LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
lbann::SGDTrainingAlgorithm Class Reference

Base class for LBANN SGD-family training algorithms. More...

#include <sgd_training_algorithm.hpp>

Inheritance diagram for lbann::SGDTrainingAlgorithm:
[legend]
Collaboration diagram for lbann::SGDTrainingAlgorithm:
[legend]

Public Member Functions

 SGDTrainingAlgorithm (std::string name, std::unique_ptr< SGDTerminationCriteria > stop, bool suppress_timer_output)
 Construct with a name. More...
 
 SGDTrainingAlgorithm (const SGDTrainingAlgorithm &other)=delete
 
SGDTrainingAlgorithm & operator= (const SGDTrainingAlgorithm &other)=delete
 
 SGDTrainingAlgorithm (SGDTrainingAlgorithm &&other)=default
 
SGDTrainingAlgorithm & operator= (SGDTrainingAlgorithm &&other)=default
 
virtual ~SGDTrainingAlgorithm ()=default
 
std::string get_type () const override
 
void apply (ExecutionContext &c, model &model, data_coordinator &dc, execution_mode mode) override
 
void train (SGDExecutionContext &c, model &model, data_coordinator &dc, SGDTerminationCriteria const &term)
 
void evaluate (SGDExecutionContext &c, model &model, data_coordinator &dc, execution_mode mode, SGDTerminationCriteria const &term)
 
std::unique_ptr< SGDExecutionContext > get_new_execution_context () const
 Get a default-initialized execution context. More...
 
- Public Member Functions inherited from lbann::TrainingAlgorithm
 TrainingAlgorithm (std::string name)
 Constructor. More...
 
virtual ~TrainingAlgorithm ()=default
 
std::string const & get_name () const noexcept
 A user-defined string identifying the algorithm object. More...
 
void apply (model &model, data_coordinator &dc)
 Apply the algorithm to the given model. More...
 
void setup_models (std::vector< observer_ptr< model >> const &models, size_t max_mini_batch_size, const std::vector< El::Grid *> &grids)
 Setup a collection of models. More...
 
std::unique_ptr< ExecutionContext > get_new_execution_context () const
 Get a default-initialized execution context that fits this training algorithm. More...
 

Protected Member Functions

bool train_mini_batch (SGDExecutionContext &c, model &model, data_coordinator &dc, ScopeTimer timer)
 
bool evaluate_mini_batch (SGDExecutionContext &c, model &model, data_coordinator &dc, execution_mode mode, ScopeTimer timer)
 
void do_train_begin_cbs (model &model, ScopeTimer timer)
 
void do_train_end_cbs (model &model, ScopeTimer timer)
 
void do_evaluate_begin_cbs (model &model, execution_mode mode, ScopeTimer timer)
 
void do_evaluate_end_cbs (model &model, execution_mode mode, ScopeTimer timer)
 
void do_epoch_begin_cbs (model &model, ScopeTimer timer)
 
void do_epoch_end_cbs (model &model, ScopeTimer timer)
 
void do_batch_begin_cbs (model &model, execution_mode mode, ScopeTimer timer)
 
void do_batch_end_cbs (model &model, execution_mode mode, ScopeTimer timer)
 
SGDExecutionContext * do_get_new_execution_context () const override
 Covariant return-friendly implementation of get_new_execution_context(). More...
 
- Protected Member Functions inherited from lbann::TrainingAlgorithm
 TrainingAlgorithm (const TrainingAlgorithm &other)=delete
 
TrainingAlgorithm & operator= (const TrainingAlgorithm &other)=delete
 
 TrainingAlgorithm (TrainingAlgorithm &&other)=default
 
TrainingAlgorithm & operator= (TrainingAlgorithm &&other)=default
 

Private Attributes

TimerMap m_timers
 
std::unique_ptr< SGDTerminationCriteria > m_stopping_criteria
 
SGDExecutionContext m_validation_context
 
size_t m_validation_epochs
 
bool m_suppress_timer = false
 Suppress timer output. More...
 

Detailed Description

Base class for LBANN SGD-family training algorithms.

Definition at line 50 of file sgd_training_algorithm.hpp.

Constructor & Destructor Documentation

◆ SGDTrainingAlgorithm() [1/3]

lbann::SGDTrainingAlgorithm::SGDTrainingAlgorithm ( std::string  name,
std::unique_ptr< SGDTerminationCriteria > stop,
bool  suppress_timer_output 
)

Construct with a name.

◆ SGDTrainingAlgorithm() [2/3]

lbann::SGDTrainingAlgorithm::SGDTrainingAlgorithm ( const SGDTrainingAlgorithm & other)
delete

◆ SGDTrainingAlgorithm() [3/3]

lbann::SGDTrainingAlgorithm::SGDTrainingAlgorithm ( SGDTrainingAlgorithm &&  other)
default

◆ ~SGDTrainingAlgorithm()

virtual lbann::SGDTrainingAlgorithm::~SGDTrainingAlgorithm ( )
virtual default

Member Function Documentation

◆ apply()

void lbann::SGDTrainingAlgorithm::apply ( ExecutionContext & c,
model & model,
data_coordinator & dc,
execution_mode  mode 
)
override virtual

Apply the training algorithm to the model with the provided context and execution mode.

Implements lbann::TrainingAlgorithm.

◆ do_batch_begin_cbs()

void lbann::SGDTrainingAlgorithm::do_batch_begin_cbs ( model & model,
execution_mode  mode,
ScopeTimer  timer 
)
protected

Execute callbacks at start of mini-batch.

◆ do_batch_end_cbs()

void lbann::SGDTrainingAlgorithm::do_batch_end_cbs ( model & model,
execution_mode  mode,
ScopeTimer  timer 
)
protected

Execute callbacks at end of mini-batch.

◆ do_epoch_begin_cbs()

void lbann::SGDTrainingAlgorithm::do_epoch_begin_cbs ( model & model,
ScopeTimer  timer 
)
protected

Execute callbacks at start of epoch.

◆ do_epoch_end_cbs()

void lbann::SGDTrainingAlgorithm::do_epoch_end_cbs ( model & model,
ScopeTimer  timer 
)
protected

Execute callbacks at end of epoch.

◆ do_evaluate_begin_cbs()

void lbann::SGDTrainingAlgorithm::do_evaluate_begin_cbs ( model & model,
execution_mode  mode,
ScopeTimer  timer 
)
protected

Execute callbacks at start of evaluation.

◆ do_evaluate_end_cbs()

void lbann::SGDTrainingAlgorithm::do_evaluate_end_cbs ( model & model,
execution_mode  mode,
ScopeTimer  timer 
)
protected

Execute callbacks at end of evaluation.

◆ do_get_new_execution_context()

SGDExecutionContext* lbann::SGDTrainingAlgorithm::do_get_new_execution_context ( ) const
override protected virtual

Covariant return-friendly implementation of get_new_execution_context().

Implements lbann::TrainingAlgorithm.

◆ do_train_begin_cbs()

void lbann::SGDTrainingAlgorithm::do_train_begin_cbs ( model & model,
ScopeTimer  timer 
)
protected

Execute callbacks at start of training.

◆ do_train_end_cbs()

void lbann::SGDTrainingAlgorithm::do_train_end_cbs ( model & model,
ScopeTimer  timer 
)
protected

Execute callbacks at end of training.

◆ evaluate()

void lbann::SGDTrainingAlgorithm::evaluate ( SGDExecutionContext & c,
model & model,
data_coordinator & dc,
execution_mode  mode,
SGDTerminationCriteria const &  term 
)

Evaluate a model using the forward pass of an SGD solver.

◆ evaluate_mini_batch()

bool lbann::SGDTrainingAlgorithm::evaluate_mini_batch ( SGDExecutionContext & c,
model & model,
data_coordinator & dc,
execution_mode  mode,
ScopeTimer  timer 
)
protected

Evaluate model on one step / mini-batch of an SGD forward pass.

◆ get_new_execution_context()

std::unique_ptr<SGDExecutionContext> lbann::SGDTrainingAlgorithm::get_new_execution_context ( ) const

Get a default-initialized execution context.

Note
This method participates in the "covariant-smart-pointer-return" pattern. In particular, it hides the base-class method to give the illusion of a covariant return.

◆ get_type()

std::string lbann::SGDTrainingAlgorithm::get_type ( ) const
override virtual

Get the type name of this training algorithm as a string.

Implements lbann::TrainingAlgorithm.

◆ operator=() [1/2]

SGDTrainingAlgorithm& lbann::SGDTrainingAlgorithm::operator= ( const SGDTrainingAlgorithm & other)
delete

◆ operator=() [2/2]

SGDTrainingAlgorithm& lbann::SGDTrainingAlgorithm::operator= ( SGDTrainingAlgorithm &&  other)
default

◆ train()

void lbann::SGDTrainingAlgorithm::train ( SGDExecutionContext & c,
model & model,
data_coordinator & dc,
SGDTerminationCriteria const &  term 
)

Train a model using an iterative SGD solver.

◆ train_mini_batch()

bool lbann::SGDTrainingAlgorithm::train_mini_batch ( SGDExecutionContext & c,
model & model,
data_coordinator & dc,
ScopeTimer  timer 
)
protected

Train model on one step / mini-batch of an SGD forward pass.

Member Data Documentation

◆ m_stopping_criteria

std::unique_ptr<SGDTerminationCriteria> lbann::SGDTrainingAlgorithm::m_stopping_criteria
private

Definition at line 142 of file sgd_training_algorithm.hpp.

◆ m_suppress_timer

bool lbann::SGDTrainingAlgorithm::m_suppress_timer = false
private

Suppress timer output.

Deprecated:
This is a temporary way to disable timer output. This will be more configurable in the future.

Definition at line 156 of file sgd_training_algorithm.hpp.

◆ m_timers

TimerMap lbann::SGDTrainingAlgorithm::m_timers
private

Definition at line 141 of file sgd_training_algorithm.hpp.

◆ m_validation_context

SGDExecutionContext lbann::SGDTrainingAlgorithm::m_validation_context
private

Definition at line 148 of file sgd_training_algorithm.hpp.

◆ m_validation_epochs

size_t lbann::SGDTrainingAlgorithm::m_validation_epochs
private

Definition at line 149 of file sgd_training_algorithm.hpp.


The documentation for this class was generated from the following file: