LBANN  0.103.0
LivermoreBigArtificialNeuralNetworkToolkit
scatter.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_LAYERS_TRANSFORM_SCATTER_HPP_INCLUDED
28 #define LBANN_LAYERS_TRANSFORM_SCATTER_HPP_INCLUDED
29 
32 #include "lbann/proto/layers.pb.h"
34 #include "lbann/utils/protobuf.hpp"
35 
36 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM)
39 #include "lbann/utils/nvshmem.hpp"
40 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM
41 
42 namespace lbann {
43 
44 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM)
// Short alias for the distconv scatter operator templated on the tensor data
// type. `Backend` is not defined anywhere in this view (the local alias below
// is commented out) -- presumably it comes from an included distconv header;
// TODO confirm.
45 namespace dc {
46 // using Backend = ::distconv::BackendDNNLib;
47 template <typename TensorDataType>
48 using Scatter = ::distconv::Scatter<Backend, TensorDataType>;
49 } // namespace dc
50 
/** @brief Distconv adapter for the scatter layer.
 *
 *  Owns the distconv scatter operator and hands it the layer's tensors in
 *  fp_compute() / bp_compute(). Only compiled when both distconv and NVSHMEM
 *  are available.
 */
51 template <typename TensorDataType, data_layout Layout, El::Device Device>
52 class scatter_distconv_adapter
53  : public data_type_distconv_adapter<TensorDataType>
54 {
55 public:
56  using TensorDevType =
58 
  // Operator construction is deferred to setup_layer(); the constructor only
  // binds the adapter to its layer.
59  scatter_distconv_adapter(Layer& layer)
60  : data_type_distconv_adapter<TensorDataType>(layer)
61  {}
62  virtual ~scatter_distconv_adapter() = default;
63 
  // Clears overlap on every activation / error-signal distribution.
64  void setup_distributions(tensor_overlap_constraints& constraints) override;
  // Creates m_scatter_operator and sets it up with the values (input 0),
  // indices (input 1) and output tensors.
65  void setup_layer(size_t workspace_capacity) override;
  // Forward / backward passes; both delegate to m_scatter_operator.
66  void fp_compute();
67  void bp_compute();
  // Local shape of the output tensor, derived from the values input's local
  // shape (see the out-of-line definition).
68  dc::Shape get_activations_local_shape(int index = 0) const override;
69 
  // Distconv scatter operator; null until setup_layer() runs.
70  std::unique_ptr<dc::Scatter<TensorDataType>> m_scatter_operator;
  // NOTE(review): not read or written by any code visible in this file;
  // confirm whether it is still needed.
71  size_t m_workspace_buffer_size{0};
72 };
73 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM
74 
/** @brief Scatter values to specified tensor indices.
 *
 *  Takes two inputs: a values tensor (input 0) and an indices tensor
 *  (input 1). The output dimensions are fixed at construction time.
 *  Only the data-parallel layout is supported (enforced by static_assert).
 */
96 template <typename TensorDataType,
98  El::Device Device = El::Device::CPU>
99 class scatter_layer : public data_type_layer<TensorDataType>
100 {
101  static_assert(Layout == data_layout::DATA_PARALLEL,
102  "scatter layer only supports data parallel layout");
103 
104 public:
  /** @param dims Output tensor dimensions.
   *  @param axis Scatter axis; -1 means "not set", which setup_dims()
   *              rejects for 2-D values tensors.
   */
105  scatter_layer(const std::vector<int>& dims, const int axis);
106  scatter_layer(const scatter_layer& other) = default;
107  scatter_layer& operator=(const scatter_layer& other) = default;
108 
109  scatter_layer* copy() const override;
110 
112 
114  template <typename ArchiveT>
115  void serialize(ArchiveT& ar);
116 
118 
119  std::string get_type() const override;
120  data_layout get_data_layout() const override;
121  El::Device get_device_allocation() const override;
  // Scatter never runs in-place.
122  bool can_run_inplace() const override { return false; }
123  int get_backprop_requirements() const override
124  {
126  }
127 
128 protected:
130  void write_specific_proto(lbann_data::Layer& proto) const final;
131 
132  friend class cereal::access;
  // Default constructor for cereal (output dims {1}, axis unset); the real
  // values are presumably restored by serialize() -- confirm.
133  scatter_layer() : scatter_layer({1}, -1) {}
  // Validates the values / indices / output tensor dimensions.
134  void setup_dims() override;
135  void fp_compute() override;
136  void bp_compute() override;
137 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM)
138  friend class scatter_distconv_adapter<TensorDataType, Layout, Device>;
139  void setup_distconv_adapter() override;
140  bool is_distconv_supported() const override;
141  scatter_distconv_adapter<TensorDataType, Layout, Device>&
142  get_distconv_adapter() override;
143  const scatter_distconv_adapter<TensorDataType, Layout, Device>&
144  get_distconv_adapter() const override;
145 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM
146 private:
148 };
149 
150 // =========================================================
151 // Implementation
152 // =========================================================
153 
// Serialize this layer's configuration into the protobuf message: the data
// type, the output dimensions (into the scatter message's `dims` field) and
// the scatter axis.
154 template <typename T, data_layout L, El::Device D>
156  lbann_data::Layer& proto) const
157 {
158  proto.set_datatype(proto::ProtoDataType<T>);
159  auto* msg = proto.mutable_scatter();
  // The output dims are exactly the `dims` the constructor was given.
160  protobuf::assign_to_repeated(*msg->mutable_dims(), this->get_output_dims());
161  msg->mutable_axis()->set_value(m_scatter_axis);
162 }
163 
// Construct a scatter layer whose output dimensions are `dims` and whose
// scatter axis is `axis` (-1 = not set). No communicator is passed to the
// base class here (nullptr).
164 template <typename TensorDataType, data_layout Layout, El::Device Device>
166  const std::vector<int>& dims,
167  const int axis)
168  : data_type_layer<TensorDataType>(nullptr), m_scatter_axis{axis}
169 {
171  this->set_output_dims(dims);
172 }
173 
// Polymorphic copy: heap-allocates a copy-constructed layer and returns a raw
// pointer to it; the caller is responsible for that allocation.
174 template <typename TensorDataType, data_layout Layout, El::Device Device>
177 {
178  return new scatter_layer(*this);
179 }
180 
// Layer type name; used as the prefix of this layer's error messages.
181 template <typename TensorDataType, data_layout Layout, El::Device Device>
183 {
184  return "scatter";
185 }
186 
// Returns the compile-time Layout parameter (always DATA_PARALLEL, enforced
// by the class's static_assert).
187 template <typename TensorDataType, data_layout Layout, El::Device Device>
190 {
191  return Layout;
192 }
193 
// Returns the compile-time Device template parameter.
194 template <typename TensorDataType, data_layout Layout, El::Device Device>
197 {
198  return Device;
199 }
200 
201 template <typename TensorDataType, data_layout Layout, El::Device Device>
203 {
205 
206  const auto& input0_dims = this->get_input_dims(0);
207  const auto& input1_dims = this->get_input_dims(1);
208  const auto& output_dims = this->get_output_dims();
209 
210  auto dims_to_str = [](const std::vector<int>& dims) -> std::string {
211  std::ostringstream ss;
212  for (size_t i = 0; i < dims.size(); ++i) {
213  ss << (i > 0 ? "x" : "") << dims[i];
214  }
215  return ss.str();
216  };
217 
218  // Tensor dimension requirements are different
219  // when using distconv
220  // Distconv requires 3D inputs for both values
221  // and indices
222 
223 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM)
224 
225  if (this->distconv_enabled()) {
226  const auto is_values_3D = input0_dims.size() == 3;
227  const auto is_indices_3D = input1_dims.size() == 3;
228  const auto is_output_3D = output_dims.size() == 3;
229 
230  // Matrices need to be 3D
231  if (!is_values_3D || !is_indices_3D || !is_output_3D) {
232 
233  LBANN_ERROR(this->get_type(),
234  " Layer \"",
235  this->get_name(),
236  "\" ",
237  "has values input shape (",
238  dims_to_str(input0_dims),
239  ") ",
240  "has indices input shape (",
241  dims_to_str(input1_dims),
242  "). ",
243  "has output shape (",
244  dims_to_str(output_dims),
245  ")",
246  "Distconv Scatter requires all three to be 3D. ");
247  }
248  // The 0-th dimension of the values and indices must match
249  if (input0_dims[0] != input1_dims[0]) {
250  LBANN_ERROR(this->get_type(),
251  " Layer \"",
252  this->get_name(),
253  "\" ",
254  "has values input (",
255  dims_to_str(input0_dims),
256  ") ",
257  "has indices input (",
258  dims_to_str(input1_dims),
259  "). ",
260  "Distconv requires the 0-th dimension to match. ");
261  }
262 
263  // The 1st and 2D dimension of the values and output must match
264  const auto output_dim_fail =
265  input0_dims[1] != output_dims[1] || input0_dims[2] != output_dims[2];
266 
267  if (output_dim_fail) {
268  LBANN_ERROR(this->get_type(),
269  " Layer \"",
270  this->get_name(),
271  "\" ",
272  "has values input (",
273  dims_to_str(input0_dims),
274  ") ",
275  "has indices input (",
276  dims_to_str(input1_dims),
277  "). ",
278  "Distconv requires the 0-th dimension to match. ");
279  }
280 
281  // Enable distconv only for scatter along the 0-th dimension
282  if (this->m_scatter_axis != 0) {
283  LBANN_ERROR(this->get_type(),
284  " Layer \"",
285  this->get_name(),
286  "\" ",
287  "requires the scatter dimension to be 0 when using distconv");
288  }
289 
290  return;
291  }
292 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM
293 
294  // Tensor dimensions
295  // Check if value matrix is 1D or 2D
296 
297  const auto is_values_1D = input0_dims.size() == 1;
298  const auto is_values_2D = input0_dims.size() == 2;
299 
300  // Check if output matrix is 1D or 2D
301 
302  const auto is_output_1D = output_dims.size() == 1;
303  const auto is_output_2D = output_dims.size() == 2;
304 
305  if (is_values_2D) {
306  if (this->m_scatter_axis == -1) {
307  LBANN_ERROR(this->get_type(),
308  " Layer \"",
309  this->get_name(),
310  "\" ",
311  "has 2D input, but does not set a scatter axis.",
312  " Axis must be either set to 0 or 1");
313  }
314  }
315  // Make sure input tensors have same dimensions
316  if (input0_dims != input1_dims) {
317 
318  // If input tensors are not same, make sure it's 2D and 1D
319  const auto matching_dim = this->m_scatter_axis == 0 ? 0 : 1;
320  if (input0_dims[matching_dim] != input1_dims[0]) {
321  const auto& parent0 = this->get_parent_layer(0);
322  const auto& parent1 = this->get_parent_layer(1);
323  LBANN_ERROR(this->get_type(),
324  " layer \"",
325  this->get_name(),
326  "\" ",
327  "has input tensors with different outer dimensions ",
328  "(",
329  parent0.get_type(),
330  " layer \"",
331  parent0.get_name(),
332  "\" ",
333  "outputs ",
334  dims_to_str(input0_dims),
335  ", ",
336  parent1.get_type(),
337  " layer \"",
338  parent1.get_name(),
339  "\" ",
340  "outputs ",
341  dims_to_str(input1_dims),
342  ")");
343  }
344  }
345 
346  // Check tensor dimensions
347  if (input1_dims.size() != 1 || !(is_values_1D || is_values_2D) ||
348  input0_dims.size() != output_dims.size()) {
349  LBANN_ERROR(this->get_type(),
350  " layer \"",
351  this->get_name(),
352  "\" ",
353  "attempted to scatter from a ",
354  input0_dims.size(),
355  "-D tensor ",
356  "(",
357  dims_to_str(input0_dims),
358  "), to a ",
359  output_dims.size(),
360  "-D tensor ",
361  "but the scatter layer currently only supports ",
362  "scattering to and from a 1-D or 2-D tensor and the input and "
363  "output tensors",
364  "must have the same number of dimensions");
365  }
366  // Check if either output is 1D or the first dim matches for input and output
367  if (!is_output_1D && (is_output_2D && output_dims[0] != input0_dims[0])) {
368  const auto matching_dim = this->m_scatter_axis == 0 ? 1 : 0;
369  if (output_dims[matching_dim] != input0_dims[matching_dim]) {
370 
371  LBANN_ERROR(this->get_type(),
372  " layer \"",
373  this->get_name(),
374  "\" ",
375  "attempted to scatter into a ",
376  output_dims.size(),
377  "-D tensor ",
378  "(",
379  dims_to_str(output_dims),
380  "), "
381  "but expected ",
382  input0_dims[matching_dim],
383  " on axis ",
384  matching_dim);
385  }
386  }
387 }
388 
389 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM)
390 
391 // =============================================================
392 // DistConv-enabled Scatter member functions
393 // =============================================================
394 
// Distconv is supported only for GPU, data-parallel instantiations.
395 template <typename TensorDataType, data_layout Layout, El::Device Device>
397  const
398 {
399  return Device == El::Device::GPU && Layout == data_layout::DATA_PARALLEL;
400 }
401 
// Install a scatter-specific distconv adapter bound to this layer, replacing
// whatever adapter pointer was stored before.
402 template <typename TensorDataType, data_layout Layout, El::Device Device>
404 {
405  this->get_distconv_adapter_ptr() =
406  std::make_unique<scatter_distconv_adapter<TensorDataType, Layout, Device>>(
407  *this);
408 }
409 
// Const accessor: downcasts the layer's generic distconv adapter to the
// scatter adapter. A reference dynamic_cast throws std::bad_cast if the
// stored adapter is of a different type.
410 template <typename TensorDataType, data_layout Layout, El::Device Device>
411 const scatter_distconv_adapter<TensorDataType, Layout, Device>&
413 {
414  return dynamic_cast<
415  const scatter_distconv_adapter<TensorDataType, Layout, Device>&>(
417 }
418 
// Non-const accessor: forwards to the const overload and casts constness
// away, so the downcast logic lives in a single place.
419 template <typename TensorDataType, data_layout Layout, El::Device Device>
420 scatter_distconv_adapter<TensorDataType, Layout, Device>&
422 {
423  return const_cast<scatter_distconv_adapter<TensorDataType, Layout, Device>&>(
424  static_cast<const scatter_layer<TensorDataType, Layout, Device>&>(*this)
425  .get_distconv_adapter());
426 }
427 
428 // =============================================================
429 // Scatter DistConv Adapter implementation
430 // =============================================================
431 
// Configure tensor distributions. Scatter needs no overlap, so overlap is
// cleared on every previous-activation, activation, previous-error-signal and
// error-signal distribution, and each one is marked both updated and
// invariant in the overlap constraints.
432 template <typename TensorDataType, data_layout Layout, El::Device Device>
433 void scatter_distconv_adapter<TensorDataType, Layout, Device>::
434  setup_distributions(tensor_overlap_constraints& constraints)
435 {
437  // no overlap needed
438  for (auto& d : this->m_prev_activations_dists) {
439  d.clear_overlap();
440  constraints.mark_updated(d);
441  constraints.mark_invariant(d);
442  }
443  for (auto& d : this->m_activations_dists) {
444  d.clear_overlap();
445  constraints.mark_updated(d);
446  constraints.mark_invariant(d);
447  }
448  for (auto& d : this->m_prev_error_signals_dists) {
449  d.clear_overlap();
450  constraints.mark_updated(d);
451  constraints.mark_invariant(d);
452  }
453  for (auto& d : this->m_error_signals_dists) {
454  d.clear_overlap();
455  constraints.mark_updated(d);
456  constraints.mark_invariant(d);
457  }
458 }
459 
// Create the distconv scatter operator on the current distconv backend and
// set it up with the values (input 0), indices (input 1) and output tensors.
// NOTE(review): `workspace_capacity` is not used in the visible lines.
// `make_unique` is unqualified here, unlike the std::make_unique used in
// setup_distconv_adapter(); it presumably resolves to an LBANN helper --
// confirm.
460 template <typename TensorDataType, data_layout Layout, El::Device Device>
461 void scatter_distconv_adapter<TensorDataType, Layout, Device>::setup_layer(
462  size_t workspace_capacity)
463 {
465  m_scatter_operator =
466  make_unique<dc::Scatter<TensorDataType>>(dc::get_backend());
468  m_scatter_operator->setup(this->get_prev_activations(0),
469  this->get_prev_activations(1),
470  this->get_activations());
471 }
472 
// Local shape of the output (activations) tensor. It starts from the local
// shape of the values input and replaces entry 2 with the layer's first output
// dimension divided by entry [-2] of the values tensor's split shape.
// NOTE(review): `index` is not used. `[-2]` relies on the distconv shape type
// accepting negative indices -- confirm.
473 template <typename TensorDataType, data_layout Layout, El::Device Device>
474 dc::Shape scatter_distconv_adapter<TensorDataType, Layout, Device>::
475  get_activations_local_shape(int index) const
476 {
477  const auto& layer =
479  this->layer());
480  // Get the output dims without the mini batch size
481  auto output_dims = layer.get_output_dims();
482  // Get the values layer shape
483  auto output_shape = this->get_prev_activations().get_local_shape();
484  // Divide the output channel dimension by the number of channel splits
485  // To do: Maybe move this to distconv namespace - SZ
486  output_shape[2] =
487  output_dims[0] /
488  this->get_prev_activations().get_distribution().get_split_shape()[-2];
489  return output_shape;
490 }
491 
492 template <typename TensorDataType, data_layout Layout, El::Device Device>
493 void scatter_distconv_adapter<TensorDataType, Layout, Device>::fp_compute()
494 {
495  // Compute the forward pass
496  m_scatter_operator->forward(this->get_prev_activations(0),
497  this->get_prev_activations(1),
498  this->get_activations());
499 }
500 
501 template <typename TensorDataType, data_layout Layout, El::Device Device>
502 void scatter_distconv_adapter<TensorDataType, Layout, Device>::bp_compute()
503 {
504  // Compute the backward pass
505  m_scatter_operator->backward(
506  this->get_prev_error_signals(0),
507  this->get_prev_activations(1),
508  this->get_error_signals(0), // Values gradient
509  this->get_error_signals(1)); // Indices gradient. Will be 0'ed out
510 }
511 
// Explicit instantiations of the distconv adapter for the data-parallel
// layout. PROTO_DEVICE is presumably expanded by an instantiation include that
// is not visible here.
// NOTE(review): these are explicit instantiation *definitions* living in a
// header. Every translation unit that includes this file with distconv +
// NVSHMEM enabled will instantiate them, but an explicit instantiation
// definition may appear at most once per program. Consider moving them to a
// single .cpp and declaring them `extern template` here, as the
// scatter_layer block below does.
512 #define PROTO_DEVICE(T, Device) \
513  template class scatter_distconv_adapter<T, data_layout::DATA_PARALLEL, Device>
515 #undef PROTO_DEVICE
516 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM
517 
// Unless this translation unit is the one that provides the instantiations
// (LBANN_SCATTER_LAYER_INSTANTIATE defined), declare the data-parallel
// scatter_layer instantiations `extern` so that includers do not instantiate
// them themselves.
518 #ifndef LBANN_SCATTER_LAYER_INSTANTIATE
519 #define PROTO_DEVICE(T, Device) \
520  extern template class scatter_layer<T, data_layout::DATA_PARALLEL, Device>;
522 #undef PROTO_DEVICE
523 #endif // LBANN_SCATTER_LAYER_INSTANTIATE
524 
525 } // namespace lbann
526 
527 #endif // LBANN_LAYERS_TRANSFORM_SCATTER_HPP_INCLUDED
bool distconv_enabled() const
Indicate whether distconv is enabled.
Definition: layer.hpp:1082
virtual void setup_dims()
Setup tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
bool can_run_inplace() const override
If True, the computation can run in-place (feeding each input activations tensor as the corresponding...
Definition: scatter.hpp:122
data_layout get_data_layout() const override
Get data layout of the data tensors. We assume that the data layouts of the previous activations...
Definition: scatter.hpp:189
#define LBANN_ERROR(...)
Definition: exception.hpp:37
void mark_updated(const dc::Dist &d)
int get_backprop_requirements() const override
Returns the necessary tensors for computing backpropagation.
Definition: scatter.hpp:123
scatter_layer * copy() const override
Copy function. This function dynamically allocates memory for a layer instance and instantiates a cop...
Definition: scatter.hpp:176
std::vector< int > get_input_dims(size_t input_index=0) const
Get input tensor dimensions.
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
void setup_dims() override
Setup tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
Definition: scatter.hpp:202
constexpr El::Device Device
virtual void setup_distributions(tensor_overlap_constraints &constraints)
OutputAbsDistMatrixType & get_prev_error_signals(int child_index=0)
InputAbsDistMatrixType & get_prev_activations(int parent_index=0)
const OutputAbsDistMatrixType & get_activations(const Layer &child) const override
void set_output_dims(std::vector< int > dims, size_t output_index=0)
Set output tensor dimensions.
void assign_to_repeated(google::protobuf::RepeatedField< T > &field, ContainerT const &values)
Assign a range of values to a repeated protobuf field.
Definition: impl.hpp:125
std::string get_type() const override
Get the layer type&#39;s name.
Definition: scatter.hpp:182
::distconv::tensor::Shape Shape
std::string get_name() const
Get the layer instance&#39;s name.
Definition: layer.hpp:332
virtual void setup_layer(size_t workspace_capacity)
void write_specific_proto(lbann_data::Layer &proto) const final
Definition: scatter.hpp:155
world_comm_ptr initialize(int &argc, char **&argv)
data_layout
Data layout that is optimized for different modes of parallelism.
Definition: base.hpp:218
const Layer & get_parent_layer(size_t index=0) const
Scatter values to specified tensor indices.
Definition: scatter.hpp:99
std::vector< int > get_output_dims(size_t output_index=0) const
Get output tensor dimensions.
El::Device get_device_allocation() const override
Get the device allocation for the data tensors. We assume that the device allocation of the previous ...
Definition: scatter.hpp:196
void mark_invariant(const dc::Dist &d)
int m_expected_num_parent_layers
Definition: layer.hpp:838
const InputAbsDistMatrixType & get_error_signals(const Layer &parent) const override
dc::TensorDev< OutputTensorDataType > TensorDevType