LBANN
0.103.0
Livermore Big Artificial Neural Network Toolkit
Base class for LBANN training_algorithms.
#include <training_algorithm.hpp>
Public Member Functions

Lifecycle Management
| | TrainingAlgorithm (std::string name) | Constructor. |
| virtual | ~TrainingAlgorithm ()=default | |

Queries
| virtual std::string | get_type () const =0 | A string identifying the type of the object. |
| std::string const & | get_name () const noexcept | A user-defined string identifying the algorithm object. |

Execution interfaces
| virtual void | apply (ExecutionContext &context, model &model, data_coordinator &dc, execution_mode mode)=0 | Apply the algorithm to the given model. |
| void | apply (model &model, data_coordinator &dc) | Apply the algorithm to the given model. |
| void | setup_models (std::vector< observer_ptr< model >> const &models, size_t max_mini_batch_size, const std::vector< El::Grid *> &grids) | Set up a collection of models. |
| std::unique_ptr< ExecutionContext > | get_new_execution_context () const | Get a default-initialized execution context that fits this training algorithm. |

Protected Member Functions
| virtual ExecutionContext * | do_get_new_execution_context () const =0 | Covariant return-friendly implementation of get_new_execution_context(). |

In-hierarchy Lifecycle Management
| | TrainingAlgorithm (const TrainingAlgorithm &other)=delete | |
| TrainingAlgorithm & | operator= (const TrainingAlgorithm &other)=delete | |
| | TrainingAlgorithm (TrainingAlgorithm &&other)=default | |
| TrainingAlgorithm & | operator= (TrainingAlgorithm &&other)=default | |

Private Attributes
| std::string | m_name | The user-defined name of the algorithm. |
Base class for LBANN training_algorithms.
A "training algorithm" is defined as a method for modifying one or more models, where "model" is defined in the LBANN sense (that is, a model object typically consists of a machine learning model plus a "sub-DAG" for computing a training-specific objective function). At this time, we only have support for training a single model unit, though some ad hoc methods exist for training multi-model scenarios such as GANs.
Logically, the inputs to a training algorithm are a model architecture (encapsulated in a model object) and a data source, and the output is a trained model (or a set of parameters that define the action of the model). Here, "trained" means that the training algorithm has evolved the parameters until user-specified stopping criteria have been met; it does not necessarily imply that any underlying optimization method has converged (or even exists) or that such convergence is even well-defined.
A key capability is that training algorithms should be composable. This allows metaheuristic algorithms to simply be implemented as training algorithms constructed from one or more "inner" training algorithms.
Definition at line 86 of file training_algorithm.hpp.
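The composability described above can be illustrated with a small, self-contained sketch. The names `Algorithm`, `Model`, `DataSource`, `StepToward` and `Repeat` are hypothetical stand-ins, not LBANN classes, and the interface is reduced to a two-argument `apply`; the point is only that a wrapper which is itself a training algorithm can own and drive an inner one.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-ins for lbann::model and lbann::data_coordinator.
struct Model { double weight = 0.0; };
struct DataSource { double target = 1.0; };

// Simplified, hypothetical analogue of lbann::TrainingAlgorithm.
class Algorithm {
public:
  explicit Algorithm(std::string name) : m_name(std::move(name)) {}
  virtual ~Algorithm() = default;
  std::string const& get_name() const noexcept { return m_name; }
  virtual void apply(Model& m, DataSource& dc) = 0;
private:
  std::string m_name;
};

// "Inner" algorithm: one relaxation step that moves the weight toward the target.
class StepToward : public Algorithm {
public:
  StepToward(std::string name, double lr) : Algorithm(std::move(name)), m_lr(lr) {}
  void apply(Model& m, DataSource& dc) override {
    m.weight += m_lr * (dc.target - m.weight);
  }
private:
  double m_lr;
};

// Composite algorithm: it is itself an Algorithm, built from an inner one,
// and simply delegates to the inner algorithm a fixed number of rounds.
class Repeat : public Algorithm {
public:
  Repeat(std::string name, std::unique_ptr<Algorithm> inner, int n_rounds)
    : Algorithm(std::move(name)), m_inner(std::move(inner)), m_rounds(n_rounds) {
    assert(m_inner && "a composite algorithm needs an inner algorithm");
  }
  void apply(Model& m, DataSource& dc) override {
    for (int i = 0; i < m_rounds; ++i) m_inner->apply(m, dc);
  }
private:
  std::unique_ptr<Algorithm> m_inner;
  int m_rounds;
};
```

Because `Repeat` is itself an `Algorithm`, wrappers nest: a `Repeat` of 3 rounds around a `Repeat` of 2 rounds runs the inner step six times. A real metaheuristic would typically do more work between inner calls; this sketch shows only the ownership and delegation structure.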
lbann::TrainingAlgorithm::TrainingAlgorithm (std::string name)
Constructor.
| [in] | name | The user-defined name of the algorithm. |
lbann::TrainingAlgorithm::~TrainingAlgorithm () [virtual], [default]
lbann::TrainingAlgorithm::TrainingAlgorithm (const TrainingAlgorithm &other) [protected], [delete]
lbann::TrainingAlgorithm::TrainingAlgorithm (TrainingAlgorithm &&other) [protected], [default]
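virtual void lbann::TrainingAlgorithm::apply (ExecutionContext &context, model &model, data_coordinator &dc, execution_mode mode) [pure virtual]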
Apply the algorithm to the given model.
| [in,out] | context | The persistent state tracked by the model. |
| [in,out] | model | A model architecture with trainable weights. On exit, the weights will have been updated according to the algorithm. |
| [in,out] | dc | The data source for this round of training. |
| [in] | mode | Considered superfluous; slated for removal. |
Implemented in lbann::KFAC, lbann::LTFB, and lbann::SGDTrainingAlgorithm.
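A minimal sketch of a derived class that satisfies the pure virtual `apply` and `get_type`, assuming simplified hypothetical stand-ins (`Context`, `Model`, `DataSource`, `Mode`, `StepContext`, `BudgetedStep`) rather than the real LBANN types. It illustrates the `[in,out]` role of the context: progress is stored in the context, so a second call with the same context resumes from where the first stopped instead of starting over.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical stand-ins for ExecutionContext, model, data_coordinator
// and execution_mode; these are not the LBANN types.
struct Context { virtual ~Context() = default; };
struct Model { double weight = 0.0; };
struct DataSource { double target = 1.0; };
enum class Mode { training, inference };

// Simplified, hypothetical analogue of the pure virtual part of the interface.
class Algorithm {
public:
  explicit Algorithm(std::string name) : m_name(std::move(name)) {}
  virtual ~Algorithm() = default;
  virtual std::string get_type() const = 0;
  virtual void apply(Context& ctx, Model& m, DataSource& dc, Mode mode) = 0;
  std::string const& get_name() const noexcept { return m_name; }
private:
  std::string m_name;
};

// Algorithm-specific context: persists the number of steps already taken.
struct StepContext : Context { int steps = 0; };

// Takes relaxation steps until a fixed step budget (this sketch's stopping
// criterion) is exhausted; progress lives in the context, not the algorithm.
class BudgetedStep : public Algorithm {
public:
  BudgetedStep(std::string name, int budget, double lr)
    : Algorithm(std::move(name)), m_budget(budget), m_lr(lr) {}
  std::string get_type() const override { return "budgeted_step"; }
  void apply(Context& ctx, Model& m, DataSource& dc, Mode mode) override {
    auto* sc = dynamic_cast<StepContext*>(&ctx);
    assert(sc && "context was not created for this algorithm");
    if (mode != Mode::training) return;  // this sketch only updates weights in training
    for (; sc->steps < m_budget; ++sc->steps)
      m.weight += m_lr * (dc.target - m.weight);
  }
private:
  int m_budget;
  double m_lr;
};
```

With a budget of 2 and a step size of 0.5, a first call moves the weight from 0 to 0.75; a second call with the same context changes nothing, because the budget recorded in the context is already spent.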
void lbann::TrainingAlgorithm::apply (model &model, data_coordinator &dc) [inline]
Apply the algorithm to the given model.
| [in,out] | model | A model architecture with trainable weights. On exit, the weights will have been updated according to the algorithm. |
| [in,out] | dc | The data source for this round of training. |
Definition at line 129 of file training_algorithm.hpp.
virtual ExecutionContext * lbann::TrainingAlgorithm::do_get_new_execution_context () const [protected], [pure virtual]
Covariant return-friendly implementation of get_new_execution_context().
Implemented in lbann::KFAC, lbann::SGDTrainingAlgorithm, and lbann::LTFB.
std::string const & lbann::TrainingAlgorithm::get_name () const [noexcept]
A user-defined string identifying the algorithm object.
std::unique_ptr< ExecutionContext > lbann::TrainingAlgorithm::get_new_execution_context () const [inline]
Get a default-initialized execution context that fits this training algorithm.
This method gets a clean, default-initialized execution context suitable for the training algorithm being used. The concrete type is guaranteed to match the concrete type required by the training algorithm.
Implemented via do_get_new_execution_context().
Definition at line 157 of file training_algorithm.hpp.
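The pairing of a public non-virtual `get_new_execution_context()` with a protected virtual `do_get_new_execution_context()` matches a common C++ idiom: covariant return types are allowed for raw pointers and references but not for `std::unique_ptr`, so the virtual hook returns a raw pointer that overrides may narrow, and the non-virtual wrapper takes ownership. The sketch below uses hypothetical classes (`Base`, `Derived`, `BaseContext`, `DerivedContext`) and illustrates the idiom; it is not the LBANN source.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical context hierarchy.
struct BaseContext {
  virtual ~BaseContext() = default;
  virtual std::string kind() const { return "base"; }
};
struct DerivedContext : BaseContext {
  int step = 0;  // state only the derived algorithm understands
  std::string kind() const override { return "derived"; }
};

class Base {
public:
  virtual ~Base() = default;
  // Public, non-virtual: takes ownership of the freshly allocated context.
  std::unique_ptr<BaseContext> get_new_execution_context() const {
    BaseContext* raw = do_get_new_execution_context();
    assert(raw != nullptr && "implementations must return a new context");
    return std::unique_ptr<BaseContext>(raw);
  }
protected:
  // Raw pointer so that overrides may use a covariant return type;
  // std::unique_ptr<DerivedContext> is not covariant with std::unique_ptr<BaseContext>.
  virtual BaseContext* do_get_new_execution_context() const = 0;
};

class Derived : public Base {
public:
  // Optional in this sketch: hides Base's wrapper with one that keeps the
  // concrete type, so callers holding a Derived need no downcast.
  std::unique_ptr<DerivedContext> get_new_execution_context() const {
    return std::unique_ptr<DerivedContext>(do_get_new_execution_context());
  }
protected:
  // Covariant override: DerivedContext* replaces BaseContext*.
  DerivedContext* do_get_new_execution_context() const override {
    return new DerivedContext();
  }
};
```

Calling through a `Base const&` still yields a `DerivedContext` at run time (only the static type is `BaseContext`), while calling on a `Derived` directly returns `std::unique_ptr<DerivedContext>` with no cast.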
virtual std::string lbann::TrainingAlgorithm::get_type () const [pure virtual]
A string identifying the type of the object.
Implemented in lbann::KFAC, lbann::LTFB, and lbann::SGDTrainingAlgorithm.
TrainingAlgorithm & lbann::TrainingAlgorithm::operator= (const TrainingAlgorithm &other) [protected], [delete]
TrainingAlgorithm & lbann::TrainingAlgorithm::operator= (TrainingAlgorithm &&other) [protected], [default]
void lbann::TrainingAlgorithm::setup_models (std::vector< observer_ptr< model >> const &models, size_t max_mini_batch_size, const std::vector< El::Grid *> &grids)
Set up a collection of models.
| [in] | models | The collection of models to be set up. |
| [in] | max_mini_batch_size | The largest minibatch size accepted by any model. |
| [in] | grids | Process grids for distributed tensors. |
std::string lbann::TrainingAlgorithm::m_name [private]
The user-defined name of the algorithm.
Definition at line 179 of file training_algorithm.hpp.