27 #ifndef LBANN_LAYERS_TRANSFORM_SCATTER_HPP_INCLUDED 28 #define LBANN_LAYERS_TRANSFORM_SCATTER_HPP_INCLUDED 32 #include "lbann/proto/layers.pb.h" 36 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM) 40 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM 44 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM) 47 template <
typename TensorDataType>
48 using Scatter = ::distconv::Scatter<Backend, TensorDataType>;
51 template <
typename TensorDataType, data_layout Layout, El::Device Device>
52 class scatter_distconv_adapter
53 :
public data_type_distconv_adapter<TensorDataType>
59 scatter_distconv_adapter(Layer& layer)
60 : data_type_distconv_adapter<TensorDataType>(layer)
62 virtual ~scatter_distconv_adapter() =
default;
64 void setup_distributions(tensor_overlap_constraints& constraints)
override;
65 void setup_layer(
size_t workspace_capacity)
override;
68 dc::Shape get_activations_local_shape(
int index = 0)
const override;
70 std::unique_ptr<dc::Scatter<TensorDataType>> m_scatter_operator;
71 size_t m_workspace_buffer_size{0};
73 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM 96 template <
typename TensorDataType,
102 "scatter layer only supports data parallel layout");
114 template <
typename ArchiveT>
119 std::string get_type()
const override;
121 El::Device get_device_allocation()
const override;
130 void write_specific_proto(lbann_data::Layer& proto)
const final;
132 friend class cereal::access;
134 void setup_dims()
override;
135 void fp_compute()
override;
136 void bp_compute()
override;
137 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM) 138 friend class scatter_distconv_adapter<TensorDataType, Layout, Device>;
139 void setup_distconv_adapter()
override;
140 bool is_distconv_supported()
const override;
141 scatter_distconv_adapter<TensorDataType, Layout, Device>&
142 get_distconv_adapter()
override;
143 const scatter_distconv_adapter<TensorDataType, Layout, Device>&
144 get_distconv_adapter()
const override;
145 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM 154 template <
typename T, data_layout L, El::Device D>
156 lbann_data::Layer& proto)
const 158 proto.set_datatype(proto::ProtoDataType<T>);
159 auto* msg = proto.mutable_scatter();
161 msg->mutable_axis()->set_value(m_scatter_axis);
164 template <
typename TensorDataType, data_layout Layout, El::Device Device>
166 const std::vector<int>& dims,
174 template <
typename TensorDataType, data_layout Layout, El::Device Device>
181 template <
typename TensorDataType, data_layout Layout, El::Device Device>
187 template <
typename TensorDataType, data_layout Layout, El::Device Device>
194 template <
typename TensorDataType, data_layout Layout, El::Device Device>
201 template <
typename TensorDataType, data_layout Layout, El::Device Device>
// Helper lambda used when composing the shape text in the messages below:
// it streams every entry of `dims` into `ss`, writing "x" before each entry
// except the first, so {2, 3, 4} is written as 2x3x4.
210 auto dims_to_str = [](
const std::vector<int>& dims) -> std::string {
211 std::ostringstream ss;
212 for (
size_t i = 0; i < dims.size(); ++i) {
// No separator in front of the first dimension.
213 ss << (i > 0 ?
"x" :
"") << dims[i];
// Only compiled when both LBANN_HAS_DISTCONV and LBANN_HAS_NVSHMEM are defined.
// Records whether the values input (input0), the indices input (input1) and
// the output each have exactly three dimensions.
223 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM) 226 const auto is_values_3D = input0_dims.size() == 3;
227 const auto is_indices_3D = input1_dims.size() == 3;
228 const auto is_output_3D = output_dims.size() == 3;
// Any tensor that is not 3D takes the failure branch; the fragments below
// put all three shapes into the message.
231 if (!is_values_3D || !is_indices_3D || !is_output_3D) {
237 "has values input shape (",
238 dims_to_str(input0_dims),
240 "has indices input shape (",
241 dims_to_str(input1_dims),
243 "has output shape (",
244 dims_to_str(output_dims),
246 "Distconv Scatter requires all three to be 3D. ");
// The values input (input0) and the indices input (input1) must agree in
// dimension 0; on a mismatch both shapes are placed in the message.
// NOTE(review): the call that consumes these message fragments is not visible
// in this excerpt.
249 if (input0_dims[0] != input1_dims[0]) {
254 "has values input (",
255 dims_to_str(input0_dims),
257 "has indices input (",
258 dims_to_str(input1_dims),
260 "Distconv requires the 0-th dimension to match. ");
// Compares the values input (input0) against the output in dimensions 1 and 2.
264 const auto output_dim_fail =
265 input0_dims[1] != output_dims[1] || input0_dims[2] != output_dims[2];
// NOTE(review): the message fragments below print the indices shape and say
// "0-th dimension", which matches the preceding values-vs-indices check rather
// than this condition (values vs. output, dimensions 1 and 2). This looks like
// a copy-paste leftover; confirm the intended wording and reported shapes.
267 if (output_dim_fail) {
272 "has values input (",
273 dims_to_str(input0_dims),
275 "has indices input (",
276 dims_to_str(input1_dims),
278 "Distconv requires the 0-th dimension to match. ");
287 "requires the scatter dimension to be 0 when using distconv");
292 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM 297 const auto is_values_1D = input0_dims.size() == 1;
298 const auto is_values_2D = input0_dims.size() == 2;
302 const auto is_output_1D = output_dims.size() == 1;
303 const auto is_output_2D = output_dims.size() == 2;
311 "has 2D input, but does not set a scatter axis.",
312 " Axis must be either set to 0 or 1");
316 if (input0_dims != input1_dims) {
320 if (input0_dims[matching_dim] != input1_dims[0]) {
327 "has input tensors with different outer dimensions ",
334 dims_to_str(input0_dims),
341 dims_to_str(input1_dims),
347 if (input1_dims.size() != 1 || !(is_values_1D || is_values_2D) ||
348 input0_dims.size() != output_dims.size()) {
353 "attempted to scatter from a ",
357 dims_to_str(input0_dims),
361 "but the scatter layer currently only supports ",
362 "scattering to and from a 1-D or 2-D tensor and the input and " 364 "must have the same number of dimensions");
367 if (!is_output_1D && (is_output_2D && output_dims[0] != input0_dims[0])) {
369 if (output_dims[matching_dim] != input0_dims[matching_dim]) {
375 "attempted to scatter into a ",
379 dims_to_str(output_dims),
382 input0_dims[matching_dim],
389 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM) 395 template <
typename TensorDataType, data_layout Layout, El::Device Device>
402 template <
typename TensorDataType, data_layout Layout, El::Device Device>
405 this->get_distconv_adapter_ptr() =
406 std::make_unique<scatter_distconv_adapter<TensorDataType, Layout, Device>>(
410 template <
typename TensorDataType, data_layout Layout, El::Device Device>
411 const scatter_distconv_adapter<TensorDataType, Layout, Device>&
415 const scatter_distconv_adapter<TensorDataType, Layout, Device>&
>(
419 template <
typename TensorDataType, data_layout Layout, El::Device Device>
420 scatter_distconv_adapter<TensorDataType, Layout, Device>&
423 return const_cast<scatter_distconv_adapter<TensorDataType, Layout, Device>&
>(
425 .get_distconv_adapter());
432 template <
typename TensorDataType, data_layout Layout, El::Device Device>
433 void scatter_distconv_adapter<TensorDataType, Layout, Device>::
438 for (
auto& d : this->m_prev_activations_dists) {
443 for (
auto& d : this->m_activations_dists) {
448 for (
auto& d : this->m_prev_error_signals_dists) {
453 for (
auto& d : this->m_error_signals_dists) {
460 template <
typename TensorDataType, data_layout Layout, El::Device Device>
461 void scatter_distconv_adapter<TensorDataType, Layout, Device>::setup_layer(
462 size_t workspace_capacity)
466 make_unique<dc::Scatter<TensorDataType>>(dc::get_backend());
473 template <
typename TensorDataType, data_layout Layout, El::Device Device>
474 dc::Shape scatter_distconv_adapter<TensorDataType, Layout, Device>::
475 get_activations_local_shape(
int index)
const 492 template <
typename TensorDataType, data_layout Layout, El::Device Device>
493 void scatter_distconv_adapter<TensorDataType, Layout, Device>::fp_compute()
501 template <
typename TensorDataType, data_layout Layout, El::Device Device>
502 void scatter_distconv_adapter<TensorDataType, Layout, Device>::bp_compute()
505 m_scatter_operator->backward(
512 #define PROTO_DEVICE(T, Device) \ 513 template class scatter_distconv_adapter<T, data_layout::DATA_PARALLEL, Device> 516 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM 518 #ifndef LBANN_SCATTER_LAYER_INSTANTIATE 519 #define PROTO_DEVICE(T, Device) \ 520 extern template class scatter_layer<T, data_layout::DATA_PARALLEL, Device>; 523 #endif // LBANN_SCATTER_LAYER_INSTANTIATE 527 #endif // LBANN_LAYERS_TRANSFORM_SCATTER_HPP_INCLUDED bool distconv_enabled() const
Indicate whether distconv is enabled.
virtual void setup_dims()
Set up tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor's dimensions.
bool can_run_inplace() const override
If True, the computation can run in-place (feeding each input activations tensor as the corresponding...
data_layout get_data_layout() const override
Get data layout of the data tensors. We assume that the data layouts of the previous activations...
void mark_updated(const dc::Dist &d)
int get_backprop_requirements() const override
Returns the necessary tensors for computing backpropagation.
scatter_layer * copy() const override
Copy function. This function dynamically allocates memory for a layer instance and instantiates a cop...
std::vector< int > get_input_dims(size_t input_index=0) const
Get input tensor dimensions.
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
void setup_dims() override
Set up tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor's dimensions.
constexpr El::Device Device
virtual void setup_distributions(tensor_overlap_constraints &constraints)
OutputAbsDistMatrixType & get_prev_error_signals(int child_index=0)
InputAbsDistMatrixType & get_prev_activations(int parent_index=0)
const OutputAbsDistMatrixType & get_activations(const Layer &child) const override
void set_output_dims(std::vector< int > dims, size_t output_index=0)
Set output tensor dimensions.
void assign_to_repeated(google::protobuf::RepeatedField< T > &field, ContainerT const &values)
Assign a range of values to a repeated protobuf field.
std::string get_type() const override
Get the layer type's name.
::distconv::tensor::Shape Shape
std::string get_name() const
Get the layer instance's name.
virtual void setup_layer(size_t workspace_capacity)
void write_specific_proto(lbann_data::Layer &proto) const final
world_comm_ptr initialize(int &argc, char **&argv)
data_layout
Data layout that is optimized for different modes of parallelism.
const Layer & get_parent_layer(size_t index=0) const
Scatter values to specified tensor indices.
std::vector< int > get_output_dims(size_t output_index=0) const
Get output tensor dimensions.
El::Device get_device_allocation() const override
Get the device allocation for the data tensors. We assume that the device allocation of the previous ...
void mark_invariant(const dc::Dist &d)
int m_expected_num_parent_layers
const InputAbsDistMatrixType & get_error_signals(const Layer &parent) const override
dc::TensorDev< OutputTensorDataType > TensorDevType