LBANN  0.103.0
LivermoreBigArtificialNeuralNetworkToolkit
kfac.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the License.
26 #ifndef LBANN_EXECUTION_ALGORITHMS_KFAC_HPP_INCLUDED
27 #define LBANN_EXECUTION_ALGORITHMS_KFAC_HPP_INCLUDED
28 
37 
38 #include <google/protobuf/message.h>
39 #include <memory>
40 
41 namespace lbann {
42 
// KFAC: an implementation of the KFAC second-order optimization algorithm,
// declared as a final subclass of TrainingAlgorithm.
// NOTE: this is a line-numbered listing; gaps in the numbering are lines that
// are not reproduced here, so a few declarations below appear only in part.
59 class KFAC final : public TrainingAlgorithm
60 {
61 
62 public:
65 
66 public:
68 
    // Construct KFAC from its component pieces: a name, the stopping
    // criteria, and the damping, update-interval, printing, inverse-strategy,
    // layer-exclusion, learning-rate and compute-distribution settings below.
71  KFAC(std::string name,
72  std::unique_ptr<TermCriteriaType> stop,
73  std::vector<double> damping_act_params,
74  std::vector<double> damping_err_params,
75  std::vector<double> damping_bn_act_params,
76  std::vector<double> damping_bn_err_params,
77  std::vector<bool> kfac_use_interval,
78  size_t damping_warmup_steps,
79  double kronecker_decay,
80  bool print_time,
81  bool print_matrix,
82  bool print_matrix_summary,
83  bool use_pi,
84  std::vector<size_t> update_intervals,
85  size_t update_interval_steps,
86  kfac::kfac_inverse_strategy inverse_strategy,
87  std::vector<std::string> disable_layers,
88  double learning_rate_factor,
89  double learning_rate_factor_gru,
90  size_t compute_interval,
91  bool distribute_precondition_compute,
92  bool use_eigen_decomposition,
93  bool enable_copy_errors,
94  bool enable_copy_activations);
95 
    // Move-only type: copy construction and copy assignment are deleted;
    // the destructor and both move operations are defaulted.
96  ~KFAC() noexcept = default;
97  KFAC(KFAC const& other) = delete;
98  KFAC& operator=(const KFAC& other) = delete;
99  KFAC(KFAC&& other) = default;
100  KFAC& operator=(KFAC&& other) = default;
102 
    // Query: returns this algorithm's type as a string.
103  std::string get_type() const final;
106 
107 
    // Apply the training algorithm to refine model weights.
115  void apply(ExecutionContext& context,
116  model& m,
117  data_coordinator& dc,
118  execution_mode mode) final;
    // Train a model using KFAC; `term` is passed by const reference.
120  void train(ExeContextType& c,
121  model& model,
122  data_coordinator& dc,
123  TermCriteriaType const& term);
125 
    // Device constant: GPU when LBANN_HAS_GPU is defined, CPU otherwise.
126 #ifdef LBANN_HAS_GPU
127  constexpr static const El::Device Device = El::Device::GPU;
128 #else
129  constexpr static const El::Device Device = El::Device::CPU;
130 #endif // LBANN_HAS_GPU
131 
    // The default parameters of a Tikhonov damping technique.
133  constexpr static const double damping_0_default = 3e-2;
134  constexpr static const size_t damping_warmup_steps_default = 100;
135 
    // The default value of the decay factor.
137  constexpr static const double kronecker_decay_default = 0.99;
138 
    // Parameters for prof_region_*.
140  constexpr static const bool prof_sync = true;
141  constexpr static const int prof_color = 0;
142 
143 protected:
    // Train model on one step / mini-batch of an SGD forward pass.
145  bool train_mini_batch(ExeContextType& c, model& model, data_coordinator& dc);
146 
148 
    // NOTE(review): presumably these invoke the model's callbacks at the
    // named points (train/epoch/batch begin and end) -- confirm in the
    // definitions, which are not part of this listing.
150  void do_train_begin_cbs(model& model);
152  void do_train_end_cbs(model& model);
154  void do_epoch_begin_cbs(model& model);
156  void do_epoch_end_cbs(model& model);
158  void do_batch_begin_cbs(model& model);
160  void do_batch_end_cbs(model& model);
162 
167 
    // NOTE(review): the next two lines are the tails of two declarations
    // whose first lines are missing from this listing; presumably
    // start_send_recv_inverse_matrices and end_send_recv_inverse_matrices
    // -- confirm against the original header.
169  lbann_comm* comm);
171  lbann_comm* comm);
172 
173 private:
    // `#if 1` compiles the on_*_prop_end hooks; the #else branch below
    // (compute / invert / precondition) is not compiled.
174 #if 1
175 
176  void on_forward_prop_end(ExeContextType& context, model& model);
177  void on_backward_prop_end(ExeContextType& context, model& model);
178 
179 #else
180 
181  void compute_kronecker_factors(ExeContextType& context, model& model);
182 
184  void invert_kronecker_factors(ExeContextType& context, model& model);
185 
187  void precondition_gradients(ExeContextType& context, model& model);
188 #endif // 0
189 
    // Data exchange functions to synchronize model and weights.
191  void sync_weights_model(model& model, lbann_comm* comm);
192  void start_sync_weights_async(model& model, lbann_comm* comm);
193  void end_sync_weights_async(model& model, lbann_comm* comm);
194 
195  void start_old_async_weights_model(model& model,
196  lbann_comm* comm,
197  ExeContextType& context);
198  void end_old_async_weights_model(model& model,
199  lbann_comm* comm,
200  ExeContextType& context);
    // NOTE(review): tail of a declaration whose first line is missing from
    // this listing; presumably allgather_precondition_gradient -- confirm.
202  ExeContextType& context);
203 
    // The KFAC stopping criteria.
205  std::unique_ptr<TermCriteriaType> m_stopping_criteria;
206 
212 
215 
218 
221 
    // Whether to use the pi constant to adjust the damping constant.
224  bool m_use_pi;
225 
    // Pairs of the initial and the target update intervals.
230  std::vector<size_t> m_update_intervals;
231 
234 
237 
    // List of layers to be ignored by the callback.
239  std::vector<std::string> m_disable_layers;
240 
243 
246 
249 
252 
256 
260 
263 
264  El::Matrix<double, El::Device::CPU> m_inverse_matrices_size;
265 
267 
    // Vector for async communication requests.
    // NOTE(review): the trailing comma continues onto a declarator line that
    // is missing from this listing (presumably m_weights_communication_reqs
    // -- confirm against the original header).
269  std::vector<kfac::ReqT> m_inverse_matrix_communication_reqs,
271 
278 
279  std::vector<bool> m_use_KFAC_epoch;
280 
281 }; // class KFAC
282 
283 } // namespace lbann
284 
// Declaration of the explicit specialization of lbann::make for lbann::KFAC.
// It takes a protobuf message by const reference and returns a uniquely-owned
// KFAC; only the declaration appears here.
// NOTE(review): presumably `msg` carries the KFAC configuration -- confirm
// against the definition.
288 template <>
289 std::unique_ptr<lbann::KFAC>
290 lbann::make<lbann::KFAC>(google::protobuf::Message const& msg);
291 
292 #endif // LBANN_EXECUTION_ALGORITHMS_KFAC_HPP_INCLUDED
int m_time_span_forward_comm_end
Definition: kfac.hpp:274
void end_old_async_weights_model(model &model, lbann_comm *comm, ExeContextType &context)
KFAC(std::string name, std::unique_ptr< TermCriteriaType > stop, std::vector< double > damping_act_params, std::vector< double > damping_err_params, std::vector< double > damping_bn_act_params, std::vector< double > damping_bn_err_params, std::vector< bool > kfac_use_interval, size_t damping_warmup_steps, double kronecker_decay, bool print_time, bool print_matrix, bool print_matrix_summary, bool use_pi, std::vector< size_t > update_intervals, size_t update_interval_steps, kfac::kfac_inverse_strategy inverse_strategy, std::vector< std::string > disable_layers, double learning_rate_factor, double learning_rate_factor_gru, size_t compute_interval, bool distribute_precondition_compute, bool use_eigen_decomposition, bool enable_copy_errors, bool enable_copy_activations)
Construct KFAC from its component pieces.
void allgather_precondition_gradient(lbann_comm &comm, ExeContextType &context)
double m_kronecker_decay
The decay factor of kronecker factors.
Definition: kfac.hpp:217
int m_global_inverse_buffer_size
Definition: kfac.hpp:266
int m_time_span_backward_comm
Definition: kfac.hpp:275
std::vector< size_t > m_update_intervals
Space-separated pairs of the initial and the target update intervals. If only one value is specified...
Definition: kfac.hpp:230
void do_epoch_begin_cbs(model &model)
void do_train_begin_cbs(model &model)
int m_time_span_inverse_comm
Profiling variables.
Definition: kfac.hpp:273
static constexpr const El::Device Device
Definition: kfac.hpp:129
bool m_distribute_precondition_compute
distribute precondition gradient compute.
Definition: kfac.hpp:251
std::vector< double > m_damping_err_params
Definition: kfac.hpp:210
void end_sync_weights_async(model &model, lbann_comm *comm)
void on_forward_prop_end(ExeContextType &context, model &model)
std::string get_type() const final
Queries.
void on_backward_prop_end(ExeContextType &context, model &model)
void start_send_recv_inverse_matrices(ExeContextType &context, lbann_comm *comm)
size_t m_damping_warmup_steps
The number of warmup steps of the Tikhonov damping technique.
Definition: kfac.hpp:214
bool m_print_time
Knobs to print information for debugging.
Definition: kfac.hpp:220
int m_time_span_inverse_send_recv
Definition: kfac.hpp:273
void sync_weights_model(model &model, lbann_comm *comm)
Data exchange functions to synchronize model and weights.
double m_learning_rate_factor_gru
Definition: kfac.hpp:242
static constexpr const int prof_color
Definition: kfac.hpp:141
constexpr El::Device Device
int m_time_kfac
Definition: kfac.hpp:277
kfac::kfac_inverse_strategy m_inverse_strategy
Assignment strategy for the model-parallel part.
Definition: kfac.hpp:236
An implementation of the KFAC second-order optimization algorithm.
Definition: kfac.hpp:59
void do_batch_end_cbs(model &model)
int m_time_forward_pass
Definition: kfac.hpp:276
bool train_mini_batch(ExeContextType &c, model &model, data_coordinator &dc)
Train model on one step / mini-batch of an SGD forward pass.
std::vector< bool > m_use_KFAC_epoch
Definition: kfac.hpp:279
size_t m_update_interval_steps
The number of steps for changing the update interval.
Definition: kfac.hpp:233
Abstract base class for neural network models.
Definition: model.hpp:83
double m_learning_rate_factor
Factors to be multiplied to the learning rate.
Definition: kfac.hpp:242
int m_time_backward_pass
Definition: kfac.hpp:277
bool m_use_pi
Whether to use the pi constant to adjust the damping constant.
Definition: kfac.hpp:224
void do_train_end_cbs(model &model)
bool m_print_matrix
Definition: kfac.hpp:220
execution_mode
Neural network execution mode.
Definition: base.hpp:229
static constexpr const double damping_0_default
The default parameters of a Tikhonov damping technique.
Definition: kfac.hpp:133
bool m_use_eigen_decomposition
Use eigenvalue decomposition for inverting the matrix.
Definition: kfac.hpp:262
static constexpr const bool prof_sync
Parameters for prof_region_*.
Definition: kfac.hpp:140
size_t m_compute_interval
KFAC Compute interval.
Definition: kfac.hpp:248
bool m_enable_copy_errors
copy errors to a temporary matrix to increase overlap of compute and communication.
Definition: kfac.hpp:255
bool m_print_matrix_summary
Definition: kfac.hpp:220
void start_old_async_weights_model(model &model, lbann_comm *comm, ExeContextType &context)
std::vector< double > m_damping_act_params
Pairs of the initial and the target damping value. If only one value is specified, it will be used throughout training.
Definition: kfac.hpp:210
bool m_enable_copy_activations
copy activations to a temporary matrix to increase overlap of compute and communication.
Definition: kfac.hpp:259
void do_epoch_end_cbs(model &model)
int m_time_span_precond_comm
Definition: kfac.hpp:276
El::Matrix< double, El::Device::CPU > m_inverse_matrices_size
Definition: kfac.hpp:264
void train(ExeContextType &c, model &model, data_coordinator &dc, TermCriteriaType const &term)
Train a model using KFAC.
std::vector< std::string > m_disable_layers
List of layers to be ignored by the callback.
Definition: kfac.hpp:239
std::vector< kfac::ReqT > m_inverse_matrix_communication_reqs
vector for async communication reqs.
Definition: kfac.hpp:269
std::vector< double > m_damping_bn_err_params
Definition: kfac.hpp:210
static constexpr const size_t damping_warmup_steps_default
Definition: kfac.hpp:134
~KFAC() noexcept=default
void do_batch_begin_cbs(model &model)
int m_weight_matrices_buffer_size
Definition: kfac.hpp:266
std::vector< kfac::ReqT > m_weights_communication_reqs
Definition: kfac.hpp:269
std::unique_ptr< TermCriteriaType > m_stopping_criteria
The KFAC stopping criteria.
Definition: kfac.hpp:205
std::vector< double > m_damping_bn_act_params
Definition: kfac.hpp:210
void end_send_recv_inverse_matrices(ExeContextType &context, lbann_comm *comm)
kfac::KFACExecutionContext * do_get_new_execution_context() const final
Covariant return-friendly implementation of get_new_execution_context().
static constexpr const double kronecker_decay_default
The default parameters of the decay factor.
Definition: kfac.hpp:137
Base class for LBANN training_algorithms.
void start_sync_weights_async(model &model, lbann_comm *comm)
int m_time_span_backward_comm_end
Definition: kfac.hpp:275
bool m_has_kronecker_inverse
Whether inverse of Kronecker factors are available.
Definition: kfac.hpp:245
int m_time_span_forward_comm
Definition: kfac.hpp:274
Base class for SGD stopping.
void apply(ExecutionContext &context, model &m, data_coordinator &dc, execution_mode mode) final
Apply the training algorithm to refine model weights.