27 #ifndef LBANN_OPTIMIZERS_RMSPROP_HPP_INCLUDED 28 #define LBANN_OPTIMIZERS_RMSPROP_HPP_INCLUDED 32 #include "lbann/proto/optimizers.pb.h" 42 template <
typename TensorDataType>
44 data_type_optimizer<TensorDataType>>
65 rmsprop(TensorDataType learning_rate,
66 TensorDataType decay_rate,
67 TensorDataType eps = 1e-8);
73 template <
class Archive>
77 std::string
get_type()
const override {
return "RMSprop"; }
85 void write_proto(lbann_data::Optimizer& opt)
const final;
88 friend cereal::access;
96 El::To<TensorDataType>(1.f),
97 El::To<TensorDataType>(1e-8))
119 #endif // LBANN_HAS_GPU 122 template <
typename TensorDataType>
123 std::unique_ptr<optimizer>
128 #endif // LBANN_OPTIMIZERS_RMSPROP_HPP_INCLUDED
Inject polymorphic clone functions into hierarchies.
void setup(weights *w) override
Must be called before training.
TensorDataType m_decay_rate
Generates nicely formatted description messages.
description get_description() const override
std::unique_ptr< AbsDistMatrixType > m_cache
void write_proto(lbann_data::Optimizer &opt) const final
void step_compute_cpu(AbsDistMatrixType &values, const AbsDistMatrixType &gradient)
El::AbstractDistMatrix< TensorDataType > AbsDistMatrixType
The tensor type expected in this object.
rmsprop & operator=(const rmsprop &other)
rmsprop()
Default constructor.
void setup(WeightsType *w=nullptr) override
std::unique_ptr< optimizer > build_rmsprop_optimizer_from_pbuf(google::protobuf::Message const &)
void step_compute(AbsDistMatrixType &values, const AbsDistMatrixType &gradient) override
~rmsprop() override=default
void serialize(Archive &ar)
std::string get_type() const override