27 #ifndef LBANN_LAYER_MATH_MATMUL_HPP_INCLUDED 28 #define LBANN_LAYER_MATH_MATMUL_HPP_INCLUDED 32 #ifdef LBANN_HAS_DISTCONV 35 #endif // LBANN_HAS_DISTCONV 39 #ifdef LBANN_HAS_DISTCONV 41 using Backend = ::distconv::BackendDNNLib;
42 template <
typename TensorDataType>
43 using MatMul = ::distconv::MatMul<Backend, TensorDataType>;
46 template <
typename TensorDataType, data_layout Layout, El::Device Device>
47 class matmul_distconv_adapter
48 :
public data_type_distconv_adapter<TensorDataType>
54 matmul_distconv_adapter(Layer& layer)
55 : data_type_distconv_adapter<TensorDataType>(layer)
58 virtual ~matmul_distconv_adapter() =
default;
59 void setup_distributions(tensor_overlap_constraints& constraints)
override;
60 void setup_layer(
size_t workspace_capacity)
override;
63 dc::Shape get_activations_local_shape(
int index = 0)
const override;
64 std::unique_ptr<dc::MatMul<TensorDataType>> m_matmul_operator;
67 #endif // LBANN_HAS_DISTCONV 79 template <
typename TensorDataType,
85 "matmul_layer only supports " 86 "data-parallel data layout");
90 bool transpose_a =
false,
91 bool transpose_b =
false);
96 std::string get_type()
const override;
98 El::Device get_device_allocation()
const override;
107 template <
typename ArchiveT>
112 void write_specific_proto(lbann_data::Layer& proto)
const final;
114 friend class cereal::access;
117 void setup_dims()
override;
118 void fp_compute()
override;
119 void bp_compute()
override;
121 #ifdef LBANN_HAS_DISTCONV 122 friend class matmul_distconv_adapter<TensorDataType, Layout, Device>;
125 void setup_distconv_adapter()
override;
126 bool is_distconv_supported()
const override;
127 matmul_distconv_adapter<TensorDataType, Layout, Device>&
128 get_distconv_adapter()
override;
129 const matmul_distconv_adapter<TensorDataType, Layout, Device>&
130 get_distconv_adapter()
const override;
131 #endif // LBANN_HAS_DISTCONV 141 template <
typename U>
143 template <
typename U>
151 template <
typename TensorDataType, data_layout Layout, El::Device Device>
156 m_transpose_a{transpose_a},
162 template <
typename TensorDataType, data_layout Layout, El::Device Device>
169 template <
typename TensorDataType, data_layout Layout, El::Device Device>
172 return "matrix multiply";
175 template <
typename TensorDataType, data_layout Layout, El::Device Device>
182 template <
typename TensorDataType, data_layout Layout, El::Device Device>
189 template <
typename TensorDataType, data_layout Layout, El::Device Device>
203 #ifndef LBANN_MATMUL_LAYER_INSTANTIATE 205 #define PROTO_DEVICE(T, Device) \ 206 extern template class matmul_layer<T, data_layout::DATA_PARALLEL, Device> 211 #endif // LBANN_MATMUL_LAYER_INSTANTIATE 215 #endif // LBANN_LAYER_MATH_MATMUL_HPP_INCLUDED
matmul_layer * copy() const override
Copy function. This function dynamically allocates memory for a layer instance and instantiates a cop...
data_layout get_data_layout() const override
Get data layout of the data tensors. We assume that the data layouts of the previous activations...
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
Generates nicely formatted description messages.
virtual description get_description() const
Human-readable description.
constexpr El::Device Device
El::Device get_device_allocation() const override
Get the device allocation for the data tensors. We assume that the device allocation of the previous ...
description get_description() const override
Human-readable description.
int get_backprop_requirements() const override
Returns the necessary tensors for computing backpropagation.
::distconv::tensor::Shape Shape
data_layout
Data layout that is optimized for different modes of parallelism.
bool can_run_inplace() const override
If True, the computation can run in-place (feeding each input activations tensor as the corresponding...
std::string get_type() const override
Get the layer type's name.
int m_expected_num_parent_layers
dc::TensorDev< OutputTensorDataType > TensorDevType