LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
lbann::rmsprop< TensorDataType > Class Template Reference

#include <rmsprop.hpp>

Public Types

using AbsDistMatrixType = El::AbstractDistMatrix< TensorDataType >
 The tensor type expected in this object.
 
using OptimizerType = data_type_optimizer< TensorDataType >
 The optimizer base type of this object.
 
using WeightsType = data_type_weights< TensorDataType >
 The concrete weights type used by this object.
 

Public Member Functions

 rmsprop (TensorDataType learning_rate, TensorDataType decay_rate, TensorDataType eps=1e-8)
 
 rmsprop (const rmsprop &other)
 
rmsprop & operator= (const rmsprop &other)
 
 ~rmsprop () override=default
 
template<class Archive >
void serialize (Archive &ar)
 
std::string get_type () const override
 
description get_description () const override
 
void setup (WeightsType *w=nullptr) override
 
void write_proto (lbann_data::Optimizer &opt) const final
 
- Public Member Functions inherited from lbann::Cloneable< rmsprop< TensorDataType >, data_type_optimizer< TensorDataType > >
std::unique_ptr< rmsprop< TensorDataType > > clone () const
 Return an exception-safe, memory-safe copy of this object.
 

Protected Member Functions

 rmsprop ()
 Default constructor.
 
void step_compute (AbsDistMatrixType &values, const AbsDistMatrixType &gradient) override
 

Private Types

using BaseType = Cloneable< rmsprop< TensorDataType >, data_type_optimizer< TensorDataType > >
 

Private Member Functions

void step_compute_cpu (AbsDistMatrixType &values, const AbsDistMatrixType &gradient)
 

Private Attributes

TensorDataType m_decay_rate
 
TensorDataType m_eps
 
std::unique_ptr< AbsDistMatrixType > m_cache
 

Detailed Description

template<typename TensorDataType>
class lbann::rmsprop< TensorDataType >

RMSprop optimizer.

See https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf.
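For reference, the textbook RMSprop update from the linked lecture slides can be written per element as below. The symbols are introduced here for illustration only: ρ stands for decay_rate (m_decay_rate), η for learning_rate, ε for eps (m_eps), r for the running cache (m_cache), g for the gradient, and w for the weight values. Whether LBANN adds ε outside the square root (as written here) or inside it is an assumption that this page does not confirm.

```latex
% Textbook RMSprop update, applied independently to each element i.
% Assumed mapping: \rho = decay_rate, \eta = learning_rate,
% \epsilon = eps, r = m_cache, g = gradient, w = values.
\begin{aligned}
r_i &\leftarrow \rho\, r_i + (1 - \rho)\, g_i^{2} \\
w_i &\leftarrow w_i - \frac{\eta\, g_i}{\sqrt{r_i} + \epsilon}
\end{aligned}
```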

Definition at line 43 of file rmsprop.hpp.

Member Typedef Documentation

◆ AbsDistMatrixType

template<typename TensorDataType >
using lbann::rmsprop< TensorDataType >::AbsDistMatrixType = El::AbstractDistMatrix<TensorDataType>

The tensor type expected in this object.

Definition at line 54 of file rmsprop.hpp.

◆ BaseType

template<typename TensorDataType >
using lbann::rmsprop< TensorDataType >::BaseType = Cloneable<rmsprop<TensorDataType>, data_type_optimizer<TensorDataType> >
private

Definition at line 47 of file rmsprop.hpp.

◆ OptimizerType

template<typename TensorDataType >
using lbann::rmsprop< TensorDataType >::OptimizerType = data_type_optimizer<TensorDataType>

The optimizer base type of this object.

Definition at line 57 of file rmsprop.hpp.

◆ WeightsType

template<typename TensorDataType >
using lbann::rmsprop< TensorDataType >::WeightsType = data_type_weights<TensorDataType>

The concrete weights type used by this object.

Definition at line 60 of file rmsprop.hpp.

Constructor & Destructor Documentation

◆ rmsprop() [1/3]

template<typename TensorDataType >
lbann::rmsprop< TensorDataType >::rmsprop ( TensorDataType  learning_rate,
TensorDataType  decay_rate,
TensorDataType  eps = 1e-8 
)

◆ rmsprop() [2/3]

template<typename TensorDataType >
lbann::rmsprop< TensorDataType >::rmsprop ( const rmsprop< TensorDataType > &  other)

◆ ~rmsprop()

template<typename TensorDataType >
lbann::rmsprop< TensorDataType >::~rmsprop ( )
override default

◆ rmsprop() [3/3]

template<typename TensorDataType >
lbann::rmsprop< TensorDataType >::rmsprop ( )
inline protected

Default constructor.

This constructor exists as an implementation detail of the serialization code. It is not for general use.

Definition at line 94 of file rmsprop.hpp.

Member Function Documentation

◆ get_description()

template<typename TensorDataType >
description lbann::rmsprop< TensorDataType >::get_description ( ) const
override

Human-readable description.

◆ get_type()

template<typename TensorDataType >
std::string lbann::rmsprop< TensorDataType >::get_type ( ) const
inline override

Human-readable type name.

Definition at line 77 of file rmsprop.hpp.

◆ operator=()

template<typename TensorDataType >
rmsprop& lbann::rmsprop< TensorDataType >::operator= ( const rmsprop< TensorDataType > &  other)

◆ serialize()

template<typename TensorDataType >
template<class Archive >
void lbann::rmsprop< TensorDataType >::serialize ( Archive &  ar)

Archive for checkpoint and restart.

Definition at line 37 of file rmsprop_impl.hpp.
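A single serialize(Archive&) template follows the pattern used by serialization libraries such as Cereal, where one function both saves and loads depending on the archive type passed in. The sketch below illustrates that pattern with hypothetical toy archives (ToySaveArchive, ToyLoadArchive) and a hypothetical rmsprop_state_sketch class. It is not LBANN code: it archives only the two scalar hyperparameters and says nothing about which members the real serialize() writes. It also shows why a default constructor is useful for restart: the loader must create an object before the archive overwrites its fields.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy archive: writes every field it is given into a buffer.
struct ToySaveArchive {
  std::vector<double> buffer;
  template <typename... Ts>
  void operator()(Ts&... fields) {
    (buffer.push_back(static_cast<double>(fields)), ...);
  }
};

// Hypothetical toy archive: reads fields back in the same order.
struct ToyLoadArchive {
  const std::vector<double>& buffer;
  std::size_t pos = 0;
  template <typename... Ts>
  void operator()(Ts&... fields) {
    ((fields = static_cast<Ts>(buffer.at(pos++))), ...);
  }
};

// Hypothetical stand-in for an optimizer's checkpointable state.
template <typename TensorDataType>
class rmsprop_state_sketch {
public:
  rmsprop_state_sketch(TensorDataType decay_rate, TensorDataType eps)
    : m_decay_rate(decay_rate), m_eps(eps) {}

  // One function for both directions: the archive type decides whether
  // the fields are written out or read in.
  template <class Archive>
  void serialize(Archive& ar) { ar(m_decay_rate, m_eps); }

  // Restart path: default-construct, then let the archive fill the fields.
  static rmsprop_state_sketch restore(const std::vector<double>& buffer) {
    rmsprop_state_sketch s;
    ToyLoadArchive in{buffer};
    s.serialize(in);
    assert(in.pos == buffer.size());  // every saved field was consumed
    return s;
  }

  TensorDataType decay_rate() const { return m_decay_rate; }
  TensorDataType eps() const { return m_eps; }

protected:
  // Placeholder state that exists only to be overwritten by the loader.
  rmsprop_state_sketch() = default;

private:
  TensorDataType m_decay_rate{};
  TensorDataType m_eps{};
};
```

Saving calls serialize() with a ToySaveArchive; restore() then rebuilds an equal object from the resulting buffer.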

◆ setup()

template<typename TensorDataType >
void lbann::rmsprop< TensorDataType >::setup ( WeightsType *  w = nullptr)
override

◆ step_compute()

template<typename TensorDataType >
void lbann::rmsprop< TensorDataType >::step_compute ( AbsDistMatrixType &  values,
const AbsDistMatrixType &  gradient 
)
override protected

Computation for an optimization step.

◆ step_compute_cpu()

template<typename TensorDataType >
void lbann::rmsprop< TensorDataType >::step_compute_cpu ( AbsDistMatrixType &  values,
const AbsDistMatrixType &  gradient 
)
private

CPU implementation of optimization step.
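A minimal CPU sketch of one RMSprop step, assuming the textbook update (ε added outside the square root) and using a flat std::vector in place of El::AbstractDistMatrix. The class rmsprop_sketch and its setup(std::size_t) and cache() members are hypothetical: they mirror the documented constructor arguments and members (m_decay_rate, m_eps, m_cache) but are not LBANN's API, and the real implementation operates on distributed matrices.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical standalone sketch, not LBANN's API.
template <typename TensorDataType>
class rmsprop_sketch {
public:
  rmsprop_sketch(TensorDataType learning_rate,
                 TensorDataType decay_rate,
                 TensorDataType eps = TensorDataType(1e-8))
    : m_learning_rate(learning_rate), m_decay_rate(decay_rate), m_eps(eps) {}

  // Analogue of setup(): allocate a zero-filled cache shaped like the weights.
  void setup(std::size_t num_weights) {
    m_cache.assign(num_weights, TensorDataType(0));
  }

  // Analogue of step_compute_cpu(): one in-place RMSprop update.
  void step_compute_cpu(std::vector<TensorDataType>& values,
                        const std::vector<TensorDataType>& gradient) {
    assert(values.size() == gradient.size());
    assert(m_cache.size() == values.size());
    const TensorDataType one(1);
    for (std::size_t i = 0; i < values.size(); ++i) {
      const TensorDataType g = gradient[i];
      // Exponential running average of the squared gradient.
      m_cache[i] = m_decay_rate * m_cache[i] + (one - m_decay_rate) * g * g;
      // Divide the step by the root of that average; eps guards against 0/0.
      values[i] -= m_learning_rate * g / (std::sqrt(m_cache[i]) + m_eps);
    }
  }

  const std::vector<TensorDataType>& cache() const { return m_cache; }

private:
  TensorDataType m_learning_rate;
  TensorDataType m_decay_rate;
  TensorDataType m_eps;
  std::vector<TensorDataType> m_cache;
};
```

With learning_rate = 0.01 and decay_rate = 0.9, a gradient of 0.1 on a fresh cache moves its weight by about 0.0316, while an element whose gradient is zero is left unchanged because eps keeps the denominator nonzero.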

◆ write_proto()

template<typename TensorDataType >
void lbann::rmsprop< TensorDataType >::write_proto ( lbann_data::Optimizer &  opt) const
final

Add this optimizer's data to a prototext message.

Member Data Documentation

◆ m_cache

template<typename TensorDataType >
std::unique_ptr<AbsDistMatrixType> lbann::rmsprop< TensorDataType >::m_cache
private

RMSprop cache: the running average of squared gradients, with one entry per weight.

Definition at line 110 of file rmsprop.hpp.

◆ m_decay_rate

template<typename TensorDataType >
TensorDataType lbann::rmsprop< TensorDataType >::m_decay_rate
private

Decay rate.

Definition at line 106 of file rmsprop.hpp.

◆ m_eps

template<typename TensorDataType >
TensorDataType lbann::rmsprop< TensorDataType >::m_eps
private

Small factor to avoid division by zero.
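A worked one-step example of the role of eps, assuming the textbook update with ε outside the square root, a zero-initialized cache, and illustrative values ρ = 0.9, η = 0.01, ε = 10^{-8} (these values are chosen for the example, not taken from LBANN defaults other than eps):

```latex
% Element with gradient g = 0.1 (the step is normalized by its own magnitude):
r = 0.9 \cdot 0 + 0.1 \cdot (0.1)^2 = 0.001, \qquad
\frac{\eta\, g}{\sqrt{r} + \epsilon} \approx \frac{0.01 \cdot 0.1}{0.031623} \approx 0.031623
% Element with gradient g = 0 (eps turns an undefined 0/0 into 0):
r = 0, \qquad
\frac{\eta \cdot 0}{\sqrt{0} + \epsilon} = \frac{0}{10^{-8}} = 0
```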

Definition at line 108 of file rmsprop.hpp.


The documentation for this class was generated from the following files: