LBANN  0.103.0
LivermoreBigArtificialNeuralNetworkToolkit
batch_normalization.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_LAYER_REGULARIZER_BATCH_NORMALIZATION_HPP_INCLUDED
28 #define LBANN_LAYER_REGULARIZER_BATCH_NORMALIZATION_HPP_INCLUDED
29 
31 #include "lbann/models/model.hpp"
33 #include "lbann/proto/layers.pb.h"
34 
35 #ifdef LBANN_HAS_DISTCONV
36 #include "distconv/dnn_backend/batchnorm.hpp"
37 #include "lbann/utils/distconv.hpp"
38 #endif
39 
40 namespace lbann {
41 
43 {
45  local,
47  node_local,
49  global
50 };
51 
52 #ifdef LBANN_HAS_DISTCONV
53 namespace dc {
55 using Backend = ::distconv::BackendDNNLib;
56 template <typename TensorDataType>
57 using BatchNormalization =
58  ::distconv::BatchNormalization<Backend, TensorDataType>;
59 } // namespace dc
60 
61 template <typename TensorDataType, data_layout T_layout, El::Device Dev>
62 class batch_normalization_distconv_adapter
63  : public data_type_distconv_adapter<TensorDataType>
64 {
65 public:
66  using TensorDevType =
68  batch_normalization_distconv_adapter(Layer& layer)
70  {}
71  virtual ~batch_normalization_distconv_adapter() = default;
72  void setup_fp_tensors() override;
73  void setup_bp_tensors() override;
74  dc::Shape get_per_channel_stat_shape() const;
75  dc::Dist get_per_channel_stat_dist(const dc::Dist& input_dist) const;
76  void setup_layer(size_t workspace_capacity) override;
77  std::unique_ptr<TensorDevType>
78  setup_error_signals_i(int index) const override;
79  void fp_compute();
80  void bp_compute();
81 
82  TensorDevType m_mean;
83  TensorDevType m_var;
84  TensorDevType m_scale;
85  TensorDevType m_bias;
86  TensorDevType m_running_mean;
87  TensorDevType m_running_var;
88  TensorDevType m_mean_gradient;
89  TensorDevType m_var_gradient;
90  TensorDevType m_scale_gradient;
91  TensorDevType m_bias_gradient;
92  std::unique_ptr<dc::BatchNormalization<TensorDataType>> m_bn;
93 };
94 #endif // LBANN_HAS_DISTCONV
95 
109 template <typename TensorDataType, data_layout T_layout, El::Device Dev>
110 class batch_normalization_layer : public data_type_layer<TensorDataType>
111 {
112  static_assert(T_layout == data_layout::DATA_PARALLEL,
113  "batch normalization only supports DATA_PARALLEL");
114 
115 public:
117 
120  using AbsDistMatrixType = El::AbstractDistMatrix<TensorDataType>;
121 
124 
127 
129 
130 private:
132  TensorDataType m_decay;
134  TensorDataType m_epsilon;
152  std::unordered_map<El::Int, El::Int> m_num_per_sum_cache;
153 
158  std::unique_ptr<AbsDistMatrixType> m_mean_and_var;
160  std::unique_ptr<AbsDistMatrixType> m_mean_v;
162  std::unique_ptr<AbsDistMatrixType> m_var_v;
167  std::unique_ptr<AbsDistMatrixType> m_mean_and_var_gradient;
169  std::unique_ptr<AbsDistMatrixType> m_mean_gradient_v;
171  std::unique_ptr<AbsDistMatrixType> m_var_gradient_v;
173  std::unique_ptr<AbsDistMatrixType> m_scale_gradient;
175  std::unique_ptr<AbsDistMatrixType> m_bias_gradient;
176 
177 public:
187  batch_normalization_layer(TensorDataType decay = 0.9,
188  TensorDataType epsilon = 1e-5,
189  int statistics_group_size = 1,
190  bool bessel_correction = true)
191  : data_type_layer<TensorDataType>(nullptr),
192  m_decay(decay),
193  m_epsilon(epsilon),
194  m_statistics_group_size(statistics_group_size),
195  m_bessel_correction(bessel_correction)
196  {
197 #ifdef LBANN_DETERMINISTIC
198  // Force global computation.
199  m_statistics_group_size = 0;
200 #endif
201  }
202 
204  : data_type_layer<TensorDataType>(other),
205  m_decay(other.m_decay),
206  m_epsilon(other.m_epsilon),
207  m_statistics_group_size(other.m_statistics_group_size),
208  m_bessel_correction(other.m_bessel_correction),
209  m_num_per_sum_cache(other.m_num_per_sum_cache),
210  m_mean_and_var(other.m_mean_and_var ? other.m_mean_and_var->Copy()
211  : nullptr),
212  m_mean_v(other.m_mean_v ? other.m_mean_v->Copy() : nullptr),
213  m_var_v(other.m_var_v ? other.m_var_v->Copy() : nullptr),
214  m_mean_and_var_gradient(other.m_mean_and_var_gradient
215  ? other.m_mean_and_var_gradient->Copy()
216  : nullptr),
217  m_mean_gradient_v(
218  other.m_mean_gradient_v ? other.m_mean_gradient_v->Copy() : nullptr),
219  m_var_gradient_v(other.m_var_gradient_v ? other.m_var_gradient_v->Copy()
220  : nullptr),
221  m_scale_gradient(other.m_scale_gradient ? other.m_scale_gradient->Copy()
222  : nullptr),
223  m_bias_gradient(other.m_bias_gradient ? other.m_bias_gradient->Copy()
224  : nullptr)
225  {}
226 
228  {
230  m_decay = other.m_decay;
231  m_epsilon = other.m_epsilon;
232  m_statistics_group_size = other.m_statistics_group_size;
233  m_bessel_correction = other.m_bessel_correction;
234  m_num_per_sum_cache = other.m_num_per_sum_cache;
235 
236  // Deep copy matrices
237  m_mean_and_var.reset(other.m_mean_and_var ? other.m_mean_and_var->Copy()
238  : nullptr);
239  m_mean_v.reset(other.m_mean_v ? other.m_mean_v->Copy() : nullptr);
240  m_var_v.reset(other.m_var_v ? other.m_var_v->Copy() : nullptr);
241  m_mean_and_var_gradient.reset(other.m_mean_and_var_gradient
242  ? other.m_mean_and_var_gradient->Copy()
243  : nullptr);
244  m_mean_gradient_v.reset(
245  other.m_mean_gradient_v ? other.m_mean_gradient_v->Copy() : nullptr);
246  m_var_gradient_v.reset(
247  other.m_var_gradient_v ? other.m_var_gradient_v->Copy() : nullptr);
248  m_scale_gradient.reset(
249  other.m_scale_gradient ? other.m_scale_gradient->Copy() : nullptr);
250  m_bias_gradient.reset(other.m_bias_gradient ? other.m_bias_gradient->Copy()
251  : nullptr);
252 
253  return *this;
254  }
255 
256  batch_normalization_layer* copy() const override
257  {
258  return new batch_normalization_layer(*this);
259  }
260  std::string get_type() const override { return "batch normalization"; }
261  data_layout get_data_layout() const override { return T_layout; }
262  El::Device get_device_allocation() const override { return Dev; }
263  bool can_run_inplace() const override { return false; }
264  int get_backprop_requirements() const override
265  {
267  }
268 
269  description get_description() const override
270  {
272  desc.add("Decay", m_decay);
273  desc.add("Epsilon", m_epsilon);
274  desc.add("Statistics group size", m_statistics_group_size);
275  desc.add("Bessel's correction", m_bessel_correction);
276  return desc;
277  }
278 
280 
282  template <typename ArchiveT>
283  void serialize(ArchiveT& ar);
284 
286 
287 protected:
289  void write_specific_proto(lbann_data::Layer& proto) const final;
290 
291  void setup_dims() override
292  {
294  this->set_output_dims(this->get_input_dims());
295  }
296 
297  void setup_data(size_t max_mini_batch_size) override
298  {
299  data_type_layer<TensorDataType>::setup_data(max_mini_batch_size);
300  const auto& output_dims = this->get_output_dims();
301  const auto& num_channels = output_dims[0];
302 
303  // Display warning if mini-batch size is small
304  const auto& output = this->get_activations();
305  const auto& mini_batch_size = output.Width();
306  const auto& local_mini_batch_size = mini_batch_size / output.DistSize();
307  if (m_statistics_group_size == 0 && mini_batch_size <= 4) {
308  if (output.DistRank() == 0) {
309  std::stringstream err;
310  err << "LBANN warning: " << get_type() << " layer \""
311  << this->get_name() << "\" "
312  << "is using global statistics and "
313  << "the mini-batch size (" << mini_batch_size << ") "
314  << "may be too small to get good statistics";
315  std::cerr << err.str() << std::endl;
316  }
317  }
318  else if (m_statistics_group_size != 0 &&
319  m_statistics_group_size * local_mini_batch_size <= 4) {
320  // This possibly underestimates the aggregation size for processors with
321  // smaller local mini-batch sizes.
322  if (output.DistRank() == 0) {
323  std::stringstream err;
324  err << "LBANN warning: " << get_type() << " layer \""
325  << this->get_name() << "\" "
326  << "is aggregating statistics over " << m_statistics_group_size
327  << "processors and the aggregated mini-batch size ("
328  << (m_statistics_group_size * local_mini_batch_size) << ") "
329  << "may be too small to get good statistics";
330  std::cerr << err.str() << std::endl;
331  }
332  }
333 
334  // Initialize default weights if none are provided
335  if (this->num_weights() > 4) {
336  std::stringstream err;
337  err << "attempted to setup layer \"" << this->m_name << "\" "
338  << "with an invalid number of weights";
339  LBANN_ERROR(err.str());
340  }
341  this->set_num_weights(4);
342  if (!this->has_weights(0)) {
343  auto w = std::make_shared<WeightsType>(*this->get_comm());
344  auto init = std::make_unique<constant_initializer<TensorDataType>>(
345  El::TypeTraits<TensorDataType>::One());
346  auto opt = this->m_model->template create_optimizer<TensorDataType>();
347  w->set_name(this->get_name() + "_scale");
348  w->set_initializer(std::move(init));
349  w->set_optimizer(std::move(opt));
350  this->set_weights(0, w);
351  this->m_model->add_weights(std::move(w));
352  }
353  if (!this->has_weights(1)) {
354  auto w = std::make_shared<WeightsType>(*this->get_comm());
355  auto init = std::make_unique<constant_initializer<TensorDataType>>(
356  El::TypeTraits<TensorDataType>::Zero());
357  auto opt = this->m_model->template create_optimizer<TensorDataType>();
358  w->set_name(this->get_name() + "_bias");
359  w->set_initializer(std::move(init));
360  w->set_optimizer(std::move(opt));
361  this->set_weights(1, w);
362  this->m_model->add_weights(std::move(w));
363  }
364  if (!this->has_weights(2)) {
365  auto w = std::make_shared<WeightsType>(*this->get_comm());
366  auto init = std::make_unique<constant_initializer<TensorDataType>>(
367  El::TypeTraits<TensorDataType>::Zero());
368  w->set_name(this->get_name() + "_running_mean");
369  w->set_initializer(std::move(init));
370  this->set_weights(2, w);
371  this->m_model->add_weights(std::move(w));
372  }
373  if (!this->has_weights(3)) {
374  auto w = std::make_shared<WeightsType>(*this->get_comm());
375  auto init = std::make_unique<constant_initializer<TensorDataType>>(
376  El::TypeTraits<TensorDataType>::One());
377  w->set_name(this->get_name() + "_running_variance");
378  w->set_initializer(std::move(init));
379  this->set_weights(3, w);
380  this->m_model->add_weights(std::move(w));
381  }
382 
383  // Setup weights
384  auto dist = this->get_prev_activations().DistData();
385  dist.colDist = El::STAR;
386  dist.rowDist = El::STAR;
387  size_t const num_weights = this->num_weights();
388  for (size_t ii = 0; ii < num_weights; ++ii) {
389  auto& w = this->get_weights(ii);
390  w.set_dims(num_channels);
391  w.set_matrix_distribution(dist);
392  }
393 
394  // Initialize matrices
395  m_mean_and_var.reset(new StarMatDT<TensorDataType, Dev>(*dist.grid));
396  m_mean_v.reset(new StarMatDT<TensorDataType, Dev>(*dist.grid));
397  m_var_v.reset(new StarMatDT<TensorDataType, Dev>(*dist.grid));
398  m_mean_and_var_gradient.reset(
399  new StarMatDT<TensorDataType, Dev>(*dist.grid));
400  m_mean_gradient_v.reset(new StarMatDT<TensorDataType, Dev>(*dist.grid));
401  m_var_gradient_v.reset(new StarMatDT<TensorDataType, Dev>(*dist.grid));
402  m_scale_gradient.reset(new StarMatDT<TensorDataType, Dev>(*dist.grid));
403  m_bias_gradient.reset(new StarMatDT<TensorDataType, Dev>(*dist.grid));
404  El::Zeros(*m_mean_and_var, num_channels, 2);
405  El::Zeros(*m_mean_and_var_gradient, num_channels, 2);
406  El::Zeros(*m_scale_gradient, num_channels, 1);
407  El::Zeros(*m_bias_gradient, num_channels, 1);
408 
409  // Initialize views.
410  El::View(*m_mean_v, *m_mean_and_var, El::ALL, El::IR(0, 1));
411  El::View(*m_var_v, *m_mean_and_var, El::ALL, El::IR(1, 2));
412  El::View(*m_mean_gradient_v,
413  *m_mean_and_var_gradient,
414  El::ALL,
415  El::IR(0, 1));
416  El::View(*m_var_gradient_v,
417  *m_mean_and_var_gradient,
418  El::ALL,
419  El::IR(1, 2));
420 
421  // Initialize freeze state
422  for (size_t ii = 0; ii < num_weights; ++ii) {
423  auto& w = this->get_weights(ii);
424  if (this->m_frozen) {
425  w.freeze();
426  }
427  else {
428  w.unfreeze();
429  }
430  }
431  for (size_t ii = 0; ii < num_weights; ++ii) {
432  auto& w = this->get_weights(ii);
433  if (w.is_frozen() != this->m_frozen) {
434  LBANN_ERROR((this->m_frozen ? "" : "un"),
435  "frozen layer "
436  "\"",
437  this->get_name(),
438  "\" has ",
439  (w.is_frozen() ? "" : "un"),
440  "frozen weights "
441  "\"",
442  w.get_name(),
443  "\"");
444  ;
445  }
446  }
447  }
448 
449  void fp_compute() override;
450  void bp_compute() override;
451 
452 #ifdef LBANN_HAS_DISTCONV
453  friend class batch_normalization_distconv_adapter<TensorDataType,
454  T_layout,
455  Dev>;
456 
457 protected:
458  bool is_distconv_supported() const override
459  {
460  return Dev == El::Device::GPU && T_layout == data_layout::DATA_PARALLEL;
461  }
462  void setup_distconv_adapter() override
463  {
464  this->get_distconv_adapter_ptr() = std::make_unique<
465  batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>>(
466  *this);
467  }
468  batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>&
469  get_distconv_adapter() override;
470  const batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>&
471  get_distconv_adapter() const override;
472 #endif // LBANN_HAS_DISTCONV
473 };
474 
475 LBANN_DEFINE_LAYER_BUILDER(batch_normalization);
476 
477 #ifndef LBANN_BATCH_NORMALIZATION_LAYER_INSTANTIATE
478 #define PROTO_DEVICE(T, Device) \
479  extern template class batch_normalization_layer<T, \
480  data_layout::DATA_PARALLEL, \
481  Device>
482 
484 #undef PROTO_DEVICE
485 #endif // LBANN_BATCH_NORMALIZATION_LAYER_INSTANTIATE
486 
487 } // namespace lbann
488 
489 #endif // LBANN_LAYER_REGULARIZER_BATCH_NORMALIZATION_HPP_INCLUDED
bool m_bessel_correction
Add Bessel's correction to the batch normalization denominator.
El::DistMatrix< TensorDataType, El::STAR, El::STAR, El::ELEMENT, D > StarMatDT
Definition: base.hpp:145
virtual void setup_dims()
Setup tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
batch_normalization_layer * copy() const override
Copy function. This function dynamically allocates memory for a layer instance and instantiates a cop...
std::unique_ptr< AbsDistMatrixType > m_var_gradient_v
#define LBANN_ERROR(...)
Definition: exception.hpp:37
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
Neural network tensor operation.
Definition: layer.hpp:285
Generates nicely formatted description messages.
Definition: description.hpp:49
El::Device get_device_allocation() const override
Get the device allocation for the data tensors. We assume that the device allocation of the previous ...
batch_normalization_layer(const batch_normalization_layer &other)
virtual description get_description() const
Human-readable description.
constexpr El::Device Device
TensorDataType m_decay
Decay rate for running statistics.
Channel-wise batch normalization, including scale/bias.
batch_normalization_layer(TensorDataType decay=0.9, TensorDataType epsilon=1e-5, int statistics_group_size=1, bool bessel_correction=true)
Set up batch normalization.
std::unique_ptr< AbsDistMatrixType > m_mean_and_var
Current minibatch means and standard deviations.
std::unique_ptr< AbsDistMatrixType > m_mean_v
std::unique_ptr< AbsDistMatrixType > m_var_v
int get_backprop_requirements() const override
Returns the necessary tensors for computing backpropagation.
std::unique_ptr< AbsDistMatrixType > m_mean_and_var_gradient
Gradients w.r.t. means and standard deviations.
TensorDataType m_epsilon
Small number for numerical stability.
std::string get_type() const override
Get the layer type's name.
El::AbstractDistMatrix< TensorDataType > AbsDistMatrixType
The tensor type expected in this object.
batch_normalization_stats_aggregation
::distconv::tensor::Shape Shape
std::unique_ptr< AbsDistMatrixType > m_scale_gradient
batch_normalization_layer & operator=(const batch_normalization_layer &other)
std::unique_ptr< AbsDistMatrixType > m_mean_gradient_v
data_layout
Data layout that is optimized for different modes of parallelism.
Definition: base.hpp:218
std::unique_ptr< AbsDistMatrixType > m_bias_gradient
std::unordered_map< El::Int, El::Int > m_num_per_sum_cache
int m_statistics_group_size
Size of process group for computing statistics.
void setup_dims() override
Setup tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
void setup_data(size_t max_mini_batch_size) override
LBANN_DEFINE_LAYER_BUILDER(elu)
data_layout get_data_layout() const override
Get data layout of the data tensors. We assume that the data layouts of the previous activations...
::distconv::tensor::Distribution Dist
bool can_run_inplace() const override
If True, the computation can run in-place (feeding each input activations tensor as the corresponding...
data_type_layer & operator=(data_type_layer &&other)=default
void setup_data(size_t max_mini_batch_size) override
Setup layer data. Called by the 'setup' function. Memory is allocated for distributed matrices...
description get_description() const override
Human-readable description.
dc::TensorDev< OutputTensorDataType > TensorDevType