LBANN 0.103.0
Livermore Big Artificial Neural Network Toolkit
lbann::adagrad< TensorDataType > Class Template Reference

#include <adagrad.hpp>


Public Types
using AbsDistMatrixType = El::AbstractDistMatrix< TensorDataType >
 The tensor type expected in this object.
 
using OptimizerType = data_type_optimizer< TensorDataType >
 The optimizer base type of this object.
 
using WeightsType = data_type_weights< TensorDataType >
 The concrete weights type used by this object.
 

Public Member Functions

 adagrad (TensorDataType learning_rate, TensorDataType eps=1e-8)
 
 adagrad (const adagrad &other)
 
adagrad & operator= (const adagrad &other)
 
 ~adagrad () override=default
 
template<class Archive >
void serialize (Archive &ar)
 
std::string get_type () const override
 
description get_description () const override
 
void setup (WeightsType *w=nullptr) override
 
void write_proto (lbann_data::Optimizer &opt) const final
 
- Public Member Functions inherited from lbann::Cloneable< adagrad< TensorDataType >, data_type_optimizer< TensorDataType > >
std::unique_ptr< adagrad< TensorDataType > > clone () const
 Return an exception-safe, memory-safe copy of this object.
 

Protected Member Functions

 adagrad ()
 Default constructor.
 
void step_compute (AbsDistMatrixType &values, const AbsDistMatrixType &gradient) override
 

Private Types

using BaseType = Cloneable< adagrad< TensorDataType >, data_type_optimizer< TensorDataType > >
 

Private Member Functions

void step_compute_cpu (AbsDistMatrixType &values, const AbsDistMatrixType &gradient)
 

Private Attributes

TensorDataType m_eps
 
std::unique_ptr< AbsDistMatrixType > m_cache
 

Detailed Description

template<typename TensorDataType>
class lbann::adagrad< TensorDataType >

AdaGrad optimizer.

Reference:

John Duchi, Elad Hazan, and Yoram Singer. "Adaptive subgradient methods for online learning and stochastic optimization." Journal of Machine Learning Research 12, no. Jul (2011): 2121-2159.
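
For orientation, the per-entry update commonly associated with AdaGrad can be written as below. This is the textbook form, not a transcription of this class's source: the symbols w (weight), g (gradient), c (accumulated cache), \eta (learning rate) and \epsilon are notation introduced here, and the placement of \epsilon outside the square root is an assumption (some formulations place it inside).

```latex
\begin{aligned}
c_i &\leftarrow c_i + g_i^{2} \\
w_i &\leftarrow w_i - \frac{\eta \, g_i}{\sqrt{c_i} + \epsilon}
\end{aligned}
```

In this notation, c plays the role of the m_cache member and \epsilon the role of m_eps.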

Definition at line 45 of file adagrad.hpp.

Member Typedef Documentation

◆ AbsDistMatrixType

template<typename TensorDataType >
using lbann::adagrad< TensorDataType >::AbsDistMatrixType = El::AbstractDistMatrix<TensorDataType>

The tensor type expected in this object.

Definition at line 56 of file adagrad.hpp.

◆ BaseType

template<typename TensorDataType >
using lbann::adagrad< TensorDataType >::BaseType = Cloneable<adagrad<TensorDataType>, data_type_optimizer<TensorDataType> >
private

Definition at line 49 of file adagrad.hpp.

◆ OptimizerType

template<typename TensorDataType >
using lbann::adagrad< TensorDataType >::OptimizerType = data_type_optimizer<TensorDataType>

The optimizer base type of this object.

Definition at line 59 of file adagrad.hpp.

◆ WeightsType

template<typename TensorDataType >
using lbann::adagrad< TensorDataType >::WeightsType = data_type_weights<TensorDataType>

The concrete weights type used by this object.

Definition at line 62 of file adagrad.hpp.

Constructor & Destructor Documentation

◆ adagrad() [1/3]

template<typename TensorDataType >
lbann::adagrad< TensorDataType >::adagrad ( TensorDataType  learning_rate,
TensorDataType  eps = 1e-8 
)

◆ adagrad() [2/3]

template<typename TensorDataType >
lbann::adagrad< TensorDataType >::adagrad ( const adagrad< TensorDataType > &  other)

◆ ~adagrad()

template<typename TensorDataType >
lbann::adagrad< TensorDataType >::~adagrad ( )
override default

◆ adagrad() [3/3]

template<typename TensorDataType >
lbann::adagrad< TensorDataType >::adagrad ( )
inline protected

Default constructor.

This constructor exists as an implementation detail of the serialization code. It is not for general use.

Definition at line 94 of file adagrad.hpp.

Member Function Documentation

◆ get_description()

template<typename TensorDataType >
description lbann::adagrad< TensorDataType >::get_description ( ) const
override

Human-readable description.

◆ get_type()

template<typename TensorDataType >
std::string lbann::adagrad< TensorDataType >::get_type ( ) const
inline override

Human-readable type name.

Definition at line 77 of file adagrad.hpp.

◆ operator=()

template<typename TensorDataType >
adagrad& lbann::adagrad< TensorDataType >::operator= ( const adagrad< TensorDataType > &  other)

◆ serialize()

template<typename TensorDataType >
template<class Archive >
void lbann::adagrad< TensorDataType >::serialize ( Archive &  ar)

Archive for checkpoint and restart.

Definition at line 37 of file adagrad_impl.hpp.

◆ setup()

template<typename TensorDataType >
void lbann::adagrad< TensorDataType >::setup ( WeightsType *  w = nullptr)
override

◆ step_compute()

template<typename TensorDataType >
void lbann::adagrad< TensorDataType >::step_compute ( AbsDistMatrixType &  values,
const AbsDistMatrixType &  gradient 
)
override protected

Computation for an optimization step.

◆ step_compute_cpu()

template<typename TensorDataType >
void lbann::adagrad< TensorDataType >::step_compute_cpu ( AbsDistMatrixType &  values,
const AbsDistMatrixType &  gradient 
)
private

CPU implementation of optimization step.
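
The arithmetic of a single AdaGrad step can be sketched on plain std::vector buffers. This is a minimal illustration under stated assumptions, not LBANN's implementation: adagrad_step is a hypothetical free function, the buffers stand in for the local data of the distributed matrices, and the update uses the textbook rule with eps added outside the square root.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-alone helper; it is NOT part of LBANN.
// Applies one AdaGrad step to `values` in place.
//   cache : running sum of squared gradients (same role as m_cache)
//   eps   : small constant in the denominator (same role as m_eps)
template <typename T>
void adagrad_step(std::vector<T>& values,
                  const std::vector<T>& gradient,
                  std::vector<T>& cache,
                  T learning_rate,
                  T eps = T(1e-8))
{
  assert(values.size() == gradient.size());
  assert(cache.size() == gradient.size());
  for (std::size_t i = 0; i < values.size(); ++i) {
    const T g = gradient[i];
    // Accumulate the squared gradient for this entry.
    cache[i] += g * g;
    // Scale the step by the inverse root of the accumulated history.
    // eps keeps the denominator nonzero while the cache is still zero.
    values[i] -= learning_rate * g / (std::sqrt(cache[i]) + eps);
  }
}
```

Because the cache only grows, the effective step size for each entry shrinks over time, and entries that have seen large gradients shrink fastest.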

◆ write_proto()

template<typename TensorDataType >
void lbann::adagrad< TensorDataType >::write_proto ( lbann_data::Optimizer &  opt) const
final

Add optimizer data to prototext.

Member Data Documentation

◆ m_cache

template<typename TensorDataType >
std::unique_ptr<AbsDistMatrixType> lbann::adagrad< TensorDataType >::m_cache
private

AdaGrad cache.

Definition at line 105 of file adagrad.hpp.

◆ m_eps

template<typename TensorDataType >
TensorDataType lbann::adagrad< TensorDataType >::m_eps
private

Small factor to avoid division by zero.

Definition at line 103 of file adagrad.hpp.


The documentation for this class was generated from the following files: