LBANN 0.103.0
Livermore Big Artificial Neural Network Toolkit
lbann::sgd< TensorDataType > Class Template Reference

Stochastic gradient descent optimizer.

#include <sgd.hpp>


Public Types

using AbsDistMatrixType = El::AbstractDistMatrix< TensorDataType >
 The tensor type expected in this object.
 
using OptimizerType = data_type_optimizer< TensorDataType >
 The optimizer base type of this object.
 
using WeightsType = data_type_weights< TensorDataType >
 The concrete weights type used by this object.
 

Public Member Functions

void write_proto (lbann_data::Optimizer &opt) const final
 
Life cycle functions
 sgd (TensorDataType learning_rate, TensorDataType momentum=0, bool nesterov=false)
 
 sgd (const sgd &other)
 
sgd & operator= (const sgd &other)
 
 ~sgd () override=default
 
Serialization
template<class ArchiveT >
void serialize (ArchiveT &ar)
 Serialize to the archive.
 
Descriptions
std::string get_type () const override
 
description get_description () const override
 
Access functions
TensorDataType get_momentum () const noexcept
 Decay rate for gradient accumulation.
 
void set_momentum (TensorDataType momentum)
 Set the decay rate for gradient accumulation.
 
bool using_nesterov () const noexcept
 
void set_nesterov (bool nesterov)
 
const AbsDistMatrixType & get_velocity () const
 
AbsDistMatrixType & get_velocity ()
 
Setup
void setup (WeightsType *w=nullptr) override
 
- Public Member Functions inherited from lbann::Cloneable< sgd< TensorDataType >, data_type_optimizer< TensorDataType > >
std::unique_ptr< sgd< TensorDataType > > clone () const
 Return an exception-safe, memory-safe copy of this object.
 

Protected Member Functions

 sgd ()
 Default constructor.
 
void step_compute (AbsDistMatrixType &values, const AbsDistMatrixType &gradient) override
 

Private Types

using BaseType = Cloneable< sgd< TensorDataType >, data_type_optimizer< TensorDataType > >
 

Private Member Functions

void momentum_step_cpu (AbsDistMatrixType &values, const AbsDistMatrixType &gradient)
 

Private Attributes

TensorDataType m_momentum
 Decay rate for gradient accumulation.
 
bool m_nesterov
 
std::unique_ptr< AbsDistMatrixType > m_velocity
 Accumulated gradients.
 

Friends

class cereal::access
 

Detailed Description

template<typename TensorDataType>
class lbann::sgd< TensorDataType >

Stochastic gradient descent optimizer.

Supports momentum and Nesterov acceleration.
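The three update modes can be written down as follows. This is a sketch under assumptions: the symbols (learning rate η, momentum μ, gradient g_t, weights x_t, velocity v_t) and the exact placement of the learning rate follow a common formulation, the one used for example in PyTorch's SGD documentation, and are not equations stated on this page.

```latex
\begin{aligned}
\text{vanilla SGD } (\mu = 0):\quad & x_{t+1} = x_t - \eta\, g_t \\
\text{momentum:}\quad & v_t = \mu\, v_{t-1} + g_t, \qquad x_{t+1} = x_t - \eta\, v_t \\
\text{Nesterov:}\quad & v_t = \mu\, v_{t-1} + g_t, \qquad x_{t+1} = x_t - \eta\,(\mu\, v_t + g_t)
\end{aligned}
```

The Nesterov line is the reformulated "look-ahead" variant that needs only the current gradient, which is why both accelerated modes can share one velocity buffer.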

Todo:
Dedicated optimizers for momentum or Nesterov SGD.

Definition at line 41 of file sgd.hpp.

Member Typedef Documentation

◆ AbsDistMatrixType

template<typename TensorDataType >
using lbann::sgd< TensorDataType >::AbsDistMatrixType = El::AbstractDistMatrix<TensorDataType>

The tensor type expected in this object.

Definition at line 52 of file sgd.hpp.

◆ BaseType

template<typename TensorDataType >
using lbann::sgd< TensorDataType >::BaseType = Cloneable<sgd<TensorDataType>, data_type_optimizer<TensorDataType> >
private

Definition at line 45 of file sgd.hpp.

◆ OptimizerType

template<typename TensorDataType >
using lbann::sgd< TensorDataType >::OptimizerType = data_type_optimizer<TensorDataType>

The optimizer base type of this object.

Definition at line 55 of file sgd.hpp.

◆ WeightsType

template<typename TensorDataType >
using lbann::sgd< TensorDataType >::WeightsType = data_type_weights<TensorDataType>

The concrete weights type used by this object.

Definition at line 58 of file sgd.hpp.

Constructor & Destructor Documentation

◆ sgd() [1/3]

template<typename TensorDataType >
lbann::sgd< TensorDataType >::sgd ( TensorDataType  learning_rate,
TensorDataType  momentum = 0,
bool  nesterov = false 
)

◆ sgd() [2/3]

template<typename TensorDataType >
lbann::sgd< TensorDataType >::sgd ( const sgd< TensorDataType > &  other)

◆ ~sgd()

template<typename TensorDataType >
lbann::sgd< TensorDataType >::~sgd ( )
override default

◆ sgd() [3/3]

template<typename TensorDataType >
lbann::sgd< TensorDataType >::sgd ( )
inline protected

Default constructor.

This constructor exists as an implementation detail of the serialization code. It is not for general use.

Definition at line 131 of file sgd.hpp.
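The comment above says the protected default constructor exists only for serialization. A minimal sketch of that pattern follows, assuming a cereal-style archive that is invoked as ar(a, b, ...). ToyArchive, ToySgd, and ToyAccess are hypothetical stand-ins (ToyAccess plays the role that the friend cereal::access plays for sgd); none of them is LBANN or cereal API. The loader default-constructs a placeholder through friendship and then lets serialize() overwrite every field.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for a cereal-style archive (NOT cereal's API):
// ar(a, b, ...) appends each value when saving and overwrites it when loading.
class ToyArchive {
public:
  explicit ToyArchive(bool loading) : m_loading(loading) {}

  template <class... Ts>
  void operator()(Ts&... values) { (process(values), ...); }

  std::vector<double> buffer;  // serialized values, in call order

private:
  template <class T>
  void process(T& value) {
    if (m_loading) {
      assert(m_pos < buffer.size());
      value = static_cast<T>(buffer[m_pos++]);
    } else {
      buffer.push_back(static_cast<double>(value));
    }
  }

  bool m_loading;
  std::size_t m_pos = 0;
};

// Hypothetical optimizer: the protected default constructor is reachable
// only through ToyAccess, mirroring how sgd befriends cereal::access.
class ToySgd {
public:
  ToySgd(double learning_rate, double momentum = 0, bool nesterov = false)
      : m_learning_rate(learning_rate), m_momentum(momentum), m_nesterov(nesterov) {}

  // One template handles both saving and loading, like sgd::serialize.
  template <class ArchiveT>
  void serialize(ArchiveT& ar) { ar(m_learning_rate, m_momentum, m_nesterov); }

  double get_learning_rate() const noexcept { return m_learning_rate; }
  double get_momentum() const noexcept { return m_momentum; }
  bool using_nesterov() const noexcept { return m_nesterov; }

protected:
  // Creates a placeholder object whose fields serialize() then overwrites.
  ToySgd() = default;

private:
  friend class ToyAccess;
  double m_learning_rate = 0;
  double m_momentum = 0;
  bool m_nesterov = false;
};

class ToyAccess {
public:
  // Default-construct through friendship, then fill the fields from the archive.
  static std::unique_ptr<ToySgd> load(ToyArchive& ar) {
    // std::make_unique could not reach the protected constructor here.
    std::unique_ptr<ToySgd> opt(new ToySgd());
    opt->serialize(ar);
    return opt;
  }
};
```

With this split, ordinary callers must supply a learning rate, while the loader can still create an empty object to fill in. Which members LBANN's own serialize() archives is not shown on this page.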

Member Function Documentation

◆ get_description()

template<typename TensorDataType >
description lbann::sgd< TensorDataType >::get_description ( ) const
override

Human-readable description.

◆ get_momentum()

template<typename TensorDataType >
TensorDataType lbann::sgd< TensorDataType >::get_momentum ( ) const
inline noexcept

Decay rate for gradient accumulation.

A momentum of zero corresponds to vanilla SGD.

Definition at line 98 of file sgd.hpp.
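Assuming the common updates v_t = μ v_{t−1} + g_t with x_{t+1} = x_t − η v_t (momentum) or x_{t+1} = x_t − η(μ v_t + g_t) (Nesterov), which this page does not state explicitly, substituting μ = 0 shows why zero momentum is plain SGD in either mode:

```latex
\begin{aligned}
\mu = 0 \;\Rightarrow\; v_t &= 0 \cdot v_{t-1} + g_t = g_t, \\
x_{t+1} &= x_t - \eta\, v_t = x_t - \eta\, g_t \quad \text{(classical momentum)}, \\
x_{t+1} &= x_t - \eta\,(0 \cdot v_t + g_t) = x_t - \eta\, g_t \quad \text{(Nesterov)}.
\end{aligned}
```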

◆ get_type()

template<typename TensorDataType >
std::string lbann::sgd< TensorDataType >::get_type ( ) const
inline override

Human-readable type name.

Definition at line 86 of file sgd.hpp.

◆ get_velocity() [1/2]

template<typename TensorDataType >
const AbsDistMatrixType& lbann::sgd< TensorDataType >::get_velocity ( ) const

Accumulated gradients for momentum optimizer.

◆ get_velocity() [2/2]

template<typename TensorDataType >
AbsDistMatrixType& lbann::sgd< TensorDataType >::get_velocity ( )

Accumulated gradients for momentum optimizer.
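Assuming the velocity starts at zero and follows the common recurrence v_t = μ v_{t−1} + g_t (neither detail is stated on this page), unrolling it shows why the buffer holds "accumulated gradients" and why μ is described as a decay rate: each past gradient is weighted by a power of μ.

```latex
v_t = \mu\, v_{t-1} + g_t,\quad v_0 = 0
\;\;\Longrightarrow\;\;
v_t = \sum_{k=1}^{t} \mu^{\,t-k}\, g_k
```

For example, v_2 = μ g_1 + g_2, so a gradient that is one step old contributes with weight μ.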

◆ momentum_step_cpu()

template<typename TensorDataType >
void lbann::sgd< TensorDataType >::momentum_step_cpu ( AbsDistMatrixType & values,
const AbsDistMatrixType & gradient
)
private

CPU implementation of momentum or Nesterov step.
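The elementwise update below sketches what a CPU momentum or Nesterov step can look like on local buffers. It assumes plain std::vector storage instead of El::AbstractDistMatrix, a zero-initialized velocity, and the update v ← μv + g followed by x ← x − ηv (momentum) or x ← x − η(μv + g) (Nesterov). momentum_step_cpu_sketch is a hypothetical name, and LBANN's actual kernel may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical elementwise momentum / Nesterov step on local CPU buffers.
// Assumed update (not confirmed by the reference page):
//   v <- momentum * v + g
//   x <- x - lr * v                     (classical momentum)
//   x <- x - lr * (momentum * v + g)    (Nesterov)
template <typename T>
void momentum_step_cpu_sketch(std::vector<T>& values,
                              const std::vector<T>& gradient,
                              std::vector<T>& velocity,
                              T learning_rate, T momentum, bool nesterov) {
  assert(values.size() == gradient.size());
  assert(values.size() == velocity.size());
  for (std::size_t i = 0; i < values.size(); ++i) {
    const T g = gradient[i];
    T& v = velocity[i];
    v = momentum * v + g;  // decay the old velocity, add the new gradient
    if (nesterov) {
      values[i] -= learning_rate * (momentum * v + g);  // look-ahead step
    } else {
      values[i] -= learning_rate * v;
    }
  }
}
```

For example, one classical step from x = 1 with g = 0.5, η = 0.1, μ = 0.9 and zero velocity gives v = 0.5 and x = 0.95; the same step with Nesterov gives x = 0.905.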

◆ operator=()

template<typename TensorDataType >
sgd& lbann::sgd< TensorDataType >::operator= ( const sgd< TensorDataType > &  other)

◆ serialize()

template<typename TensorDataType >
template<class ArchiveT >
void lbann::sgd< TensorDataType >::serialize ( ArchiveT &  ar)

Serialize to the archive.

Definition at line 37 of file sgd_impl.hpp.

◆ set_momentum()

template<typename TensorDataType >
void lbann::sgd< TensorDataType >::set_momentum ( TensorDataType  momentum)
inline

Set the decay rate for gradient accumulation.

A momentum of zero corresponds to vanilla SGD.

Definition at line 102 of file sgd.hpp.

◆ set_nesterov()

template<typename TensorDataType >
void lbann::sgd< TensorDataType >::set_nesterov ( bool  nesterov)
inline

Set whether Nesterov acceleration is applied.

Definition at line 107 of file sgd.hpp.

◆ setup()

template<typename TensorDataType >
void lbann::sgd< TensorDataType >::setup ( WeightsType * w = nullptr)
override

◆ step_compute()

template<typename TensorDataType >
void lbann::sgd< TensorDataType >::step_compute ( AbsDistMatrixType & values,
const AbsDistMatrixType & gradient
)
override protected

Computation for an optimization step.
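One plausible way to structure an optimization step is sketched below: a vanilla fast path when the momentum is zero (consistent with the velocity being unused for vanilla SGD) and a momentum or Nesterov update otherwise. ToySgd is a hypothetical class over std::vector<double>, and lazily allocating the velocity on the first momentum step is an assumption of this sketch, not documented LBANN behavior (the real class also has a setup() member, whose behavior this page does not describe).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy optimizer mirroring the constructor defaults of sgd.
// Assumption: the velocity buffer is allocated only when momentum != 0.
class ToySgd {
public:
  explicit ToySgd(double learning_rate, double momentum = 0, bool nesterov = false)
      : m_lr(learning_rate), m_momentum(momentum), m_nesterov(nesterov) {}

  void step(std::vector<double>& values, const std::vector<double>& gradient) {
    assert(values.size() == gradient.size());
    if (m_momentum == 0) {
      // Vanilla SGD: x <- x - lr * g; the velocity is never touched.
      for (std::size_t i = 0; i < values.size(); ++i) {
        values[i] -= m_lr * gradient[i];
      }
      return;
    }
    // Momentum or Nesterov: zero-initialize the velocity on first use.
    if (m_velocity.size() != values.size()) {
      m_velocity.assign(values.size(), 0.0);
    }
    for (std::size_t i = 0; i < values.size(); ++i) {
      double& v = m_velocity[i];
      v = m_momentum * v + gradient[i];
      values[i] -= m_lr * (m_nesterov ? m_momentum * v + gradient[i] : v);
    }
  }

  bool velocity_allocated() const noexcept { return !m_velocity.empty(); }

private:
  double m_lr;
  double m_momentum;
  bool m_nesterov;
  std::vector<double> m_velocity;  // accumulated gradients; unused for vanilla SGD
};
```

On f(x) = x²/2, whose gradient is x, all three configurations drive x toward 0 from x = 1, and only the momentum variants allocate a velocity buffer.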

◆ using_nesterov()

template<typename TensorDataType >
bool lbann::sgd< TensorDataType >::using_nesterov ( ) const
inline noexcept

Whether Nesterov acceleration is applied.

Definition at line 105 of file sgd.hpp.

◆ write_proto()

template<typename TensorDataType >
void lbann::sgd< TensorDataType >::write_proto ( lbann_data::Optimizer &  opt) const
final

Add optimizer data to prototext.

Friends And Related Function Documentation

◆ cereal::access

template<typename TensorDataType >
friend class cereal::access
friend

Definition at line 159 of file sgd.hpp.

Member Data Documentation

◆ m_momentum

template<typename TensorDataType >
TensorDataType lbann::sgd< TensorDataType >::m_momentum
private

Decay rate for gradient accumulation.

A momentum of zero corresponds to vanilla SGD.

Definition at line 142 of file sgd.hpp.

◆ m_nesterov

template<typename TensorDataType >
bool lbann::sgd< TensorDataType >::m_nesterov
private

Whether Nesterov acceleration is used.

Definition at line 144 of file sgd.hpp.

◆ m_velocity

template<typename TensorDataType >
std::unique_ptr<AbsDistMatrixType> lbann::sgd< TensorDataType >::m_velocity
private

Accumulated gradients.

Not used for vanilla SGD.

Definition at line 148 of file sgd.hpp.

