// NOTE(review): this text looks like extraction residue of a header rather
// than compilable source. Original line numbers (27, 28, 32, ...) are fused
// into the code, several directives are joined onto one line, and lines are
// missing between the numbered fragments (e.g. 29-31, 37-39, 41-45).
// Restore it from the upstream file before building. The comments in this
// chunk describe only what the visible fragments show.
//
// Visible content:
//   - an include guard;
//   - a region compiled only when both LBANN_HAS_NVSHMEM and
//     LBANN_HAS_DISTCONV are defined;
//   - the distconv base/tensor/tensor_mpi includes;
//   - the template parameter list of what is presumably the `Gather` class
//     template (its `class ... {` head is among the missing lines).
27 #ifndef LBANN_LAYERS_TRANSFORM_DISTCONV_GATHER 28 #define LBANN_LAYERS_TRANSFORM_DISTCONV_GATHER 32 #if defined(LBANN_HAS_NVSHMEM) && defined(LBANN_HAS_DISTCONV) 34 #include "distconv/base.hpp" 35 #include "distconv/tensor/tensor.hpp" 36 #include "distconv/tensor/tensor_mpi.hpp" 40 template <
// Backend: must provide get_stream(), which the constructor calls.
typename Backend,
// DataType: element type of every tensor argument and of the NVSHMEM
// gather/scatter helpers.
typename DataType>
// Constructor. Stores `backend` in m_backend, then creates two helpers:
//   - m_dist_scatter, a tensor::ScatterNVSHMEM<DataType>;
//   - m_dist_gather, a tensor::GatherNVSHMEM<DataType>.
// Both helpers are constructed with m_backend.get_stream().
// NOTE(review): the body's braces (fragments 47 and 52) and the declaration
// of m_backend are missing from this view. If m_backend is a reference
// member, `backend` must outlive this object — confirm against the full
// header.
46 Gather(Backend& backend) : m_backend(backend)
// The scatter and gather helpers are bound to the same backend stream.
48 m_dist_scatter = std::make_unique<tensor::ScatterNVSHMEM<DataType>>(
49 m_backend.get_stream());
50 m_dist_gather = std::make_unique<tensor::GatherNVSHMEM<DataType>>(
51 m_backend.get_stream());
// Template member parameterized on the tensor Allocator. Its return type
// and name are on the missing fragment 55.
// Parameters, all tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>:
//   - `input` (read-only);
//   - `indices` (read-only);
//   - `output` (writable).
// Presumably this is the forward gather (output <- input selected by
// indices) — confirm against the implementation.
// NOTE(review): `indices` uses the value type DataType rather than an
// integer type. Presumably the indices are stored as DataType and converted
// on use — confirm.
54 template <
typename Allocator>
56 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& input,
57 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& indices,
58 tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& output);
// Template member parameterized on the tensor Allocator. Its return type
// and name are on the missing fragment 61.
// Parameters, all tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>:
//   - `output_grad` (read-only);
//   - `indices` (read-only);
//   - `values_grad` (writable);
//   - `indices_grad` (writable).
// Presumably this is the backward pass, routing output gradients back to
// the gathered values (possibly via the scatter helper) — confirm against
// the implementation.
60 template <
typename Allocator>
62 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& output_grad,
63 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& indices,
64 tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& values_grad,
65 tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& indices_grad);
// setup(): template member parameterized on the tensor Allocator. Its
// return type is on the missing fragment 68.
// Parameters, all read-only and all
// tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>:
//   - `input`;
//   - `indices`;
//   - `output`.
// Presumably it prepares the NVSHMEM helpers for these tensors before the
// gather/scatter calls — confirm against the implementation.
67 template <
typename Allocator>
69 setup(
const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& input,
70 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& indices,
71 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& output);
// Uniquely owned NVSHMEM helpers. The member names are on the missing
// fragments 76 and 78. Presumably these are m_dist_gather and
// m_dist_scatter, which the constructor assigns via std::make_unique —
// confirm against the full header.
75 std::unique_ptr<tensor::GatherNVSHMEM<DataType>>
77 std::unique_ptr<tensor::ScatterNVSHMEM<DataType>>
82 #endif // defined(LBANN_HAS_DISTCONV) && defined(LBANN_HAS_NVSHMEM) 83 #endif // LBANN_LAYERS_TRANSFORM_DISTCONV_GATHER ::distconv::tensor::LocaleMPI LocaleMPI