27 #ifndef LBANN_OPTIMIZERS_SGD_HPP_INCLUDED 28 #define LBANN_OPTIMIZERS_SGD_HPP_INCLUDED 32 #include "lbann/proto/optimizers.pb.h" 40 template <
typename TensorDataType>
42 :
public Cloneable<sgd<TensorDataType>, data_type_optimizer<TensorDataType>>
/**
 * @brief Construct an SGD optimizer from its hyperparameters.
 *
 * @param learning_rate Learning rate passed to the optimizer.
 * @param momentum      Decay rate for gradient accumulation; defaults to 0.
 * @param nesterov      Nesterov flag; defaults to false. Presumably the value
 *                      later reported by using_nesterov() -- confirm against
 *                      the definition, which is not part of this declaration.
 */
66 sgd(TensorDataType learning_rate,
67 TensorDataType momentum = 0,
68 bool nesterov =
false);
/** @brief Destructor; overrides the base-class destructor with the compiler-generated default. */
71 ~sgd()
override =
default;
78 template <
class ArchiveT>
/**
 * @brief Human-readable name of this optimizer type.
 * @return Always the string "SGD".
 */
86 std::string
get_type()
const override {
return "SGD"; }
/**
 * @brief Write this optimizer into the protobuf message @c opt.
 *
 * Declared @c const, so the optimizer itself is not modified, and @c final,
 * so derived classes cannot override it. Only the declaration appears here;
 * which fields of @c opt are populated must be checked in the definition.
 *
 * @param opt Protobuf optimizer message, taken by mutable reference.
 */
124 void write_proto(lbann_data::Optimizer& opt)
const final;
131 sgd() :
sgd(
El::To<TensorDataType>(0.f),
El::To<TensorDataType>(0.f), false)
157 #endif // LBANN_HAS_GPU 162 template <
typename TensorDataType>
163 std::unique_ptr<optimizer>
168 #endif // LBANN_OPTIMIZERS_SGD_HPP_INCLUDED void serialize(ArchiveT &ar)
Serialize to the archive.
std::unique_ptr< optimizer > build_sgd_optimizer_from_pbuf(google::protobuf::Message const &)
Cloneable
Inject polymorphic clone functions into hierarchies.
friend class cereal::access
void setup(weights *w) override
Must be called before training.
TensorDataType m_momentum
Decay rate for gradient accumulation.
description get_description() const override
void set_momentum(TensorDataType momentum)
Decay rate for gradient accumulation.
void setup(WeightsType *w=nullptr) override
description
Generates nicely formatted description messages.
std::unique_ptr< AbsDistMatrixType > m_velocity
Accumulated gradients.
void momentum_step_cpu(AbsDistMatrixType &values, const AbsDistMatrixType &gradient)
std::string get_type() const override
void step_compute(AbsDistMatrixType &values, const AbsDistMatrixType &gradient) override
El::AbstractDistMatrix< TensorDataType > AbsDistMatrixType
The tensor type expected in this object.
TensorDataType get_momentum() const noexcept
Decay rate for gradient accumulation.
void write_proto(lbann_data::Optimizer &opt) const final
sgd & operator=(const sgd &other)
bool using_nesterov() const noexcept
void set_nesterov(bool nesterov)
sgd()
Default constructor.
const AbsDistMatrixType & get_velocity() const
sgd
Stochastic gradient descent optimizer.