26 #ifndef LBANN_EXECUTION_ALGORITHMS_KFAC_HPP_INCLUDED 27 #define LBANN_EXECUTION_ALGORITHMS_KFAC_HPP_INCLUDED 38 #include <google/protobuf/message.h> 71 KFAC(std::string name,
72 std::unique_ptr<TermCriteriaType> stop,
73 std::vector<double> damping_act_params,
74 std::vector<double> damping_err_params,
75 std::vector<double> damping_bn_act_params,
76 std::vector<double> damping_bn_err_params,
77 std::vector<bool> kfac_use_interval,
78 size_t damping_warmup_steps,
79 double kronecker_decay,
82 bool print_matrix_summary,
84 std::vector<size_t> update_intervals,
85 size_t update_interval_steps,
87 std::vector<std::string> disable_layers,
88 double learning_rate_factor,
89 double learning_rate_factor_gru,
90 size_t compute_interval,
91 bool distribute_precondition_compute,
92 bool use_eigen_decomposition,
93 bool enable_copy_errors,
94 bool enable_copy_activations);
96 ~KFAC() noexcept = default;
98 KFAC& operator=(const
KFAC& other) = delete;
100 KFAC& operator=(
KFAC&& other) = default;
130 #endif // LBANN_HAS_GPU 181 void compute_kronecker_factors(
ExeContextType& context, model& model);
184 void invert_kronecker_factors(
ExeContextType& context, model& model);
187 void precondition_gradients(
ExeContextType& context, model& model);
289 std::unique_ptr<lbann::KFAC>
290 lbann::make<lbann::KFAC>(google::protobuf::Message
const& msg);
292 #endif // LBANN_EXECUTION_ALGORITHMS_KFAC_HPP_INCLUDED int m_time_span_forward_comm_end
void end_old_async_weights_model(model &model, lbann_comm *comm, ExeContextType &context)
KFAC(std::string name, std::unique_ptr< TermCriteriaType > stop, std::vector< double > damping_act_params, std::vector< double > damping_err_params, std::vector< double > damping_bn_act_params, std::vector< double > damping_bn_err_params, std::vector< bool > kfac_use_interval, size_t damping_warmup_steps, double kronecker_decay, bool print_time, bool print_matrix, bool print_matrix_summary, bool use_pi, std::vector< size_t > update_intervals, size_t update_interval_steps, kfac::kfac_inverse_strategy inverse_strategy, std::vector< std::string > disable_layers, double learning_rate_factor, double learning_rate_factor_gru, size_t compute_interval, bool distribute_precondition_compute, bool use_eigen_decomposition, bool enable_copy_errors, bool enable_copy_activations)
Construct KFAC from its component pieces.
void allgather_precondition_gradient(lbann_comm &comm, ExeContextType &context)
double m_kronecker_decay
The decay factor of the Kronecker factors.
int m_global_inverse_buffer_size
int m_time_span_backward_comm
std::vector< size_t > m_update_intervals
Space-separated pairs of the initial and the target update intervals. If only one value is specified...
void do_epoch_begin_cbs(model &model)
void do_train_begin_cbs(model &model)
int m_time_span_inverse_comm
Profiling variables.
static constexpr const El::Device Device
bool m_distribute_precondition_compute
distribute precondition gradient compute.
std::vector< double > m_damping_err_params
void end_sync_weights_async(model &model, lbann_comm *comm)
void on_forward_prop_end(ExeContextType &context, model &model)
std::string get_type() const final
Queries.
void on_backward_prop_end(ExeContextType &context, model &model)
void start_send_recv_inverse_matrices(ExeContextType &context, lbann_comm *comm)
size_t m_damping_warmup_steps
The number of warmup steps of the Tikhonov damping technique.
bool m_print_time
Knobs to print information for debugging.
int m_time_span_inverse_send_recv
void sync_weights_model(model &model, lbann_comm *comm)
Data exchange functions to synchronize model and weights.
double m_learning_rate_factor_gru
static constexpr const int prof_color
constexpr El::Device Device
kfac::kfac_inverse_strategy m_inverse_strategy
Assignment strategy for the model-parallel part.
An implementation of the KFAC second-order optimization algorithm.
void do_batch_end_cbs(model &model)
bool train_mini_batch(ExeContextType &c, model &model, data_coordinator &dc)
Train model on one step / mini-batch of an SGD forward pass.
std::vector< bool > m_use_KFAC_epoch
size_t m_update_interval_steps
The number of steps for changing the update interval.
Abstract base class for neural network models.
double m_learning_rate_factor
Factors to be multiplied to the learning rate.
bool m_use_pi
Whether to use the pi constant to adjust the damping constant.
void do_train_end_cbs(model &model)
execution_mode
Neural network execution mode.
static constexpr const double damping_0_default
The default parameters of a Tikhonov damping technique.
bool m_use_eigen_decomposition
Use eigenvalue decomposition for inverting the matrix.
static constexpr const bool prof_sync
Parameters for prof_region_*.
size_t m_compute_interval
KFAC Compute interval.
bool m_enable_copy_errors
copy errors to a temporary matrix to increase overlap of compute and communication.
bool m_print_matrix_summary
void start_old_async_weights_model(model &model, lbann_comm *comm, ExeContextType &context)
std::vector< double > m_damping_act_params
Pairs of the initial and the target damping value. If only one value is specified, it will be used throughout training.
bool m_enable_copy_activations
copy activations to a temporary matrix to increase overlap of compute and communication.
void do_epoch_end_cbs(model &model)
int m_time_span_precond_comm
El::Matrix< double, El::Device::CPU > m_inverse_matrices_size
void train(ExeContextType &c, model &model, data_coordinator &dc, TermCriteriaType const &term)
Train a model using KFAC.
std::vector< std::string > m_disable_layers
List of layers to be ignored by the callback.
std::vector< kfac::ReqT > m_inverse_matrix_communication_reqs
vector for async communication reqs.
std::vector< double > m_damping_bn_err_params
static constexpr const size_t damping_warmup_steps_default
void do_batch_begin_cbs(model &model)
int m_weight_matrices_buffer_size
std::vector< kfac::ReqT > m_weights_communication_reqs
std::unique_ptr< TermCriteriaType > m_stopping_criteria
The KFAC stopping criteria.
std::vector< double > m_damping_bn_act_params
void end_send_recv_inverse_matrices(ExeContextType &context, lbann_comm *comm)
kfac::KFACExecutionContext * do_get_new_execution_context() const final
Covariant return-friendly implementation of get_new_execution_context().
static constexpr const double kronecker_decay_default
The default parameters of the decay factor.
Base class for LBANN training_algorithms.
void start_sync_weights_async(model &model, lbann_comm *comm)
int m_time_span_backward_comm_end
bool m_has_kronecker_inverse
Whether the inverses of the Kronecker factors are available.
int m_time_span_forward_comm
Base class for SGD stopping.
void apply(ExecutionContext &context, model &m, data_coordinator &dc, execution_mode mode) final
Apply the training algorithm to refine model weights.