27 #ifndef LBANN_LAYERS_LEARNING_DISTCONV_LAYERS 28 #define LBANN_LAYERS_LEARNING_DISTCONV_LAYERS 29 #include "distconv/base.hpp" 30 #include "distconv/tensor/tensor.hpp" 31 #include "distconv/tensor/tensor_mpi.hpp" 34 #ifdef LBANN_HAS_DISTCONV 36 template <
typename Backend,
typename DataType>
37 class ChannelwiseFullyConnected
42 ChannelwiseFullyConnected(Backend& backend) : m_be(backend){};
44 template <
typename Allocator>
47 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& input,
48 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& linearity,
49 tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& output);
51 template <
typename Allocator>
53 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& bias,
54 tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& output);
56 template <
typename Allocator>
57 int backward_wrt_input(
59 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& output_grad,
60 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& linearity,
61 tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& input_grad);
63 template <
typename Allocator>
64 int backward_wrt_weight(
67 DataType gradient_scale,
68 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& input,
69 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& output_grad,
70 tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& linearity_grad);
72 template <
typename Allocator>
73 int backward_wrt_bias(
74 DataType gradient_scale,
76 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& output_grad,
77 tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& bias_grad);
83 template <
typename DataType,
typename locale,
typename Allocator>
85 const tensor::Tensor<DataType, locale, Allocator>& input,
86 const int_vector& linearity_dims,
94 auto output_local_shape = input.get_local_shape();
95 output_local_shape[0] = transpose ? linearity_dims[1] : linearity_dims[0];
96 return output_local_shape;
98 extern template class ChannelwiseFullyConnected<::distconv::BackendDNNLib,
100 extern template class ChannelwiseFullyConnected<::distconv::BackendDNNLib,
104 #endif // LBANN_HAS_DISTCONV 105 #endif // LBANN_LAYERS_LEARNING_DISTCONV_LAYERS std::map< El::Int, std::set< El::Int > > transpose(const std::set< El::Int > &nodes, const std::map< El::Int, std::set< El::Int >> &edges)
::distconv::tensor::LocaleMPI LocaleMPI
::distconv::tensor::Shape Shape