28 #ifndef LBANN_CALLBACKS_CALLBACK_REPLACE_WEIGHTS_HPP_INCLUDED 29 #define LBANN_CALLBACKS_CALLBACK_REPLACE_WEIGHTS_HPP_INCLUDED 49 std::vector<std::string> dst,
50 int batch_interval = 1);
58 std::string
name()
const override {
return "replace weights"; }
69 std::unique_ptr<callback_base>
71 std::shared_ptr<lbann_summary>
const&);
76 #endif // LBANN_CALLBACKS_CALLBACK_REPLACE_WEIGHTS_HPP_INCLUDED std::vector< Layer * > m_dst_layers
std::unique_ptr< callback_base > build_replace_weights_callback_from_pbuf(const google::protobuf::Message &, std::shared_ptr< lbann_summary > const &)
std::vector< std::string > m_src_layer_names
replace_weights & operator=(const replace_weights &)=default
replace_weights * copy() const override
Base class for callbacks during training/testing.
Abstract base class for neural network models.
void write_specific_proto(lbann_data::Callback &proto) const final
void setup(model *m) override
Called once to set up the callback on the model (after all layers are set up).
std::vector< std::string > m_dst_layer_names
std::string name() const override
Return this callback's name.
void on_batch_end(model *m) override
Called immediately after the end of a (mini-)batch.
std::vector< Layer * > m_src_layers
replace_weights(std::vector< std::string > src, std::vector< std::string > dst, int batch_interval=1)