// Include guard, distconv includes and the LBANN_HAS_DISTCONV gate, followed by
// the start of the MatMul class template (parameterised on Backend and
// DataType).
// NOTE(review): the class head, method names, return types and the closing
// brace are not visible in this excerpt; comments below describe only the
// parameter lists that are shown.
26 #ifndef LBANN_LAYERS_MATH_DISTCONV_MATMUL 27 #define LBANN_LAYERS_MATH_DISTCONV_MATMUL 28 #include "distconv/base.hpp" 29 #include "distconv/tensor/tensor.hpp" 30 #include "distconv/tensor/tensor_mpi.hpp" 33 #ifdef LBANN_HAS_DISTCONV 35 template <
typename Backend,
typename DataType>
// Constructor: initialises the member m_be from the supplied backend.
// NOTE(review): single-argument and not `explicit`, so a Backend converts
// implicitly to MatMul; the trailing `;` after the body is redundant.
41 MatMul(Backend& backend) : m_be(backend){};
// Member template over the tensor Allocator. Takes two read-only MPI-locale
// input tensors, one non-const output tensor, and one transpose flag per input.
// Declaration only; presumably the forward pass -- confirm against the full
// header.
43 template <
typename Allocator>
45 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& input_0,
46 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& input_1,
47 tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& output,
48 const bool transpose_0,
49 const bool transpose_1);
// Member template over the tensor Allocator. Takes the two read-only inputs
// and a read-only output_grad tensor, plus two non-const tensors
// (input_grad_0, input_grad_1) and the same pair of transpose flags.
// Declaration only; presumably the backward pass that fills the input
// gradients -- confirm against the full header.
51 template <
typename Allocator>
53 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& input_0,
54 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& input_1,
55 const tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& output_grad,
56 tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& input_grad_0,
57 tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>& input_grad_1,
58 const bool transpose_0,
59 const bool transpose_1);
// Free function template that derives the local shape of a matmul output from
// the local shapes of its two input tensors.
// NOTE(review): the function name, return type and the declarations of
// transpose_1 / transpose_2 are not visible in this excerpt.
// NOTE(review): the flags here are named transpose_1 / transpose_2, whereas
// the MatMul member declarations use transpose_0 / transpose_1 -- confirm the
// off-by-one naming is intentional.
65 template <
typename DataType,
typename locale,
typename Allocator>
67 const tensor::Tensor<DataType, locale, Allocator>& input_0,
68 const tensor::Tensor<DataType, locale, Allocator>& input_1,
// Start from input_0's local shape; only entries 0 and 1 are overwritten
// below, so any remaining entries are carried over from input_0 unchanged.
73 auto output_local_shape = input_0.get_local_shape();
// inp_0_dims repeats the get_local_shape() call already made above for
// input_0.
75 auto inp_0_dims = input_0.get_local_shape();
76 auto inp_1_dims = input_1.get_local_shape();
// Entry 0 comes from input_1: its entry 1 when transpose_2 is set, otherwise
// its entry 0.
79 output_local_shape[0] = transpose_2 ? inp_1_dims[1] : inp_1_dims[0];
// Entry 1 comes from input_0: its entry 0 when transpose_1 is set, otherwise
// its entry 1.
80 output_local_shape[1] = transpose_1 ? inp_0_dims[0] : inp_0_dims[1];
82 return output_local_shape;
// Explicit instantiation declarations: translation units including this header
// do not implicitly instantiate MatMul for BackendDNNLib with float or double.
// The matching explicit instantiation definitions are presumably in a
// corresponding source file -- confirm.
85 extern template class MatMul<::distconv::BackendDNNLib, float>;
86 extern template class MatMul<::distconv::BackendDNNLib, double>;
89 #endif // LBANN_HAS_DISTCONV 90 #endif // LBANN_LAYERS_MATH_DISTCONV_MATMUL ::distconv::tensor::LocaleMPI LocaleMPI
::distconv::tensor::Shape Shape