27 #ifndef LBANN_LAYERS_TRANSFORM_GATHER_HPP_INCLUDED 28 #define LBANN_LAYERS_TRANSFORM_GATHER_HPP_INCLUDED 32 #include "lbann/proto/layers.pb.h" 35 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM) 39 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM 43 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM) 46 template <
typename TensorDataType>
47 using Gather = ::distconv::Gather<Backend, TensorDataType>;
50 template <
typename TensorDataType, data_layout Layout, El::Device Device>
51 class gather_distconv_adapter
52 :
public data_type_distconv_adapter<TensorDataType>
58 gather_distconv_adapter(Layer& layer)
59 : data_type_distconv_adapter<TensorDataType>(layer)
61 virtual ~gather_distconv_adapter() =
default;
63 void setup_distributions(tensor_overlap_constraints& constraints)
override;
64 void setup_layer(
size_t workspace_capacity)
override;
67 dc::Shape get_activations_local_shape(
int index = 0)
const override;
69 std::unique_ptr<dc::Gather<TensorDataType>> m_gather_operator;
70 size_t m_workspace_buffer_size{0};
72 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM 97 template <
typename TensorDataType,
103 "gather layer only supports data parallel layout");
115 template <
typename ArchiveT>
120 std::string get_type()
const override;
122 El::Device get_device_allocation()
const override;
131 void write_specific_proto(lbann_data::Layer& proto)
const final;
133 friend class cereal::access;
135 void setup_dims()
override;
136 void fp_compute()
override;
137 void bp_compute()
override;
138 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM) 139 friend class gather_distconv_adapter<TensorDataType, Layout, Device>;
140 void setup_distconv_adapter()
override;
141 bool is_distconv_supported()
const override;
142 gather_distconv_adapter<TensorDataType, Layout, Device>&
143 get_distconv_adapter()
override;
144 const gather_distconv_adapter<TensorDataType, Layout, Device>&
145 get_distconv_adapter()
const override;
146 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM 155 template <
typename T, data_layout L, El::Device D>
158 proto.set_datatype(proto::ProtoDataType<T>);
159 auto* msg = proto.mutable_gather();
160 msg->mutable_axis()->set_value(m_gather_axis);
163 template <
typename TensorDataType, data_layout Layout, El::Device Device>
170 template <
typename TensorDataType, data_layout Layout, El::Device Device>
177 template <
typename TensorDataType, data_layout Layout, El::Device Device>
183 template <
typename TensorDataType, data_layout Layout, El::Device Device>
190 template <
typename TensorDataType, data_layout Layout, El::Device Device>
197 template <
typename TensorDataType, data_layout Layout, El::Device Device>
208 auto dims_to_str = [](
const std::vector<int>& dims) -> std::string {
209 std::ostringstream ss;
210 for (
size_t i = 0; i < dims.size(); ++i) {
211 ss << (i > 0 ?
"x" :
"") << dims[i];
221 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM) 224 const auto is_values_3D = input0_dims.size() == 3;
225 const auto is_indices_3D = input1_dims.size() == 3;
228 if (!is_values_3D || !is_indices_3D) {
234 "has values input (",
235 dims_to_str(input0_dims),
237 "has indices input (",
238 dims_to_str(input1_dims),
240 "Distconv Gather requires both to be 3D. ");
245 std::vector<int>{input1_dims[0], input0_dims[1], 1});
252 "cannot gather along axis ",
254 " when distconv is enabled");
258 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM 261 const auto is_indices_not_1D = input1_dims.size() != 1;
264 const auto is_values_1D = input0_dims.size() == 1;
265 const auto is_values_2D = input0_dims.size() == 2;
273 "has 2D input, but does not set a gather axis.",
274 "Axis must be either set to 0 or 1");
283 this->
set_output_dims(std::vector<int>{input1_dims[0], input0_dims[1]});
286 this->
set_output_dims(std::vector<int>{input0_dims[0], input1_dims[0]});
292 if (is_indices_not_1D || !(is_values_1D || is_values_2D)) {
299 "has input tensors with incorrect numbers of dimensions. " 300 "Expected 1D or 2D values tensor and 1D indices tensor.",
301 " Expected 3D-only tensors for distconv-enabled Gather. ",
308 dims_to_str(input0_dims),
315 dims_to_str(input1_dims),
321 if (!is_values_1D && !is_values_2D) {
326 "attempted to gather from a ",
330 dims_to_str(input0_dims),
332 "but the gather layer currently only supports ",
333 "gathering from a 1-D or 2-D tensor");
337 #if defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM) 343 template <
typename TensorDataType, data_layout Layout, El::Device Device>
349 template <
typename TensorDataType, data_layout Layout, El::Device Device>
352 this->get_distconv_adapter_ptr() =
353 std::make_unique<gather_distconv_adapter<TensorDataType, Layout, Device>>(
357 template <
typename TensorDataType, data_layout Layout, El::Device Device>
358 const gather_distconv_adapter<TensorDataType, Layout, Device>&
362 const gather_distconv_adapter<TensorDataType, Layout, Device>&
>(
366 template <
typename TensorDataType, data_layout Layout, El::Device Device>
367 gather_distconv_adapter<TensorDataType, Layout, Device>&
370 return const_cast<gather_distconv_adapter<TensorDataType, Layout, Device>&
>(
372 .get_distconv_adapter());
379 template <
typename TensorDataType, data_layout Layout, El::Device Device>
380 void gather_distconv_adapter<TensorDataType, Layout, Device>::
385 for (
auto& d : this->m_prev_activations_dists) {
390 for (
auto& d : this->m_activations_dists) {
395 for (
auto& d : this->m_prev_error_signals_dists) {
400 for (
auto& d : this->m_error_signals_dists) {
407 template <
typename TensorDataType, data_layout Layout, El::Device Device>
408 void gather_distconv_adapter<TensorDataType, Layout, Device>::setup_layer(
409 size_t workspace_capacity)
413 make_unique<dc::Gather<TensorDataType>>(dc::get_backend());
420 template <
typename TensorDataType, data_layout Layout, El::Device Device>
421 dc::Shape gather_distconv_adapter<TensorDataType, Layout, Device>::
422 get_activations_local_shape(
int index)
const 433 output_shape[1] = values_shape[1];
437 template <
typename TensorDataType, data_layout Layout, El::Device Device>
438 void gather_distconv_adapter<TensorDataType, Layout, Device>::fp_compute()
446 template <
typename TensorDataType, data_layout Layout, El::Device Device>
447 void gather_distconv_adapter<TensorDataType, Layout, Device>::bp_compute()
450 m_gather_operator->backward(
457 #endif // LBANN_HAS_DISTCONV && LBANN_HAS_NVSHMEM 459 #ifndef LBANN_GATHER_LAYER_INSTANTIATE 460 #define PROTO_DEVICE(T, Device) \ 461 extern template class gather_layer<T, data_layout::DATA_PARALLEL, Device> 464 #endif // LBANN_GATHER_LAYER_INSTANTIATE 468 #endif // LBANN_LAYERS_TRANSFORM_GATHER_HPP_INCLUDED bool distconv_enabled() const
Indicate whether distconv is enabled.
virtual void setup_dims()
Setup tensor dimensions Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
Gather values from specified tensor indices.
void mark_updated(const dc::Dist &d)
std::vector< int > get_input_dims(size_t input_index=0) const
Get input tensor dimensions.
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
bool can_run_inplace() const override
If True, the computation can run in-place (feeding each input activations tensor as the corresponding...
int get_backprop_requirements() const override
Returns the necessary tensors for computing backpropagation.
constexpr El::Device Device
virtual void setup_distributions(tensor_overlap_constraints &constraints)
OutputAbsDistMatrixType & get_prev_error_signals(int child_index=0)
InputAbsDistMatrixType & get_prev_activations(int parent_index=0)
const OutputAbsDistMatrixType & get_activations(const Layer &child) const override
std::string get_type() const override
Get the layer type's name.
gather_layer * copy() const override
Copy function. This function dynamically allocates memory for a layer instance and instantiates a cop...
void set_output_dims(std::vector< int > dims, size_t output_index=0)
Set output tensor dimensions.
data_layout get_data_layout() const override
Get data layout of the data tensors. We assume that the data layouts of the previous activations...
El::Device get_device_allocation() const override
Get the device allocation for the data tensors. We assume that the decice allocation of the previous ...
::distconv::tensor::Shape Shape
std::string get_name() const
Get the layer instance's name.
virtual void setup_layer(size_t workspace_capacity)
world_comm_ptr initialize(int &argc, char **&argv)
data_layout
Data layout that is optimized for different modes of parallelism.
const Layer & get_parent_layer(size_t index=0) const
std::vector< int > get_output_dims(size_t output_index=0) const
Get output tensor dimensions.
void mark_invariant(const dc::Dist &d)
void write_specific_proto(lbann_data::Layer &proto) const final
int m_expected_num_parent_layers
const InputAbsDistMatrixType & get_error_signals(const Layer &parent) const override
dc::TensorDev< OutputTensorDataType > TensorDevType
void setup_dims() override
Setup tensor dimensions Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.