// Header for the distributed embedding layer. Everything below the include
// guard is compiled only when LBANN_HAS_SHMEM or LBANN_HAS_NVSHMEM is defined.
// NOTE(review): this listing carries embedded line numbers and omits some
// source lines (braces, some parameters and members); the comments below
// describe only what is visible here.
27 #ifndef LBANN_LAYERS_MISC_DIST_EMBEDDING_HPP_INCLUDED 28 #define LBANN_LAYERS_MISC_DIST_EMBEDDING_HPP_INCLUDED 32 #if defined(LBANN_HAS_SHMEM) || defined(LBANN_HAS_NVSHMEM) 37 #include "lbann/proto/layers.pb.h" 69 template <
typename TensorDataType, data_layout Layout, El::Device Device>
70 class dist_embedding_layer :
public data_type_layer<TensorDataType>
// Diagnostic text stating that only the data-parallel layout is supported
// (presumably the message of a static_assert on Layout -- confirm).
74 "distributed embedding layer only supports data parallel layout");
// Main constructor. Visible parameters: number of embeddings, learning rate,
// and whether to barrier in forward prop (remaining parameters not shown).
77 dist_embedding_layer(
size_t num_embeddings,
80 DataType learning_rate,
81 bool barrier_in_forward_prop);
// Copy constructor and copy assignment are declared, but their out-of-line
// definitions in this file report an error via LBANN_ERROR.
83 dist_embedding_layer(
const dist_embedding_layer& other);
84 dist_embedding_layer& operator=(
const dist_embedding_layer& other);
85 ~dist_embedding_layer();
87 dist_embedding_layer* copy()
const override;
89 std::string get_type()
const override;
91 El::Device get_device_allocation()
const override;
// In-place execution is never allowed for this layer.
92 bool can_run_inplace()
const override {
return false; }
93 int get_backprop_requirements()
const override 98 description get_description()
const override;
103 template <
typename ArchiveT>
// Writes the layer-specific fields into the protobuf Layer message.
110 void write_specific_proto(lbann_data::Layer& proto)
const final;
// cereal::access is a friend and a default constructor is declared --
// presumably so cereal can default-construct the layer; confirm.
112 friend class cereal::access;
113 dist_embedding_layer();
115 void setup_dims()
override;
116 void setup_data(
size_t max_mini_batch_size)
override;
118 void fp_compute()
override;
119 void bp_compute()
override;
120 bool update_compute()
override;
// Bookkeeping record with a source rank/index, a target rank/index and an
// active flag; every field defaults to zero / false.
129 struct vector_metadata
131 size_t source_rank{0};
132 size_t source_index{0};
133 size_t target_rank{0};
134 size_t target_index{0};
135 bool is_active{
false};
// Non-distributed matrix type for this layer's data type and device.
139 using LocalMat = El::Matrix<TensorDataType, Device>;
// Helper that starts a non-blocking barrier on communicator `c`, tracked by `req`.
146 nb_barrier(lbann_comm& comm,
const El::mpi::Comm& c, Al::request& req);
148 void attach_embeddings_to_shmem_buffer();
149 void apply_sparse_sgd_step(
size_t num_gradients, LocalMat& local_embeddings);
// Raw buffer pointers with their sizes; all start as nullptr / 0.
// NOTE(review): allocation and release of these buffers is not visible in
// this listing (presumably SHMEM allocations released in the destructor).
158 TensorDataType* m_embeddings_buffer{
nullptr};
160 size_t m_embeddings_buffer_size{0};
163 TensorDataType* m_workspace_buffer{
nullptr};
165 size_t m_workspace_buffer_size{0};
168 vector_metadata* m_metadata_buffer{
nullptr};
170 size_t m_metadata_buffer_size{0};
// Request handle passed to nb_barrier and comm.wait.
183 Al::request m_nb_barrier_request;
// Configuration captured by the constructor.
186 size_t m_num_embeddings;
188 size_t m_embedding_dim;
196 DataType m_learning_rate;
209 bool m_barrier_in_forward_prop;
// Serializes this layer's configuration into `proto`: sets the datatype from
// the template type T, then fills the dist_embedding sub-message with the
// number of embeddings, embedding dimension, sparse-SGD flag, learning rate
// and the barrier-in-forward-prop flag.
216 template <
typename T, data_layout L, El::Device D>
217 void dist_embedding_layer<T, L, D>::write_specific_proto(
218 lbann_data::Layer& proto)
const 220 proto.set_datatype(proto::ProtoDataType<T>);
// mutable_dist_embedding() yields the sub-message that the setters below fill.
221 auto* msg = proto.mutable_dist_embedding();
222 msg->set_num_embeddings(m_num_embeddings);
223 msg->set_embedding_dim(m_embedding_dim);
224 msg->set_sparse_sgd(m_sparse_sgd);
225 msg->set_learning_rate(m_learning_rate);
226 msg->set_barrier_in_forward_prop(m_barrier_in_forward_prop);
// Main constructor: passes nullptr to the data_type_layer base and stores the
// configuration arguments in the corresponding members.
// NOTE(review): the `sparse_sgd` parameter used in the initializer list is
// not visible in the parameter list of this listing (a line is missing).
229 template <
typename TensorDataType, data_layout Layout, El::Device Device>
230 dist_embedding_layer<TensorDataType, Layout, Device>::dist_embedding_layer(
231 size_t num_embeddings,
232 size_t embedding_dim,
234 DataType learning_rate,
235 bool barrier_in_forward_prop)
236 : data_type_layer<TensorDataType>(nullptr),
237 m_num_embeddings{num_embeddings},
238 m_embedding_dim{embedding_dim},
239 m_sparse_sgd{sparse_sgd},
240 m_learning_rate{learning_rate},
241 m_barrier_in_forward_prop{barrier_in_forward_prop}
// The learning rate is overwritten with -1.0 here.
// NOTE(review): the guarding condition is not visible in this listing
// (presumably applied only when sparse SGD is disabled) -- confirm.
246 m_learning_rate = -1.0;
// Default constructor: delegates to the main constructor with the arguments
// (1, 1, false, El::To<DataType>(1), false). Going by the member initializer
// order of the main constructor these are presumably num_embeddings,
// embedding_dim, sparse_sgd, learning_rate and barrier_in_forward_prop.
250 template <
typename TensorDataType, data_layout Layout, El::Device Device>
251 dist_embedding_layer<TensorDataType, Layout, Device>::dist_embedding_layer()
252 : dist_embedding_layer(1, 1, false,
El::To<DataType>(1), false)
// Copy constructor: copy-constructs the data_type_layer base from `other`,
// then reports an error through LBANN_ERROR. Copying this layer is rejected
// at runtime rather than being deleted at compile time.
255 template <
typename TensorDataType, data_layout Layout, El::Device Device>
256 dist_embedding_layer<TensorDataType, Layout, Device>::dist_embedding_layer(
257 const dist_embedding_layer& other)
258 : data_type_layer<TensorDataType>(other)
260 LBANN_ERROR(
"copy constructor is invalid for dist_embedding_layer");
// Copy assignment: reports an error through LBANN_ERROR; assigning this layer
// is not supported.
// NOTE(review): no return statement is visible in this listing -- confirm the
// full source returns *this (or that LBANN_ERROR never returns).
263 template <
typename TensorDataType, data_layout Layout, El::Device Device>
264 dist_embedding_layer<TensorDataType, Layout, Device>&
265 dist_embedding_layer<TensorDataType, Layout, Device>::operator=(
266 const dist_embedding_layer& other)
268 LBANN_ERROR(
"copy assignment operator is invalid for dist_embedding_layer");
// Returns a new heap-allocated layer built with the copy constructor; the
// caller receives a raw pointer.
// NOTE(review): the copy constructor in this file calls LBANN_ERROR, so this
// presumably always fails at runtime -- confirm that is intended.
271 template <
typename TensorDataType, data_layout Layout, El::Device Device>
272 dist_embedding_layer<TensorDataType, Layout, Device>*
273 dist_embedding_layer<TensorDataType, Layout, Device>::copy()
const 275 return new dist_embedding_layer(*
this);
// Returns the layer type name, the fixed string "distributed embedding".
278 template <
typename TensorDataType, data_layout Layout, El::Device Device>
280 dist_embedding_layer<TensorDataType, Layout, Device>::get_type()
const 282 return "distributed embedding";
// Out-of-line definitions of get_data_layout() and get_device_allocation().
// NOTE(review): neither body is visible in this listing; presumably they
// return the Layout and Device template parameters respectively -- confirm.
285 template <
typename TensorDataType, data_layout Layout, El::Device Device>
287 dist_embedding_layer<TensorDataType, Layout, Device>::get_data_layout()
const 292 template <
typename TensorDataType, data_layout Layout, El::Device Device>
294 dist_embedding_layer<TensorDataType, Layout, Device>::get_device_allocation()
// Builds the human-readable description by adding four entries to `desc`:
// number of embeddings, embedding dimension, whether sparse SGD is used, and
// the SGD learning rate.
// NOTE(review): the declaration of `desc` and the return statement are not
// visible in this listing.
300 template <
typename TensorDataType, data_layout Layout, El::Device Device>
302 dist_embedding_layer<TensorDataType, Layout, Device>::get_description()
const 305 desc.add(
"Num embeddings", m_num_embeddings);
306 desc.add(
"Embedding dim", m_embedding_dim);
307 desc.add(
"Using sparse SGD", m_sparse_sgd);
308 desc.add(
"SGD learning rate", m_learning_rate);
// Sets the output dimensions to the input dimensions with the embedding
// dimension appended as an extra trailing entry.
312 template <
typename TensorDataType, data_layout Layout, El::Device Device>
313 void dist_embedding_layer<TensorDataType, Layout, Device>::setup_dims()
316 auto dims = this->get_input_dims();
// m_embedding_dim is size_t; it is narrowed to int to match the dims vector.
317 dims.push_back(static_cast<int>(m_embedding_dim));
318 this->set_output_dims(dims);
// Prepares the embedding weights and the SHMEM-backed buffer for this layer.
// Visible steps: finish the pending non-blocking barrier, create default
// weights if none exist, configure the embeddings' dims and distribution,
// register a dummy weights object, attach the embeddings to the SHMEM buffer,
// and start a new non-blocking barrier.
// NOTE(review): several lines (conditions, error calls, braces) are missing
// from this listing; comments below describe only the visible statements.
321 template <
typename TensorDataType, data_layout Layout, El::Device Device>
322 void dist_embedding_layer<TensorDataType, Layout, Device>::setup_data(
323 size_t max_mini_batch_size)
// Complete the outstanding non-blocking barrier request first.
329 auto& comm = *this->get_comm();
330 comm.wait(m_nb_barrier_request);
// No weights attached: build a default weights object named
// "<layer name>_weights" with a normal_initializer(0, 1) and an optimizer
// created by the model, then register it with both the layer and the model.
335 if (!this->has_weights()) {
336 auto w = std::make_shared<data_type_weights<TensorDataType>>(comm);
337 auto init = std::make_unique<normal_initializer<TensorDataType>>(0, 1);
338 auto opt = this->m_model->template create_optimizer<TensorDataType>();
339 w->set_name(this->get_name() +
"_weights");
340 w->set_initializer(std::move(init));
341 w->set_optimizer(std::move(opt));
342 this->add_weights(w);
343 this->m_model->add_weights(std::move(w));
// Exactly one weights object is expected at this point; the strings below are
// fragments of the error message (the reporting call itself is not visible).
345 if (this->num_weights() != 1) {
351 "with an invalid number of weights ",
352 "(expected 1, found ",
// Configure the embeddings: start from the input's distribution, force
// colDist = STAR and rowDist = VC, and set dims to
// {m_embedding_dim} x {m_num_embeddings}.
358 auto& embeddings = this->get_weights(0);
360 auto dist = this->get_prev_activations().DistData();
361 dist.colDist = El::STAR;
362 dist.rowDist = El::VC;
363 embeddings.set_dims({m_embedding_dim}, {m_num_embeddings});
364 embeddings.set_matrix_distribution(dist);
// Drop the embeddings' optimizer and register a second weights object named
// "<layer name>_dummy_weights" that carries an sgd optimizer constructed
// with 0. and shares the embeddings' matrix distribution.
// NOTE(review): the condition guarding this section is not visible
// (presumably the sparse-SGD path) -- confirm.
375 embeddings.set_optimizer(
nullptr);
376 auto w = std::make_shared<data_type_weights<TensorDataType>>(comm);
377 auto opt = std::make_unique<sgd<TensorDataType>>(0.);
378 w->set_name(this->get_name() +
"_dummy_weights");
379 w->set_optimizer(std::move(opt));
381 w->set_matrix_distribution(embeddings.get_matrix_distribution());
383 this->add_weights(w);
384 this->m_model->add_weights(std::move(w));
389 attach_embeddings_to_shmem_buffer();
// Start a new non-blocking barrier on the trainer communicator.
393 nb_barrier(comm, comm.get_trainer_comm(), m_nb_barrier_request);
// Update step. Visible statements apply a sparse SGD step to the local part
// of the embeddings using input_size * mini_batch_size gradients, then wait
// on the previous non-blocking barrier and start a new one.
// NOTE(review): any guarding conditions and the return value are not visible
// in this listing.
396 template <
typename TensorDataType, data_layout Layout, El::Device Device>
397 bool dist_embedding_layer<TensorDataType, Layout, Device>::update_compute()
402 const size_t input_size = this->get_input_size();
// Mini-batch size is taken as the width of the input activations matrix.
403 const size_t mini_batch_size = this->get_prev_activations().Width();
404 using ValuesGetter = weights_details::SafeWeightsAccessor<TensorDataType>;
405 auto& embeddings = ValuesGetter::mutable_values(this->get_weights(0));
// dynamic_cast to a reference: throws std::bad_cast if the local matrix is
// not a LocalMat.
406 auto& local_embeddings =
dynamic_cast<LocalMat&
>(embeddings.Matrix());
407 apply_sparse_sgd_step(input_size * mini_batch_size, local_embeddings);
// Finish the outstanding barrier, then start the next one on the trainer
// communicator.
412 auto& comm = *this->get_comm();
413 comm.wait(m_nb_barrier_request);
414 nb_barrier(comm, comm.get_trainer_comm(), m_nb_barrier_request);
// Non-blocking barrier, implemented as a non-blocking allreduce of `buffer`
// over communicator `c`; completion is tracked through `req`.
// NOTE(review): some parameter lines and the buffer sizing are not visible
// in this listing.
419 template <
typename TensorDataType, data_layout Layout, El::Device Device>
420 void dist_embedding_layer<TensorDataType, Layout, Device>::nb_barrier(
422 const El::mpi::Comm& c,
// Function-local static: one buffer per template instantiation, shared by
// every call and every layer instance of that instantiation.
425 static El::Matrix<float, Device> buffer;
// NOTE(review): meaning of memory mode 0 is not shown here -- confirm.
426 buffer.SetMemoryMode(0);
428 comm.nb_allreduce(buffer, c, req);
// Extern template declarations for the float, data-parallel instantiations:
// one guarded by LBANN_HAS_SHMEM and one guarded by LBANN_HAS_GPU together
// with LBANN_HAS_NVSHMEM. The remaining #endif lines close the SHMEM/NVSHMEM
// guard and the include guard.
// NOTE(review): the device template argument of each declaration is not
// visible in this listing.
435 #ifdef LBANN_HAS_SHMEM 436 extern template class dist_embedding_layer<float,
437 data_layout::DATA_PARALLEL,
439 #endif // LBANN_HAS_SHMEM 440 #if defined(LBANN_HAS_GPU) && defined(LBANN_HAS_NVSHMEM) 441 extern template class dist_embedding_layer<float,
442 data_layout::DATA_PARALLEL,
444 #endif // defined(LBANN_HAS_GPU) && defined(LBANN_HAS_NVSHMEM) 447 #endif // defined(LBANN_HAS_SHMEM) || defined(LBANN_HAS_NVSHMEM) 455 #endif // LBANN_LAYERS_MISC_DIST_EMBEDDING_HPP_INCLUDED virtual void setup_dims()
Setup tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
virtual description get_description() const
Human-readable description.
constexpr El::Device Device
data_layout
Data layout that is optimized for different modes of parallelism.
void setup_data(size_t max_mini_batch_size) override