// Include guard for the batch-normalization implementation header, followed by
// the definition that writes this layer's settings into its protobuf
// description: it sets the message datatype from the template parameter T and
// then fills the batch_normalization sub-message.
// NOTE(review): the function-name line is absent from this listing; presumably
// this is write_specific_proto -- confirm against the full file.
27 #ifndef LBANN_LAYER_REGULARIZER_BATCH_NORMALIZATION_IMPL_HPP_INCLUDED 28 #define LBANN_LAYER_REGULARIZER_BATCH_NORMALIZATION_IMPL_HPP_INCLUDED 32 #ifdef LBANN_HAS_DISTCONV 34 #endif // LBANN_HAS_DISTCONV 38 template <
typename T, data_layout L, El::Device D>
40 lbann_data::Layer& proto)
const 42 proto.set_datatype(proto::ProtoDataType<T>);
// Copy m_decay, m_epsilon and m_statistics_group_size into the
// batch_normalization sub-message obtained from the proto.
43 auto* msg = proto.mutable_batch_normalization();
44 msg->set_decay(m_decay);
45 msg->set_epsilon(m_epsilon);
46 msg->set_statistics_group_size(m_statistics_group_size);
// Start of the section compiled only when LBANN_HAS_DISTCONV is defined.
// Const accessor whose return type is a const reference to the
// batch-normalization distconv adapter.
// NOTE(review): the function-name and body lines are absent from this listing;
// the repeated adapter type followed by ">(" suggests a cast to that type --
// confirm against the full file.
49 #ifdef LBANN_HAS_DISTCONV 50 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
51 const batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>&
56 const batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>&
>(
// Accessor whose return type is a non-const reference to the
// batch-normalization distconv adapter. The visible tail closes a cast whose
// operand ends in a call to .get_distconv_adapter().
// NOTE(review): presumably this strips constness from the const overload's
// result; the intervening lines are absent from this listing -- confirm.
60 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
61 batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>&
65 batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>&
>(
68 .get_distconv_adapter());
// Returns the tensor shape used for per-channel statistics: the entry at the
// channel dimension is the channel count of this layer's activations shape,
// and the shape is otherwise built from the value 1.
71 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
72 dc::Shape batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>::
73 get_per_channel_stat_shape()
const 75 auto& l =
dynamic_cast< 78 const int num_channels = this->get_activations_shape()[dc::get_channel_dim()];
// Sanity checks: the channel count must be non-zero, and the element count
// (Width * Height) of the layer's m_mean_and_var matrix is compared against an
// expected value. NOTE(review): the right-hand operand of assert_eq is on a
// line absent from this listing.
80 assert_ne(num_channels, 0);
81 assert_eq(l.m_mean_and_var->Matrix().Width() *
82 l.m_mean_and_var->Matrix().Height(),
// Shape is constructed with (dc::get_num_dims(l), 1) -- presumably that many
// dimensions, each of extent 1 (TODO confirm dc::Shape constructor semantics)
// -- and then the channel extent is overwritten.
84 dc::Shape per_channel_stat_shape(dc::get_num_dims(l), 1);
85 per_channel_stat_shape[dc::get_channel_dim()] = num_channels;
86 return per_channel_stat_shape;
// Derives the distribution used for per-channel statistics tensors from the
// distribution of an input tensor: a new distribution is made from the input's
// locale shape, and a split shape based on the input's split shape is applied.
//
// input_dist: distribution to derive from.
89 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
90 dc::Dist batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>::
91 get_per_channel_stat_dist(
const dc::Dist& input_dist)
const 93 auto shared_dist = dc::Dist::make_distribution(input_dist.get_locale_shape());
94 auto split_shape = input_dist.get_split_shape();
// Save the split-shape entry at index -2.
// NOTE(review): presumably negative indices count from the end and this is the
// channel dimension -- confirm against the split-shape type's operator[].
96 auto pc = split_shape[-2];
// NOTE(review): listing lines 97-99 are absent here; presumably split_shape is
// rewritten using pc before being applied below -- confirm.
100 shared_dist.set_split_shape(split_shape);
// Creates the per-channel tensors m_mean, m_var, m_scale, m_bias,
// m_running_mean and m_running_var. Each is constructed with the per-channel
// statistics shape and a distribution derived from the previous-activations
// distribution.
// NOTE(review): the function-name line is absent from this listing; presumably
// this is setup_fp_tensors -- confirm against the full file.
105 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
106 void batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>::
114 const auto& input_dist = this->get_prev_activations_dist();
116 const auto per_channel_stat_shape = get_per_channel_stat_shape();
117 const auto shared_dist = get_per_channel_stat_dist(input_dist);
// m_mean and m_var are passed to dc::tensor::View together with the layer's
// m_mean_v / m_var_v buffers; assert0 checks the call's result (presumably 0
// means success -- TODO confirm).
122 m_mean = TensorDevType(per_channel_stat_shape, loc, shared_dist);
123 assert0(dc::tensor::View(m_mean, l.m_mean_v->Buffer()));
125 m_var = TensorDevType(per_channel_stat_shape, loc, shared_dist);
126 assert0(dc::tensor::View(m_var, l.m_var_v->Buffer()));
// The remaining tensors are only constructed in the visible lines.
// NOTE(review): the listing lines following each construction (129, 131, 133,
// 135) are absent, so any buffer binding for them is not visible here.
128 m_scale = TensorDevType(per_channel_stat_shape, loc, shared_dist);
130 m_bias = TensorDevType(per_channel_stat_shape, loc, shared_dist);
132 m_running_mean = TensorDevType(per_channel_stat_shape, loc, shared_dist);
134 m_running_var = TensorDevType(per_channel_stat_shape, loc, shared_dist);
// Creates the per-channel gradient tensors m_scale_gradient, m_bias_gradient,
// m_mean_gradient and m_var_gradient. Each is constructed with the per-channel
// statistics shape and a distribution derived from the previous-error-signals
// distribution, then passed to dc::tensor::View with the matching layer buffer.
// NOTE(review): the function-name line is absent from this listing; presumably
// this is setup_bp_tensors -- confirm against the full file.
137 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
138 void batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>::
143 const auto& prev_error_signal_dist = this->get_prev_error_signals_dist();
148 const auto per_channel_stat_shape = get_per_channel_stat_shape();
149 const auto shared_dist = get_per_channel_stat_dist(prev_error_signal_dist);
// Each assert0 checks the result of the View call (presumably 0 means
// success -- TODO confirm).
154 m_scale_gradient = TensorDevType(per_channel_stat_shape, loc, shared_dist);
155 assert0(dc::tensor::View(m_scale_gradient, l.m_scale_gradient->Buffer()));
157 m_bias_gradient = TensorDevType(per_channel_stat_shape, loc, shared_dist);
158 assert0(dc::tensor::View(m_bias_gradient, l.m_bias_gradient->Buffer()));
160 m_mean_gradient = TensorDevType(per_channel_stat_shape, loc, shared_dist);
161 assert0(dc::tensor::View(m_mean_gradient, l.m_mean_gradient_v->Buffer()));
163 m_var_gradient = TensorDevType(per_channel_stat_shape, loc, shared_dist);
164 assert0(dc::tensor::View(m_var_gradient, l.m_var_gradient_v->Buffer()));
// Creates the dc::BatchNormalization object held in m_bn, after translating
// the layer's statistics_group_size into the global_stats flag.
//
// workspace_capacity: not referenced in the lines visible in this listing.
167 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
168 void batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>::
169 setup_layer(
size_t workspace_capacity)
// statistics_group_size == 1 sets global_stats to false. The assignment for
// the == 0 case is on a line absent from this listing (presumably true --
// confirm). The LBANN_ERROR call that follows is presumably the branch for
// every other value, per its message; its enclosing else line is also absent.
175 if (l.m_statistics_group_size == 0) {
178 else if (l.m_statistics_group_size == 1) {
179 global_stats =
false;
182 LBANN_ERROR(
"statistics_group_size must be either 0 or 1 for now.");
// NOTE(review): the constructor arguments are on lines absent from this
// listing.
185 m_bn = std::make_unique<dc::BatchNormalization<TensorDataType>>(
// Creates the error-signal tensor for this layer and returns it through a
// unique_ptr. When none of the conditions below hold, the tensor is
// constructed from the previous-activations tensor.
//
// index: not referenced in the lines visible in this listing.
193 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
194 std::unique_ptr<
typename batch_normalization_distconv_adapter<TensorDataType,
197 batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>::
198 setup_error_signals_i(
int index)
const 201 auto& parent_layer = this->layer().get_parent_layer();
// The guarded branch is taken when any of these hold: the parent layer's
// backprop requirements include the ACTIVATIONS bit, the parent layer's type
// is "identity", the previous-activations and error-signals distributions
// differ, or the DISTCONV_DISABLE_MEM_OPT environment variable is set.
// NOTE(review): the branch body (listing lines 206-207) is absent here;
// presumably it falls back to the base adapter's implementation -- confirm.
202 if (parent_layer.get_backprop_requirements() &
ACTIVATIONS 203 || parent_layer.get_type() ==
"identity" 204 || this->get_prev_activations_dist() != this->get_error_signals_dist()
205 || std::getenv(
"DISTCONV_DISABLE_MEM_OPT")) {
// Otherwise build the result from the first previous-activations tensor.
// NOTE(review): presumably this shares memory with that tensor rather than
// copying, given the MEM_OPT switch above -- confirm TensorDevType semantics.
208 const auto& prev_activations = this->get_prev_activations(0);
209 return std::make_unique<TensorDevType>(prev_activations);
211 #endif // LBANN_HAS_DISTCONV 215 #endif // LBANN_LAYER_REGULARIZER_BATCH_NORMALIZATION_IMPL_HPP_INCLUDED ::distconv::tensor::LocaleMPI LocaleMPI
virtual std::unique_ptr< InputTensorDevType > setup_error_signals_i(int index) const
Channel-wise batch normalization, including scale/bias.
::distconv::tensor::Shape Shape
virtual void setup_bp_tensors()
::distconv::tensor::Distribution Dist
virtual void setup_fp_tensors()
void write_specific_proto(lbann_data::Layer &proto) const final