28 #ifndef LBANN_CALLBACKS_CALLBACK_CHECKPOINT_HPP_INCLUDED 29 #define LBANN_CALLBACKS_CALLBACK_CHECKPOINT_HPP_INCLUDED 38 class TrainingAlgorithm;
72 std::string restart_dir,
73 int checkpoint_epochs,
76 std::string per_rank_dir,
80 m_active_trainer(nullptr),
81 m_active_training_algorithm(nullptr),
82 m_checkpoint_dir(std::move(checkpoint_dir)),
83 m_restart_dir(std::move(restart_dir)),
84 m_checkpoint_epochs(checkpoint_epochs),
85 m_checkpoint_steps(checkpoint_steps),
86 m_checkpoint_secs(checkpoint_secs),
87 m_per_rank_dir(per_rank_dir),
88 m_ckpt_dist_epochs(ckpt_dist_epochs),
89 m_ckpt_dist_steps(ckpt_dist_steps)
/** Setup hook taking a model pointer; overrides the base-class virtual.
 *  Only the declaration is visible here; the body is defined elsewhere. */
94 void setup(
model* m)
override;
/** Setup hook overload taking a trainer pointer; overrides the base-class virtual. */
95 void setup(
trainer* t)
override;
/** Overrides the base-class on_train_begin hook for model @c m.
 *  NOTE(review): presumably invoked when training starts -- confirm against the callback base class. */
96 void on_train_begin(
model* m)
override;
/** Overrides the base-class on_train_end hook for model @c m. */
97 void on_train_end(
model* m)
override;
/** Overrides the base-class on_epoch_begin hook for model @c m. */
98 void on_epoch_begin(
model* m)
override;
/** Overrides the base-class on_batch_begin hook for model @c m. */
99 void on_batch_begin(
model* m)
override;
/** Overrides the base-class on_validation_begin hook for model @c m. */
100 void on_validation_begin(
model* m)
override;
104 m_checkpoint_dir = dir;
114 if (m_restart_dir.length() != 0) {
115 return m_restart_dir;
118 return m_checkpoint_dir;
128 m_active_training_algorithm = t;
135 m_checkpoint_epochs = epochs;
148 m_ckpt_dist_epochs = ckpt_dist_epochs;
153 m_ckpt_dist_steps = ckpt_dist_steps;
158 return get_restart_dir();
165 if (m_per_rank_dir.length()) {
166 return get_per_rank_dir() +
"/" + get_restart_dir();
169 return get_restart_dir();
174 std::string find_latest_checkpoint(
lbann_comm& comm,
175 const std::string& trainer_name,
176 const std::string& alg_name,
182 bool open_latest_checkpoint(
184 const std::string& task_label,
185 const std::string& trainer_name,
186 const std::string& alg_name,
187 std::function<
bool(
persist&)> reload_shared_ckpt,
188 std::function<
bool(
persist&)> reload_distributed_ckpt);
/** Reload state for model @c m. Returns a bool.
 *  NOTE(review): presumably true on success -- confirm in the implementation file. */
189 bool reload_model(
model* m);
/** Reload state for trainer @c t. Returns a bool.
 *  NOTE(review): presumably true on success -- confirm in the implementation file. */
190 bool reload_trainer(
trainer* t);
/** Restart using model @c m. Returns a bool.
 *  NOTE(review): the meaning of the return value is not visible from this declaration -- confirm. */
191 bool restart(
model* m);
/** Return this callback's name, which is always the literal string "checkpoint". */
192 std::string
name()
const override {
return "checkpoint"; }
/** Takes the protobuf callback message @c proto by non-const reference.
 *  Declared const (does not modify this object) and final (cannot be overridden further).
 *  NOTE(review): presumably writes the checkpoint-specific fields into @c proto -- confirm in the implementation file. */
196 void write_specific_proto(lbann_data::Callback& proto)
const final;
199 void do_distributed_checkpoint(
lbann_comm& comm,
233 template <
size_t _max_dir_len>
241 char dirname[_max_dir_len];
246 const std::string& dir);
249 const std::string& dir);
252 const std::string& alg_name,
253 const std::string& dir);
256 const std::string& dir,
263 const std::string& alg_name,
264 const std::string& dir,
272 const std::string& dir);
276 const std::string& alg_name,
277 const std::string& dir);
280 const int rank_in_trainer,
281 const std::string& dir,
288 const std::string& alg_name,
289 const int rank_in_trainer,
290 const std::string& dir,
314 std::unique_ptr<callback_base>
320 #endif // LBANN_CALLBACKS_CALLBACK_CHECKPOINT_HPP_INCLUDED void set_checkpoint_dir(const std::string &dir)
std::string m_checkpoint_dir
std::string get_shared_checkpoint_dirname(const std::string &alg_name, const std::string &dir, visitor_hook hook, execution_mode mode, size_t epoch, size_t step)
checkpoint(std::string checkpoint_dir, std::string restart_dir, int checkpoint_epochs, int checkpoint_steps, int checkpoint_secs, std::string per_rank_dir, int ckpt_dist_epochs, int ckpt_dist_steps)
Construct the checkpoint callback.
const std::string & get_restart_dir()
void set_checkpoint_epochs(int epochs)
void set_active_training_algorithm(TrainingAlgorithm *t)
std::string get_distributed_checkpoint_dirname(const std::string &alg_name, const int rank_in_trainer, const std::string &dir, visitor_hook hook, execution_mode mode, size_t epoch, size_t step)
std::string get_last_distributed_checkpoint_filename(const std::string &alg_name, const std::string &dir)
TrainingAlgorithm * m_active_training_algorithm
void set_ckpt_dist_steps(int ckpt_dist_steps)
std::string m_per_rank_dir
std::string get_distributed_checkpoint_rootdir()
void set_per_rank_dir(std::string dir)
Base class for callbacks during training/testing.
bool write_latest(std::string filename, visitor_hook hook, execution_mode mode, size_t epoch, size_t train)
EvalType m_checkpoint_last
Abstract base class for neural network models.
Checkpoint at given interval in given directory.
std::string get_trainer_checkpoint_dirname(const std::string &trainer_name, const std::string &dir)
execution_mode
Neural network execution mode.
void set_restart_dir(const std::string &dir)
const std::string & get_checkpoint_dir()
void set_ckpt_dist_epochs(int ckpt_dist_epochs)
User-facing class that represents a set of compute resources.
std::string get_last_shared_checkpoint_filename(const std::string &alg_name, const std::string &dir)
checkpoint * copy() const override
void set_active_trainer(trainer *t)
std::string m_restart_dir
void set_checkpoint_secs(EvalType secs)
void set_checkpoint_steps(int steps)
std::string get_shared_checkpoint_rootdir()
visitor_hook
Neural network execution mode.
std::string name() const override
Return this callback's name.
std::unique_ptr< callback_base > build_checkpoint_callback_from_pbuf(const google::protobuf::Message &)
Base class for LBANN training_algorithms.
trainer * m_active_trainer
EvalType m_checkpoint_secs
const std::string & get_per_rank_dir()
bool read_latest(std::string filename, visitor_hook *hook, execution_mode *mode, size_t *epochLast, size_t *trainLast)
Reads the "latest" file and returns the epoch number and sample offset for most recent checkpoint...