27 #ifndef LBANN_CALLBACKS_CALLBACK_PERTURB_LEARNING_RATE_HPP_INCLUDED 28 #define LBANN_CALLBACKS_CALLBACK_PERTURB_LEARNING_RATE_HPP_INCLUDED 67 DataType learning_rate_factor,
68 bool perturb_during_training =
false,
69 El::Int batch_interval = 1,
70 std::set<std::string> weights_names = std::set<std::string>());
75 std::string
name()
const override 77 return "perturb optimizer learning rate";
87 template <
class Archive>
125 const google::protobuf::Message&,
126 std::shared_ptr<lbann_summary>
const&);
131 #endif // LBANN_CALLBACKS_CALLBACK_PERTURB_LEARNING_RATE_HPP_INCLUDED
void on_batch_begin(model *m) override
Called at the beginning of a (mini-)batch.
bool m_perturb_during_training
Base class for callbacks during training/testing.
Abstract base class for neural network models.
Hyperparameter exploration of optimizer learning rate.
void perturb(model &m) const
void write_specific_proto(lbann_data::Callback &proto) const final
perturb_learning_rate * copy() const override
DataType m_learning_rate_factor
friend class cereal::access
std::unique_ptr< callback_base > build_perturb_learning_rate_callback_from_pbuf(const google::protobuf::Message &, std::shared_ptr< lbann_summary > const &)
std::set< std::string > m_weights_names
std::string name() const override
Return this callback's name.
void serialize(Archive &ar)
Store state to archive for checkpoint and restart.
void setup(model *m) override
Called once to set up the callback on the model (after all layers are set up).