|
LBANN
0.103.0
Livermore Big Artificial Neural Network Toolkit
|
Adam optimizer. More...
#include <adam.hpp>
Public Types | |
| using | AbsDistMatrixType = El::AbstractDistMatrix< TensorDataType > |
| The tensor type expected in this object. More... | |
| using | OptimizerType = data_type_optimizer< TensorDataType > |
| The optimizer base type of this object. More... | |
| using | WeightsType = data_type_weights< TensorDataType > |
| The concrete weights type used by this object. More... | |
Public Member Functions | |
| void | write_proto (lbann_data::Optimizer &opt) const final |
Life cycle functions | |
| adam (TensorDataType learning_rate, TensorDataType beta1=0.9, TensorDataType beta2=0.99, TensorDataType eps=1e-8, TensorDataType adamw_weight_decay=0.0) | |
| adam (const adam &other) | |
| adam & | operator= (const adam &other) |
| ~adam ()=default | |
| template<class Archive > | |
| void | serialize (Archive &ar) |
Descriptions | |
| std::string | get_type () const override |
| description | get_description () const override |
Access functions | |
| TensorDataType | get_beta1 () const noexcept |
| void | set_beta1 (TensorDataType beta1) |
| TensorDataType | get_beta2 () const noexcept |
| void | set_beta2 (TensorDataType beta2) |
| TensorDataType | get_eps () const noexcept |
| void | set_eps (TensorDataType eps) |
| TensorDataType | get_adamw_weight_decay () const noexcept |
| void | set_adamw_weight_decay (TensorDataType adamw_weight_decay) |
| const AbsDistMatrixType & | get_moment1 () const |
| AbsDistMatrixType & | get_moment1 () |
| const AbsDistMatrixType & | get_moment2 () const |
| AbsDistMatrixType & | get_moment2 () |
| TensorDataType | get_current_beta1 () const noexcept |
| void | set_current_beta1 (TensorDataType current_beta1) |
| TensorDataType | get_current_beta2 () const noexcept |
| void | set_current_beta2 (TensorDataType current_beta2) |
Setup | |
| void | setup (WeightsType *w=nullptr) override |
Public Member Functions inherited from lbann::Cloneable< adam< TensorDataType >, data_type_optimizer< TensorDataType > > | |
| std::unique_ptr< adam< TensorDataType > > | clone () const |
| Return an exception-safe, memory-safe copy of this object. More... | |
Protected Member Functions | |
| adam () | |
| Default constructor. More... | |
| void | step_compute (AbsDistMatrixType &values, const AbsDistMatrixType &gradient) override |
Private Types | |
| using | BaseType = Cloneable< adam< TensorDataType >, data_type_optimizer< TensorDataType > > |
Private Member Functions | |
| void | step_compute_cpu (AbsDistMatrixType &values, const AbsDistMatrixType &gradient, const TensorDataType &correction) |
Private Attributes | |
| TensorDataType | m_beta1 |
| TensorDataType | m_beta2 |
| TensorDataType | m_eps |
| TensorDataType | m_adamw_weight_decay |
| TensorDataType | m_current_beta1 = TensorDataType(1.) |
| TensorDataType | m_current_beta2 = TensorDataType(1.) |
| std::unique_ptr< AbsDistMatrixType > | m_moment1 |
| std::unique_ptr< AbsDistMatrixType > | m_moment2 |
Friends | |
| class | callback::perturb_adam |
Adam optimizer.
Reference:
Diederik P. Kingma and Jimmy Ba. "Adam: A method for stochastic optimization." arXiv preprint arXiv:1412.6980 (2014).
| using lbann::adam< TensorDataType >::AbsDistMatrixType = El::AbstractDistMatrix<TensorDataType> |
|
private |
| using lbann::adam< TensorDataType >::OptimizerType = data_type_optimizer<TensorDataType> |
| using lbann::adam< TensorDataType >::WeightsType = data_type_weights<TensorDataType> |
| lbann::adam< TensorDataType >::adam | ( TensorDataType learning_rate, TensorDataType beta1 = 0.9, TensorDataType beta2 = 0.99, TensorDataType eps = 1e-8, TensorDataType adamw_weight_decay = 0.0 ) |
| lbann::adam< TensorDataType >::adam | ( | const adam< TensorDataType > & | other | ) |
|
default |
|
inline protected |
|
inline noexcept |
|
inline noexcept |
|
inline noexcept |
|
inline noexcept |
|
inline noexcept |
|
override |
Human-readable description.
|
inline noexcept |
| const AbsDistMatrixType& lbann::adam< TensorDataType >::get_moment1 | ( | ) | const |
First moment estimates.
| AbsDistMatrixType& lbann::adam< TensorDataType >::get_moment1 | ( | ) |
First moment estimates.
| const AbsDistMatrixType& lbann::adam< TensorDataType >::get_moment2 | ( | ) | const |
Second moment estimates.
| AbsDistMatrixType& lbann::adam< TensorDataType >::get_moment2 | ( | ) |
Second moment estimates.
|
inline override |
| adam& lbann::adam< TensorDataType >::operator= | ( | const adam< TensorDataType > & | other | ) |
| void lbann::adam< TensorDataType >::serialize | ( | Archive & | ar | ) |
Serialize this object to or from an archive (used for checkpoint and restart).
Definition at line 37 of file adam_impl.hpp.
|
inline |
|
inline |
|
inline |
|
inline |
|
inline |
|
inline |
|
override |
|
override protected |
Computation for an optimization step.
|
private |
CPU implementation of optimization step.
|
final |
Add optimizer data to prototext.
|
friend |
|
private |
|
private |
|
private |
|
private |
|
private |
|
private |
|
private |
|
private |