27 #ifndef LBANN_CALLBACKS_CALLBACK_PERTURB_ADAM_HPP_INCLUDED 28 #define LBANN_CALLBACKS_CALLBACK_PERTURB_ADAM_HPP_INCLUDED 72 DataType beta1_factor,
73 DataType beta2_factor,
74 DataType eps_factor = 0,
75 bool perturb_during_training =
false,
76 El::Int batch_interval = 1,
77 std::set<std::string> weights_names = std::set<std::string>());
79 std::string
name()
const override {
return "perturb Adam"; }
88 template <
class Archive>
140 std::unique_ptr<callback_base>
142 std::shared_ptr<lbann_summary>
const&);
147 #endif // LBANN_CALLBACKS_CALLBACK_PERTURB_ADAM_HPP_INCLUDED friend class cereal::access
void write_specific_proto(lbann_data::Callback &proto) const final
bool m_perturb_during_training
void on_batch_begin(model *m) override
Called at the beginning of a (mini-)batch.
std::string name() const override
Return this callback's name.
DataType m_learning_rate_factor
std::set< std::string > m_weights_names
void setup(model *m) override
Called once to set up the callback on the model (after all layers are set up).
Base class for callbacks during training/testing.
Abstract base class for neural network models.
Hyperparameter exploration with Adam optimizers.
std::unique_ptr< callback_base > build_perturb_adam_callback_from_pbuf(const google::protobuf::Message &, std::shared_ptr< lbann_summary > const &)
void perturb(model &m) const
void serialize(Archive &ar)
Store state to archive for checkpoint and restart.
perturb_adam * copy() const override