LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
lbann::callback::learning_rate Class Reference

#include <learning_rate.hpp>

Inheritance diagram for lbann::callback::learning_rate:
[legend]
Collaboration diagram for lbann::callback::learning_rate:
[legend]

Public Member Functions

 learning_rate ()
 
 learning_rate (const learning_rate &)=default
 
learning_rate & operator= (const learning_rate &)=default
 
 learning_rate (std::vector< std::string > weights_names)
 
void setup (model *m) override
 
void on_epoch_end (model *m) override
 
void on_backward_prop_end (model *m) override
 
- Public Member Functions inherited from lbann::callback_base
 callback_base (int batch_interval=1)
 Initialize a callback with an optional batch interval. More...
 
 callback_base (const callback_base &)=default
 
virtual ~callback_base ()=default
 
virtual callback_base * copy () const =0
 
virtual void setup (trainer *t)
 Called once to set up the callback on the trainer. More...
 
virtual void on_setup_end (model *m)
 Called at the end of setup. More...
 
virtual void on_train_begin (model *m)
 Called at the beginning of training. More...
 
virtual void on_train_end (model *m)
 Called at the end of training. More...
 
virtual void on_phase_end (model *m)
 Called at the end of every phase (multiple epochs) in a layer-wise model training. More...
 
virtual void on_epoch_begin (model *m)
 Called at the beginning of each epoch. More...
 
virtual void on_batch_begin (model *m)
 Called at the beginning of a (mini-)batch. More...
 
virtual void on_batch_end (model *m)
 Called immediately after the end of a (mini-)batch. More...
 
virtual void on_test_begin (model *m)
 Called at the beginning of testing. More...
 
virtual void on_test_end (model *m)
 Called immediately after the end of testing. More...
 
virtual void on_validation_begin (model *m)
 Called at the beginning of validation. More...
 
virtual void on_validation_end (model *m)
 Called immediately after the end of validation. More...
 
virtual void on_forward_prop_begin (model *m)
 Called when a model begins forward propagation. More...
 
virtual void on_forward_prop_begin (model *m, Layer *l)
 Called when a layer begins forward propagation. More...
 
virtual void on_forward_prop_end (model *m)
 Called when a model ends forward propagation. More...
 
virtual void on_forward_prop_end (model *m, Layer *l)
 Called when a layer ends forward propagation. More...
 
virtual void on_backward_prop_begin (model *m)
 Called when a model begins backward propagation. More...
 
virtual void on_backward_prop_begin (model *m, Layer *l)
 Called when a layer begins backward propagation. More...
 
virtual void on_backward_prop_end (model *m, Layer *l)
 Called when a layer ends backward propagation. More...
 
virtual void on_optimize_begin (model *m)
 Called when a model begins optimization. More...
 
virtual void on_optimize_begin (model *m, weights *w)
 Called when a weights object begins optimization. More...
 
virtual void on_optimize_end (model *m)
 Called when a model ends optimization. More...
 
virtual void on_optimize_end (model *m, weights *w)
 Called when a weights object ends optimization. More...
 
virtual void on_batch_evaluate_begin (model *m)
 Called at the beginning of a (mini-)batch evaluation (validation / testing). More...
 
virtual void on_batch_evaluate_end (model *m)
 Called at the end of a (mini-)batch evaluation (validation / testing). More...
 
virtual void on_evaluate_forward_prop_begin (model *m)
 Called when a model begins forward propagation for evaluation (validation / testing). More...
 
virtual void on_evaluate_forward_prop_begin (model *m, Layer *l)
 Called when a layer begins forward propagation for evaluation (validation / testing). More...
 
virtual void on_evaluate_forward_prop_end (model *m)
 Called when a model ends forward propagation for evaluation (validation / testing). More...
 
virtual void on_evaluate_forward_prop_end (model *m, Layer *l)
 Called when a layer ends forward propagation for evaluation (validation / testing). More...
 
int get_batch_interval () const
 Return the batch interval. More...
 
virtual std::string name () const =0
 Return this callback's name. More...
 
virtual description get_description () const
 Human-readable description. More...
 
template<class Archive >
void serialize (Archive &ar)
 Store state to archive for checkpoint and restart. More...
 
void write_proto (lbann_data::Callback &proto) const
 Write a protobuf description of the callback. More...
 

Protected Member Functions

std::vector< std::string > const & get_weights_names () const
 
virtual float global_schedule (model *m)
 
virtual float optimizer_schedule (model *m, optimizer &opt)
 
const std::unordered_set< weights * > & get_weights () const noexcept
 
- Protected Member Functions inherited from lbann::callback_base
std::string get_multi_trainer_path (const model &m, const std::string &root_dir)
 Build a standard directory hierarchy including trainer ID. More...
 
std::string get_multi_trainer_ec_model_path (const model &m, const std::string &root_dir)
 Build a standard directory hierarchy including trainer, execution context, and model information (in that order). More...
 
std::string get_multi_trainer_model_path (const model &m, const std::string &root_dir)
 Build a standard directory hierarchy including trainer and model information (in that order). More...
 
callback_base & operator= (const callback_base &)=default
 Copy-assignment operator. More...
 
virtual void write_specific_proto (lbann_data::Callback &proto) const =0
 Add callback specific data to prototext. More...
 

Static Protected Member Functions

static float get_current_global_learning_rate () noexcept
 
static void update_global_learning_rate (float rate) noexcept
 

Private Attributes

std::vector< std::string > m_weights_names
 
std::unordered_set< weights * > m_weights
 

Static Private Attributes

static float m_cur_global_lr
 

Additional Inherited Members

- Protected Attributes inherited from lbann::callback_base
int m_batch_interval
 Batch methods should run once every this many steps. More...
 

Detailed Description

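Base class for learning rate schedules. Child classes should override global_schedule (applied at the end of every epoch) and/or optimizer_schedule (applied at the end of every training mini-batch) to make changes.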

Definition at line 49 of file learning_rate.hpp.

Constructor & Destructor Documentation

◆ learning_rate() [1/3]

lbann::callback::learning_rate::learning_rate ( )

◆ learning_rate() [2/3]

lbann::callback::learning_rate::learning_rate ( const learning_rate & )
default

◆ learning_rate() [3/3]

lbann::callback::learning_rate::learning_rate ( std::vector< std::string >  weights_names)

Only apply the schedule to the weights named in weights_names.

Member Function Documentation

◆ get_current_global_learning_rate()

static float lbann::callback::learning_rate::get_current_global_learning_rate ( )
inline static protected noexcept

Definition at line 97 of file learning_rate.hpp.

Here is the caller graph for this function:

◆ get_weights()

const std::unordered_set<weights*>& lbann::callback::learning_rate::get_weights ( ) const
inline protected noexcept

Definition at line 92 of file learning_rate.hpp.

◆ get_weights_names()

std::vector<std::string> const& lbann::callback::learning_rate::get_weights_names ( ) const
inline protected

Definition at line 67 of file learning_rate.hpp.

◆ global_schedule()

virtual float lbann::callback::learning_rate::global_schedule ( model * m)
inline protected virtual

This is called at the end of every epoch to update the learning rate for every optimizer. Adjustments should be made based on the current global learning rate. The returned learning rate will be used to automatically update the current global learning rate.

Reimplemented in lbann::callback::cosine_decay_learning_rate, lbann::callback::poly_learning_rate, lbann::callback::linear_growth_learning_rate, lbann::callback::drop_fixed_learning_rate, lbann::callback::adaptive_learning_rate, lbann::callback::set_learning_rate, and lbann::callback::step_learning_rate.

Definition at line 80 of file learning_rate.hpp.

Here is the call graph for this function:
Here is the caller graph for this function:

◆ on_backward_prop_end()

void lbann::callback::learning_rate::on_backward_prop_end ( model * m)
override virtual

Apply local/per-optimizer learning rate schedules.

Reimplemented from lbann::callback_base.

◆ on_epoch_end()

void lbann::callback::learning_rate::on_epoch_end ( model * m)
override virtual

Apply global learning rate schedules.

Reimplemented from lbann::callback_base.

◆ operator=()

learning_rate& lbann::callback::learning_rate::operator= ( const learning_rate & )
default

◆ optimizer_schedule()

virtual float lbann::callback::learning_rate::optimizer_schedule ( model * m,
optimizer & opt 
)
protected virtual

This is called at the end of every training mini-batch to update the learning rate for optimizer opt. The current global learning rate is not updated automatically based on this method.

Reimplemented in lbann::callback::cosine_decay_learning_rate, lbann::callback::optimizerwise_adaptive_learning_rate, and lbann::callback::poly_learning_rate.

Here is the caller graph for this function:

◆ setup()

void lbann::callback::learning_rate::setup ( model * m)
override virtual

Do some initialization.

Reimplemented from lbann::callback_base.

Reimplemented in lbann::callback::cosine_decay_learning_rate, lbann::callback::poly_learning_rate, and lbann::callback::linear_growth_learning_rate.

Here is the caller graph for this function:

◆ update_global_learning_rate()

static void lbann::callback::learning_rate::update_global_learning_rate ( float  rate)
inline static protected noexcept

Definition at line 102 of file learning_rate.hpp.

Member Data Documentation

◆ m_cur_global_lr

float lbann::callback::learning_rate::m_cur_global_lr
static private

This should be maintained by all learning rate schedule implementations as the current global learning rate. This enables coordination among different schedules, particularly ones that work on a per-optimizer basis.

Definition at line 114 of file learning_rate.hpp.

◆ m_weights

std::unordered_set<weights*> lbann::callback::learning_rate::m_weights
private

Weights to update.

Definition at line 120 of file learning_rate.hpp.

◆ m_weights_names

std::vector<std::string> lbann::callback::learning_rate::m_weights_names
private

Names of the weights being updated.

Definition at line 117 of file learning_rate.hpp.


The documentation for this class was generated from the following file: