27 #ifndef LBANN_LAYER_REGULARIZER_BATCH_NORMALIZATION_HPP_INCLUDED 28 #define LBANN_LAYER_REGULARIZER_BATCH_NORMALIZATION_HPP_INCLUDED 33 #include "lbann/proto/layers.pb.h" 35 #ifdef LBANN_HAS_DISTCONV 36 #include "distconv/dnn_backend/batchnorm.hpp" 52 #ifdef LBANN_HAS_DISTCONV 55 using Backend = ::distconv::BackendDNNLib;
/** @brief Shorthand for ::distconv::BatchNormalization instantiated with
 *  the @c Backend alias and the given tensor data type. */
56 template <
typename TensorDataType>
57 using BatchNormalization =
58 ::distconv::BatchNormalization<Backend, TensorDataType>;
61 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
62 class batch_normalization_distconv_adapter
68 batch_normalization_distconv_adapter(
Layer& layer)
71 virtual ~batch_normalization_distconv_adapter() =
default;
72 void setup_fp_tensors()
override;
73 void setup_bp_tensors()
override;
74 dc::Shape get_per_channel_stat_shape()
const;
76 void setup_layer(
size_t workspace_capacity)
override;
77 std::unique_ptr<TensorDevType>
78 setup_error_signals_i(
int index)
const override;
84 TensorDevType m_scale;
86 TensorDevType m_running_mean;
87 TensorDevType m_running_var;
88 TensorDevType m_mean_gradient;
89 TensorDevType m_var_gradient;
90 TensorDevType m_scale_gradient;
91 TensorDevType m_bias_gradient;
92 std::unique_ptr<dc::BatchNormalization<TensorDataType>> m_bn;
94 #endif // LBANN_HAS_DISTCONV 109 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
113 "batch normalization only supports DATA_PARALLEL");
188 TensorDataType epsilon = 1e-5,
189 int statistics_group_size = 1,
190 bool bessel_correction =
true)
194 m_statistics_group_size(statistics_group_size),
195 m_bessel_correction(bessel_correction)
197 #ifdef LBANN_DETERMINISTIC 199 m_statistics_group_size = 0;
205 m_decay(other.m_decay),
206 m_epsilon(other.m_epsilon),
207 m_statistics_group_size(other.m_statistics_group_size),
208 m_bessel_correction(other.m_bessel_correction),
209 m_num_per_sum_cache(other.m_num_per_sum_cache),
210 m_mean_and_var(other.m_mean_and_var ? other.m_mean_and_var->Copy()
212 m_mean_v(other.m_mean_v ? other.m_mean_v->Copy() : nullptr),
213 m_var_v(other.m_var_v ? other.m_var_v->Copy() : nullptr),
214 m_mean_and_var_gradient(other.m_mean_and_var_gradient
215 ? other.m_mean_and_var_gradient->Copy()
218 other.m_mean_gradient_v ? other.m_mean_gradient_v->Copy() : nullptr),
219 m_var_gradient_v(other.m_var_gradient_v ? other.m_var_gradient_v->Copy()
221 m_scale_gradient(other.m_scale_gradient ? other.m_scale_gradient->Copy()
223 m_bias_gradient(other.m_bias_gradient ? other.m_bias_gradient->Copy()
244 m_mean_gradient_v.reset(
246 m_var_gradient_v.reset(
248 m_scale_gradient.reset(
/** @brief Get the layer type's name; always returns "batch normalization". */
260 std::string
get_type()
const override {
return "batch normalization"; }
272 desc.add(
"Decay", m_decay);
273 desc.add(
"Epsilon", m_epsilon);
274 desc.add(
"Statistics group size", m_statistics_group_size);
275 desc.add(
"Bessel's correction", m_bessel_correction);
282 template <
typename ArchiveT>
289 void write_specific_proto(lbann_data::Layer& proto)
const final;
294 this->set_output_dims(this->get_input_dims());
300 const auto& output_dims = this->get_output_dims();
301 const auto& num_channels = output_dims[0];
304 const auto& output = this->get_activations();
305 const auto& mini_batch_size = output.Width();
306 const auto& local_mini_batch_size = mini_batch_size / output.DistSize();
307 if (m_statistics_group_size == 0 && mini_batch_size <= 4) {
308 if (output.DistRank() == 0) {
309 std::stringstream err;
310 err <<
"LBANN warning: " << get_type() <<
" layer \"" 311 << this->get_name() <<
"\" " 312 <<
"is using global statistics and " 313 <<
"the mini-batch size (" << mini_batch_size <<
") " 314 <<
"may be too small to get good statistics";
315 std::cerr << err.str() << std::endl;
318 else if (m_statistics_group_size != 0 &&
319 m_statistics_group_size * local_mini_batch_size <= 4) {
322 if (output.DistRank() == 0) {
323 std::stringstream err;
324 err <<
"LBANN warning: " << get_type() <<
" layer \"" 325 << this->get_name() <<
"\" " 326 <<
"is aggregating statistics over " << m_statistics_group_size
327 <<
"processors and the aggregated mini-batch size (" 328 << (m_statistics_group_size * local_mini_batch_size) <<
") " 329 <<
"may be too small to get good statistics";
330 std::cerr << err.str() << std::endl;
335 if (this->num_weights() > 4) {
336 std::stringstream err;
337 err <<
"attempted to setup layer \"" << this->m_name <<
"\" " 338 <<
"with an invalid number of weights";
341 this->set_num_weights(4);
342 if (!this->has_weights(0)) {
343 auto w = std::make_shared<WeightsType>(*this->get_comm());
344 auto init = std::make_unique<constant_initializer<TensorDataType>>(
345 El::TypeTraits<TensorDataType>::One());
346 auto opt = this->m_model->template create_optimizer<TensorDataType>();
347 w->set_name(this->get_name() +
"_scale");
348 w->set_initializer(std::move(init));
349 w->set_optimizer(std::move(opt));
350 this->set_weights(0, w);
351 this->m_model->add_weights(std::move(w));
353 if (!this->has_weights(1)) {
354 auto w = std::make_shared<WeightsType>(*this->get_comm());
355 auto init = std::make_unique<constant_initializer<TensorDataType>>(
356 El::TypeTraits<TensorDataType>::Zero());
357 auto opt = this->m_model->template create_optimizer<TensorDataType>();
358 w->set_name(this->get_name() +
"_bias");
359 w->set_initializer(std::move(init));
360 w->set_optimizer(std::move(opt));
361 this->set_weights(1, w);
362 this->m_model->add_weights(std::move(w));
364 if (!this->has_weights(2)) {
365 auto w = std::make_shared<WeightsType>(*this->get_comm());
366 auto init = std::make_unique<constant_initializer<TensorDataType>>(
367 El::TypeTraits<TensorDataType>::Zero());
368 w->set_name(this->get_name() +
"_running_mean");
369 w->set_initializer(std::move(init));
370 this->set_weights(2, w);
371 this->m_model->add_weights(std::move(w));
373 if (!this->has_weights(3)) {
374 auto w = std::make_shared<WeightsType>(*this->get_comm());
375 auto init = std::make_unique<constant_initializer<TensorDataType>>(
376 El::TypeTraits<TensorDataType>::One());
377 w->set_name(this->get_name() +
"_running_variance");
378 w->set_initializer(std::move(init));
379 this->set_weights(3, w);
380 this->m_model->add_weights(std::move(w));
384 auto dist = this->get_prev_activations().DistData();
385 dist.colDist = El::STAR;
386 dist.rowDist = El::STAR;
387 size_t const num_weights = this->num_weights();
388 for (
size_t ii = 0; ii < num_weights; ++ii) {
389 auto& w = this->get_weights(ii);
390 w.set_dims(num_channels);
391 w.set_matrix_distribution(dist);
398 m_mean_and_var_gradient.reset(
404 El::Zeros(*m_mean_and_var, num_channels, 2);
405 El::Zeros(*m_mean_and_var_gradient, num_channels, 2);
406 El::Zeros(*m_scale_gradient, num_channels, 1);
407 El::Zeros(*m_bias_gradient, num_channels, 1);
410 El::View(*m_mean_v, *m_mean_and_var, El::ALL, El::IR(0, 1));
411 El::View(*m_var_v, *m_mean_and_var, El::ALL, El::IR(1, 2));
412 El::View(*m_mean_gradient_v,
413 *m_mean_and_var_gradient,
416 El::View(*m_var_gradient_v,
417 *m_mean_and_var_gradient,
422 for (
size_t ii = 0; ii < num_weights; ++ii) {
423 auto& w = this->get_weights(ii);
424 if (this->m_frozen) {
431 for (
size_t ii = 0; ii < num_weights; ++ii) {
432 auto& w = this->get_weights(ii);
433 if (w.is_frozen() != this->m_frozen) {
439 (w.is_frozen() ?
"" :
"un"),
449 void fp_compute()
override;
450 void bp_compute()
override;
452 #ifdef LBANN_HAS_DISTCONV 453 friend class batch_normalization_distconv_adapter<TensorDataType,
458 bool is_distconv_supported()
const override 462 void setup_distconv_adapter()
override 464 this->get_distconv_adapter_ptr() = std::make_unique<
465 batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>>(
468 batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>&
469 get_distconv_adapter()
override;
470 const batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>&
471 get_distconv_adapter()
const override;
472 #endif // LBANN_HAS_DISTCONV 477 #ifndef LBANN_BATCH_NORMALIZATION_LAYER_INSTANTIATE 478 #define PROTO_DEVICE(T, Device) \ 479 extern template class batch_normalization_layer<T, \ 480 data_layout::DATA_PARALLEL, \ 485 #endif // LBANN_BATCH_NORMALIZATION_LAYER_INSTANTIATE 489 #endif // LBANN_LAYER_REGULARIZER_BATCH_NORMALIZATION_HPP_INCLUDED
bool m_bessel_correction
Add Bessel's correction to the batch normalization denominator.
El::DistMatrix< TensorDataType, El::STAR, El::STAR, El::ELEMENT, D > StarMatDT
virtual void setup_dims()
Set up tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
batch_normalization_layer * copy() const override
Copy function. This function dynamically allocates memory for a layer instance and instantiates a cop...
std::unique_ptr< AbsDistMatrixType > m_var_gradient_v
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
Neural network tensor operation.
Generates nicely formatted description messages.
El::Device get_device_allocation() const override
Get the device allocation for the data tensors. We assume that the device allocation of the previous ...
batch_normalization_layer(const batch_normalization_layer &other)
virtual description get_description() const
Human-readable description.
constexpr El::Device Device
TensorDataType m_decay
Decay rate for running statistics.
Channel-wise batch normalization, including scale/bias.
batch_normalization_layer(TensorDataType decay=0.9, TensorDataType epsilon=1e-5, int statistics_group_size=1, bool bessel_correction=true)
Set up batch normalization.
std::unique_ptr< AbsDistMatrixType > m_mean_and_var
Current minibatch means and variances.
std::unique_ptr< AbsDistMatrixType > m_mean_v
std::unique_ptr< AbsDistMatrixType > m_var_v
int get_backprop_requirements() const override
Returns the necessary tensors for computing backpropagation.
std::unique_ptr< AbsDistMatrixType > m_mean_and_var_gradient
Gradients w.r.t. means and variances.
TensorDataType m_epsilon
Small number for numerical stability.
std::string get_type() const override
Get the layer type's name.
El::AbstractDistMatrix< TensorDataType > AbsDistMatrixType
The tensor type expected in this object.
batch_normalization_stats_aggregation
::distconv::tensor::Shape Shape
std::unique_ptr< AbsDistMatrixType > m_scale_gradient
batch_normalization_layer & operator=(const batch_normalization_layer &other)
std::unique_ptr< AbsDistMatrixType > m_mean_gradient_v
data_layout
Data layout that is optimized for different modes of parallelism.
std::unique_ptr< AbsDistMatrixType > m_bias_gradient
std::unordered_map< El::Int, El::Int > m_num_per_sum_cache
int m_statistics_group_size
Size of process group for computing statistics.
void setup_dims() override
Set up tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
void setup_data(size_t max_mini_batch_size) override
LBANN_DEFINE_LAYER_BUILDER(elu)
data_layout get_data_layout() const override
Get data layout of the data tensors. We assume that the data layouts of the previous activations...
::distconv::tensor::Distribution Dist
bool can_run_inplace() const override
If True, the computation can run in-place (feeding each input activations tensor as the corresponding...
data_type_layer & operator=(data_type_layer &&other)=default
void setup_data(size_t max_mini_batch_size) override
Setup layer data. Called by the 'setup' function. Memory is allocated for distributed matrices...
description get_description() const override
Human-readable description.
dc::TensorDev< OutputTensorDataType > TensorDevType