LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
trainer.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "Licensee"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_TRAINER_HPP
28 #define LBANN_TRAINER_HPP
29 
30 #include "lbann/base.hpp"
31 #include "lbann/detect_El_mpi.hpp"
32 #include "lbann/io/persist.hpp"
33 #include "lbann/proto/lbann.pb.h"
34 #include "lbann/utils/hash.hpp"
36 #include <memory>
37 #include <string>
38 #include <unordered_map>
39 #include <vector>
40 
41 namespace lbann {
42 
43 // Forward-declarations
44 class data_coordinator;
45 class description;
46 class lbann_comm;
47 class callback_base;
48 class ExecutionContext;
49 class generic_data_reader;
50 class TrainingAlgorithm;
52 class model;
53 
60 class trainer
61 {
62 public:
64 
/** @brief Construct with a communicator and data coordinator.
 *
 *  @param comm            Communicator for the trainer.
 *  @param dc              Data coordinator; ownership moves to the trainer.
 *  @param mini_batch_size Mini-batch size; presumably the value later
 *                         reported by get_max_mini_batch_size() — TODO confirm.
 *  @param alg             Training algorithm; ownership moves to the
 *                         trainer. Defaults to null.
 */
74  trainer(lbann_comm* comm,
75  std::unique_ptr<data_coordinator> dc,
76  size_t mini_batch_size,
77  std::unique_ptr<TrainingAlgorithm> alg = nullptr);
78 
79  ~trainer();
80 
82 
83 
/** @brief Archive for checkpoint and restart.
 *  @tparam Archive Serialization archive type.
 *  @param ar       Archive the trainer's state is read from or written to.
 */
86  template <class Archive>
87  void serialize(Archive& ar);
88 
90 
91 
/// Set the trainer's name.
97  void set_name(std::string const& name);
98 
100  void set_random_seeds(int root_random_seed,
101  int random_seed,
102  int data_seq_random_seed)
103  {
104  m_root_random_seed = root_random_seed;
105  m_random_seed = random_seed;
106  m_data_seq_random_seed = data_seq_random_seed;
107  }
108 
109  void add_callback(std::shared_ptr<callback_base> cb)
110  {
111  if (cb == nullptr) {
112  throw lbann_exception(
113  "model: Attempted to add null pointer as a callback.");
114  }
115  m_callbacks.push_back(std::move(cb));
116  }
117 
/** @brief Set up the trainer.
 *  @param io_thread_pool Thread pool for I/O; ownership moves to the trainer.
 *  @param data_readers   Data reader for each execution mode. These are raw
 *                        pointers, and their ownership/lifetime requirements
 *                        are not visible here — TODO confirm.
 */
119  void setup(std::unique_ptr<thread_pool> io_thread_pool,
120  std::map<execution_mode, generic_data_reader*> data_readers);
121 
126  {
127  m_background_io_allowed = enable;
128  }
129 
131 
132 
138  std::string get_name() const { return m_name; }
139 
142 
143  int get_random_seed() const noexcept { return m_random_seed; }
144  int get_data_seq_random_seed() const noexcept
145  {
146  return m_data_seq_random_seed;
147  }
148 
150  std::vector<observer_ptr<callback_base>> get_callbacks() const
151  {
152  std::vector<observer_ptr<callback_base>> callback_list;
153  callback_list.reserve(m_callbacks.size());
154  for (const auto& ptr : m_callbacks) {
155  callback_list.push_back(ptr.get());
156  }
157  return callback_list;
158  }
159 
160  std::vector<std::shared_ptr<callback_base>>& get_callbacks_with_ownership()
161  {
162  return m_callbacks;
163  }
164 
166  {
167  if (m_data_coordinator == nullptr) {
168  LBANN_ERROR("data_coordinator is nullptr");
169  }
170  return *m_data_coordinator;
171  }
172 
174  {
175  return const_cast<data_coordinator&>(
176  static_cast<const trainer&>(*this).get_data_coordinator());
177  }
178 
181  {
182  if (!m_io_thread_pool) {
183  LBANN_ERROR("m_io_thread_pool is null");
184  }
185  return *(m_io_thread_pool.get());
186  }
187 
189  lbann_comm* get_comm() const noexcept { return m_comm; }
190 
192  persist& get_persist_obj() noexcept { return m_persist; }
193 
195  size_t get_max_mini_batch_size() const noexcept
196  {
197  return m_max_mini_batch_size;
198  }
199 
201  bool background_io_activity_allowed() const noexcept
202  {
204  }
205 
207 
209  typename std::pair<observer_ptr<model>, execution_mode>;
210 
214  execution_mode mode);
215 
218  model& model,
219  execution_mode mode);
220 
222  execution_mode mode);
223 
225 
226  bool execution_context_valid(model& m, execution_mode mode) const noexcept;
227 
228  bool execution_context_valid(execution_context_key_pair_t key) const noexcept;
229 
231 
233  void
234  train(observer_ptr<model> model, El::Int num_epochs, El::Int num_batches = 0);
235 
237  execution_mode mode,
238  El::Int num_batches = 0);
239 
241 
242  std::vector<El::Grid*> get_grids() const;
244  void add_grid(std::unique_ptr<El::Grid> g);
246 
247 
251 
254 
257 
260 
263 
266 
/// Write the trainer to the protobuf message @p proto.
268  void write_proto(lbann_data::Trainer& proto);
269 
271 
272 private:
274 
276  std::function<void(observer_ptr<ExecutionContext>)> fn);
277 
278 private:
281 
286  std::hash<observer_ptr<model>>,
288 
289  using ModelContextMapType =
290  std::unordered_map<std::pair<observer_ptr<model>, execution_mode>,
291  std::unique_ptr<ExecutionContext>,
293 
296 
/// This trainer's name.
298  std::string m_name;
299 
/// Current callbacks to process. Shared ownership; exposed as observers via
/// get_callbacks() and directly via get_callbacks_with_ownership().
301  std::vector<std::shared_ptr<callback_base>> m_callbacks;
302 
/// Threads available for I/O.
304  std::unique_ptr<thread_pool> m_io_thread_pool;
305 
/// Data coordinator holding the trainer's data readers.
307  std::unique_ptr<data_coordinator> m_data_coordinator;
308 
/// The training algorithm being used. May be null.
313  std::unique_ptr<TrainingAlgorithm> m_training_alg;
314 
317 
/// Processor grids for sub-grid parallelism. Owned here; presumably
/// populated by add_grid() and exposed as raw pointers by get_grids()
/// — TODO confirm.
323  std::vector<std::unique_ptr<El::Grid>> m_grids;
324 
331 
336 
339 
342 
345 };
346 
349 
/// Get a const reference to the global trainer visible to this rank.
353 trainer const& get_const_trainer();
354 
355 } // namespace lbann
356 
357 #endif // LBANN_TRAINER_HPP
persist & get_persist_obj() noexcept
Get the trainer's persist object.
Definition: trainer.hpp:192
void evaluate(observer_ptr< model > model, execution_mode mode, El::Int num_batches=0)
std::unordered_map< std::pair< observer_ptr< model >, execution_mode >, std::unique_ptr< ExecutionContext >, model_execution_context_hash_t > ModelContextMapType
Definition: trainer.hpp:292
std::string get_name() const
Definition: trainer.hpp:138
void allow_background_io_activity(bool enable)
Set a flag that can be used to enable / disable the background I/O activities.
Definition: trainer.hpp:125
void add_grid(std::unique_ptr< El::Grid > g)
std::string m_name
This trainer's name.
Definition: trainer.hpp:298
#define LBANN_ERROR(...)
Definition: exception.hpp:37
persist m_persist
Persist object used for serializing LBANN classes.
Definition: trainer.hpp:280
int m_data_seq_random_seed
Random seed used for the RNG used to fetch data.
Definition: trainer.hpp:341
bool background_io_activity_allowed() const noexcept
Whether background I/O activities are enabled for the input layers.
Definition: trainer.hpp:201
std::vector< observer_ptr< callback_base > > get_callbacks() const
Get the list of callbacks for the trainer.
Definition: trainer.hpp:150
Generates nicely formatted description messages.
Definition: description.hpp:49
std::vector< std::unique_ptr< El::Grid > > m_grids
Processor grids for sub-grid parallelism.
Definition: trainer.hpp:323
size_t get_max_mini_batch_size() const noexcept
Get the trainer's maximum mini-batch size.
Definition: trainer.hpp:195
typename std::pair< observer_ptr< model >, execution_mode > execution_context_key_pair_t
Definition: trainer.hpp:209
bool load_from_checkpoint_distributed(persist &p)
Restore a trainer from a distributed checkpoint.
void for_each_execution_context(std::function< void(observer_ptr< ExecutionContext >)> fn)
ExecutionContext & get_execution_context(observer_ptr< model > model, execution_mode mode)
data_coordinator & get_data_coordinator()
Definition: trainer.hpp:173
lbann_comm * get_comm() const noexcept
Get the trainer's communicator.
Definition: trainer.hpp:189
bool save_to_checkpoint_distributed()
Create a distributed checkpoint of the trainer.
std::unique_ptr< TrainingAlgorithm > m_training_alg
The training algorithm being used. May be null.
Definition: trainer.hpp:313
bool execution_context_valid(model &m, execution_mode mode) const noexcept
trainer const & get_const_trainer()
Get a const reference to the global trainer visible to this rank.
trainer(lbann_comm *comm, std::unique_ptr< data_coordinator > dc, size_t mini_batch_size, std::unique_ptr< TrainingAlgorithm > alg=nullptr)
Construct with a communicator and data coordinator.
Abstract base class for neural network models.
Definition: model.hpp:83
size_t m_max_mini_batch_size
Maximum possible mini-batch size supported by models and layers in this trainer.
Definition: trainer.hpp:330
bool m_background_io_allowed
Flag that allows input layers to fetch data in the background.
Definition: trainer.hpp:344
The execution context for a KFAC algorithm.
std::vector< std::shared_ptr< callback_base > > & get_callbacks_with_ownership()
Definition: trainer.hpp:160
Hash function for enumeration type.
Definition: hash.hpp:58
const data_coordinator & get_data_coordinator() const
Definition: trainer.hpp:165
bool load_from_checkpoint_shared(persist &p)
Restore trainer from a shared checkpoint.
typename std::add_pointer< T >::type observer_ptr
Creating an observer_ptr to complement the unique_ptr and shared_ptr.
Definition: base.hpp:54
execution_mode
Neural network execution mode.
Definition: base.hpp:229
The stopping criteria for an LTFB-type algorithm.
std::vector< El::Grid * > get_grids() const
exception lbann_exception
Definition: exception.hpp:145
void add_callback(std::shared_ptr< callback_base > cb)
Definition: trainer.hpp:109
execution_context_key_pair_t check_and_build_execution_context(TrainingAlgorithm &alg, observer_ptr< model > model, execution_mode mode)
int get_random_seed() const noexcept
Definition: trainer.hpp:143
lbann_comm * m_comm
Communication domain for the trainer.
Definition: trainer.hpp:316
thread_pool & get_io_thread_pool() const
Get the I/O thread pool.
Definition: trainer.hpp:180
std::unique_ptr< data_coordinator > m_data_coordinator
Data coordinator holding the trainer's data readers.
Definition: trainer.hpp:307
description get_description() const
int m_root_random_seed
Root of the random seed tree.
Definition: trainer.hpp:335
ModelContextMapType m_model_execution_context
Map from model and execution mode to its execution context.
Definition: trainer.hpp:295
User-facing class that represents a set of compute resources.
Definition: trainer.hpp:60
std::vector< std::shared_ptr< callback_base > > m_callbacks
Current callbacks to process.
Definition: trainer.hpp:301
bool save_to_checkpoint_shared()
Create a shared checkpoint of the trainer.
void write_proto(lbann_data::Trainer &proto)
Write trainer to proto message.
void set_random_seeds(int root_random_seed, int random_seed, int data_seq_random_seed)
Set the random seeds used for the trainer.
Definition: trainer.hpp:100
trainer & get_trainer()
Get a reference to the global trainer visible to this rank.
void delete_execution_context(execution_context_key_pair_t key)
void set_name(std::string const &name)
Set the trainer's name.
int get_data_seq_random_seed() const noexcept
Definition: trainer.hpp:144
int m_random_seed
Random seed used for the general RNGs.
Definition: trainer.hpp:338
std::unique_ptr< thread_pool > m_io_thread_pool
Threads available for I/O.
Definition: trainer.hpp:304
Hash function for std::pair.
Definition: hash.hpp:75
Base class for LBANN training_algorithms.
void serialize(Archive &ar)
Archive for checkpoint and restart.
void train(observer_ptr< model > model, El::Int num_epochs, El::Int num_batches=0)
void setup(std::unique_ptr< thread_pool > io_thread_pool, std::map< execution_mode, generic_data_reader *> data_readers)
Set up the trainer.