LBANN 0.103.0
Livermore Big Artificial Neural Network Toolkit
lbann::callback::perturb_adam Class Reference

Hyperparameter exploration with Adam optimizers.

#include <perturb_adam.hpp>


Public Member Functions

 perturb_adam (DataType learning_rate_factor, DataType beta1_factor, DataType beta2_factor, DataType eps_factor=0, bool perturb_during_training=false, El::Int batch_interval=1, std::set< std::string > weights_names=std::set< std::string >())
 
perturb_adam * copy () const override
 
std::string name () const override
 Return this callback's name.
 
void setup (model *m) override
 Called once to set up the callback on the model (after all layers are set up).
 
void on_batch_begin (model *m) override
 Called at the beginning of a (mini-)batch.
 
Serialization
template<class Archive >
void serialize (Archive &ar)
 Store state to archive for checkpoint and restart.
 
- Public Member Functions inherited from lbann::callback_base
 callback_base (int batch_interval=1)
 Initialize a callback with an optional batch interval.
 
 callback_base (const callback_base &)=default
 
virtual ~callback_base ()=default
 
virtual void setup (trainer *t)
 Called once to set up the callback on the trainer.
 
virtual void on_setup_end (model *m)
 Called at the end of setup.
 
virtual void on_train_begin (model *m)
 Called at the beginning of training.
 
virtual void on_train_end (model *m)
 Called at the end of training.
 
virtual void on_phase_end (model *m)
 Called at the end of every phase (multiple epochs) in layer-wise model training.
 
virtual void on_epoch_begin (model *m)
 Called at the beginning of each epoch.
 
virtual void on_epoch_end (model *m)
 Called immediately after the end of each epoch.
 
virtual void on_batch_end (model *m)
 Called immediately after the end of a (mini-)batch.
 
virtual void on_test_begin (model *m)
 Called at the beginning of testing.
 
virtual void on_test_end (model *m)
 Called immediately after the end of testing.
 
virtual void on_validation_begin (model *m)
 Called at the beginning of validation.
 
virtual void on_validation_end (model *m)
 Called immediately after the end of validation.
 
virtual void on_forward_prop_begin (model *m)
 Called when a model begins forward propagation.
 
virtual void on_forward_prop_begin (model *m, Layer *l)
 Called when a layer begins forward propagation.
 
virtual void on_forward_prop_end (model *m)
 Called when a model ends forward propagation.
 
virtual void on_forward_prop_end (model *m, Layer *l)
 Called when a layer ends forward propagation.
 
virtual void on_backward_prop_begin (model *m)
 Called when a model begins backward propagation.
 
virtual void on_backward_prop_begin (model *m, Layer *l)
 Called when a layer begins backward propagation.
 
virtual void on_backward_prop_end (model *m)
 Called when a model ends backward propagation.
 
virtual void on_backward_prop_end (model *m, Layer *l)
 Called when a layer ends backward propagation.
 
virtual void on_optimize_begin (model *m)
 Called when a model begins optimization.
 
virtual void on_optimize_begin (model *m, weights *w)
 Called when a weights object begins optimization.
 
virtual void on_optimize_end (model *m)
 Called when a model ends optimization.
 
virtual void on_optimize_end (model *m, weights *w)
 Called when a weights object ends optimization.
 
virtual void on_batch_evaluate_begin (model *m)
 Called at the beginning of a (mini-)batch evaluation (validation / testing).
 
virtual void on_batch_evaluate_end (model *m)
 Called at the end of a (mini-)batch evaluation (validation / testing).
 
virtual void on_evaluate_forward_prop_begin (model *m)
 Called when a model begins forward propagation for evaluation (validation / testing).
 
virtual void on_evaluate_forward_prop_begin (model *m, Layer *l)
 Called when a layer begins forward propagation for evaluation (validation / testing).
 
virtual void on_evaluate_forward_prop_end (model *m)
 Called when a model ends forward propagation for evaluation (validation / testing).
 
virtual void on_evaluate_forward_prop_end (model *m, Layer *l)
 Called when a layer ends forward propagation for evaluation (validation / testing).
 
int get_batch_interval () const
 Return the batch interval.
 
virtual description get_description () const
 Human-readable description.
 
template<class Archive >
void serialize (Archive &ar)
 Store state to archive for checkpoint and restart.
 
void write_proto (lbann_data::Callback &proto) const
 Write a protobuf description of the callback.
 

Private Member Functions

void write_specific_proto (lbann_data::Callback &proto) const final
 
 perturb_adam ()
 
void perturb (model &m) const
 
void perturb (lbann_comm &comm, adam< DataType > &m) const
 

Private Attributes

DataType m_learning_rate_factor
 
DataType m_beta1_factor
 
DataType m_beta2_factor
 
DataType m_eps_factor
 
bool m_perturb_during_training
 
std::set< std::string > m_weights_names
 

Friends

class cereal::access
 

Additional Inherited Members

- Protected Member Functions inherited from lbann::callback_base
std::string get_multi_trainer_path (const model &m, const std::string &root_dir)
 Build a standard directory hierarchy including trainer ID.
 
std::string get_multi_trainer_ec_model_path (const model &m, const std::string &root_dir)
 Build a standard directory hierarchy including trainer, execution context, and model information (in that order).
 
std::string get_multi_trainer_model_path (const model &m, const std::string &root_dir)
 Build a standard directory hierarchy including trainer and model information (in that order).
 
callback_base & operator= (const callback_base &)=default
 Copy-assignment operator.
 
- Protected Attributes inherited from lbann::callback_base
int m_batch_interval
 Batch methods should run once every this many steps.
 

Detailed Description

Hyperparameter exploration with Adam optimizers.

Goes through the Adam optimizers in a model and perturbs four hyperparameters: the learning rate, $\beta_1$, $\beta_2$, and $\epsilon$. Since these hyperparameters can range over orders of magnitude, the perturbations are performed in log space. More precisely, random values are drawn from normal distributions (with user-provided standard deviations) and added to $\log(\text{learning rate})$, $\log(1-\beta_1)$, $\log(1-\beta_2)$, and $\log\epsilon$.
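
The log-space update described above can be sketched in standalone C++. This is only an illustration under stated assumptions, not LBANN's implementation: the adam_hyperparams struct and the helper names are hypothetical, the noise is assumed to be zero-mean, and no clamping of the results is applied.

```cpp
#include <cassert>
#include <cmath>
#include <random>

// Hypothetical stand-in for the four Adam hyperparameters (not LBANN's adam class).
struct adam_hyperparams {
  double learning_rate;
  double beta1;
  double beta2;
  double eps;
};

// Add noise to log(value) and map back; equivalent to value * exp(noise).
inline double perturb_log(double value, double noise) {
  assert(value > 0.0);
  return std::exp(std::log(value) + noise);
}

// Add noise to log(1 - beta) and map back to beta.
// The result is always below 1, but a large positive noise value can push
// it below 0; this sketch does not clamp.
inline double perturb_log_one_minus(double beta, double noise) {
  assert(beta < 1.0);
  return 1.0 - std::exp(std::log(1.0 - beta) + noise);
}

// Draw one standard normal sample per hyperparameter and scale it by the
// matching factor (the standard deviation in log space). Scaling a standard
// normal also handles a factor of 0, which leaves that hyperparameter
// unchanged up to rounding. Zero-mean noise is an assumption of this sketch.
template <class Generator>
adam_hyperparams perturb_hyperparams(const adam_hyperparams& h,
                                     double learning_rate_factor,
                                     double beta1_factor,
                                     double beta2_factor,
                                     double eps_factor,
                                     Generator& gen) {
  std::normal_distribution<double> standard_normal(0.0, 1.0);
  adam_hyperparams out;
  out.learning_rate =
    perturb_log(h.learning_rate, learning_rate_factor * standard_normal(gen));
  out.beta1 = perturb_log_one_minus(h.beta1, beta1_factor * standard_normal(gen));
  out.beta2 = perturb_log_one_minus(h.beta2, beta2_factor * standard_normal(gen));
  out.eps = perturb_log(h.eps, eps_factor * standard_normal(gen));
  return out;
}
```

Because the noise is added to a logarithm, each perturbation acts multiplicatively: with a learning rate factor of 0.1, a one-standard-deviation draw scales the learning rate by exp(0.1), roughly 1.105.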

Definition at line 49 of file perturb_adam.hpp.

Constructor & Destructor Documentation

◆ perturb_adam() [1/2]

lbann::callback::perturb_adam::perturb_adam ( DataType  learning_rate_factor,
DataType  beta1_factor,
DataType  beta2_factor,
DataType  eps_factor = 0,
bool  perturb_during_training = false,
El::Int  batch_interval = 1,
std::set< std::string >  weights_names = std::set< std::string >() 
)
Parameters
learning_rate_factor      Standard deviation of learning rate perturbation (in log space).
beta1_factor              Standard deviation of $\beta_1$ perturbation (in log space).
beta2_factor              Standard deviation of $\beta_2$ perturbation (in log space).
eps_factor                Standard deviation of $\epsilon$ perturbation (in log space).
perturb_during_training   Whether to periodically perturb hyperparameters during training or to only perturb once during setup.
batch_interval            Number of training mini-batch steps between perturbations. Only used if perturb_during_training is true.
weights_names             Names of weights with Adam optimizers. If empty, all Adam optimizers in the model are perturbed.
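
The interaction between perturb_during_training and batch_interval can be sketched as a small predicate. The helper below is hypothetical and not part of LBANN; in particular, skipping step 0 (on the reasoning that setup has already perturbed once) is an assumption of this sketch, not documented behavior.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not an LBANN function): decide whether a batch-begin
// hook should perturb the optimizers at the given training step.
inline bool should_perturb_on_batch(bool perturb_during_training,
                                    std::int64_t batch_interval,
                                    std::int64_t step) {
  assert(batch_interval > 0);
  if (!perturb_during_training) {
    return false;  // Only the one-time perturbation during setup applies.
  }
  // Assumption: step 0 is skipped; afterwards perturb every batch_interval steps.
  return step > 0 && step % batch_interval == 0;
}
```

With perturb_during_training set to false, batch_interval has no effect, which matches the parameter description above.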

◆ perturb_adam() [2/2]

lbann::callback::perturb_adam::perturb_adam ( )
private

Member Function Documentation

◆ copy()

perturb_adam* lbann::callback::perturb_adam::copy ( ) const
inline override virtual

Implements lbann::callback_base.

Definition at line 78 of file perturb_adam.hpp.


◆ name()

std::string lbann::callback::perturb_adam::name ( ) const
inline override virtual

Return this callback's name.

Implements lbann::callback_base.

Definition at line 79 of file perturb_adam.hpp.


◆ on_batch_begin()

void lbann::callback::perturb_adam::on_batch_begin ( model *  m)
override virtual

Called at the beginning of a (mini-)batch.

Reimplemented from lbann::callback_base.


◆ perturb() [1/2]

void lbann::callback::perturb_adam::perturb ( model &  m) const
private

Perturb Adam optimizers in model.

◆ perturb() [2/2]

void lbann::callback::perturb_adam::perturb ( lbann_comm &  comm,
adam< DataType > &  m 
) const
private

Perturb Adam optimizer hyperparameters.

◆ serialize()

template<class Archive >
void lbann::callback::perturb_adam::serialize ( Archive &  ar)

Store state to archive for checkpoint and restart.


◆ setup()

void lbann::callback::perturb_adam::setup ( model *  m)
override virtual

Called once to set up the callback on the model (after all layers are set up).

Reimplemented from lbann::callback_base.


◆ write_specific_proto()

void lbann::callback::perturb_adam::write_specific_proto ( lbann_data::Callback &  proto) const
final private virtual

Add callback-specific data to prototext.

Implements lbann::callback_base.


Friends And Related Function Documentation

◆ cereal::access

friend class cereal::access
friend

Definition at line 97 of file perturb_adam.hpp.

Member Data Documentation

◆ m_beta1_factor

DataType lbann::callback::perturb_adam::m_beta1_factor
private

Standard deviation of $\beta_1$ perturbation.

In log space.

Definition at line 109 of file perturb_adam.hpp.

◆ m_beta2_factor

DataType lbann::callback::perturb_adam::m_beta2_factor
private

Standard deviation of $\beta_2$ perturbation.

In log space.

Definition at line 114 of file perturb_adam.hpp.

◆ m_eps_factor

DataType lbann::callback::perturb_adam::m_eps_factor
private

Standard deviation of $\epsilon$ perturbation.

In log space.

Definition at line 119 of file perturb_adam.hpp.

◆ m_learning_rate_factor

DataType lbann::callback::perturb_adam::m_learning_rate_factor
private

Standard deviation of learning rate perturbation.

In log space.

Definition at line 104 of file perturb_adam.hpp.

◆ m_perturb_during_training

bool lbann::callback::perturb_adam::m_perturb_during_training
private

Whether to periodically perturb during training.

If false, only perturb once during setup.

Definition at line 125 of file perturb_adam.hpp.

◆ m_weights_names

std::set<std::string> lbann::callback::perturb_adam::m_weights_names
private

Optimizers for these weights will be perturbed.

If empty, all Adam optimizers in the model will be perturbed.

Definition at line 131 of file perturb_adam.hpp.
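
The empty-set convention for m_weights_names can be sketched as follows. The function name is hypothetical and the sketch covers only the name check; it does not test whether the weights object actually uses an Adam optimizer, which the callback also requires.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical helper (not an LBANN function): an empty filter selects every
// weights object, otherwise only the listed names are selected.
inline bool is_selected(const std::set<std::string>& weights_names,
                        const std::string& name) {
  return weights_names.empty() || weights_names.count(name) > 0;
}
```

Treating the empty set as "select all" keeps the default-constructed argument in the constructor useful: callers that do not care about specific weights can omit the argument entirely.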


The documentation for this class was generated from the following file: