LBANN 0.103.0
Livermore Big Artificial Neural Network Toolkit
Hypergradient Adam optimizer.
#include <hypergradient_adam.hpp>
Public Types | |
| using | AbsDistMatrixType = El::AbstractDistMatrix< TensorDataType > |
| The tensor type expected in this object. | |
| using | WeightsType = data_type_weights< TensorDataType > |
| The concrete weights type used by this object. | |
| using | OptimizerType = data_type_optimizer< TensorDataType > |
| The base optimizer type for this class. | |
Public Member Functions | |
| hypergradient_adam (TensorDataType init_learning_rate=El::To< TensorDataType >(1e-3), TensorDataType hyper_learning_rate=El::To< TensorDataType >(1e-7), TensorDataType beta1=El::To< TensorDataType >(0.9), TensorDataType beta2=El::To< TensorDataType >(0.99), TensorDataType eps=El::To< TensorDataType >(1e-8)) | |
| Construct a Hypergradient Adam optimizer object. | |
| hypergradient_adam (const hypergradient_adam &other) | |
| hypergradient_adam & | operator= (const hypergradient_adam &other) |
| ~hypergradient_adam () override=default | |
| template<class Archive > | |
| void | serialize (Archive &ar) |
| std::string | get_type () const override |
| Human-readable type name. | |
| description | get_description () const override |
| Human-readable description. | |
| void | setup (WeightsType *w=nullptr) override |
| void | write_proto (lbann_data::Optimizer &opt) const final |
Public Member Functions inherited from lbann::Cloneable< hypergradient_adam< TensorDataType >, data_type_optimizer< TensorDataType > > | |
| std::unique_ptr< hypergradient_adam< TensorDataType > > | clone () const |
| Return an exception-safe, memory-safe copy of this object. | |
Protected Member Functions | |
| void | step_compute (AbsDistMatrixType &values, const AbsDistMatrixType &gradient) override |
| Computation for an optimization step. | |
Private Types | |
| using | BaseType = Cloneable< hypergradient_adam< TensorDataType >, data_type_optimizer< TensorDataType > > |
Private Attributes | |
| TensorDataType | m_hyper_learning_rate |
| Hypergradient learning rate. | |
| TensorDataType | m_beta1 |
| Update factor for first moment estimate. | |
| TensorDataType | m_beta2 |
| Update factor for second moment estimate. | |
| TensorDataType | m_eps |
| Small factor to avoid division by zero. | |
| TensorDataType | m_current_beta1 |
| beta1 ^ iteration. | |
| TensorDataType | m_current_beta2 |
| beta2 ^ iteration. | |
| std::unique_ptr< AbsDistMatrixType > | m_moment1 |
| First moment estimates. | |
| std::unique_ptr< AbsDistMatrixType > | m_moment2 |
| Second moment estimates. | |
| std::unique_ptr< AbsDistMatrixType > | m_old_gradient |
| Gradient estimate from the prior step (for hypergradient). | |
Hypergradient Adam optimizer.
Reference:
Baydin et al. "Online Learning Rate Adaptation with Hypergradient Descent", 2017.
Definition at line 45 of file hypergradient_adam.hpp.
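For orientation, the Adam variant described in the cited paper can be written roughly as follows. This is a sketch of the published rule as understood here, not a transcription of the LBANN source, which may differ in detail. The notation is chosen for this note: $\alpha_t$ is the learning rate, $\beta_h$ the hypergradient learning rate, $g_t$ the gradient, $u_t$ the bias-corrected update direction, and $u_0 = 0$.

```latex
\begin{aligned}
\alpha_t &= \alpha_{t-1} + \beta_h \, g_t^{\top} u_{t-1}, \\
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^{2}, \\
u_t &= \frac{m_t / (1 - \beta_1^{t})}{\sqrt{v_t / (1 - \beta_2^{t})} + \epsilon}, \\
\theta_t &= \theta_{t-1} - \alpha_t \, u_t .
\end{aligned}
```

With $\beta_h = 0$ the learning rate stays at its initial value and the rule reduces to standard Adam.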
| using lbann::hypergradient_adam< TensorDataType >::AbsDistMatrixType = El::AbstractDistMatrix<TensorDataType> |
The tensor type expected in this object.
Definition at line 56 of file hypergradient_adam.hpp.
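| using lbann::hypergradient_adam< TensorDataType >::BaseType = Cloneable<hypergradient_adam<TensorDataType>, data_type_optimizer<TensorDataType> > | private |
Definition at line 49 of file hypergradient_adam.hpp.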
| using lbann::hypergradient_adam< TensorDataType >::OptimizerType = data_type_optimizer<TensorDataType> |
The base optimizer type for this class.
Definition at line 62 of file hypergradient_adam.hpp.
| using lbann::hypergradient_adam< TensorDataType >::WeightsType = data_type_weights<TensorDataType> |
The concrete weights type used by this object.
Definition at line 59 of file hypergradient_adam.hpp.
| lbann::hypergradient_adam< TensorDataType >::hypergradient_adam | ( | TensorDataType | init_learning_rate = El::To< TensorDataType >(1e-3), |
| | | TensorDataType | hyper_learning_rate = El::To< TensorDataType >(1e-7), |
| | | TensorDataType | beta1 = El::To< TensorDataType >(0.9), |
| | | TensorDataType | beta2 = El::To< TensorDataType >(0.99), |
| | | TensorDataType | eps = El::To< TensorDataType >(1e-8) |
| | ) | | |
Construct a Hypergradient Adam optimizer object.
| init_learning_rate | Initial Adam learning rate (0.001 is reasonable). |
| hyper_learning_rate | Hypergradient learning rate. |
| beta1 | Decay rate for the first moment moving average. |
| beta2 | Decay rate for the second moment moving average. |
| eps | Small factor to avoid division by zero. |
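To make the role of `hyper_learning_rate` concrete, the following standalone sketch mirrors the five constructor arguments in a plain struct and shows one learning-rate adaptation. `HypergradAdamParams` and `adapt_learning_rate` are hypothetical names introduced for illustration and are not part of LBANN; the adaptation formula is the one from the cited paper, assumed rather than taken from the LBANN source.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the constructor arguments (not an LBANN type).
// The defaults mirror the documented default arguments.
struct HypergradAdamParams {
  double init_learning_rate = 1e-3;
  double hyper_learning_rate = 1e-7;
  double beta1 = 0.9;
  double beta2 = 0.99;
  double eps = 1e-8;
};

// Assumed adaptation rule: shift the learning rate by the inner product of
// the current gradient and the previous update direction, scaled by the
// hypergradient learning rate.
inline double adapt_learning_rate(double learning_rate,
                                  double hyper_learning_rate,
                                  const std::vector<double>& gradient,
                                  const std::vector<double>& prev_direction) {
  assert(gradient.size() == prev_direction.size());
  double dot = 0.0;
  for (std::size_t i = 0; i < gradient.size(); ++i) {
    dot += gradient[i] * prev_direction[i];
  }
  return learning_rate + hyper_learning_rate * dot;
}
```

Under this rule the learning rate grows when the gradient keeps pointing along the previous update direction and shrinks when it reverses. A `hyper_learning_rate` of zero leaves the rate fixed at `init_learning_rate`.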
| lbann::hypergradient_adam< TensorDataType >::hypergradient_adam | ( | const hypergradient_adam< TensorDataType > & | other | ) |
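| lbann::hypergradient_adam< TensorDataType >::~hypergradient_adam | ( | ) | override default |
| description lbann::hypergradient_adam< TensorDataType >::get_description | ( | ) | const | override |
Human-readable description.
| std::string lbann::hypergradient_adam< TensorDataType >::get_type | ( | ) | const | inline override |
Human-readable type name.
Definition at line 94 of file hypergradient_adam.hpp.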
| hypergradient_adam& lbann::hypergradient_adam< TensorDataType >::operator= | ( | const hypergradient_adam< TensorDataType > & | other | ) |
| void lbann::hypergradient_adam< TensorDataType >::serialize | ( | Archive & | ar | ) |
Archive for checkpoint and restart.
Definition at line 37 of file hypergradient_adam_impl.hpp.
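| void lbann::hypergradient_adam< TensorDataType >::setup | ( | WeightsType * | w = nullptr | ) | override |
| void lbann::hypergradient_adam< TensorDataType >::step_compute | ( | AbsDistMatrixType & | values, |
| | | const AbsDistMatrixType & | gradient |
| | ) | | override protected |
Computation for an optimization step.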
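As an illustration of what one such step involves, here is a minimal sketch on plain `std::vector<double>` data. It assumes the update rule from the cited paper and is not the LBANN implementation, which operates on distributed matrices. `HypergradAdamState` and `step` are hypothetical names; the fields echo the private attributes listed on this page, except `old_direction`, which is assumed to hold the previous bias-corrected update direction.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical optimizer state for one weights tensor (not LBANN code).
struct HypergradAdamState {
  double learning_rate = 1e-3;
  double hyper_learning_rate = 1e-7;
  double beta1 = 0.9, beta2 = 0.99, eps = 1e-8;
  double current_beta1 = 1.0, current_beta2 = 1.0;  // beta ^ iteration
  std::vector<double> moment1, moment2, old_direction;
};

// One optimization step: adapt the learning rate, update the moment
// estimates, then move the values along the bias-corrected direction.
inline void step(HypergradAdamState& s, std::vector<double>& values,
                 const std::vector<double>& gradient) {
  const std::size_t n = values.size();
  assert(gradient.size() == n);
  if (s.moment1.empty()) {  // Lazily allocate zero-initialized state.
    s.moment1.assign(n, 0.0);
    s.moment2.assign(n, 0.0);
    s.old_direction.assign(n, 0.0);
  }

  // Hypergradient: inner product of the gradient and the previous direction.
  double dot = 0.0;
  for (std::size_t i = 0; i < n; ++i) dot += gradient[i] * s.old_direction[i];
  s.learning_rate += s.hyper_learning_rate * dot;

  // Running powers used for bias correction.
  s.current_beta1 *= s.beta1;
  s.current_beta2 *= s.beta2;

  for (std::size_t i = 0; i < n; ++i) {
    const double g = gradient[i];
    s.moment1[i] = s.beta1 * s.moment1[i] + (1.0 - s.beta1) * g;
    s.moment2[i] = s.beta2 * s.moment2[i] + (1.0 - s.beta2) * g * g;
    const double m_hat = s.moment1[i] / (1.0 - s.current_beta1);
    const double v_hat = s.moment2[i] / (1.0 - s.current_beta2);
    s.old_direction[i] = m_hat / (std::sqrt(v_hat) + s.eps);
    values[i] -= s.learning_rate * s.old_direction[i];
  }
}
```

On the first call the stored direction is zero, so the learning rate is unchanged and the step is an ordinary Adam step. From the second call onward the rate moves according to how well successive gradients agree.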
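| void lbann::hypergradient_adam< TensorDataType >::write_proto | ( | lbann_data::Optimizer & | opt | ) | const | final |
Add optimizer data to prototext.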
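| TensorDataType lbann::hypergradient_adam< TensorDataType >::m_beta1 | private |
Update factor for first moment estimate.
Definition at line 113 of file hypergradient_adam.hpp.
| TensorDataType lbann::hypergradient_adam< TensorDataType >::m_beta2 | private |
Update factor for second moment estimate.
Definition at line 115 of file hypergradient_adam.hpp.
| TensorDataType lbann::hypergradient_adam< TensorDataType >::m_current_beta1 | private |
beta1 ^ iteration.
Definition at line 119 of file hypergradient_adam.hpp.
| TensorDataType lbann::hypergradient_adam< TensorDataType >::m_current_beta2 | private |
beta2 ^ iteration.
Definition at line 121 of file hypergradient_adam.hpp.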
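The two running powers above are what Adam-style optimizers typically use for bias correction. The sketch below shows one way to maintain `beta ^ iteration` as a running product instead of calling `std::pow` each step. `RunningPower` is a hypothetical helper, and starting the product at 1 is an assumption rather than something this page states.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper that tracks beta ^ iteration as a running product.
struct RunningPower {
  double beta;
  double current;  // beta ^ (number of advance() calls so far)

  explicit RunningPower(double b) : beta(b), current(1.0) {}

  // Advance one iteration and return the bias-correction factor
  // 1 - beta ^ iteration for that iteration.
  double advance() {
    current *= beta;
    return 1.0 - current;
  }
};
```

Dividing a moment estimate by the returned factor compensates for the zero initialization of the moving average, which matters most in the first few iterations.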
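| TensorDataType lbann::hypergradient_adam< TensorDataType >::m_eps | private |
Small factor to avoid division by zero.
Definition at line 117 of file hypergradient_adam.hpp.
| TensorDataType lbann::hypergradient_adam< TensorDataType >::m_hyper_learning_rate | private |
Hypergradient learning rate.
Definition at line 111 of file hypergradient_adam.hpp.
| std::unique_ptr<AbsDistMatrixType> lbann::hypergradient_adam< TensorDataType >::m_moment1 | private |
First moment estimates.
Definition at line 123 of file hypergradient_adam.hpp.
| std::unique_ptr<AbsDistMatrixType> lbann::hypergradient_adam< TensorDataType >::m_moment2 | private |
Second moment estimates.
Definition at line 125 of file hypergradient_adam.hpp.
| std::unique_ptr<AbsDistMatrixType> lbann::hypergradient_adam< TensorDataType >::m_old_gradient | private |
Gradient estimate from the prior step (for hypergradient).
Definition at line 127 of file hypergradient_adam.hpp.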