27 #ifndef LBANN_TRAINER_HPP 28 #define LBANN_TRAINER_HPP 33 #include "lbann/proto/lbann.pb.h" 38 #include <unordered_map> 44 class data_coordinator;
49 class generic_data_reader;
50 class TrainingAlgorithm;
75 std::unique_ptr<data_coordinator> dc,
76 size_t mini_batch_size,
77 std::unique_ptr<TrainingAlgorithm> alg =
nullptr);
86 template <
class Archive>
97 void set_name(std::string
const& name);
102 int data_seq_random_seed)
113 "model: Attempted to add null pointer as a callback.");
119 void setup(std::unique_ptr<thread_pool> io_thread_pool,
120 std::map<execution_mode, generic_data_reader*> data_readers);
152 std::vector<observer_ptr<callback_base>> callback_list;
155 callback_list.push_back(ptr.get());
157 return callback_list;
214 execution_mode mode);
219 execution_mode mode);
222 execution_mode mode);
238 El::Int num_batches = 0);
242 std::vector<El::Grid*>
get_grids()
const;
244 void add_grid(std::unique_ptr<El::Grid> g);
286 std::hash<observer_ptr<model>>,
290 std::unordered_map<std::pair<observer_ptr<model>, execution_mode>,
291 std::unique_ptr<ExecutionContext>,
323 std::vector<std::unique_ptr<El::Grid>>
m_grids;
357 #endif // LBANN_TRAINER_HPP persist & get_persist_obj() noexcept
Get the trainer's persist object.
void evaluate(observer_ptr< model > model, execution_mode mode, El::Int num_batches=0)
std::unordered_map< std::pair< observer_ptr< model >, execution_mode >, std::unique_ptr< ExecutionContext >, model_execution_context_hash_t > ModelContextMapType
std::string get_name() const
void allow_background_io_activity(bool enable)
Set a flag that enables or disables background I/O activities.
void add_grid(std::unique_ptr< El::Grid > g)
std::string m_name
This trainer's name.
persist m_persist
Persist object used for serializing LBANN classes.
int m_data_seq_random_seed
Random seed for the RNG used to fetch data.
bool background_io_activity_allowed() const noexcept
Whether background I/O activities are enabled for the input layers.
std::vector< observer_ptr< callback_base > > get_callbacks() const
Get the list of callbacks for the trainer.
Generates nicely formatted description messages.
std::vector< std::unique_ptr< El::Grid > > m_grids
Processor grids for sub-grid parallelism.
size_t get_max_mini_batch_size() const noexcept
Get the trainer's maximum mini-batch size.
typename std::pair< observer_ptr< model >, execution_mode > execution_context_key_pair_t
bool load_from_checkpoint_distributed(persist &p)
Restore a trainer from a distributed checkpoint.
void for_each_execution_context(std::function< void(observer_ptr< ExecutionContext >)> fn)
ExecutionContext & get_execution_context(observer_ptr< model > model, execution_mode mode)
data_coordinator & get_data_coordinator()
lbann_comm * get_comm() const noexcept
Get the trainer's comm.
bool save_to_checkpoint_distributed()
Create a distributed checkpoint of the trainer.
std::unique_ptr< TrainingAlgorithm > m_training_alg
The training algorithm being used. May be null.
bool execution_context_valid(model &m, execution_mode mode) const noexcept
trainer const & get_const_trainer()
Get a const reference to the global trainer visible to this rank.
trainer(lbann_comm *comm, std::unique_ptr< data_coordinator > dc, size_t mini_batch_size, std::unique_ptr< TrainingAlgorithm > alg=nullptr)
Construct with a communicator and data coordinator.
Abstract base class for neural network models.
size_t m_max_mini_batch_size
Maximum possible mini-batch size supported by models and layers in this trainer.
bool m_background_io_allowed
Flag that allows input layers to fetch data in the background.
The execution context for a KFAC algorithm.
std::vector< std::shared_ptr< callback_base > > & get_callbacks_with_ownership()
Hash function for enumeration types.
const data_coordinator & get_data_coordinator() const
bool load_from_checkpoint_shared(persist &p)
Restore trainer from a shared checkpoint.
typename std::add_pointer< T >::type observer_ptr
Defines observer_ptr, a raw-pointer alias that complements unique_ptr and shared_ptr.
execution_mode
Neural network execution mode.
The stopping criteria for an LTFB-type algorithm.
std::vector< El::Grid * > get_grids() const
exception lbann_exception
void add_callback(std::shared_ptr< callback_base > cb)
execution_context_key_pair_t check_and_build_execution_context(TrainingAlgorithm &alg, observer_ptr< model > model, execution_mode mode)
int get_random_seed() const noexcept
lbann_comm * m_comm
Communication domain for the trainer.
thread_pool & get_io_thread_pool() const
Get the I/O thread pool.
std::unique_ptr< data_coordinator > m_data_coordinator
Data coordinator holding the trainer's data readers.
description get_description() const
int m_root_random_seed
Root of the random seed tree.
ModelContextMapType m_model_execution_context
Map from model and execution mode to its execution context.
User-facing class that represents a set of compute resources.
std::vector< std::shared_ptr< callback_base > > m_callbacks
Current callbacks to process.
bool save_to_checkpoint_shared()
Create a shared checkpoint of the trainer.
void write_proto(lbann_data::Trainer &proto)
Write trainer to proto message.
void set_random_seeds(int root_random_seed, int random_seed, int data_seq_random_seed)
Set the random seeds used for the trainer.
trainer & get_trainer()
Get a reference to the global trainer visible to this rank.
void delete_execution_context(execution_context_key_pair_t key)
void set_name(std::string const &name)
Set the trainer's name.
int get_data_seq_random_seed() const noexcept
int m_random_seed
Random seed used for the general RNGs.
std::unique_ptr< thread_pool > m_io_thread_pool
Threads available for I/O.
Hash function for std::pair.
Base class for LBANN training_algorithms.
void serialize(Archive &ar)
Archive for checkpoint and restart.
void train(observer_ptr< model > model, El::Int num_epochs, El::Int num_batches=0)
void setup(std::unique_ptr< thread_pool > io_thread_pool, std::map< execution_mode, generic_data_reader *> data_readers)
Set up the trainer.