LBANN 0.103.0
Livermore Big Artificial Neural Network Toolkit
lbann::KFAC Class Reference
final

An implementation of the KFAC second-order optimization algorithm.

#include <kfac.hpp>


Public Types

using TermCriteriaType = SGDTerminationCriteria
 
using ExeContextType = kfac::KFACExecutionContext
 

Public Member Functions

Life-cycle management
 KFAC (std::string name, std::unique_ptr< TermCriteriaType > stop, std::vector< double > damping_act_params, std::vector< double > damping_err_params, std::vector< double > damping_bn_act_params, std::vector< double > damping_bn_err_params, std::vector< bool > kfac_use_interval, size_t damping_warmup_steps, double kronecker_decay, bool print_time, bool print_matrix, bool print_matrix_summary, bool use_pi, std::vector< size_t > update_intervals, size_t update_interval_steps, kfac::kfac_inverse_strategy inverse_strategy, std::vector< std::string > disable_layers, double learning_rate_factor, double learning_rate_factor_gru, size_t compute_interval, bool distribute_precondition_compute, bool use_eigen_decomposition, bool enable_copy_errors, bool enable_copy_activations)
 Construct KFAC from its component pieces.
 
 ~KFAC () noexcept=default
 
 KFAC (KFAC const &other)=delete
 
KFAC & operator= (const KFAC &other)=delete
 
 KFAC (KFAC &&other)=default
 
KFAC & operator= (KFAC &&other)=default
 
std::string get_type () const final
 Queries.
 
Apply interface
void apply (ExecutionContext &context, model &m, data_coordinator &dc, execution_mode mode) final
 Apply the training algorithm to refine model weights.
 
void train (ExeContextType &c, model &model, data_coordinator &dc, TermCriteriaType const &term)
 Train a model using KFAC.
 
- Public Member Functions inherited from lbann::TrainingAlgorithm
 TrainingAlgorithm (std::string name)
 Constructor.
 
virtual ~TrainingAlgorithm ()=default
 
std::string const & get_name () const noexcept
 A user-defined string identifying the algorithm object.
 
void apply (model &model, data_coordinator &dc)
 Apply the algorithm to the given model.
 
void setup_models (std::vector< observer_ptr< model >> const &models, size_t max_mini_batch_size, const std::vector< El::Grid *> &grids)
 Set up a collection of models.
 
std::unique_ptr< ExecutionContext > get_new_execution_context () const
 Get a default-initialized execution context that fits this training algorithm.
 

Static Public Attributes

static constexpr const El::Device Device = El::Device::CPU
 
static constexpr const double damping_0_default = 3e-2
 The default parameters of a Tikhonov damping technique.
 
static constexpr const size_t damping_warmup_steps_default = 100
 
static constexpr const double kronecker_decay_default = 0.99
 The default parameters of the decay factor.
 
static constexpr const bool prof_sync = true
 Parameters for prof_region_*.
 
static constexpr const int prof_color = 0
 

Protected Member Functions

bool train_mini_batch (ExeContextType &c, model &model, data_coordinator &dc)
 Train model on one step / mini-batch of an SGD forward pass.
 
kfac::KFACExecutionContext * do_get_new_execution_context () const final
 Covariant return-friendly implementation of get_new_execution_context().
 
void start_send_recv_inverse_matrices (ExeContextType &context, lbann_comm *comm)
 
void end_send_recv_inverse_matrices (ExeContextType &context, lbann_comm *comm)
 
Callback hooks
void do_train_begin_cbs (model &model)
 
void do_train_end_cbs (model &model)
 
void do_epoch_begin_cbs (model &model)
 
void do_epoch_end_cbs (model &model)
 
void do_batch_begin_cbs (model &model)
 
void do_batch_end_cbs (model &model)
 
- Protected Member Functions inherited from lbann::TrainingAlgorithm
 TrainingAlgorithm (const TrainingAlgorithm &other)=delete
 
TrainingAlgorithm & operator= (const TrainingAlgorithm &other)=delete
 
 TrainingAlgorithm (TrainingAlgorithm &&other)=default
 
TrainingAlgorithm & operator= (TrainingAlgorithm &&other)=default
 

Private Member Functions

void on_forward_prop_end (ExeContextType &context, model &model)
 
void on_backward_prop_end (ExeContextType &context, model &model)
 
void sync_weights_model (model &model, lbann_comm *comm)
 Data exchange functions to synchronize model and weights.
 
void start_sync_weights_async (model &model, lbann_comm *comm)
 
void end_sync_weights_async (model &model, lbann_comm *comm)
 
void start_old_async_weights_model (model &model, lbann_comm *comm, ExeContextType &context)
 
void end_old_async_weights_model (model &model, lbann_comm *comm, ExeContextType &context)
 
void allgather_precondition_gradient (lbann_comm &comm, ExeContextType &context)
 

Private Attributes

std::unique_ptr< TermCriteriaType > m_stopping_criteria
 The KFAC stopping criteria.
 
std::vector< double > m_damping_act_params
 Pairs of the initial and the target damping value. If only one value is specified, it will be used throughout training.
 
std::vector< double > m_damping_err_params
 
std::vector< double > m_damping_bn_act_params
 
std::vector< double > m_damping_bn_err_params
 
size_t m_damping_warmup_steps
 The number of warmup steps of the Tikhonov damping technique.
 
double m_kronecker_decay
 The decay factor of Kronecker factors.
 
bool m_print_time
 Knobs to print information for debugging.
 
bool m_print_matrix
 
bool m_print_matrix_summary
 
bool m_use_pi
 Whether to use the pi constant to adjust the damping constant.
 
std::vector< size_t > m_update_intervals
 Space-separated pairs of the initial and the target update intervals. If only one value is specified, it will be used throughout training.
 
size_t m_update_interval_steps
 The number of steps for changing the update interval.
 
kfac::kfac_inverse_strategy m_inverse_strategy
 Assignment strategy for the model-parallel part.
 
std::vector< std::string > m_disable_layers
 List of layers to be ignored by the callback.
 
double m_learning_rate_factor
 Factors by which the learning rate is multiplied.
 
double m_learning_rate_factor_gru
 
bool m_has_kronecker_inverse = false
 Whether the inverses of the Kronecker factors are available.
 
size_t m_compute_interval
 KFAC compute interval.
 
bool m_distribute_precondition_compute
 Distribute the preconditioned-gradient computation.
 
bool m_enable_copy_errors
 Copy errors to a temporary matrix to increase overlap of compute and communication.
 
bool m_enable_copy_activations
 Copy activations to a temporary matrix to increase overlap of compute and communication.
 
bool m_use_eigen_decomposition
 Use eigenvalue decomposition to invert the matrix.
 
El::Matrix< double, El::Device::CPU > m_inverse_matrices_size
 
int m_global_inverse_buffer_size = 0
 
int m_weight_matrices_buffer_size = 0
 
std::vector< kfac::ReqT > m_inverse_matrix_communication_reqs
 Vector of asynchronous communication requests.
 
std::vector< kfac::ReqT > m_weights_communication_reqs
 
int m_time_span_inverse_comm = 0
 Profiling variables.
 
int m_time_span_inverse_send_recv = 0
 
int m_time_span_forward_comm = 0
 
int m_time_span_forward_comm_end = 0
 
int m_time_span_backward_comm = 0
 
int m_time_span_backward_comm_end = 0
 
int m_time_span_precond_comm = 0
 
int m_time_forward_pass = 0
 
int m_time_backward_pass = 0
 
int m_time_kfac = 0
 
std::vector< bool > m_use_KFAC_epoch
 

Detailed Description

An implementation of the KFAC second-order optimization algorithm.

Martens, James and Roger Grosse. "Optimizing neural networks with kronecker-factored approximate curvature." International conference on machine learning. 2015.

Grosse, Roger, and James Martens. "A kronecker-factored approximate fisher matrix for convolution layers." International Conference on Machine Learning. 2016.

Osawa, Kazuki, et al. "Large-scale distributed second-order optimization using kronecker-factored approximate curvature for deep convolutional neural networks." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.
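
For orientation, the core approximation described in the papers cited above can be summarized for a single fully connected layer as follows. This is a sketch of the published method, not a transcription of kfac.hpp; the symbols (W, a, g, and the two damping constants) are notation chosen here, and mapping the damping constants to the "act" and "err" damping parameters of this class is an assumption.

```latex
% Notation chosen for this sketch:
%   W         weight matrix of the layer (outputs x inputs)
%   a         input activations of the layer
%   g         back-propagated errors at the layer output
%   \gamma_A, \gamma_G   damping constants added to each factor
\begin{align}
A &= \mathbb{E}\left[ a a^{\top} \right], \qquad
G = \mathbb{E}\left[ g g^{\top} \right], \\
F_W &\approx A \otimes G, \\
\Delta W &= (G + \gamma_G I)^{-1} \, \nabla_W \mathcal{L} \, (A + \gamma_A I)^{-1}.
\end{align}
```

The last line uses the identity $(A \otimes G)^{-1} \operatorname{vec}(V) = \operatorname{vec}(G^{-1} V A^{-1})$ for symmetric $A$, which is why only the two small factors need to be inverted rather than the full Fisher matrix.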

Definition at line 59 of file kfac.hpp.

Member Typedef Documentation

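◆ ExeContextType

using lbann::KFAC::ExeContextType = kfac::KFACExecutionContext

◆ TermCriteriaType

using lbann::KFAC::TermCriteriaType = SGDTerminationCriteria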

Definition at line 63 of file kfac.hpp.

Constructor & Destructor Documentation

◆ KFAC() [1/3]

lbann::KFAC::KFAC ( std::string  name,
std::unique_ptr< TermCriteriaType >  stop,
std::vector< double >  damping_act_params,
std::vector< double >  damping_err_params,
std::vector< double >  damping_bn_act_params,
std::vector< double >  damping_bn_err_params,
std::vector< bool >  kfac_use_interval,
size_t  damping_warmup_steps,
double  kronecker_decay,
bool  print_time,
bool  print_matrix,
bool  print_matrix_summary,
bool  use_pi,
std::vector< size_t >  update_intervals,
size_t  update_interval_steps,
kfac::kfac_inverse_strategy  inverse_strategy,
std::vector< std::string >  disable_layers,
double  learning_rate_factor,
double  learning_rate_factor_gru,
size_t  compute_interval,
bool  distribute_precondition_compute,
bool  use_eigen_decomposition,
bool  enable_copy_errors,
bool  enable_copy_activations 
)

Construct KFAC from its component pieces.

◆ ~KFAC()

lbann::KFAC::~KFAC ( )
default noexcept

◆ KFAC() [2/3]

lbann::KFAC::KFAC ( KFAC const &  other)
delete

◆ KFAC() [3/3]

lbann::KFAC::KFAC ( KFAC &&  other)
default
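
The special member functions listed for this class (deleted copy operations, defaulted move operations) make KFAC a move-only type, which is the natural consequence of owning a std::unique_ptr. The sketch below illustrates that pattern with a hypothetical stand-in class; `MoveOnlyAlgorithm` and its `int` payload are placeholders invented for this example and are not LBANN types.

```cpp
#include <memory>
#include <type_traits>
#include <utility>

// Hypothetical stand-in that mirrors the special member declarations
// documented for lbann::KFAC. The int payload is a placeholder for the
// termination criteria object.
class MoveOnlyAlgorithm
{
public:
  explicit MoveOnlyAlgorithm(std::unique_ptr<int> stop)
    : m_stop{std::move(stop)}
  {}
  ~MoveOnlyAlgorithm() noexcept = default;

  // Copying is disabled: the owned pointer cannot be shared.
  MoveOnlyAlgorithm(MoveOnlyAlgorithm const& other) = delete;
  MoveOnlyAlgorithm& operator=(MoveOnlyAlgorithm const& other) = delete;

  // Moving transfers ownership of the stopping criteria.
  MoveOnlyAlgorithm(MoveOnlyAlgorithm&& other) = default;
  MoveOnlyAlgorithm& operator=(MoveOnlyAlgorithm&& other) = default;

  bool has_criteria() const noexcept { return m_stop != nullptr; }

private:
  std::unique_ptr<int> m_stop;
};

static_assert(!std::is_copy_constructible<MoveOnlyAlgorithm>::value,
              "copy construction is deleted");
static_assert(std::is_move_constructible<MoveOnlyAlgorithm>::value,
              "move construction is defaulted");
```

In practice this means such an object is handed over with std::move or held behind a pointer, never passed by value as a copy.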

Member Function Documentation

◆ allgather_precondition_gradient()

void lbann::KFAC::allgather_precondition_gradient ( lbann_comm &  comm,
ExeContextType &  context 
)
private

◆ apply()

void lbann::KFAC::apply ( ExecutionContext &  context,
model &  m,
data_coordinator &  dc,
execution_mode  mode 
)
final virtual

Apply the training algorithm to refine model weights.

Parameters
[in,out]  context  The persistent execution context for this algorithm.
[in,out]  m        The model to be trained.
[in,out]  dc       The data source for training.
[in]      mode     Completely superfluous.

Implements lbann::TrainingAlgorithm.

◆ do_batch_begin_cbs()

void lbann::KFAC::do_batch_begin_cbs ( model &  model)
protected

Execute callbacks at start of mini-batch.

◆ do_batch_end_cbs()

void lbann::KFAC::do_batch_end_cbs ( model &  model)
protected

Execute callbacks at end of mini-batch.

◆ do_epoch_begin_cbs()

void lbann::KFAC::do_epoch_begin_cbs ( model &  model)
protected

Execute callbacks at start of epoch.

◆ do_epoch_end_cbs()

void lbann::KFAC::do_epoch_end_cbs ( model &  model)
protected

Execute callbacks at end of epoch.

◆ do_get_new_execution_context()

kfac::KFACExecutionContext* lbann::KFAC::do_get_new_execution_context ( ) const
final protected virtual

Covariant return-friendly implementation of get_new_execution_context().

Implements lbann::TrainingAlgorithm.

◆ do_train_begin_cbs()

void lbann::KFAC::do_train_begin_cbs ( model &  model)
protected

Execute callbacks at start of training.

◆ do_train_end_cbs()

void lbann::KFAC::do_train_end_cbs ( model &  model)
protected

Execute callbacks at end of training.

◆ end_old_async_weights_model()

void lbann::KFAC::end_old_async_weights_model ( model &  model,
lbann_comm *  comm,
ExeContextType &  context 
)
private

◆ end_send_recv_inverse_matrices()

void lbann::KFAC::end_send_recv_inverse_matrices ( ExeContextType &  context,
lbann_comm *  comm 
)
protected

◆ end_sync_weights_async()

void lbann::KFAC::end_sync_weights_async ( model &  model,
lbann_comm *  comm 
)
private

◆ get_type()

std::string lbann::KFAC::get_type ( ) const
final virtual

Queries.

Implements lbann::TrainingAlgorithm.

◆ on_backward_prop_end()

void lbann::KFAC::on_backward_prop_end ( ExeContextType &  context,
model &  model 
)
private

◆ on_forward_prop_end()

void lbann::KFAC::on_forward_prop_end ( ExeContextType &  context,
model &  model 
)
private
Todo:
Break up into more manageable pieces

◆ operator=() [1/2]

KFAC& lbann::KFAC::operator= ( const KFAC &  other)
delete

◆ operator=() [2/2]

KFAC& lbann::KFAC::operator= ( KFAC &&  other)
default

◆ start_old_async_weights_model()

void lbann::KFAC::start_old_async_weights_model ( model &  model,
lbann_comm *  comm,
ExeContextType &  context 
)
private

◆ start_send_recv_inverse_matrices()

void lbann::KFAC::start_send_recv_inverse_matrices ( ExeContextType &  context,
lbann_comm *  comm 
)
protected

◆ start_sync_weights_async()

void lbann::KFAC::start_sync_weights_async ( model &  model,
lbann_comm *  comm 
)
private

◆ sync_weights_model()

void lbann::KFAC::sync_weights_model ( model &  model,
lbann_comm *  comm 
)
private

Data exchange functions to synchronize model and weights.

◆ train()

void lbann::KFAC::train ( ExeContextType &  c,
model &  model,
data_coordinator &  dc,
TermCriteriaType const &  term 
)

Train a model using KFAC.

◆ train_mini_batch()

bool lbann::KFAC::train_mini_batch ( ExeContextType &  c,
model &  model,
data_coordinator &  dc 
)
protected

Train model on one step / mini-batch of an SGD forward pass.

Member Data Documentation

◆ damping_0_default

constexpr const double lbann::KFAC::damping_0_default = 3e-2
static

The default parameters of a Tikhonov damping technique.

Definition at line 133 of file kfac.hpp.

◆ damping_warmup_steps_default

constexpr const size_t lbann::KFAC::damping_warmup_steps_default = 100
static

Definition at line 134 of file kfac.hpp.

◆ Device

constexpr const El::Device lbann::KFAC::Device = El::Device::CPU
static

Definition at line 129 of file kfac.hpp.

◆ kronecker_decay_default

constexpr const double lbann::KFAC::kronecker_decay_default = 0.99
static

The default parameters of the decay factor.

Definition at line 137 of file kfac.hpp.

◆ m_compute_interval

size_t lbann::KFAC::m_compute_interval
private

KFAC Compute interval.

Definition at line 248 of file kfac.hpp.

◆ m_damping_act_params

std::vector<double> lbann::KFAC::m_damping_act_params
private

Pairs of the initial and the target damping value. If only one value is specified, it will be used throughout training.

Definition at line 210 of file kfac.hpp.
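
To make the "initial and target" convention concrete, the sketch below shows one plausible way such a parameter vector could be turned into a per-step damping value. The helper `damping_at_step` and the log-space interpolation are assumptions made for illustration; the actual schedule used by lbann::KFAC is not given in this reference and may differ.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of lbann::KFAC.
// `params` holds either {value} or {initial, target}; all values must be
// positive. `warmup_steps` plays the role of m_damping_warmup_steps.
double damping_at_step(std::vector<double> const& params,
                       std::size_t warmup_steps,
                       std::size_t step)
{
  assert(params.size() == 1 || params.size() == 2);
  // A single value is used throughout training; with no warmup the
  // target applies immediately.
  if (params.size() == 1 || warmup_steps == 0) {
    return params.back();
  }
  // Fraction of the warmup completed, clamped to [0, 1].
  double const t = std::min(
    1.0, static_cast<double>(step) / static_cast<double>(warmup_steps));
  // Assumed schedule: interpolate in log space, so the damping changes
  // by a constant ratio per step until it reaches the target.
  return params[0] * std::pow(params[1] / params[0], t);
}
```

With `{1e-1, 1e-3}` and 100 warmup steps, this sketch starts at the initial value, passes through the geometric midpoint halfway through the warmup, and holds the target value afterwards.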

◆ m_damping_bn_act_params

std::vector<double> lbann::KFAC::m_damping_bn_act_params
private

Definition at line 210 of file kfac.hpp.

◆ m_damping_bn_err_params

std::vector<double> lbann::KFAC::m_damping_bn_err_params
private

Definition at line 210 of file kfac.hpp.

◆ m_damping_err_params

std::vector<double> lbann::KFAC::m_damping_err_params
private

Definition at line 210 of file kfac.hpp.

◆ m_damping_warmup_steps

size_t lbann::KFAC::m_damping_warmup_steps
private

The number of warmup steps of the Tikhonov damping technique.

Definition at line 214 of file kfac.hpp.

◆ m_disable_layers

std::vector<std::string> lbann::KFAC::m_disable_layers
private

List of layers to be ignored by the callback.

Definition at line 239 of file kfac.hpp.

◆ m_distribute_precondition_compute

bool lbann::KFAC::m_distribute_precondition_compute
private

Distribute the preconditioned-gradient computation.

Definition at line 251 of file kfac.hpp.

◆ m_enable_copy_activations

bool lbann::KFAC::m_enable_copy_activations
private

Copy activations to a temporary matrix to increase overlap of compute and communication.

Definition at line 259 of file kfac.hpp.

◆ m_enable_copy_errors

bool lbann::KFAC::m_enable_copy_errors
private

Copy errors to a temporary matrix to increase overlap of compute and communication.

Definition at line 255 of file kfac.hpp.

◆ m_global_inverse_buffer_size

int lbann::KFAC::m_global_inverse_buffer_size = 0
private

Definition at line 266 of file kfac.hpp.

◆ m_has_kronecker_inverse

bool lbann::KFAC::m_has_kronecker_inverse = false
private

Whether the inverses of the Kronecker factors are available.

Definition at line 245 of file kfac.hpp.

◆ m_inverse_matrices_size

El::Matrix<double, El::Device::CPU> lbann::KFAC::m_inverse_matrices_size
private

Definition at line 264 of file kfac.hpp.

◆ m_inverse_matrix_communication_reqs

std::vector<kfac::ReqT> lbann::KFAC::m_inverse_matrix_communication_reqs
private

Vector of asynchronous communication requests.

Definition at line 269 of file kfac.hpp.

◆ m_inverse_strategy

kfac::kfac_inverse_strategy lbann::KFAC::m_inverse_strategy
private

Assignment strategy for the model-parallel part.

Definition at line 236 of file kfac.hpp.

◆ m_kronecker_decay

double lbann::KFAC::m_kronecker_decay
private

The decay factor of Kronecker factors.

Definition at line 217 of file kfac.hpp.
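
A decay factor of this kind is typically applied as an exponential moving average over the per-batch Kronecker factors. The sketch below shows that update on a flat buffer; the helper `update_running_factor` and the buffer layout are assumptions for illustration, not the code in kfac.hpp.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of lbann::KFAC.
// Blends the factor computed on the current mini-batch into the running
// average, in place. `decay` plays the role of m_kronecker_decay.
void update_running_factor(std::vector<double>& running,
                           std::vector<double> const& current,
                           double decay)
{
  assert(running.size() == current.size());
  assert(decay >= 0.0 && decay <= 1.0);
  for (std::size_t i = 0; i < running.size(); ++i) {
    // decay close to 1 keeps most of the history; decay = 0 keeps none.
    running[i] = decay * running[i] + (1.0 - decay) * current[i];
  }
}
```

With the default of 0.99 listed above, each new mini-batch would contribute 1% of the running estimate under this scheme.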

◆ m_learning_rate_factor

double lbann::KFAC::m_learning_rate_factor
private

Factors by which the learning rate is multiplied.

Definition at line 242 of file kfac.hpp.

◆ m_learning_rate_factor_gru

double lbann::KFAC::m_learning_rate_factor_gru
private

Definition at line 242 of file kfac.hpp.

◆ m_print_matrix

bool lbann::KFAC::m_print_matrix
private

Definition at line 220 of file kfac.hpp.

◆ m_print_matrix_summary

bool lbann::KFAC::m_print_matrix_summary
private

Definition at line 220 of file kfac.hpp.

◆ m_print_time

bool lbann::KFAC::m_print_time
private

Knobs to print information for debugging.

Definition at line 220 of file kfac.hpp.

◆ m_stopping_criteria

std::unique_ptr<TermCriteriaType> lbann::KFAC::m_stopping_criteria
private

The KFAC stopping criteria.

Definition at line 205 of file kfac.hpp.

◆ m_time_backward_pass

int lbann::KFAC::m_time_backward_pass = 0
private

Definition at line 277 of file kfac.hpp.

◆ m_time_forward_pass

int lbann::KFAC::m_time_forward_pass = 0
private

Definition at line 276 of file kfac.hpp.

◆ m_time_kfac

int lbann::KFAC::m_time_kfac = 0
private

Definition at line 277 of file kfac.hpp.

◆ m_time_span_backward_comm

int lbann::KFAC::m_time_span_backward_comm = 0
private

Definition at line 275 of file kfac.hpp.

◆ m_time_span_backward_comm_end

int lbann::KFAC::m_time_span_backward_comm_end = 0
private

Definition at line 275 of file kfac.hpp.

◆ m_time_span_forward_comm

int lbann::KFAC::m_time_span_forward_comm = 0
private

Definition at line 274 of file kfac.hpp.

◆ m_time_span_forward_comm_end

int lbann::KFAC::m_time_span_forward_comm_end = 0
private

Definition at line 274 of file kfac.hpp.

◆ m_time_span_inverse_comm

int lbann::KFAC::m_time_span_inverse_comm = 0
private

Profiling variables.

Definition at line 273 of file kfac.hpp.

◆ m_time_span_inverse_send_recv

int lbann::KFAC::m_time_span_inverse_send_recv = 0
private

Definition at line 273 of file kfac.hpp.

◆ m_time_span_precond_comm

int lbann::KFAC::m_time_span_precond_comm = 0
private

Definition at line 276 of file kfac.hpp.

◆ m_update_interval_steps

size_t lbann::KFAC::m_update_interval_steps
private

The number of steps for changing the update interval.

Definition at line 233 of file kfac.hpp.

◆ m_update_intervals

std::vector<size_t> lbann::KFAC::m_update_intervals
private

Space-separated pairs of the initial and the target update intervals. If only one value is specified, it will be used throughout training.

Definition at line 230 of file kfac.hpp.
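
The interplay between the interval pair and m_update_interval_steps can be pictured as a ramp from the initial interval to the target interval. The sketch below assumes a linear ramp; `update_interval_at_step` is a hypothetical helper, and the interpolation actually used by lbann::KFAC is not stated in this reference.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of lbann::KFAC.
// `intervals` holds either {value} or {initial, target}; `ramp_steps`
// plays the role of m_update_interval_steps.
std::size_t update_interval_at_step(std::vector<std::size_t> const& intervals,
                                    std::size_t ramp_steps,
                                    std::size_t step)
{
  assert(intervals.size() == 1 || intervals.size() == 2);
  // A single value is used throughout training.
  if (intervals.size() == 1 || ramp_steps == 0) {
    return intervals.back();
  }
  // Fraction of the ramp completed, clamped to [0, 1].
  double const t = std::min(
    1.0, static_cast<double>(step) / static_cast<double>(ramp_steps));
  double const first = static_cast<double>(intervals[0]);
  double const last = static_cast<double>(intervals[1]);
  // Assumed schedule: linear interpolation, rounded to the nearest step count.
  return static_cast<std::size_t>(std::llround(first + (last - first) * t));
}
```

Starting with a short interval and lengthening it later refreshes the curvature estimate often while it changes quickly, then amortizes the cost once training has settled.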

◆ m_use_eigen_decomposition

bool lbann::KFAC::m_use_eigen_decomposition
private

Use eigenvalue decomposition to invert the matrix.

Definition at line 262 of file kfac.hpp.
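
As background for this option, a damped symmetric factor can be inverted through its eigendecomposition by the standard identity below. The symbols are chosen here for illustration; whether kfac.hpp applies the damping in exactly this form is an assumption.

```latex
% A is a symmetric positive semi-definite Kronecker factor,
% \gamma > 0 a damping constant (notation chosen for this sketch).
\begin{align}
A &= Q \Lambda Q^{\top}, \qquad Q^{\top} Q = I, \\
(A + \gamma I)^{-1} &= Q \, (\Lambda + \gamma I)^{-1} \, Q^{\top}.
\end{align}
```

Because $\Lambda + \gamma I$ is diagonal, its inverse is an element-wise reciprocal, and the decomposition can be reused when only the damping constant changes.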

◆ m_use_KFAC_epoch

std::vector<bool> lbann::KFAC::m_use_KFAC_epoch
private

Definition at line 279 of file kfac.hpp.

◆ m_use_pi

bool lbann::KFAC::m_use_pi
private

Whether to use the pi constant to adjust the damping constant.

Definition at line 224 of file kfac.hpp.
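
In the Martens and Grosse paper cited in the detailed description, pi is the square root of the ratio between the mean eigenvalues (trace divided by dimension) of the activation factor and the error factor, and it rebalances the damping between the two factors. The sketch below follows that definition; the helper `scale_damping`, the `DampingPair` struct, and the choice to multiply the activation damping and divide the error damping are assumptions for illustration rather than the code in kfac.hpp.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical result type: the damping applied to each Kronecker factor.
struct DampingPair
{
  double activations;
  double errors;
};

// Hypothetical helper, not part of lbann::KFAC.
// trace_a / dim_a describe the activation factor A, trace_g / dim_g the
// error factor G. `use_pi` plays the role of m_use_pi.
DampingPair scale_damping(double damping_act,
                          double damping_err,
                          bool use_pi,
                          double trace_a,
                          std::size_t dim_a,
                          double trace_g,
                          std::size_t dim_g)
{
  double pi = 1.0;
  if (use_pi) {
    assert(trace_a > 0.0 && trace_g > 0.0 && dim_a > 0 && dim_g > 0);
    // Mean eigenvalue of each factor: trace divided by dimension.
    double const mean_eig_a = trace_a / static_cast<double>(dim_a);
    double const mean_eig_g = trace_g / static_cast<double>(dim_g);
    pi = std::sqrt(mean_eig_a / mean_eig_g);
  }
  // The product of the two damping values is unchanged by pi.
  return DampingPair{damping_act * pi, damping_err / pi};
}
```

The effect is that the factor with the larger average eigenvalue receives proportionally more damping, while the overall damping strength stays the same.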

◆ m_weight_matrices_buffer_size

int lbann::KFAC::m_weight_matrices_buffer_size = 0
private

Definition at line 266 of file kfac.hpp.

◆ m_weights_communication_reqs

std::vector<kfac::ReqT> lbann::KFAC::m_weights_communication_reqs
private

Definition at line 269 of file kfac.hpp.

◆ prof_color

constexpr const int lbann::KFAC::prof_color = 0
static

Definition at line 141 of file kfac.hpp.

◆ prof_sync

constexpr const bool lbann::KFAC::prof_sync = true
static

Parameters for prof_region_*.

Definition at line 140 of file kfac.hpp.


The documentation for this class was generated from the following file: