Base class for LBANN SGD-family training algorithms.
More...
#include <sgd_training_algorithm.hpp>
|
| | SGDTrainingAlgorithm (std::string name, std::unique_ptr< SGDTerminationCriteria > stop, bool suppress_timer_output) |
| | Construct with a name. More...
|
| |
| | SGDTrainingAlgorithm (const SGDTrainingAlgorithm &other)=delete |
| |
| SGDTrainingAlgorithm & | operator= (const SGDTrainingAlgorithm &other)=delete |
| |
| | SGDTrainingAlgorithm (SGDTrainingAlgorithm &&other)=default |
| |
| SGDTrainingAlgorithm & | operator= (SGDTrainingAlgorithm &&other)=default |
| |
| virtual | ~SGDTrainingAlgorithm ()=default |
| |
| std::string | get_type () const override |
| |
| void | apply (ExecutionContext &c, model &model, data_coordinator &dc, execution_mode mode) override |
| |
| void | train (SGDExecutionContext &c, model &model, data_coordinator &dc, SGDTerminationCriteria const &term) |
| |
| void | evaluate (SGDExecutionContext &c, model &model, data_coordinator &dc, execution_mode mode, SGDTerminationCriteria const &term) |
| |
| std::unique_ptr< SGDExecutionContext > | get_new_execution_context () const |
| | Get a default-initialized execution context. More...
|
| |
| | TrainingAlgorithm (std::string name) |
| | Constructor. More...
|
| |
| virtual | ~TrainingAlgorithm ()=default |
| |
| std::string const & | get_name () const noexcept |
| | A user-defined string identifying the algorithm object. More...
|
| |
| void | apply (model &model, data_coordinator &dc) |
| | Apply the algorithm to the given model. More...
|
| |
| void | setup_models (std::vector< observer_ptr< model >> const &models, size_t max_mini_batch_size, const std::vector< El::Grid *> &grids) |
| | Setup a collection of models. More...
|
| |
| std::unique_ptr< ExecutionContext > | get_new_execution_context () const |
| | Get a default-initialized execution context that fits this training algorithm. More...
|
| |
|
| bool | train_mini_batch (SGDExecutionContext &c, model &model, data_coordinator &dc, ScopeTimer timer) |
| |
| bool | evaluate_mini_batch (SGDExecutionContext &c, model &model, data_coordinator &dc, execution_mode mode, ScopeTimer timer) |
| |
| void | do_train_begin_cbs (model &model, ScopeTimer timer) |
| |
| void | do_train_end_cbs (model &model, ScopeTimer timer) |
| |
| void | do_evaluate_begin_cbs (model &model, execution_mode mode, ScopeTimer timer) |
| |
| void | do_evaluate_end_cbs (model &model, execution_mode mode, ScopeTimer timer) |
| |
| void | do_epoch_begin_cbs (model &model, ScopeTimer timer) |
| |
| void | do_epoch_end_cbs (model &model, ScopeTimer timer) |
| |
| void | do_batch_begin_cbs (model &model, execution_mode mode, ScopeTimer timer) |
| |
| void | do_batch_end_cbs (model &model, execution_mode mode, ScopeTimer timer) |
| |
| SGDExecutionContext * | do_get_new_execution_context () const override |
| | Covariant return-friendly implementation of get_new_execution_context(). More...
|
| |
| | TrainingAlgorithm (const TrainingAlgorithm &other)=delete |
| |
| TrainingAlgorithm & | operator= (const TrainingAlgorithm &other)=delete |
| |
| | TrainingAlgorithm (TrainingAlgorithm &&other)=default |
| |
| TrainingAlgorithm & | operator= (TrainingAlgorithm &&other)=default |
| |
Base class for LBANN SGD-family training algorithms.
Definition at line 50 of file sgd_training_algorithm.hpp.
◆ SGDTrainingAlgorithm() [1/3]
| lbann::SGDTrainingAlgorithm::SGDTrainingAlgorithm |
( |
std::string |
name, |
|
|
std::unique_ptr< SGDTerminationCriteria > |
stop, |
|
|
bool |
suppress_timer_output |
|
) |
| |
◆ SGDTrainingAlgorithm() [2/3]
◆ SGDTrainingAlgorithm() [3/3]
◆ ~SGDTrainingAlgorithm()
| virtual lbann::SGDTrainingAlgorithm::~SGDTrainingAlgorithm |
( |
| ) |
|
|
virtual, default |
◆ apply()
Apply the training algorithm to the model with the provided context and execution mode.
Implements lbann::TrainingAlgorithm.
◆ do_batch_begin_cbs()
Execute callbacks at start of mini-batch.
◆ do_batch_end_cbs()
Execute callbacks at end of mini-batch.
◆ do_epoch_begin_cbs()
| void lbann::SGDTrainingAlgorithm::do_epoch_begin_cbs |
( |
model & |
model, |
|
|
ScopeTimer |
timer |
|
) |
| |
|
protected |
Execute callbacks at start of epoch.
◆ do_epoch_end_cbs()
| void lbann::SGDTrainingAlgorithm::do_epoch_end_cbs |
( |
model & |
model, |
|
|
ScopeTimer |
timer |
|
) |
| |
|
protected |
Execute callbacks at end of epoch.
◆ do_evaluate_begin_cbs()
Execute callbacks at start of evaluation.
◆ do_evaluate_end_cbs()
Execute callbacks at end of evaluation.
◆ do_get_new_execution_context()
◆ do_train_begin_cbs()
| void lbann::SGDTrainingAlgorithm::do_train_begin_cbs |
( |
model & |
model, |
|
|
ScopeTimer |
timer |
|
) |
| |
|
protected |
Execute callbacks at start of training.
◆ do_train_end_cbs()
| void lbann::SGDTrainingAlgorithm::do_train_end_cbs |
( |
model & |
model, |
|
|
ScopeTimer |
timer |
|
) |
| |
|
protected |
Execute callbacks at end of training.
◆ evaluate()
Evaluate a model using the forward pass of an SGD solver.
◆ evaluate_mini_batch()
Evaluate model on one step / mini-batch of an SGD forward pass.
◆ get_new_execution_context()
| std::unique_ptr<SGDExecutionContext> lbann::SGDTrainingAlgorithm::get_new_execution_context |
( |
| ) |
const |
Get a default-initialized execution context.
- Note
- This method participates in the "covariant-smart-pointer-return" pattern. In particular, it hides the base-class method to give the illusion of a covariant return.
◆ get_type()
| std::string lbann::SGDTrainingAlgorithm::get_type |
( |
| ) |
const |
|
override, virtual |
◆ operator=() [1/2]
◆ operator=() [2/2]
◆ train()
Train a model using an iterative SGD solver.
◆ train_mini_batch()
Train model on one step / mini-batch of an SGD forward pass.
◆ m_stopping_criteria
◆ m_suppress_timer
| bool lbann::SGDTrainingAlgorithm::m_suppress_timer = false |
|
private |
◆ m_timers
| TimerMap lbann::SGDTrainingAlgorithm::m_timers |
|
private |
◆ m_validation_context
◆ m_validation_epochs
| size_t lbann::SGDTrainingAlgorithm::m_validation_epochs |
|
private |
The documentation for this class was generated from the following file: