29 #ifndef LBANN_CALLBACKS_CALLBACK_LOAD_MODEL_HPP_INCLUDED 30 #define LBANN_CALLBACKS_CALLBACK_LOAD_MODEL_HPP_INCLUDED 36 #include <google/protobuf/message.h> 56 load_model(std::vector<std::string> dirs, std::string extension =
"prototext")
58 m_dirs(std::move(dirs)),
59 m_extension(std::move(extension)),
/** @brief Append one more directory to this callback's directory list.
 *
 *  The string is copy-constructed at the end of @c m_dirs via
 *  @c emplace_back; no de-duplication or validation happens here.
 *  NOTE(review): presumably these are the directories models are loaded
 *  from -- the code consuming @c m_dirs is not visible here; confirm.
 *
 *  @param dir Directory to append.
 */
66 inline void add_dir(
const std::string& dir) { m_dirs.emplace_back(dir); }
/** @brief Override of the base-class @c on_train_begin hook.
 *
 *  Only the declaration is visible here.
 *  NOTE(review): presumably this is where the load is performed when
 *  training starts -- confirm against the definition.
 *
 *  @param m Model handed to the hook.
 */
68 void on_train_begin(
model* m)
override;
/** @brief Override of the base-class @c on_test_begin hook.
 *
 *  Only the declaration is visible here.
 *  NOTE(review): presumably mirrors on_train_begin for runs that start
 *  directly in testing -- confirm against the definition.
 *
 *  @param m Model handed to the hook.
 */
70 void on_test_begin(
model* m)
override;
/** @brief Return this callback's name.
 *
 *  @return The fixed string "load model".
 */
72 std::string
name()
const override {
return "load model"; }
78 template <
class Archive>
85 friend class cereal::access;
/** @brief Final, const override that receives the callback protobuf message.
 *
 *  Only the declaration is visible here.
 *  NOTE(review): presumably fills @p proto with this callback's own
 *  settings -- confirm against the definition.
 *
 *  @param proto Callback protobuf message, passed by mutable reference.
 */
90 void write_specific_proto(lbann_data::Callback& proto)
const final;
101 std::unique_ptr<callback_base>
103 std::shared_ptr<lbann_summary>
const&);
108 #endif // LBANN_CALLBACKS_CALLBACK_LOAD_MODEL_HPP_INCLUDED std::unique_ptr< callback_base > build_load_model_callback_from_pbuf(const google::protobuf::Message &, std::shared_ptr< lbann_summary > const &)
std::string name() const override
Return this callback's name.
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
std::string m_extension
Disables the normal behavior of saving when training is complete.
void add_dir(const std::string &dir)
bool m_loaded
Flag to indicate if the model has already been loaded.
Base class for callbacks during training/testing.
std::vector< std::string > m_dirs
Abstract base class for neural network models.
load_model * copy() const override
load_model(std::vector< std::string > dirs, std::string extension="prototext")