LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
batch_normalization_impl.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the License.
26 
27 #ifndef LBANN_LAYER_REGULARIZER_BATCH_NORMALIZATION_IMPL_HPP_INCLUDED
28 #define LBANN_LAYER_REGULARIZER_BATCH_NORMALIZATION_IMPL_HPP_INCLUDED
29 
31 
32 #ifdef LBANN_HAS_DISTCONV
34 #endif // LBANN_HAS_DISTCONV
35 
36 namespace lbann {
37 
// Writes this layer's batch-normalization settings into `proto`: sets the
// datatype from proto::ProtoDataType<T>, then copies m_decay, m_epsilon and
// m_statistics_group_size into the batch_normalization sub-message.
// NOTE(review): listing line 39 (the return type and qualified function name)
// is missing from this extract; the signature index at the end of the listing
// suggests this is write_specific_proto -- confirm against the real header.
38 template <typename T, data_layout L, El::Device D>
40  lbann_data::Layer& proto) const
41 {
42  proto.set_datatype(proto::ProtoDataType<T>);
// mutable_batch_normalization() yields the sub-message that is filled in below.
43  auto* msg = proto.mutable_batch_normalization();
44  msg->set_decay(m_decay);
45  msg->set_epsilon(m_epsilon);
46  msg->set_statistics_group_size(m_statistics_group_size);
47 }
48 
49 #ifdef LBANN_HAS_DISTCONV
// Const accessor that returns a const reference to a
// batch_normalization_distconv_adapter, obtained with a dynamic_cast to the
// concrete adapter type.
// NOTE(review): listing lines 52 (qualified function name) and 57 (the
// expression being cast) are missing from this extract; presumably this is the
// const overload of get_distconv_adapter() -- confirm against the real header.
50 template <typename TensorDataType, data_layout T_layout, El::Device Dev>
51 const batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>&
53  const
54 {
55  return dynamic_cast<
56  const batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>&>(
58 }
59 
// Non-const accessor that returns a mutable reference to a
// batch_normalization_distconv_adapter. It calls .get_distconv_adapter() on
// the result of a static_cast and then removes constness with const_cast.
// NOTE(review): listing lines 62 (qualified function name) and 67 (the
// static_cast target and operand) are missing from this extract; presumably
// the static_cast selects the const overload on this same object, so the two
// accessors share one implementation -- confirm against the real header.
60 template <typename TensorDataType, data_layout T_layout, El::Device Dev>
61 batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>&
63 {
64  return const_cast<
65  batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>&>(
66  static_cast<
68  .get_distconv_adapter());
69 }
70 
// Builds the shape used for the per-channel statistics tensors: a dc::Shape
// constructed from (dc::get_num_dims(l), 1) whose entry at
// dc::get_channel_dim() is then set to the number of channels.
// Returns that shape by value.
// NOTE(review): listing line 76 (the dynamic_cast target type) is missing from
// this extract.
71 template <typename TensorDataType, data_layout T_layout, El::Device Dev>
72 dc::Shape batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>::
73  get_per_channel_stat_shape() const
74 {
75  auto& l = dynamic_cast<
77  this->layer());
// The channel count is read from this adapter's activations shape.
78  const int num_channels = this->get_activations_shape()[dc::get_channel_dim()];
79  // Sanity check that the shared tensors have the correct shape
80  assert_ne(num_channels, 0);
// m_mean_and_var must hold exactly two values per channel (Width * Height
// == 2 * num_channels).
81  assert_eq(l.m_mean_and_var->Matrix().Width() *
82  l.m_mean_and_var->Matrix().Height(),
83  num_channels * 2);
// Presumably (rank, fill value): every extent starts at 1 -- TODO confirm
// the dc::Shape constructor semantics.
84  dc::Shape per_channel_stat_shape(dc::get_num_dims(l), 1);
85  per_channel_stat_shape[dc::get_channel_dim()] = num_channels;
86  return per_channel_stat_shape;
87 }
88 
89 template <typename TensorDataType, data_layout T_layout, El::Device Dev>
90 dc::Dist batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>::
91  get_per_channel_stat_dist(const dc::Dist& input_dist) const
92 {
93  auto shared_dist = dc::Dist::make_distribution(input_dist.get_locale_shape());
94  auto split_shape = input_dist.get_split_shape();
95  // set all dimensions to be 1 except for the channel dimension
96  auto pc = split_shape[-2];
97  // set all elements to 1
98  split_shape = 1;
99  split_shape[-2] = pc;
100  shared_dist.set_split_shape(split_shape);
101 
102  return shared_dist;
103 }
104 
// Creates the per-channel tensors used by forward propagation: m_mean, m_var,
// m_scale, m_bias, m_running_mean and m_running_var. All six are constructed
// with the same shape (get_per_channel_stat_shape()), locale and distribution
// (get_per_channel_stat_dist() of the previous-activations distribution).
// Only m_mean and m_var are attached to layer buffers here (l.m_mean_v and
// l.m_var_v). The remaining four are described by the existing comments as
// views of weights[0..3], but no View call for them appears in this function.
// NOTE(review): listing lines 109 and 112 are missing from this extract
// (112 is presumably the dynamic_cast producing `l`).
105 template <typename TensorDataType, data_layout T_layout, El::Device Dev>
106 void batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>::
107  setup_fp_tensors()
108 {
110 
111  auto& l =
113  this->layer());
114  const auto& input_dist = this->get_prev_activations_dist();
115 
116  const auto per_channel_stat_shape = get_per_channel_stat_shape();
117  const auto shared_dist = get_per_channel_stat_dist(input_dist);
118 
119  const dc::LocaleMPI loc(dc::get_mpi_comm(), false);
120 
// assert0 presumably checks that dc::tensor::View returned 0 -- TODO confirm.
121  // mean
122  m_mean = TensorDevType(per_channel_stat_shape, loc, shared_dist);
123  assert0(dc::tensor::View(m_mean, l.m_mean_v->Buffer()));
124  // var
125  m_var = TensorDevType(per_channel_stat_shape, loc, shared_dist);
126  assert0(dc::tensor::View(m_var, l.m_var_v->Buffer()));
127  // scale: view to weights[0]
128  m_scale = TensorDevType(per_channel_stat_shape, loc, shared_dist);
129  // bias: view to weights[1]
130  m_bias = TensorDevType(per_channel_stat_shape, loc, shared_dist);
131  // running_mean: view to weights[2]
132  m_running_mean = TensorDevType(per_channel_stat_shape, loc, shared_dist);
133  // running_var: view to weights[3]
134  m_running_var = TensorDevType(per_channel_stat_shape, loc, shared_dist);
135 }
136 
// Creates the per-channel gradient tensors used by backward propagation:
// m_scale_gradient, m_bias_gradient, m_mean_gradient and m_var_gradient.
// Each is constructed with the per-channel statistics shape, and with a
// distribution derived from the previous-error-signals distribution. Each is
// then attached as a view of the matching layer buffer (l.m_scale_gradient,
// l.m_bias_gradient, l.m_mean_gradient_v, l.m_var_gradient_v).
// NOTE(review): listing lines 141 and 145 are missing from this extract
// (145 is presumably the dynamic_cast producing `l`).
137 template <typename TensorDataType, data_layout T_layout, El::Device Dev>
138 void batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>::
139  setup_bp_tensors()
140 {
142 
143  const auto& prev_error_signal_dist = this->get_prev_error_signals_dist();
144  auto& l =
146  this->layer());
147 
148  const auto per_channel_stat_shape = get_per_channel_stat_shape();
149  const auto shared_dist = get_per_channel_stat_dist(prev_error_signal_dist);
150 
151  const dc::LocaleMPI loc(dc::get_mpi_comm(), false);
152 
// assert0 presumably checks that dc::tensor::View returned 0 -- TODO confirm.
153  // scale_gradient
154  m_scale_gradient = TensorDevType(per_channel_stat_shape, loc, shared_dist);
155  assert0(dc::tensor::View(m_scale_gradient, l.m_scale_gradient->Buffer()));
156  // bias_gradient
157  m_bias_gradient = TensorDevType(per_channel_stat_shape, loc, shared_dist);
158  assert0(dc::tensor::View(m_bias_gradient, l.m_bias_gradient->Buffer()));
159  // mean_gradient
160  m_mean_gradient = TensorDevType(per_channel_stat_shape, loc, shared_dist);
161  assert0(dc::tensor::View(m_mean_gradient, l.m_mean_gradient_v->Buffer()));
162  // var_gradient
163  m_var_gradient = TensorDevType(per_channel_stat_shape, loc, shared_dist);
164  assert0(dc::tensor::View(m_var_gradient, l.m_var_gradient_v->Buffer()));
165 }
166 
// Creates the DistConv batch-normalization object (m_bn) for this adapter.
// The layer's m_statistics_group_size selects the statistics mode:
//   0 -> global_stats = true, 1 -> global_stats = false,
//   anything else -> LBANN_ERROR.
// m_bn is then built from the backend, the number of dimensions, the layer's
// decay and epsilon, and the selected mode.
// The workspace_capacity parameter is not used in the visible body.
// NOTE(review): listing line 172 (presumably the dynamic_cast producing `l`)
// is missing from this extract.
167 template <typename TensorDataType, data_layout T_layout, El::Device Dev>
168 void batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>::
169  setup_layer(size_t workspace_capacity)
170 {
171  auto& l =
173  this->layer());
// NOTE(review): global_stats has no initializer and the final else branch
// does not assign it; this is only safe if LBANN_ERROR never returns
// (presumably it throws) -- confirm.
174  bool global_stats;
175  if (l.m_statistics_group_size == 0) {
176  global_stats = true;
177  }
178  else if (l.m_statistics_group_size == 1) {
179  global_stats = false;
180  }
181  else {
182  LBANN_ERROR("statistics_group_size must be either 0 or 1 for now.");
183  }
184 
185  m_bn = std::make_unique<dc::BatchNormalization<TensorDataType>>(
186  dc::get_backend(),
187  dc::get_num_dims(l),
188  l.m_decay,
189  l.m_epsilon,
190  global_stats);
191 }
192 
// Creates the error-signal tensor for the given parent index; only index 0 is
// accepted (assert_eq).
// If none of the conditions in the `if` below holds, the returned tensor is
// constructed from the previous-activations tensor.
// NOTE(review): presumably that makes the error signal reuse the memory of
// the previous activations (the environment variable name suggests a memory
// optimization) -- confirm the TensorDevType copy semantics.
// NOTE(review): listing line 206, the body of the `if`, is missing from this
// extract, so what happens when a condition holds is not visible here.
193 template <typename TensorDataType, data_layout T_layout, El::Device Dev>
194 std::unique_ptr<typename batch_normalization_distconv_adapter<TensorDataType,
195  T_layout,
196  Dev>::TensorDevType>
197 batch_normalization_distconv_adapter<TensorDataType, T_layout, Dev>::
198  setup_error_signals_i(int index) const
199 {
200  assert_eq(index, 0);
201  auto& parent_layer = this->layer().get_parent_layer();
// The branch is taken when any of these holds: the parent's backprop
// requirements include ACTIVATIONS, the parent is an "identity" layer,
// the previous-activations and error-signals distributions differ, or the
// DISTCONV_DISABLE_MEM_OPT environment variable is set.
202  if (parent_layer.get_backprop_requirements() & ACTIVATIONS
203  || parent_layer.get_type() == "identity"
204  || this->get_prev_activations_dist() != this->get_error_signals_dist()
205  || std::getenv("DISTCONV_DISABLE_MEM_OPT")) {
207  }
208  const auto& prev_activations = this->get_prev_activations(0);
209  return std::make_unique<TensorDevType>(prev_activations);
210 }
211 #endif // LBANN_HAS_DISTCONV
212 
213 } // namespace lbann
214 
215 #endif // LBANN_LAYER_REGULARIZER_BATCH_NORMALIZATION_IMPL_HPP_INCLUDED
::distconv::tensor::LocaleMPI LocaleMPI
#define LBANN_ERROR(...)
Definition: exception.hpp:37
virtual std::unique_ptr< InputTensorDevType > setup_error_signals_i(int index) const
Channel-wise batch normalization, including scale/bias.
::distconv::tensor::Shape Shape
virtual void setup_bp_tensors()
::distconv::tensor::Distribution Dist
virtual void setup_fp_tensors()
void write_specific_proto(lbann_data::Layer &proto) const final