27 #ifndef LBANN_LAYERS_LOSS_MEAN_SQUARED_ERROR_HPP_INCLUDED 28 #define LBANN_LAYERS_LOSS_MEAN_SQUARED_ERROR_HPP_INCLUDED 32 #include "lbann/proto/layers.pb.h" 34 #ifdef LBANN_HAS_DISTCONV 35 #include "distconv/dnn_backend/mean_squared_error.hpp" 41 #ifdef LBANN_HAS_DISTCONV 43 using Backend = ::distconv::BackendDNNLib;
// Alias for the DistConv mean-squared-error operator instantiated with
// `Backend` (declared just above as ::distconv::BackendDNNLib).
44 using MeanSquaredError = ::distconv::MeanSquaredError<Backend>;
/** @brief DistConv adapter for the mean-squared-error layer.
 *
 *  Derives from data_type_distconv_adapter<TensorDataType> and holds the
 *  DistConv operator that the layer's fp_compute_distconv() /
 *  bp_compute_distconv() invoke through m_mean_squared_error.
 *
 *  NOTE(review): this listing omits some lines of the class (braces,
 *  access specifiers); only the declarations shown here are documented.
 */
47 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
48 class mean_squared_error_distconv_adapter
49 :
public data_type_distconv_adapter<TensorDataType>
/** @brief Construct the adapter, forwarding @c layer to the base adapter. */
54 mean_squared_error_distconv_adapter(Layer& layer)
55 : data_type_distconv_adapter<TensorDataType>(layer)
57 virtual ~mean_squared_error_distconv_adapter() =
default;
// The overrides below are only declared here; their definitions are not
// visible in this chunk.
58 void setup_distributions(tensor_overlap_constraints& constraints)
override;
// Shape queries take an `index`; presumably it selects the parent/child
// tensor -- TODO confirm against the base adapter.
59 dc::Shape get_prev_activations_shape(
int index)
const override;
60 dc::Shape get_activations_shape(
int index)
const override;
61 dc::Shape get_activations_local_shape(
int index)
const override;
62 void setup_layer(
size_t workspace_capacity)
override;
/** DistConv mean-squared-error operator; its forward()/backward() are called
 *  by the layer. NOTE(review): presumably created in setup_layer() -- confirm. */
63 std::unique_ptr<dc::MeanSquaredError> m_mean_squared_error;
65 #endif // LBANN_HAS_DISTCONV 75 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
// Constructor fragment: the layer declares that it expects exactly two
// parent layers. NOTE(review): presumably prediction and ground truth --
// the listing does not say which is which.
91 this->m_expected_num_parent_layers = 2;
// Template header of a member parameterized on an archive type; the member's
// declaration itself is not visible in this listing.
115 template <
typename ArchiveT>
/** @brief Get the layer type's name.
 *  @return The fixed string "mean squared error".
 */
120 std::string
get_type()
const override {
return "mean squared error"; }
// fill_onnx_node() is declared only when LBANN_HAS_ONNX is defined.
// All members in this group are declarations; their definitions are not
// visible in this chunk.
129 #ifdef LBANN_HAS_ONNX 130 void fill_onnx_node(onnx::GraphProto& graph)
const override;
131 #endif // LBANN_HAS_ONNX 133 void setup_dims()
override;
135 void setup_data(
size_t max_mini_batch_size)
override;
// Compute entry points; presumably forward- and back-propagation
// respectively (the *_distconv variants call forward()/backward()).
137 void fp_compute()
override;
139 void bp_compute()
override;
// Marked `final`: classes deriving from this layer cannot override it.
143 void write_specific_proto(lbann_data::Layer& proto)
const final;
// Gives cereal::access access to this class's non-public members.
145 friend class cereal::access;
// Local compute helpers. NOTE(review): presumably the per-process part of
// fp_compute()/bp_compute() -- confirm in the implementation file.
150 void local_fp_compute();
152 void local_bp_compute();
// DistConv support, compiled only when LBANN_HAS_DISTCONV is defined.
// The adapter class is declared a friend of the layer.
157 #ifdef LBANN_HAS_DISTCONV 158 friend class mean_squared_error_distconv_adapter<TensorDataType,
// is_distconv_supported(): body not visible in this listing.
// setup_distconv_adapter(): stores a newly created
// mean_squared_error_distconv_adapter in get_distconv_adapter_ptr()
// (constructor arguments are not visible here).
163 bool is_distconv_supported()
const override 168 void setup_distconv_adapter()
override 170 this->get_distconv_adapter_ptr() = std::make_unique<
171 mean_squared_error_distconv_adapter<TensorDataType, T_layout, Dev>>(
// Accessors returning the adapter typed as
// mean_squared_error_distconv_adapter (mutable and const overloads);
// declared here, defined elsewhere.
175 mean_squared_error_distconv_adapter<TensorDataType, T_layout, Dev>&
176 get_distconv_adapter()
override;
177 const mean_squared_error_distconv_adapter<TensorDataType, T_layout, Dev>&
178 get_distconv_adapter()
const override;
/** @brief Forward pass through the DistConv operator.
 *
 *  Asserts that DistConv is enabled, then calls forward() on the adapter's
 *  m_mean_squared_error with the two previous-activation tensors
 *  (indices 0 and 1) and the activations tensor.
 *
 *  NOTE(review): which input is the prediction and which the ground truth
 *  is not shown here -- confirm against the DistConv operator.
 */
180 void fp_compute_distconv()
182 assert_always(this->distconv_enabled());
183 get_distconv_adapter().m_mean_squared_error->forward(
184 this->get_distconv_adapter().get_prev_activations(0),
185 this->get_distconv_adapter().get_prev_activations(1),
186 this->get_distconv_adapter().get_activations());
/** @brief Backward pass through the DistConv operator.
 *
 *  Asserts that DistConv is enabled, then calls backward() on the adapter's
 *  m_mean_squared_error with, in order: previous activations 0 and 1,
 *  previous error signals 0, and error signals 0 and 1.
 *
 *  NOTE(review): presumably the previous error signal is the incoming
 *  gradient and the two error signals are the gradients for each input --
 *  confirm against the DistConv operator.
 */
189 void bp_compute_distconv()
191 assert_always(this->distconv_enabled());
192 get_distconv_adapter().m_mean_squared_error->backward(
193 this->get_distconv_adapter().get_prev_activations(0),
194 this->get_distconv_adapter().get_prev_activations(1),
195 this->get_distconv_adapter().get_prev_error_signals(0),
196 this->get_distconv_adapter().get_error_signals(0),
197 this->get_distconv_adapter().get_error_signals(1));
199 #endif // LBANN_HAS_DISTCONV 202 #ifndef LBANN_MEAN_SQUARED_ERROR_LAYER_INSTANTIATE 204 #define PROTO_DEVICE(T, Device) \ 205 extern template class mean_squared_error_layer<T, \ 206 data_layout::DATA_PARALLEL, \ 208 extern template class mean_squared_error_layer<T, \ 209 data_layout::MODEL_PARALLEL, \ 215 #endif // LBANN_MEAN_SQUARED_ERROR_LAYER_INSTANTIATE 219 #endif // LBANN_LAYERS_LOSS_MEAN_SQUARED_ERROR_HPP_INCLUDED mean_squared_error_layer(const mean_squared_error_layer &other)
mean_squared_error_layer * copy() const override
Copy function. This function dynamically allocates memory for a layer instance and instantiates a cop...
std::unique_ptr< AbsDistMatrixType > m_workspace
mean_squared_error_layer & operator=(const mean_squared_error_layer &other)
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
int get_backprop_requirements() const override
Returns the necessary tensors for computing backpropagation.
bool can_run_inplace() const override
If True, the computation can run in-place (feeding each input activations tensor as the corresponding...
std::string get_type() const override
Get the layer type's name.
constexpr El::Device Device
mean_squared_error_layer(lbann_comm *comm)
mean_squared_error_layer()
El::Device get_device_allocation() const override
Get the device allocation for the data tensors. We assume that the device allocation of the previous ...
data_layout get_data_layout() const override
Get data layout of the data tensors. We assume that the data layouts of the previous activations...
::distconv::tensor::Shape Shape
data_layout
Data layout that is optimized for different modes of parallelism.
El::AbstractDistMatrix< TensorDataType > AbsDistMatrixType
The tensor type expected in this object.
data_type_layer & operator=(data_type_layer &&other)=default
dc::TensorDev< OutputTensorDataType > TensorDevType