LBANN 0.103.0
Livermore Big Artificial Neural Network Toolkit
User-facing class that represents a set of compute resources. More...
#include <trainer.hpp>
Public Types | |
| using | execution_context_key_pair_t = typename std::pair< observer_ptr< model >, execution_mode > |
Public Member Functions | |
| execution_context_key_pair_t | check_and_build_execution_context (TrainingAlgorithm &alg, observer_ptr< model > model, execution_mode mode) |
| execution_context_key_pair_t | check_and_build_execution_context (ExecutionContext &c, model &model, execution_mode mode) |
| ExecutionContext & | get_execution_context (observer_ptr< model > model, execution_mode mode) |
| ExecutionContext & | get_execution_context (execution_context_key_pair_t key) |
| bool | execution_context_valid (model &m, execution_mode mode) const noexcept |
| bool | execution_context_valid (execution_context_key_pair_t key) const noexcept |
Lifecycle management | |
| trainer (lbann_comm *comm, std::unique_ptr< data_coordinator > dc, size_t mini_batch_size, std::unique_ptr< TrainingAlgorithm > alg=nullptr) | |
| Construct with a communicator and data coordinator. More... | |
| ~trainer () | |
Serialization | |
| template<class Archive > | |
| void | serialize (Archive &ar) |
| Archive for checkpoint and restart. More... | |
Configuration | |
| void | set_name (std::string const &name) |
| Set the trainer's name. More... | |
| void | set_random_seeds (int root_random_seed, int random_seed, int data_seq_random_seed) |
| Set the random seeds used for the trainer. More... | |
| void | add_callback (std::shared_ptr< callback_base > cb) |
| void | setup (std::unique_ptr< thread_pool > io_thread_pool, std::map< execution_mode, generic_data_reader *> data_readers) |
| Set up the trainer. More... | |
| void | allow_background_io_activity (bool enable) |
| Set a flag that can be used to enable / disable the background I/O activities. More... | |
Queries | |
| std::string | get_name () const |
| description | get_description () const |
| int | get_random_seed () const noexcept |
| int | get_data_seq_random_seed () const noexcept |
| std::vector< observer_ptr< callback_base > > | get_callbacks () const |
| Get the list of callbacks for the trainer. More... | |
| std::vector< std::shared_ptr< callback_base > > & | get_callbacks_with_ownership () |
| const data_coordinator & | get_data_coordinator () const |
| data_coordinator & | get_data_coordinator () |
| thread_pool & | get_io_thread_pool () const |
| Get the I/O thread pool. More... | |
| lbann_comm * | get_comm () const noexcept |
| Get the trainer's comm. More... | |
| persist & | get_persist_obj () noexcept |
| Get the trainer's persist object. More... | |
| size_t | get_max_mini_batch_size () const noexcept |
| Get the trainer's maximum mini-batch size. More... | |
| bool | background_io_activity_allowed () const noexcept |
| Whether the input layers are allowed to perform background I/O activities. More... | |
Training and evaluation interface | |
| void | train (observer_ptr< model > model, El::Int num_epochs, El::Int num_batches=0) |
| void | evaluate (observer_ptr< model > model, execution_mode mode, El::Int num_batches=0) |
Sub-grid management | |
| std::vector< El::Grid * > | get_grids () const |
| void | add_grid (std::unique_ptr< El::Grid > g) |
Checkpointing | |
| bool | save_to_checkpoint_shared () |
| Create a shared checkpoint of the trainer. More... | |
| bool | load_from_checkpoint_shared (persist &p) |
| Restore trainer from a shared checkpoint. More... | |
| bool | load_from_checkpoint_shared (model &m, ExecutionContext &c) |
| Restore model from a shared checkpoint. More... | |
| bool | save_to_checkpoint_distributed () |
| Create a distributed checkpoint of the trainer. More... | |
| bool | load_from_checkpoint_distributed (persist &p) |
| Restore a trainer from a distributed checkpoint. More... | |
| bool | load_from_checkpoint_distributed (model &m, ExecutionContext &c) |
| Restore a model from a distributed checkpoint. More... | |
| void | write_proto (lbann_data::Trainer &proto) |
| Write trainer to proto message. More... | |
Private Types | |
| using | model_execution_context_hash_t = pair_hash< observer_ptr< model >, execution_mode, std::hash< observer_ptr< model > >, enum_hash< execution_mode > > |
Hash function for m_model_execution_context. More... | |
| using | ModelContextMapType = std::unordered_map< std::pair< observer_ptr< model >, execution_mode >, std::unique_ptr< ExecutionContext >, model_execution_context_hash_t > |
Private Member Functions | |
| void | delete_execution_context (execution_context_key_pair_t key) |
| void | for_each_execution_context (std::function< void(observer_ptr< ExecutionContext >)> fn) |
Private Attributes | |
| persist | m_persist |
| Persist object used for serializing LBANN classes. More... | |
| ModelContextMapType | m_model_execution_context |
| Map from model and execution mode to its execution context. More... | |
| std::string | m_name |
| This trainer's name. More... | |
| std::vector< std::shared_ptr< callback_base > > | m_callbacks |
| Current callbacks to process. More... | |
| std::unique_ptr< thread_pool > | m_io_thread_pool |
| Threads available for I/O. More... | |
| std::unique_ptr< data_coordinator > | m_data_coordinator |
| Data coordinator holding the trainer's data readers. More... | |
| std::unique_ptr< TrainingAlgorithm > | m_training_alg |
| The training algorithm being used. May be null. More... | |
| lbann_comm * | m_comm |
| Communication domain for the trainer. More... | |
| std::vector< std::unique_ptr< El::Grid > > | m_grids |
| Processor grids for sub-grid parallelism. More... | |
| size_t | m_max_mini_batch_size |
| Maximum possible minibatch size supported by models and layers in this trainer. More... | |
| int | m_root_random_seed |
| Root of the random seed tree. More... | |
| int | m_random_seed |
| Random seed used for the general RNGs. More... | |
| int | m_data_seq_random_seed |
| Random seed used for the RNG used to fetch data. More... | |
| bool | m_background_io_allowed |
| Flag that allows input layers to fetch data in the background. More... | |
User-facing class that represents a set of compute resources.
A trainer is responsible for managing the interactions of an lbann_comm object with other objects in the library, most notably models and data_coordinators.
Definition at line 60 of file trainer.hpp.
| using lbann::trainer::execution_context_key_pair_t = typename std::pair<observer_ptr<model>, execution_mode> |
Definition at line 209 of file trainer.hpp.
|
private |
Hash function for m_model_execution_context.
Definition at line 287 of file trainer.hpp.
|
private |
Definition at line 292 of file trainer.hpp.
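The private `ModelContextMapType` keys each execution context by a (model, execution mode) pair and needs a custom hash because `std::hash` has no specialization for `std::pair`. The following is a minimal, self-contained sketch of that pattern, not LBANN's implementation. `Model`, `Mode`, `Context`, `KeyHash`, and `ContextRegistry` are hypothetical stand-ins, a raw pointer stands in for `observer_ptr< model >`, and the hash-combining formula is one common choice rather than what `pair_hash` actually does.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <unordered_map>
#include <utility>

// Hypothetical stand-ins for illustration only; these are not LBANN types.
enum class Mode { training, validation, testing };
struct Model {};
struct Context { std::size_t step = 0; };

// Key: a non-owning pointer to the model plus the execution mode.
using Key = std::pair<Model*, Mode>;

// Hash each half of the pair, then mix the two values.
// (Assumed combining scheme; LBANN's pair_hash may differ.)
struct KeyHash {
  std::size_t operator()(Key const& k) const noexcept {
    std::size_t const h1 = std::hash<Model*>{}(k.first);
    std::size_t const h2 = std::hash<int>{}(static_cast<int>(k.second));
    return h1 ^ (h2 + 0x9e3779b9u + (h1 << 6) + (h1 >> 2));
  }
};

class ContextRegistry {
public:
  // Create the context for (model, mode) if it is missing; return its key.
  Key check_and_build(Model& m, Mode mode) {
    Key key{&m, mode};
    if (m_contexts.count(key) == 0) {
      m_contexts.emplace(key, std::make_unique<Context>());
    }
    return key;
  }

  // True if a context has already been built for this key.
  bool valid(Key key) const noexcept { return m_contexts.count(key) != 0; }

  // Look up an existing context; the key must be valid.
  Context& get(Key key) {
    auto it = m_contexts.find(key);
    assert(it != m_contexts.end() && "no context for this (model, mode) pair");
    return *it->second;
  }

  std::size_t size() const noexcept { return m_contexts.size(); }

private:
  std::unordered_map<Key, std::unique_ptr<Context>, KeyHash> m_contexts;
};
```

In this sketch, building a context twice for the same model and mode returns the same key and leaves a single entry in the map, while the same model under a different mode gets its own context.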
| lbann::trainer::trainer | ( | lbann_comm * | comm, |
| std::unique_ptr< data_coordinator > | dc, | ||
| size_t | mini_batch_size, | ||
| std::unique_ptr< TrainingAlgorithm > | alg = nullptr |
||
| ) |
Construct with a communicator and data coordinator.
| [in] | comm | A pointer to a valid lbann_comm object. |
| [in] | dc | The data coordinator used by this trainer. |
| [in] | mini_batch_size | The maximum mini-batch size supported by models and layers in this trainer. |
| [in] | alg | The training algorithm to use. |
mini_batch_size is here.
| lbann::trainer::~trainer | ( | ) |
|
inline |
| void lbann::trainer::add_grid | ( | std::unique_ptr< El::Grid > | g | ) |
|
inline |
Set a flag that can be used to enable / disable the background I/O activities.
Definition at line 125 of file trainer.hpp.
|
inline noexcept |
Whether the input layers are allowed to perform background I/O activities.
Definition at line 201 of file trainer.hpp.
| execution_context_key_pair_t lbann::trainer::check_and_build_execution_context | ( | TrainingAlgorithm & | alg, |
| observer_ptr< model > | model, | ||
| execution_mode | mode | ||
| ) |
| execution_context_key_pair_t lbann::trainer::check_and_build_execution_context | ( | ExecutionContext & | c, |
| model & | model, | ||
| execution_mode | mode | ||
| ) |
|
private |
| void lbann::trainer::evaluate | ( | observer_ptr< model > | model, |
| execution_mode | mode, | ||
| El::Int | num_batches = 0 |
||
| ) |
|
noexcept |
|
noexcept |
|
private |
|
inline |
Get the list of callbacks for the trainer.
Definition at line 150 of file trainer.hpp.
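The class exposes its callbacks in two forms: `get_callbacks()` returns non-owning observer pointers, while `get_callbacks_with_ownership()` returns a reference to the owning `std::shared_ptr` vector. The sketch below illustrates that distinction with standard-library types only. `Callback` and `CallbackOwner` are hypothetical stand-ins, and a raw pointer is used on the assumption that `observer_ptr` behaves like a non-owning pointer; it is not a copy of the code in trainer.hpp.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a callback base class; not an LBANN type.
struct Callback {
  explicit Callback(std::string n) : name(std::move(n)) {}
  std::string name;
};

class CallbackOwner {
public:
  // Take shared ownership of a callback.
  void add_callback(std::shared_ptr<Callback> cb) {
    assert(cb != nullptr);
    m_callbacks.push_back(std::move(cb));
  }

  // Non-owning view: callers can use the callbacks but cannot extend
  // their lifetime or change the owner's list.
  std::vector<Callback*> get_callbacks() const {
    std::vector<Callback*> out;
    out.reserve(m_callbacks.size());
    for (auto const& cb : m_callbacks) {
      out.push_back(cb.get());
    }
    return out;
  }

  // Owning view: callers can add, remove, or share the stored pointers.
  std::vector<std::shared_ptr<Callback>>& get_callbacks_with_ownership() {
    return m_callbacks;
  }

private:
  std::vector<std::shared_ptr<Callback>> m_callbacks;
};
```

Taking the non-owning view leaves reference counts untouched, so it suits code that only needs to invoke callbacks while the owner is alive.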
|
inline |
Definition at line 160 of file trainer.hpp.
|
inline noexcept |
Get the trainer's comm.
Definition at line 189 of file trainer.hpp.
|
inline |
|
inline |
|
inline noexcept |
Definition at line 144 of file trainer.hpp.
| description lbann::trainer::get_description | ( | ) | const |
Human-readable description.
| ExecutionContext& lbann::trainer::get_execution_context | ( | observer_ptr< model > | model, |
| execution_mode | mode | ||
| ) |
| ExecutionContext& lbann::trainer::get_execution_context | ( | execution_context_key_pair_t | key | ) |
| std::vector<El::Grid*> lbann::trainer::get_grids | ( | ) | const |
|
inline |
Get the I/O thread pool.
Definition at line 180 of file trainer.hpp.
|
inline noexcept |
Get the trainer's maximum mini-batch size.
Definition at line 195 of file trainer.hpp.
|
inline |
Return the trainer's name. This is an arbitrary string that may be useful in multi-trainer scenarios, e.g., LTFB or jag.
Definition at line 138 of file trainer.hpp.
|
inline noexcept |
Get the trainer's persist object.
Definition at line 192 of file trainer.hpp.
|
inline noexcept |
Definition at line 143 of file trainer.hpp.
| bool lbann::trainer::load_from_checkpoint_distributed | ( | persist & | p | ) |
Restore a trainer from a distributed checkpoint.
| bool lbann::trainer::load_from_checkpoint_distributed | ( | model & | m, |
| ExecutionContext & | c | ||
| ) |
Restore a model from a distributed checkpoint.
| bool lbann::trainer::load_from_checkpoint_shared | ( | persist & | p | ) |
Restore trainer from a shared checkpoint.
| bool lbann::trainer::load_from_checkpoint_shared | ( | model & | m, |
| ExecutionContext & | c | ||
| ) |
Restore model from a shared checkpoint.
| bool lbann::trainer::save_to_checkpoint_distributed | ( | ) |
Create a distributed checkpoint of the trainer.
| bool lbann::trainer::save_to_checkpoint_shared | ( | ) |
Create a shared checkpoint of the trainer.
| void lbann::trainer::serialize | ( | Archive & | ar | ) |
Archive for checkpoint and restart.
| void lbann::trainer::set_name | ( | std::string const & | name | ) |
Set the trainer's name.
This is an arbitrary string that may be useful in multi-trainer scenarios, e.g., LTFB or jag.
|
inline |
Set the random seeds used for the trainer.
Definition at line 100 of file trainer.hpp.
| void lbann::trainer::setup | ( | std::unique_ptr< thread_pool > | io_thread_pool, |
| std::map< execution_mode, generic_data_reader *> | data_readers | ||
| ) |
Set up the trainer.
| void lbann::trainer::train | ( | observer_ptr< model > | model, |
| El::Int | num_epochs, | ||
| El::Int | num_batches = 0 |
||
| ) |
| void lbann::trainer::write_proto | ( | lbann_data::Trainer & | proto | ) |
Write trainer to proto message.
|
private |
Flag that allows input layers to fetch data in the background.
Definition at line 344 of file trainer.hpp.
|
private |
Current callbacks to process.
Definition at line 301 of file trainer.hpp.
|
private |
Communication domain for the trainer.
Definition at line 316 of file trainer.hpp.
|
private |
Data coordinator holding the trainer's data readers.
Definition at line 307 of file trainer.hpp.
|
private |
Random seed used for the RNG used to fetch data.
Definition at line 341 of file trainer.hpp.
|
private |
Processor grids for sub-grid parallelism.
Does not include grid 0, which corresponds to the trainer's MPI communicator.
Definition at line 323 of file trainer.hpp.
|
private |
Threads available for I/O.
Definition at line 304 of file trainer.hpp.
|
private |
Maximum possible minibatch size supported by models and layers in this trainer.
Definition at line 330 of file trainer.hpp.
|
private |
Map from model and execution mode to its execution context.
Definition at line 295 of file trainer.hpp.
|
private |
This trainer's name.
Definition at line 298 of file trainer.hpp.
|
private |
Persist object used for serializing LBANN classes.
Definition at line 280 of file trainer.hpp.
|
private |
Random seed used for the general RNGs.
Definition at line 338 of file trainer.hpp.
|
private |
Root of the random seed tree.
Either default or user supplied.
Definition at line 335 of file trainer.hpp.
|
private |
The training algorithm being used. May be null.
If null, a different type of execution algorithm is being used (e.g., inference).
Definition at line 313 of file trainer.hpp.