LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
lbann::adam< TensorDataType > Class Template Reference

Adam optimizer.

#include <adam.hpp>


Public Types

using AbsDistMatrixType = El::AbstractDistMatrix< TensorDataType >
 The tensor type expected in this object.
 
using OptimizerType = data_type_optimizer< TensorDataType >
 The optimizer base type of this object.
 
using WeightsType = data_type_weights< TensorDataType >
 The concrete weights type used by this object.
 

Public Member Functions

void write_proto (lbann_data::Optimizer &opt) const final
 
Life cycle functions
 adam (TensorDataType learning_rate, TensorDataType beta1=0.9, TensorDataType beta2=0.99, TensorDataType eps=1e-8, TensorDataType adamw_weight_decay=0.0)
 
 adam (const adam &other)
 
adam & operator= (const adam &other)
 
 ~adam ()=default
 
template<class Archive >
void serialize (Archive &ar)
 
Descriptions
std::string get_type () const override
 
description get_description () const override
 
Access functions
TensorDataType get_beta1 () const noexcept
 
void set_beta1 (TensorDataType beta1)
 
TensorDataType get_beta2 () const noexcept
 
void set_beta2 (TensorDataType beta2)
 
TensorDataType get_eps () const noexcept
 
void set_eps (TensorDataType eps)
 
TensorDataType get_adamw_weight_decay () const noexcept
 
void set_adamw_weight_decay (TensorDataType adamw_weight_decay)
 
const AbsDistMatrixType & get_moment1 () const
 
AbsDistMatrixType & get_moment1 ()
 
const AbsDistMatrixType & get_moment2 () const
 
AbsDistMatrixType & get_moment2 ()
 
TensorDataType get_current_beta1 () const noexcept
 
void set_current_beta1 (TensorDataType current_beta1)
 
TensorDataType get_current_beta2 () const noexcept
 
void set_current_beta2 (TensorDataType current_beta2)
 
Setup
void setup (WeightsType *w=nullptr) override
 
Public Member Functions inherited from lbann::Cloneable< adam< TensorDataType >, data_type_optimizer< TensorDataType > >
std::unique_ptr< adam< TensorDataType > > clone () const
 Return an exception-safe, memory-safe copy of this object.
 

Protected Member Functions

 adam ()
 Default constructor.
 
void step_compute (AbsDistMatrixType &values, const AbsDistMatrixType &gradient) override
 

Private Types

using BaseType = Cloneable< adam< TensorDataType >, data_type_optimizer< TensorDataType > >
 

Private Member Functions

void step_compute_cpu (AbsDistMatrixType &values, const AbsDistMatrixType &gradient, const TensorDataType &correction)
 

Private Attributes

TensorDataType m_beta1
 
TensorDataType m_beta2
 
TensorDataType m_eps
 
TensorDataType m_adamw_weight_decay
 
TensorDataType m_current_beta1 = TensorDataType(1.)
 
TensorDataType m_current_beta2 = TensorDataType(1.)
 
std::unique_ptr< AbsDistMatrixType > m_moment1
 
std::unique_ptr< AbsDistMatrixType > m_moment2
 

Friends

class callback::perturb_adam
 

Detailed Description

template<typename TensorDataType>
class lbann::adam< TensorDataType >

Adam optimizer.

Reference:

Diederik P. Kingma and Jimmy Ba. "Adam: A method for stochastic optimization." arXiv preprint arXiv:1412.6980 (2014).

Definition at line 47 of file adam.hpp.
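
For orientation, the update rule from the cited paper can be sketched as follows. This is a minimal standalone sketch, not LBANN's implementation: it uses plain std::vector<double> in place of El::AbstractDistMatrix, and AdamState and adam_step are hypothetical names introduced only for this illustration. The default values mirror the constructor defaults listed above (beta1 = 0.9, beta2 = 0.99, eps = 1e-8).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper type (not part of LBANN): optimizer state for one
// parameter vector.
struct AdamState {
  std::vector<double> m1;      // first moment estimates
  std::vector<double> m2;      // second moment estimates
  double current_beta1 = 1.0;  // beta1 ^ iteration
  double current_beta2 = 1.0;  // beta2 ^ iteration
};

// One Adam step in the form given by Kingma and Ba.
// Assumption: plain vectors stand in for distributed matrices.
void adam_step(std::vector<double>& values,
               const std::vector<double>& gradient,
               AdamState& state,
               double learning_rate,
               double beta1 = 0.9,
               double beta2 = 0.99,
               double eps = 1e-8) {
  assert(values.size() == gradient.size());
  // Lazily allocate zero-initialized moments on the first call.
  if (state.m1.empty()) {
    state.m1.assign(values.size(), 0.0);
    state.m2.assign(values.size(), 0.0);
  }
  assert(state.m1.size() == values.size());
  assert(state.m2.size() == values.size());

  // Advance the running powers beta1^t and beta2^t.
  state.current_beta1 *= beta1;
  state.current_beta2 *= beta2;

  for (std::size_t i = 0; i < values.size(); ++i) {
    const double g = gradient[i];
    // Exponential moving averages of the gradient and its square.
    state.m1[i] = beta1 * state.m1[i] + (1.0 - beta1) * g;
    state.m2[i] = beta2 * state.m2[i] + (1.0 - beta2) * g * g;
    // Bias-corrected estimates.
    const double m_hat = state.m1[i] / (1.0 - state.current_beta1);
    const double v_hat = state.m2[i] / (1.0 - state.current_beta2);
    // Parameter update; eps guards against division by zero.
    values[i] -= learning_rate * m_hat / (std::sqrt(v_hat) + eps);
  }
}
```

On the first step from zero-initialized moments, the bias-corrected update is approximately learning_rate * sign(gradient) for each entry, so a value of 1.0 with gradient 0.5 and learning_rate = 0.1 moves to roughly 0.9.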

Member Typedef Documentation

◆ AbsDistMatrixType

template<typename TensorDataType>
using lbann::adam< TensorDataType >::AbsDistMatrixType = El::AbstractDistMatrix<TensorDataType>

The tensor type expected in this object.

Definition at line 58 of file adam.hpp.

◆ BaseType

template<typename TensorDataType>
using lbann::adam< TensorDataType >::BaseType = Cloneable<adam<TensorDataType>, data_type_optimizer<TensorDataType> >
private

Definition at line 51 of file adam.hpp.

◆ OptimizerType

template<typename TensorDataType>
using lbann::adam< TensorDataType >::OptimizerType = data_type_optimizer<TensorDataType>

The optimizer base type of this object.

Definition at line 61 of file adam.hpp.

◆ WeightsType

template<typename TensorDataType>
using lbann::adam< TensorDataType >::WeightsType = data_type_weights<TensorDataType>

The concrete weights type used by this object.

Definition at line 64 of file adam.hpp.

Constructor & Destructor Documentation

◆ adam() [1/3]

template<typename TensorDataType>
lbann::adam< TensorDataType >::adam ( TensorDataType  learning_rate,
TensorDataType  beta1 = 0.9,
TensorDataType  beta2 = 0.99,
TensorDataType  eps = 1e-8,
TensorDataType  adamw_weight_decay = 0.0 
)

◆ adam() [2/3]

template<typename TensorDataType>
lbann::adam< TensorDataType >::adam ( const adam< TensorDataType > &  other)

◆ ~adam()

template<typename TensorDataType>
lbann::adam< TensorDataType >::~adam ( )
default

◆ adam() [3/3]

template<typename TensorDataType>
lbann::adam< TensorDataType >::adam ( )
inline protected

Default constructor.

This constructor exists as an implementation detail of the serialization code. It is not for general use.

Definition at line 175 of file adam.hpp.

Member Function Documentation

◆ get_adamw_weight_decay()

template<typename TensorDataType>
TensorDataType lbann::adam< TensorDataType >::get_adamw_weight_decay ( ) const
inline noexcept

Regularizer coefficient for AdamW weight decay.

Definition at line 113 of file adam.hpp.
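
As a rough illustration of what a decoupled (AdamW-style) weight decay coefficient does, the sketch below applies the commonly used decoupled form, in which the decay term is added to the Adam direction rather than to the gradient. How LBANN combines this coefficient with the Adam step is not stated on this page, so the formula is an assumption. apply_decoupled_weight_decay is a hypothetical helper, and adam_direction stands for the bias-corrected direction m_hat / (sqrt(v_hat) + eps).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of LBANN).
// Assumed update: x <- x - lr * (adam_direction + weight_decay * x).
void apply_decoupled_weight_decay(std::vector<double>& values,
                                  const std::vector<double>& adam_direction,
                                  double learning_rate,
                                  double weight_decay) {
  assert(values.size() == adam_direction.size());
  for (std::size_t i = 0; i < values.size(); ++i) {
    // The decay shrinks each value toward zero independently of the gradient.
    values[i] -= learning_rate * (adam_direction[i] + weight_decay * values[i]);
  }
}
```

With weight_decay = 0.0, the constructor default, this reduces to the plain Adam update.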

◆ get_beta1()

template<typename TensorDataType>
TensorDataType lbann::adam< TensorDataType >::get_beta1 ( ) const
inline noexcept

Update factor for first moment estimate.

Definition at line 101 of file adam.hpp.

◆ get_beta2()

template<typename TensorDataType>
TensorDataType lbann::adam< TensorDataType >::get_beta2 ( ) const
inline noexcept

Update factor for second moment estimate.

Definition at line 105 of file adam.hpp.

◆ get_current_beta1()

template<typename TensorDataType>
TensorDataType lbann::adam< TensorDataType >::get_current_beta1 ( ) const
inline noexcept

beta1 ^ iteration.

Todo:
This probably shouldn't be exposed.

Definition at line 135 of file adam.hpp.
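
The member default m_current_beta1 = TensorDataType(1.) corresponds to beta1 ^ 0. A plausible reading, stated here as an assumption rather than taken from the LBANN source, is that the stored value is multiplied by beta1 once per step, so the power never has to be recomputed from the iteration count. The sketch below shows that running product and the bias-correction factor derived from it; both helper names are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper: advance beta ^ iteration by one optimization step.
double advance_beta_power(double current_beta, double beta) {
  assert(beta >= 0.0 && beta < 1.0);
  return current_beta * beta;
}

// Hypothetical helper: bias-correction factor 1 / (1 - beta ^ iteration).
// Only meaningful after at least one step, when current_beta < 1.
double bias_correction(double current_beta) {
  assert(current_beta < 1.0);
  return 1.0 / (1.0 - current_beta);
}
```

Starting from 1.0 and calling advance_beta_power five times with beta = 0.9 yields the same value as std::pow(0.9, 5) up to rounding.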

◆ get_current_beta2()

template<typename TensorDataType>
TensorDataType lbann::adam< TensorDataType >::get_current_beta2 ( ) const
inline noexcept

beta2 ^ iteration.

Todo:
This probably shouldn't be exposed.

Definition at line 146 of file adam.hpp.

◆ get_description()

template<typename TensorDataType>
description lbann::adam< TensorDataType >::get_description ( ) const
override

Human-readable description.

◆ get_eps()

template<typename TensorDataType>
TensorDataType lbann::adam< TensorDataType >::get_eps ( ) const
inline noexcept

Small factor to avoid division by zero.

Definition at line 109 of file adam.hpp.

◆ get_moment1() [1/2]

template<typename TensorDataType>
const AbsDistMatrixType& lbann::adam< TensorDataType >::get_moment1 ( ) const

First moment estimates.

◆ get_moment1() [2/2]

template<typename TensorDataType>
AbsDistMatrixType& lbann::adam< TensorDataType >::get_moment1 ( )

First moment estimates.

◆ get_moment2() [1/2]

template<typename TensorDataType>
const AbsDistMatrixType& lbann::adam< TensorDataType >::get_moment2 ( ) const

Second moment estimates.

◆ get_moment2() [2/2]

template<typename TensorDataType>
AbsDistMatrixType& lbann::adam< TensorDataType >::get_moment2 ( )

Second moment estimates.

◆ get_type()

template<typename TensorDataType>
std::string lbann::adam< TensorDataType >::get_type ( ) const
inline override

Human-readable type name.

Definition at line 91 of file adam.hpp.

◆ operator=()

template<typename TensorDataType>
adam& lbann::adam< TensorDataType >::operator= ( const adam< TensorDataType > &  other)

◆ serialize()

template<typename TensorDataType >
template<class Archive >
void lbann::adam< TensorDataType >::serialize ( Archive &  ar)

Archive for checkpoint and restart.

Definition at line 37 of file adam_impl.hpp.

◆ set_adamw_weight_decay()

template<typename TensorDataType>
void lbann::adam< TensorDataType >::set_adamw_weight_decay ( TensorDataType  adamw_weight_decay)
inline

Regularizer coefficient for AdamW weight decay.

Definition at line 118 of file adam.hpp.

◆ set_beta1()

template<typename TensorDataType>
void lbann::adam< TensorDataType >::set_beta1 ( TensorDataType  beta1)
inline

Update factor for first moment estimate.

Definition at line 103 of file adam.hpp.

◆ set_beta2()

template<typename TensorDataType>
void lbann::adam< TensorDataType >::set_beta2 ( TensorDataType  beta2)
inline

Update factor for second moment estimate.

Definition at line 107 of file adam.hpp.

◆ set_current_beta1()

template<typename TensorDataType>
void lbann::adam< TensorDataType >::set_current_beta1 ( TensorDataType  current_beta1)
inline

beta1 ^ iteration.

Todo:
This probably shouldn't be exposed.

Definition at line 139 of file adam.hpp.

◆ set_current_beta2()

template<typename TensorDataType>
void lbann::adam< TensorDataType >::set_current_beta2 ( TensorDataType  current_beta2)
inline

beta2 ^ iteration.

Todo:
This probably shouldn't be exposed.

Definition at line 150 of file adam.hpp.

◆ set_eps()

template<typename TensorDataType>
void lbann::adam< TensorDataType >::set_eps ( TensorDataType  eps)
inline

Small factor to avoid division by zero.

Definition at line 111 of file adam.hpp.

◆ setup()

template<typename TensorDataType>
void lbann::adam< TensorDataType >::setup ( WeightsType * w = nullptr)
override

◆ step_compute()

template<typename TensorDataType>
void lbann::adam< TensorDataType >::step_compute ( AbsDistMatrixType &  values,
const AbsDistMatrixType &  gradient 
)
override protected

Computation for an optimization step.

◆ step_compute_cpu()

template<typename TensorDataType>
void lbann::adam< TensorDataType >::step_compute_cpu ( AbsDistMatrixType &  values,
const AbsDistMatrixType &  gradient,
const TensorDataType &  correction 
)
private

CPU implementation of optimization step.

◆ write_proto()

template<typename TensorDataType>
void lbann::adam< TensorDataType >::write_proto ( lbann_data::Optimizer &  opt) const
final

Add optimizer data to prototext.

Friends And Related Function Documentation

◆ callback::perturb_adam

template<typename TensorDataType>
friend class callback::perturb_adam
friend

Hyperparameter exploration.

Definition at line 206 of file adam.hpp.

Member Data Documentation

◆ m_adamw_weight_decay

template<typename TensorDataType>
TensorDataType lbann::adam< TensorDataType >::m_adamw_weight_decay
private

Regularizer coefficient for AdamW weight decay.

Definition at line 195 of file adam.hpp.

◆ m_beta1

template<typename TensorDataType>
TensorDataType lbann::adam< TensorDataType >::m_beta1
private

Update factor for first moment estimate.

Definition at line 189 of file adam.hpp.

◆ m_beta2

template<typename TensorDataType>
TensorDataType lbann::adam< TensorDataType >::m_beta2
private

Update factor for second moment estimate.

Definition at line 191 of file adam.hpp.

◆ m_current_beta1

template<typename TensorDataType>
TensorDataType lbann::adam< TensorDataType >::m_current_beta1 = TensorDataType(1.)
private

beta1 ^ iteration.

Definition at line 197 of file adam.hpp.

◆ m_current_beta2

template<typename TensorDataType>
TensorDataType lbann::adam< TensorDataType >::m_current_beta2 = TensorDataType(1.)
private

beta2 ^ iteration.

Definition at line 199 of file adam.hpp.

◆ m_eps

template<typename TensorDataType>
TensorDataType lbann::adam< TensorDataType >::m_eps
private

Small factor to avoid division by zero.

Definition at line 193 of file adam.hpp.

◆ m_moment1

template<typename TensorDataType>
std::unique_ptr<AbsDistMatrixType> lbann::adam< TensorDataType >::m_moment1
private

First moment estimates.

Definition at line 201 of file adam.hpp.

◆ m_moment2

template<typename TensorDataType>
std::unique_ptr<AbsDistMatrixType> lbann::adam< TensorDataType >::m_moment2
private

Second moment estimates.

Definition at line 203 of file adam.hpp.


The documentation for this class was generated from the following files: