LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
lbann::trainer Class Reference

User-facing class that represents a set of compute resources.

#include <trainer.hpp>


Public Types

using execution_context_key_pair_t = typename std::pair< observer_ptr< model >, execution_mode >
 

Public Member Functions

execution_context_key_pair_t check_and_build_execution_context (TrainingAlgorithm &alg, observer_ptr< model > model, execution_mode mode)
 
execution_context_key_pair_t check_and_build_execution_context (ExecutionContext &c, model &model, execution_mode mode)
 
ExecutionContext & get_execution_context (observer_ptr< model > model, execution_mode mode)
 
ExecutionContext & get_execution_context (execution_context_key_pair_t key)
 
bool execution_context_valid (model &m, execution_mode mode) const noexcept
 
bool execution_context_valid (execution_context_key_pair_t key) const noexcept
 
Lifecycle management
 trainer (lbann_comm *comm, std::unique_ptr< data_coordinator > dc, size_t mini_batch_size, std::unique_ptr< TrainingAlgorithm > alg=nullptr)
 Construct with a communicator and data coordinator.
 
 ~trainer ()
 
Serialization
template<class Archive >
void serialize (Archive &ar)
 Archive for checkpoint and restart.
 
Configuration
void set_name (std::string const &name)
 Set the trainer's name.
 
void set_random_seeds (int root_random_seed, int random_seed, int data_seq_random_seed)
 Set the random seeds used for the trainer.
 
void add_callback (std::shared_ptr< callback_base > cb)
 
void setup (std::unique_ptr< thread_pool > io_thread_pool, std::map< execution_mode, generic_data_reader *> data_readers)
 Set up the trainer.
 
void allow_background_io_activity (bool enable)
 Set a flag that can be used to enable / disable the background I/O activities.
 
Queries
std::string get_name () const
 
description get_description () const
 
int get_random_seed () const noexcept
 
int get_data_seq_random_seed () const noexcept
 
std::vector< observer_ptr< callback_base > > get_callbacks () const
 Get the list of callbacks for the trainer.
 
std::vector< std::shared_ptr< callback_base > > & get_callbacks_with_ownership ()
 
const data_coordinator & get_data_coordinator () const
 
data_coordinator & get_data_coordinator ()
 
thread_pool & get_io_thread_pool () const
 Get the I/O thread pool.
 
lbann_comm * get_comm () const noexcept
 Get the trainer's comm.
 
persist & get_persist_obj () noexcept
 Get the trainer's persist object.
 
size_t get_max_mini_batch_size () const noexcept
 Get the trainer's maximum mini-batch size.
 
bool background_io_activity_allowed () const noexcept
 Whether background I/O activities are enabled for the input layers.
 
Training and evaluation interface
void train (observer_ptr< model > model, El::Int num_epochs, El::Int num_batches=0)
 
void evaluate (observer_ptr< model > model, execution_mode mode, El::Int num_batches=0)
 
Sub-grid management
std::vector< El::Grid * > get_grids () const
 
void add_grid (std::unique_ptr< El::Grid > g)
 
Checkpointing
bool save_to_checkpoint_shared ()
 Create a shared checkpoint of the trainer.
 
bool load_from_checkpoint_shared (persist &p)
 Restore a trainer from a shared checkpoint.
 
bool load_from_checkpoint_shared (model &m, ExecutionContext &c)
 Restore a model from a shared checkpoint.
 
bool save_to_checkpoint_distributed ()
 Create a distributed checkpoint of the trainer.
 
bool load_from_checkpoint_distributed (persist &p)
 Restore a trainer from a distributed checkpoint.
 
bool load_from_checkpoint_distributed (model &m, ExecutionContext &c)
 Restore a model from a distributed checkpoint.
 
void write_proto (lbann_data::Trainer &proto)
 Write trainer to proto message.
 

Private Types

using model_execution_context_hash_t = pair_hash< observer_ptr< model >, execution_mode, std::hash< observer_ptr< model > >, enum_hash< execution_mode > >
 Hash function for m_model_execution_context.
 
using ModelContextMapType = std::unordered_map< std::pair< observer_ptr< model >, execution_mode >, std::unique_ptr< ExecutionContext >, model_execution_context_hash_t >
 

Private Member Functions

void delete_execution_context (execution_context_key_pair_t key)
 
void for_each_execution_context (std::function< void(observer_ptr< ExecutionContext >)> fn)
 

Private Attributes

persist m_persist
 Persist object used for serializing LBANN classes.
 
ModelContextMapType m_model_execution_context
 Map from model and execution mode to its execution context.
 
std::string m_name
 This trainer's name.
 
std::vector< std::shared_ptr< callback_base > > m_callbacks
 Current callbacks to process.
 
std::unique_ptr< thread_pool > m_io_thread_pool
 Threads available for I/O.
 
std::unique_ptr< data_coordinator > m_data_coordinator
 Data coordinator holding the trainer's data readers.
 
std::unique_ptr< TrainingAlgorithm > m_training_alg
 The training algorithm being used. May be null.
 
lbann_comm * m_comm
 Communication domain for the trainer.
 
std::vector< std::unique_ptr< El::Grid > > m_grids
 Processor grids for sub-grid parallelism.
 
size_t m_max_mini_batch_size
 Maximum possible minibatch size supported by models and layers in this trainer.
 
int m_root_random_seed
 Root of the random seed tree.
 
int m_random_seed
 Random seed used for the general RNGs.
 
int m_data_seq_random_seed
 Random seed used for the RNG used to fetch data.
 
bool m_background_io_allowed
 Flag that allows input layers to fetch data in the background.
 

Detailed Description

User-facing class that represents a set of compute resources.

A trainer is responsible for managing the interactions of an lbann_comm object with other objects in the library, most notably models and data_coordinators.

Definition at line 60 of file trainer.hpp.

Member Typedef Documentation

◆ execution_context_key_pair_t

Definition at line 209 of file trainer.hpp.

◆ model_execution_context_hash_t

Hash function for m_model_execution_context.

Definition at line 287 of file trainer.hpp.

◆ ModelContextMapType

using lbann::trainer::ModelContextMapType = std::unordered_map<std::pair<observer_ptr<model>, execution_mode>, std::unique_ptr<ExecutionContext>, model_execution_context_hash_t>
private

Definition at line 292 of file trainer.hpp.
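
The map type above keys each execution context by a (model, execution mode) pair, which requires a custom hash because the standard library provides no std::hash specialization for std::pair. The following is a minimal sketch of how such a map can be used, not LBANN's implementation. Model, Mode, Context, PairHash, get_or_build, and context_valid are hypothetical stand-ins, and the hash-combining formula is an assumption that may differ from LBANN's pair_hash.

```cpp
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-ins for lbann::model, execution_mode, and ExecutionContext.
struct Model { std::string name; };
enum class Mode { training, validation, testing };
struct Context { std::size_t step = 0; };

using Key = std::pair<Model*, Mode>;

// Assumed hash for (pointer, enum) pairs; LBANN's pair_hash may differ.
struct PairHash {
  std::size_t operator()(Key const& key) const noexcept {
    std::size_t const h1 = std::hash<Model*>{}(key.first);
    std::size_t const h2 = std::hash<int>{}(static_cast<int>(key.second));
    // Boost-style hash combine of the two component hashes.
    return h1 ^ (h2 + 0x9e3779b9u + (h1 << 6) + (h1 >> 2));
  }
};

using ContextMap = std::unordered_map<Key, std::unique_ptr<Context>, PairHash>;

// Return the context for a (model, mode) key, creating it on first use.
Context& get_or_build(ContextMap& contexts, Model* m, Mode mode) {
  auto& slot = contexts[Key{m, mode}];
  if (!slot) slot = std::make_unique<Context>();
  return *slot;
}

// True if a context has already been built for the (model, mode) key.
bool context_valid(ContextMap const& contexts, Model* m, Mode mode) {
  return contexts.count(Key{m, mode}) != 0;
}
```

Because the key holds a non-owning model pointer, one model can have a separate context per execution mode, and the map owns only the contexts, never the models.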

Constructor & Destructor Documentation

◆ trainer()

lbann::trainer::trainer ( lbann_comm *  comm,
std::unique_ptr< data_coordinator >  dc,
size_t  mini_batch_size,
std::unique_ptr< TrainingAlgorithm >  alg = nullptr 
)

Construct with a communicator and data coordinator.

Parameters
[in]  comm  A pointer to a valid lbann_comm object.
[in]  dc  The data coordinator used by this trainer.
[in]  mini_batch_size  The minibatch size? What's a minibatch? That sounds like an SGD thing...
[in]  alg  The training algorithm to use.
Todo:
I don't know why mini_batch_size is here.

◆ ~trainer()

lbann::trainer::~trainer ( )

Member Function Documentation

◆ add_callback()

void lbann::trainer::add_callback ( std::shared_ptr< callback_base >  cb)
inline

Definition at line 109 of file trainer.hpp.

◆ add_grid()

void lbann::trainer::add_grid ( std::unique_ptr< El::Grid >  g)

◆ allow_background_io_activity()

void lbann::trainer::allow_background_io_activity ( bool  enable)
inline

Set a flag that can be used to enable / disable the background I/O activities.

Definition at line 125 of file trainer.hpp.

◆ background_io_activity_allowed()

bool lbann::trainer::background_io_activity_allowed ( ) const
inline noexcept

Whether background I/O activities are enabled for the input layers.

Definition at line 201 of file trainer.hpp.

◆ check_and_build_execution_context() [1/2]

execution_context_key_pair_t lbann::trainer::check_and_build_execution_context ( TrainingAlgorithm &  alg,
observer_ptr< model >  model,
execution_mode  mode 
)

◆ check_and_build_execution_context() [2/2]

execution_context_key_pair_t lbann::trainer::check_and_build_execution_context ( ExecutionContext &  c,
model &  model,
execution_mode  mode 
)

◆ delete_execution_context()

void lbann::trainer::delete_execution_context ( execution_context_key_pair_t  key)
private

◆ evaluate()

void lbann::trainer::evaluate ( observer_ptr< model >  model,
execution_mode  mode,
El::Int  num_batches = 0 
)

◆ execution_context_valid() [1/2]

bool lbann::trainer::execution_context_valid ( model &  m,
execution_mode  mode 
) const
noexcept

◆ execution_context_valid() [2/2]

bool lbann::trainer::execution_context_valid ( execution_context_key_pair_t  key) const
noexcept

◆ for_each_execution_context()

void lbann::trainer::for_each_execution_context ( std::function< void(observer_ptr< ExecutionContext >)>  fn)
private

◆ get_callbacks()

std::vector<observer_ptr<callback_base> > lbann::trainer::get_callbacks ( ) const
inline

Get the list of callbacks for the trainer.

Definition at line 150 of file trainer.hpp.
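
The two callback accessors differ in ownership: get_callbacks() returns non-owning pointers, while get_callbacks_with_ownership() exposes the underlying list of shared pointers. The sketch below illustrates that split under stated assumptions. It treats observer_ptr<T> as a plain T* alias, and Callback and Holder are hypothetical stand-ins rather than LBANN classes.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Assumption: a non-owning pointer alias standing in for lbann::observer_ptr.
template <typename T>
using observer_ptr = T*;

// Hypothetical stand-in for lbann::callback_base.
struct Callback {
  std::string name;
};

// Hypothetical holder that mirrors the accessors documented above.
class Holder {
public:
  // Take shared ownership of a callback.
  void add_callback(std::shared_ptr<Callback> cb) {
    m_callbacks.push_back(std::move(cb));
  }

  // Non-owning view: callers can use the callbacks but cannot extend
  // their lifetime.
  std::vector<observer_ptr<Callback>> get_callbacks() const {
    std::vector<observer_ptr<Callback>> out;
    out.reserve(m_callbacks.size());
    for (auto const& cb : m_callbacks) out.push_back(cb.get());
    return out;
  }

  // Owning view: callers can add, remove, or share ownership of callbacks.
  std::vector<std::shared_ptr<Callback>>& get_callbacks_with_ownership() {
    return m_callbacks;
  }

private:
  std::vector<std::shared_ptr<Callback>> m_callbacks;
};
```

In this sketch the non-owning view is a snapshot: pointers returned by get_callbacks() stay valid only while the holder (or another owner) keeps the corresponding callbacks alive.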

◆ get_callbacks_with_ownership()

std::vector<std::shared_ptr<callback_base> >& lbann::trainer::get_callbacks_with_ownership ( )
inline

Definition at line 160 of file trainer.hpp.

◆ get_comm()

lbann_comm* lbann::trainer::get_comm ( ) const
inline noexcept

Get the trainer's comm.

Definition at line 189 of file trainer.hpp.

◆ get_data_coordinator() [1/2]

const data_coordinator& lbann::trainer::get_data_coordinator ( ) const
inline

Definition at line 165 of file trainer.hpp.

◆ get_data_coordinator() [2/2]

data_coordinator& lbann::trainer::get_data_coordinator ( )
inline

Definition at line 173 of file trainer.hpp.

◆ get_data_seq_random_seed()

int lbann::trainer::get_data_seq_random_seed ( ) const
inline noexcept

Definition at line 144 of file trainer.hpp.

◆ get_description()

description lbann::trainer::get_description ( ) const

Human-readable description.

◆ get_execution_context() [1/2]

ExecutionContext& lbann::trainer::get_execution_context ( observer_ptr< model >  model,
execution_mode  mode 
)

◆ get_execution_context() [2/2]

ExecutionContext& lbann::trainer::get_execution_context ( execution_context_key_pair_t  key)

◆ get_grids()

std::vector<El::Grid*> lbann::trainer::get_grids ( ) const

◆ get_io_thread_pool()

thread_pool& lbann::trainer::get_io_thread_pool ( ) const
inline

Get the I/O thread pool.

Definition at line 180 of file trainer.hpp.

◆ get_max_mini_batch_size()

size_t lbann::trainer::get_max_mini_batch_size ( ) const
inline noexcept

Get the trainer's maximum mini-batch size.

Definition at line 195 of file trainer.hpp.

◆ get_name()

std::string lbann::trainer::get_name ( ) const
inline

Return the trainer's name. This is an arbitrary string that may be useful in multi-trainer scenarios, e.g., LTFB, jag, etc.

Definition at line 138 of file trainer.hpp.

◆ get_persist_obj()

persist& lbann::trainer::get_persist_obj ( )
inline noexcept

Get the trainer's persist object.

Definition at line 192 of file trainer.hpp.

◆ get_random_seed()

int lbann::trainer::get_random_seed ( ) const
inline noexcept

Definition at line 143 of file trainer.hpp.

◆ load_from_checkpoint_distributed() [1/2]

bool lbann::trainer::load_from_checkpoint_distributed ( persist &  p)

Restore a trainer from a distributed checkpoint.

◆ load_from_checkpoint_distributed() [2/2]

bool lbann::trainer::load_from_checkpoint_distributed ( model &  m,
ExecutionContext &  c 
)

Restore a model from a distributed checkpoint.

◆ load_from_checkpoint_shared() [1/2]

bool lbann::trainer::load_from_checkpoint_shared ( persist &  p)

Restore a trainer from a shared checkpoint.

◆ load_from_checkpoint_shared() [2/2]

bool lbann::trainer::load_from_checkpoint_shared ( model &  m,
ExecutionContext &  c 
)

Restore a model from a shared checkpoint.

◆ save_to_checkpoint_distributed()

bool lbann::trainer::save_to_checkpoint_distributed ( )

Create a distributed checkpoint of the trainer.

◆ save_to_checkpoint_shared()

bool lbann::trainer::save_to_checkpoint_shared ( )

Create a shared checkpoint of the trainer.

◆ serialize()

template<class Archive >
void lbann::trainer::serialize ( Archive &  ar)

Archive for checkpoint and restart.
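
A serialize(Archive&) member template of this shape follows the convention used by archive-based serialization libraries such as Cereal, where one function describes the fields and the archive type decides whether they are written or read. The sketch below illustrates the pattern with toy archives. OutArchive, InArchive, and TrainerState are hypothetical, and the fields archived here are an assumption, not the set that lbann::trainer actually serializes.

```cpp
#include <cstddef>
#include <sstream>
#include <string>

// Toy output archive: writes each field as whitespace-separated text.
class OutArchive {
public:
  explicit OutArchive(std::ostream& os) : m_os(os) {}
  void operator()() {}  // Terminates the recursion below.
  template <typename T, typename... Ts>
  void operator()(T const& first, Ts const&... rest) {
    m_os << first << ' ';
    (*this)(rest...);
  }

private:
  std::ostream& m_os;
};

// Toy input archive: reads fields back in the order they were written.
class InArchive {
public:
  explicit InArchive(std::istream& is) : m_is(is) {}
  void operator()() {}  // Terminates the recursion below.
  template <typename T, typename... Ts>
  void operator()(T& first, Ts&... rest) {
    m_is >> first;
    (*this)(rest...);
  }

private:
  std::istream& m_is;
};

// Hypothetical state object; the field list is illustrative only.
struct TrainerState {
  std::string name;  // Must not contain whitespace for this toy text format.
  std::size_t max_mini_batch_size = 0;
  int random_seed = 0;

  // One function serves both checkpoint (save) and restart (load).
  template <class Archive>
  void serialize(Archive& ar) {
    ar(name, max_mini_batch_size, random_seed);
  }
};
```

Keeping a single field list for both directions means a field added for checkpointing is restored on restart without a second code path.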

◆ set_name()

void lbann::trainer::set_name ( std::string const &  name)

Set the trainer's name.

This is an arbitrary string that may be useful in multi-trainer scenarios, e.g., LTFB, jag, etc.

◆ set_random_seeds()

void lbann::trainer::set_random_seeds ( int  root_random_seed,
int  random_seed,
int  data_seq_random_seed 
)
inline

Set the random seeds used for the trainer.

Definition at line 100 of file trainer.hpp.

◆ setup()

void lbann::trainer::setup ( std::unique_ptr< thread_pool >  io_thread_pool,
std::map< execution_mode, generic_data_reader *>  data_readers 
)

Set up the trainer.

◆ train()

void lbann::trainer::train ( observer_ptr< model >  model,
El::Int  num_epochs,
El::Int  num_batches = 0 
)

◆ write_proto()

void lbann::trainer::write_proto ( lbann_data::Trainer &  proto)

Write trainer to proto message.

Member Data Documentation

◆ m_background_io_allowed

bool lbann::trainer::m_background_io_allowed
private

Flag that allows input layers to fetch data in the background.

Definition at line 344 of file trainer.hpp.

◆ m_callbacks

std::vector<std::shared_ptr<callback_base> > lbann::trainer::m_callbacks
private

Current callbacks to process.

Definition at line 301 of file trainer.hpp.

◆ m_comm

lbann_comm* lbann::trainer::m_comm
private

Communication domain for the trainer.

Definition at line 316 of file trainer.hpp.

◆ m_data_coordinator

std::unique_ptr<data_coordinator> lbann::trainer::m_data_coordinator
private

Data coordinator holding the trainer's data readers.

Definition at line 307 of file trainer.hpp.

◆ m_data_seq_random_seed

int lbann::trainer::m_data_seq_random_seed
private

Random seed used for the RNG used to fetch data.

Definition at line 341 of file trainer.hpp.

◆ m_grids

std::vector<std::unique_ptr<El::Grid> > lbann::trainer::m_grids
private

Processor grids for sub-grid parallelism.

Does not include grid 0, which corresponds to the trainer's MPI communicator.

Definition at line 323 of file trainer.hpp.

◆ m_io_thread_pool

std::unique_ptr<thread_pool> lbann::trainer::m_io_thread_pool
private

Threads available for I/O.

Definition at line 304 of file trainer.hpp.

◆ m_max_mini_batch_size

size_t lbann::trainer::m_max_mini_batch_size
private

Maximum possible minibatch size supported by models and layers in this trainer.

Note
This field will eventually be local to the particular instance of the training context.

Definition at line 330 of file trainer.hpp.

◆ m_model_execution_context

ModelContextMapType lbann::trainer::m_model_execution_context
private

Map from model and execution mode to its execution context.

Definition at line 295 of file trainer.hpp.

◆ m_name

std::string lbann::trainer::m_name
private

This trainer's name.

Definition at line 298 of file trainer.hpp.

◆ m_persist

persist lbann::trainer::m_persist
private

Persist object used for serializing LBANN classes.

Definition at line 280 of file trainer.hpp.

◆ m_random_seed

int lbann::trainer::m_random_seed
private

Random seed used for the general RNGs.

Definition at line 338 of file trainer.hpp.

◆ m_root_random_seed

int lbann::trainer::m_root_random_seed
private

Root of the random seed tree.

Either default or user supplied.

Definition at line 335 of file trainer.hpp.

◆ m_training_alg

std::unique_ptr<TrainingAlgorithm> lbann::trainer::m_training_alg
private

The training algorithm being used. May be null.

If null, a different type of execution algorithm is being used (e.g., inference).

Definition at line 313 of file trainer.hpp.


The documentation for this class was generated from the following file: