27 #ifndef LBANN_OPTIMIZERS_ADAM_HPP_INCLUDED 28 #define LBANN_OPTIMIZERS_ADAM_HPP_INCLUDED 32 #include "lbann/proto/optimizers.pb.h" 46 template <
typename TensorDataType>
48 :
public Cloneable<adam<TensorDataType>, data_type_optimizer<TensorDataType>>
/** @brief Construct an Adam optimizer with explicit hyperparameters.
 *
 *  @param learning_rate      Learning rate (no default; must be supplied).
 *  @param beta1              Adam beta1 hyperparameter; defaults to 0.9.
 *  @param beta2              Adam beta2 hyperparameter; defaults to 0.99.
 *  @param eps                Adam epsilon; defaults to 1e-8.
 *  @param adamw_weight_decay Weight-decay coefficient; defaults to 0.0.
 *                            NOTE(review): the name suggests AdamW-style
 *                            decoupled weight decay with 0 meaning "off";
 *                            confirm against the implementation.
 */
72 adam(TensorDataType learning_rate,
73 TensorDataType beta1 = 0.9,
74 TensorDataType beta2 = 0.99,
75 TensorDataType eps = 1e-8,
76 TensorDataType adamw_weight_decay = 0.0);
82 template <
class Archive>
91 std::string
get_type()
const override {
return "Adam"; }
101 TensorDataType
get_beta1() const noexcept {
return m_beta1; }
103 void set_beta1(TensorDataType beta1) { m_beta1 = beta1; }
105 TensorDataType
get_beta2() const noexcept {
return m_beta2; }
107 void set_beta2(TensorDataType beta2) { m_beta2 = beta2; }
109 TensorDataType
get_eps() const noexcept {
return m_eps; }
111 void set_eps(TensorDataType eps) { m_eps = eps; }
115 return m_adamw_weight_decay;
120 m_adamw_weight_decay = adamw_weight_decay;
141 m_current_beta1 = current_beta1;
152 m_current_beta2 = current_beta2;
160 using OptimizerType::setup;
/** @brief Write this optimizer's state into a protobuf Optimizer message.
 *
 *  NOTE(review): presumably serializes the Adam hyperparameters into
 *  @p opt; confirm against the out-of-line definition.
 *  @param opt Destination protobuf message (modified in place).
 */
166 void write_proto(lbann_data::Optimizer& opt)
const final;
169 friend cereal::access;
176 :
adam(
El::To<TensorDataType>(1.f),
177 El::To<TensorDataType>(0.9),
178 El::To<TensorDataType>(0.99),
179 El::To<TensorDataType>(1e-8),
180 El::To<TensorDataType>(0))
/** Current beta1 state; initialized to 1.
 *  NOTE(review): likely the running power beta1^t used for bias
 *  correction; confirm in the step implementation.
 */
197 TensorDataType m_current_beta1 = TensorDataType(1.);
/** Current beta2 state; initialized to 1.
 *  NOTE(review): likely the running power beta2^t used for bias
 *  correction; confirm in the step implementation.
 */
199 TensorDataType m_current_beta2 = TensorDataType(1.);
211 const TensorDataType& correction);
216 const TensorDataType& correction);
217 #endif // LBANN_HAS_GPU 220 template <
typename TensorDataType>
221 std::unique_ptr<optimizer>
226 #endif // LBANN_OPTIMIZERS_ADAM_HPP_INCLUDED std::unique_ptr< AbsDistMatrixType > m_moment2
TensorDataType get_beta2() const noexcept
Inject polymorphic clone functions into hierarchies.
void set_current_beta1(TensorDataType current_beta1)
std::string get_type() const override
TensorDataType get_current_beta1() const noexcept
adam()
Default constructor.
TensorDataType get_current_beta2() const noexcept
TensorDataType get_beta1() const noexcept
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
Generates nicely formatted description messages.
El::AbstractDistMatrix< TensorDataType > AbsDistMatrixType
The tensor type expected in this object.
Hyperparameter exploration with Adam optimizers.
std::unique_ptr< AbsDistMatrixType > m_moment1
void set_beta1(TensorDataType beta1)
void set_beta2(TensorDataType beta2)
void set_adamw_weight_decay(TensorDataType adamw_weight_decay)
void set_eps(TensorDataType eps)
TensorDataType m_adamw_weight_decay
TensorDataType get_adamw_weight_decay() const noexcept
TensorDataType get_eps() const noexcept
void set_current_beta2(TensorDataType current_beta2)
std::unique_ptr< optimizer > build_adam_optimizer_from_pbuf(google::protobuf::Message const &)