LBANN  0.103.0
LivermoreBigArtificialNeuralNetworkToolkit
gather.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_LAYERS_TRANSFORM_GATHER_HPP_INCLUDED
28 #define LBANN_LAYERS_TRANSFORM_GATHER_HPP_INCLUDED
29 
32 #include "lbann/proto/layers.pb.h"
34 
35 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM)
38 #include "lbann/utils/nvshmem.hpp"
39 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM
40 
41 namespace lbann {
42 
43 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM)
// Distconv type aliases used by the gather layer.
44 namespace dc {
45 // using Backend = ::distconv::BackendDNNLib;
// NOTE(review): the Backend alias above is commented out, so `Backend` must
// already be declared by an earlier header for the alias below to compile --
// confirm against the includes.
// Distconv gather operator for the given tensor data type, bound to Backend.
46 template <typename TensorDataType>
47 using Gather = ::distconv::Gather<Backend, TensorDataType>;
48 } // namespace dc
49 
// Distconv adapter for the gather layer. Extends
// data_type_distconv_adapter<TensorDataType> with a distconv Gather operator
// and the hooks that configure and invoke it.
50 template <typename TensorDataType, data_layout Layout, El::Device Device>
51 class gather_distconv_adapter
52  : public data_type_distconv_adapter<TensorDataType>
53 {
54 public:
    // NOTE(review): the right-hand side of this alias is elided in this
    // listing.
55  using TensorDevType =
57 
    // Forwards the owning layer to the base adapter.
58  gather_distconv_adapter(Layer& layer)
59  : data_type_distconv_adapter<TensorDataType>(layer)
60  {}
61  virtual ~gather_distconv_adapter() = default;
62 
    // Overrides of the base adapter's setup hooks.
63  void setup_distributions(tensor_overlap_constraints& constraints) override;
64  void setup_layer(size_t workspace_capacity) override;
    // Forward and backward passes; both delegate to m_gather_operator.
65  void fp_compute();
66  void bp_compute();
    // Local shape of the output tensor; `index` defaults to 0.
67  dc::Shape get_activations_local_shape(int index = 0) const override;
68 
    // Distconv gather operator; created in setup_layer().
69  std::unique_ptr<dc::Gather<TensorDataType>> m_gather_operator;
    // Initialized to 0; not modified by any code visible in this listing.
70  size_t m_workspace_buffer_size{0};
71 };
72 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM
73 
// Gather values from specified tensor indices.
//
// Input 0 is the values tensor and input 1 is the indices tensor. Only the
// data-parallel layout is supported (enforced by the static_assert below).
// NOTE(review): several lines of this class (the Layout template parameter,
// the body of get_backprop_requirements, and the private data members) are
// elided in this listing.
97 template <typename TensorDataType,
99  El::Device Device = El::Device::CPU>
100 class gather_layer : public data_type_layer<TensorDataType>
101 {
102  static_assert(Layout == data_layout::DATA_PARALLEL,
103  "gather layer only supports data parallel layout");
104 
105 public:
    // `axis` is stored as the gather axis (m_gather_axis).
106  gather_layer(const int axis);
107  gather_layer(const gather_layer& other) = default;
108  gather_layer& operator=(const gather_layer& other) = default;
109 
    // Returns a heap-allocated copy of this layer.
110  gather_layer* copy() const override;
111 
113 
    // Serialization hook templated on the archive type.
115  template <typename ArchiveT>
116  void serialize(ArchiveT& ar);
117 
119 
120  std::string get_type() const override;
121  data_layout get_data_layout() const override;
122  El::Device get_device_allocation() const override;
    // The gather layer never runs in place.
123  bool can_run_inplace() const override { return false; }
124  int get_backprop_requirements() const override
125  {
127  }
128 
129 protected:
    // Writes the data type and gather axis into the layer's protobuf message.
131  void write_specific_proto(lbann_data::Layer& proto) const final;
132 
133  friend class cereal::access;
    // Computes output dimensions and validates the input tensor shapes.
135  void setup_dims() override;
136  void fp_compute() override;
137  void bp_compute() override;
138 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM)
    // Distconv support; compiled only when both distconv and NVSHMEM are
    // available.
139  friend class gather_distconv_adapter<TensorDataType, Layout, Device>;
140  void setup_distconv_adapter() override;
141  bool is_distconv_supported() const override;
142  gather_distconv_adapter<TensorDataType, Layout, Device>&
143  get_distconv_adapter() override;
144  const gather_distconv_adapter<TensorDataType, Layout, Device>&
145  get_distconv_adapter() const override;
146 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM
147 private:
149 };
150 
151 // =========================================================
152 // Implementation
153 // =========================================================
154 
// Writes the gather-specific fields of this layer into `proto`: the data type
// corresponding to T, and the gather axis (m_gather_axis) in the gather
// sub-message.
155 template <typename T, data_layout L, El::Device D>
156 void gather_layer<T, L, D>::write_specific_proto(lbann_data::Layer& proto) const
157 {
158  proto.set_datatype(proto::ProtoDataType<T>);
    // mutable_gather() yields the gather sub-message that holds the axis.
159  auto* msg = proto.mutable_gather();
160  msg->mutable_axis()->set_value(m_gather_axis);
161 }
162 
// gather_layer constructor (the signature line and one body line are elided
// in this listing). Passes nullptr to the data_type_layer base and stores
// `axis` in m_gather_axis.
163 template <typename TensorDataType, data_layout Layout, El::Device Device>
165  : data_type_layer<TensorDataType>(nullptr), m_gather_axis{axis}
166 {
168 }
169 
// gather_layer::copy (signature lines elided in this listing).
// Returns a copy of this layer allocated with `new`; the function itself
// does not release it.
170 template <typename TensorDataType, data_layout Layout, El::Device Device>
173 {
174  return new gather_layer(*this);
175 }
176 
// gather_layer::get_type (signature line elided in this listing).
// Returns the layer type's name, "gather".
177 template <typename TensorDataType, data_layout Layout, El::Device Device>
179 {
180  return "gather";
181 }
182 
// gather_layer::get_data_layout (signature lines elided in this listing).
// Returns the Layout template parameter.
183 template <typename TensorDataType, data_layout Layout, El::Device Device>
186 {
187  return Layout;
188 }
189 
// gather_layer::get_device_allocation (signature lines elided in this
// listing). Returns the Device template parameter.
190 template <typename TensorDataType, data_layout Layout, El::Device Device>
193 {
194  return Device;
195 }
196 
197 template <typename TensorDataType, data_layout Layout, El::Device Device>
199 {
201 
202  // Tensor dimensions
203  const auto& input0_dims = this->get_input_dims(0);
204  const auto& input1_dims = this->get_input_dims(1);
205 
206  const bool along_axis_0 = this->m_gather_axis == 0;
207 
208  auto dims_to_str = [](const std::vector<int>& dims) -> std::string {
209  std::ostringstream ss;
210  for (size_t i = 0; i < dims.size(); ++i) {
211  ss << (i > 0 ? "x" : "") << dims[i];
212  }
213  return ss.str();
214  };
215 
216  // Tensor dimension requirements are different
217  // when using distconv
218  // Distconv requires 3D inputs for both values
219  // and indices
220 
221 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM)
222 
223  if (this->distconv_enabled()) {
224  const auto is_values_3D = input0_dims.size() == 3;
225  const auto is_indices_3D = input1_dims.size() == 3;
226 
227  // Input matrices need to be 3D
228  if (!is_values_3D || !is_indices_3D) {
229 
230  LBANN_ERROR(this->get_type(),
231  " Layer \"",
232  this->get_name(),
233  "\" ",
234  "has values input (",
235  dims_to_str(input0_dims),
236  ") ",
237  "has indices input (",
238  dims_to_str(input1_dims),
239  "). ",
240  "Distconv Gather requires both to be 3D. ");
241  }
242  // Make sure only gathering along axis 0
243  if (along_axis_0) {
244  this->set_output_dims(
245  std::vector<int>{input1_dims[0], input0_dims[1], 1});
246  }
247  else {
248  LBANN_ERROR(this->get_type(),
249  "Layer \"",
250  this->get_name(),
251  "\"",
252  "cannot gather along axis ",
254  " when distconv is enabled");
255  }
256  return;
257  }
258 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM
259 
260  // Only support 1D indices
261  const auto is_indices_not_1D = input1_dims.size() != 1;
262 
263  // Only support 1D or 2D values
264  const auto is_values_1D = input0_dims.size() == 1;
265  const auto is_values_2D = input0_dims.size() == 2;
266 
267  if (is_values_2D) {
268  if (this->m_gather_axis == -1) {
269  LBANN_ERROR(this->get_type(),
270  " Layer \"",
271  this->get_name(),
272  "\" ",
273  "has 2D input, but does not set a gather axis.",
274  "Axis must be either set to 0 or 1");
275  }
276  }
277  if (is_values_1D) {
278  this->set_output_dims(input1_dims);
279  }
280  else {
281  //
282  if (along_axis_0) {
283  this->set_output_dims(std::vector<int>{input1_dims[0], input0_dims[1]});
284  }
285  else {
286  this->set_output_dims(std::vector<int>{input0_dims[0], input1_dims[0]});
287  }
288  }
289 
290  // Make sure input tensors have supported numbers of dimensions
291 
292  if (is_indices_not_1D || !(is_values_1D || is_values_2D)) {
293  const auto& parent0 = this->get_parent_layer(0);
294  const auto& parent1 = this->get_parent_layer(1);
295  LBANN_ERROR(this->get_type(),
296  " layer \"",
297  this->get_name(),
298  "\" ",
299  "has input tensors with incorrect numbers of dimensions. "
300  "Expected 1D or 2D values tensor and 1D indices tensor.",
301  " Expected 3D-only tensors for distconv-enabled Gather. ",
302  "(",
303  parent0.get_type(),
304  " layer \"",
305  parent0.get_name(),
306  "\" ",
307  "outputs ",
308  dims_to_str(input0_dims),
309  ", ",
310  parent1.get_type(),
311  " layer \"",
312  parent1.get_name(),
313  "\" ",
314  "outputs ",
315  dims_to_str(input1_dims),
316  ")");
317  }
318 
319  // Check that tensors are 1D
321  if (!is_values_1D && !is_values_2D) {
322  LBANN_ERROR(this->get_type(),
323  " layer \"",
324  this->get_name(),
325  "\" ",
326  "attempted to gather from a ",
327  input0_dims.size(),
328  "-D tensor ",
329  "(",
330  dims_to_str(input0_dims),
331  "), "
332  "but the gather layer currently only supports ",
333  "gathering from a 1-D or 2-D tensor");
334  }
335 }
336 
337 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM)
338 
339 // =============================================================
340 // DistConv-enabled Gather member functions
341 // =============================================================
342 
// gather_layer::is_distconv_supported (signature line elided in this
// listing). True only for the GPU device with the data-parallel layout.
343 template <typename TensorDataType, data_layout Layout, El::Device Device>
345 {
346  return Device == El::Device::GPU && Layout == data_layout::DATA_PARALLEL;
347 }
348 
// gather_layer::setup_distconv_adapter (signature line elided in this
// listing). Creates a gather_distconv_adapter for this layer and stores it
// in the layer's distconv adapter pointer.
349 template <typename TensorDataType, data_layout Layout, El::Device Device>
351 {
352  this->get_distconv_adapter_ptr() =
353  std::make_unique<gather_distconv_adapter<TensorDataType, Layout, Device>>(
354  *this);
355 }
356 
// gather_layer::get_distconv_adapter, const overload (signature line and the
// cast operand are elided in this listing). Returns the adapter as a
// gather_distconv_adapter reference via dynamic_cast.
357 template <typename TensorDataType, data_layout Layout, El::Device Device>
358 const gather_distconv_adapter<TensorDataType, Layout, Device>&
360 {
361  return dynamic_cast<
362  const gather_distconv_adapter<TensorDataType, Layout, Device>&>(
364 }
365 
// gather_layer::get_distconv_adapter, non-const overload (signature line
// elided in this listing). Calls the const overload on *this and casts away
// the constness of the result.
366 template <typename TensorDataType, data_layout Layout, El::Device Device>
367 gather_distconv_adapter<TensorDataType, Layout, Device>&
369 {
370  return const_cast<gather_distconv_adapter<TensorDataType, Layout, Device>&>(
371  static_cast<const gather_layer<TensorDataType, Layout, Device>&>(*this)
372  .get_distconv_adapter());
373 }
374 
375 // =============================================================
376 // Gather DistConv Adapter implementation
377 // =============================================================
378 
// Clears the overlap of every distribution held by the adapter (previous
// activations, activations, previous error signals, error signals) and marks
// each one as both updated and invariant in `constraints`.
// NOTE(review): one line at the top of the body is elided in this listing.
379 template <typename TensorDataType, data_layout Layout, El::Device Device>
380 void gather_distconv_adapter<TensorDataType, Layout, Device>::
381  setup_distributions(tensor_overlap_constraints& constraints)
382 {
384  // no overlap needed
385  for (auto& d : this->m_prev_activations_dists) {
386  d.clear_overlap();
387  constraints.mark_updated(d);
388  constraints.mark_invariant(d);
389  }
390  for (auto& d : this->m_activations_dists) {
391  d.clear_overlap();
392  constraints.mark_updated(d);
393  constraints.mark_invariant(d);
394  }
395  for (auto& d : this->m_prev_error_signals_dists) {
396  d.clear_overlap();
397  constraints.mark_updated(d);
398  constraints.mark_invariant(d);
399  }
400  for (auto& d : this->m_error_signals_dists) {
401  d.clear_overlap();
402  constraints.mark_updated(d);
403  constraints.mark_invariant(d);
404  }
405 }
406 
// Creates the distconv gather operator from dc::get_backend() and sets it up
// with the two input tensors (previous activations 0 and 1) and the output
// activations.
// NOTE(review): `workspace_capacity` is not referenced by any line visible
// here (two body lines are elided in this listing).
407 template <typename TensorDataType, data_layout Layout, El::Device Device>
408 void gather_distconv_adapter<TensorDataType, Layout, Device>::setup_layer(
409  size_t workspace_capacity)
410 {
    // NOTE(review): make_unique is unqualified here while
    // setup_distconv_adapter spells std::make_unique -- confirm which
    // overload this resolves to.
412  m_gather_operator =
413  make_unique<dc::Gather<TensorDataType>>(dc::get_backend());
415  m_gather_operator->setup(this->get_prev_activations(0),
416  this->get_prev_activations(1),
417  this->get_activations());
418 }
419 
// Returns the local shape of the output tensor: the local shape of the
// indices tensor (previous activations 1) with dimension 1 replaced by
// dimension 1 of the values tensor (previous activations 0).
// NOTE(review): `index`, `layer` and `output_dims` are not used in computing
// the returned shape in the lines visible here.
420 template <typename TensorDataType, data_layout Layout, El::Device Device>
421 dc::Shape gather_distconv_adapter<TensorDataType, Layout, Device>::
422  get_activations_local_shape(int index) const
423 {
    // The cast expression on the next line is partly elided in this listing.
424  const auto& layer =
426  this->layer());
427  auto output_dims = layer.get_output_dims();
428  // Get the indices layer shape
429  auto output_shape = this->get_prev_activations(1).get_local_shape();
430  auto values_shape = this->get_prev_activations(0).get_local_shape();
431  // Change the column dimension to match, the rest should be the same
432  // To do: Maybe move this to distconv namespace - SZ
433  output_shape[1] = values_shape[1];
434  return output_shape;
435 }
436 
// Forward pass: runs the distconv gather operator on the values tensor
// (previous activations 0) and the indices tensor (previous activations 1),
// writing into the output activations.
437 template <typename TensorDataType, data_layout Layout, El::Device Device>
438 void gather_distconv_adapter<TensorDataType, Layout, Device>::fp_compute()
439 {
440  // Compute the forward pass
441  m_gather_operator->forward(this->get_prev_activations(0),
442  this->get_prev_activations(1),
443  this->get_activations());
444 }
445 
// Backward pass: runs the distconv gather operator's backward step with the
// incoming error signals and the indices tensor (previous activations 1),
// producing the error signals for input 0 (values) and input 1 (indices).
446 template <typename TensorDataType, data_layout Layout, El::Device Device>
447 void gather_distconv_adapter<TensorDataType, Layout, Device>::bp_compute()
448 {
449  // Compute the backward pass
450  m_gather_operator->backward(
451  this->get_prev_error_signals(),
452  this->get_prev_activations(1),
453  this->get_error_signals(0), // Values gradient
454  this->get_error_signals(1)); // Indices gradient. Will be 0'ed out
455 }
456 
457 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM
458 
// Unless LBANN_GATHER_LAYER_INSTANTIATE is defined, declare the
// data-parallel gather_layer instantiations as extern templates.
// NOTE(review): the line that expands PROTO_DEVICE is elided in this listing.
459 #ifndef LBANN_GATHER_LAYER_INSTANTIATE
460 #define PROTO_DEVICE(T, Device) \
461  extern template class gather_layer<T, data_layout::DATA_PARALLEL, Device>
463 #undef PROTO_DEVICE
464 #endif // LBANN_GATHER_LAYER_INSTANTIATE
465 
466 } // namespace lbann
467 
468 #endif // LBANN_LAYERS_TRANSFORM_GATHER_HPP_INCLUDED
bool distconv_enabled() const
Indicate whether distconv is enabled.
Definition: layer.hpp:1082
virtual void setup_dims()
Setup tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
Gather values from specified tensor indices.
Definition: gather.hpp:100
#define LBANN_ERROR(...)
Definition: exception.hpp:37
void mark_updated(const dc::Dist &d)
std::vector< int > get_input_dims(size_t input_index=0) const
Get input tensor dimensions.
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
bool can_run_inplace() const override
If True, the computation can run in-place (feeding each input activations tensor as the corresponding...
Definition: gather.hpp:123
int get_backprop_requirements() const override
Returns the necessary tensors for computing backpropagation.
Definition: gather.hpp:124
constexpr El::Device Device
virtual void setup_distributions(tensor_overlap_constraints &constraints)
OutputAbsDistMatrixType & get_prev_error_signals(int child_index=0)
InputAbsDistMatrixType & get_prev_activations(int parent_index=0)
const OutputAbsDistMatrixType & get_activations(const Layer &child) const override
std::string get_type() const override
Get the layer type's name.
Definition: gather.hpp:178
gather_layer * copy() const override
Copy function. This function dynamically allocates memory for a layer instance and instantiates a cop...
Definition: gather.hpp:172
void set_output_dims(std::vector< int > dims, size_t output_index=0)
Set output tensor dimensions.
data_layout get_data_layout() const override
Get data layout of the data tensors. We assume that the data layouts of the previous activations...
Definition: gather.hpp:185
El::Device get_device_allocation() const override
Get the device allocation for the data tensors. We assume that the device allocation of the previous ...
Definition: gather.hpp:192
::distconv::tensor::Shape Shape
std::string get_name() const
Get the layer instance's name.
Definition: layer.hpp:332
virtual void setup_layer(size_t workspace_capacity)
world_comm_ptr initialize(int &argc, char **&argv)
data_layout
Data layout that is optimized for different modes of parallelism.
Definition: base.hpp:218
const Layer & get_parent_layer(size_t index=0) const
std::vector< int > get_output_dims(size_t output_index=0) const
Get output tensor dimensions.
void mark_invariant(const dc::Dist &d)
void write_specific_proto(lbann_data::Layer &proto) const final
Definition: gather.hpp:156
int m_expected_num_parent_layers
Definition: layer.hpp:838
const InputAbsDistMatrixType & get_error_signals(const Layer &parent) const override
dc::TensorDev< OutputTensorDataType > TensorDevType
void setup_dims() override
Setup tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
Definition: gather.hpp:198