29 #ifndef LBANN_CALLBACKS_CALLBACK_SAVE_MODEL_HPP_INCLUDED 30 #define LBANN_CALLBACKS_CALLBACK_SAVE_MODEL_HPP_INCLUDED 38 #include <google/protobuf/message.h> 60 bool disable_save_after_training,
61 std::string extension =
"prototext")
63 m_dir(std::move(dir)),
64 m_disable_save_after_training(disable_save_after_training),
65 m_extension(std::move(extension))
70 void on_train_end(
model* m)
override;
71 std::string
name()
const override {
return "save model"; }
78 bool do_save_model(
model* m);
79 bool do_save_model_weights(
model* m);
83 void write_specific_proto(lbann_data::Callback& proto)
const final;
91 void write_proto_binary(
const lbann_data::Model& proto,
92 const std::string filename);
93 void write_proto_text(
const lbann_data::Model& proto,
94 const std::string filename);
98 const std::string& model_name,
99 const std::string& dir)
101 return build_string(dir,
'/', trainer_name,
'/', model_name,
'/');
105 std::unique_ptr<callback_base>
107 std::shared_ptr<lbann_summary>
const&);
112 #endif // LBANN_CALLBACKS_CALLBACK_SAVE_MODEL_HPP_INCLUDED const std::string & get_target_dir()
std::string name() const override
Return this callback's name.
save_model(std::string dir, bool disable_save_after_training, std::string extension="prototext")
std::string build_string(Args &&... args)
Build a string from the arguments.
void set_target_dir(const std::string &dir)
Base class for callbacks during training/testing.
Abstract base class for neural network models.
std::string get_save_model_dirname(const std::string &trainer_name, const std::string &model_name, const std::string &dir)
save_model * copy() const override
bool m_disable_save_after_training
Disables the normal behavior of saving when training is complete.
std::unique_ptr< callback_base > build_save_model_callback_from_pbuf(const google::protobuf::Message &, std::shared_ptr< lbann_summary > const &)