LBANN 0.103.0
Livermore Big Artificial Neural Network Toolkit
lbann::kfac::KFACExecutionContext Class Reference (final)

#include <execution_context.hpp>


Public Member Functions

 KFACExecutionContext (double damping_act, double damping_err, double damping_bn_act, double damping_bn_err)
 
 ~KFACExecutionContext ()=default
 
 KFACExecutionContext (const KFACExecutionContext &other)=delete
 
KFACExecutionContext & operator= (const KFACExecutionContext &other)=delete
 
std::unique_ptr< lbann::ExecutionContext > get_new () const override
 
std::string get_type () const override
 Get a string identifying the type of execution context.
 
std::string get_state_string () const noexcept override
 Return the state of the execution context as a string.
 
SGDExecutionContext & get_sgd_execution_context () noexcept
 Return execution context for SGD-family training algorithm.
 
El::Matrix< DataType, Device > & get_workspace_matrix (const std::string &key, const size_t height, const size_t width)
 Get the Kronecker-factor matrix of a fully-connected (FC) layer. Each key is tied to a single matrix instance.
 
void print_workspace_size (model &model)
 
Checkpointing and Serialization
template<class Archive >
void serialize (Archive &ar)
 
void save_to_checkpoint_shared (persist &p) override
 Checkpoint execution_context to a shared checkpoint.
 
void load_from_checkpoint_shared (persist &p) override
 Restore execution_context from a shared checkpoint.
 
void save_to_checkpoint_distributed (persist &p) override
 Checkpoint execution_context to a distributed checkpoint.
 
void load_from_checkpoint_distributed (persist &p) override
 Restore execution_context from a distributed checkpoint.
 
- Public Member Functions inherited from lbann::ExecutionContext
 ExecutionContext ()
 
virtual ~ExecutionContext ()=default
 
virtual execution_mode get_execution_mode () const noexcept
 
size_t get_step () const noexcept
 Current step in the training algorithm.
 
void inc_step () noexcept
 Increment the current step in the training algorithm.
 
template<class Archive >
void serialize (Archive &ar)
 

Private Attributes

SGDExecutionContext m_sgd_execution_context
 
double m_damping_act
 The current damping values.
 
double m_damping_err
 
double m_damping_bn_act
 
double m_damping_bn_err
 
size_t m_update_interval
 The current update interval.
 
std::vector< std::shared_ptr< kfac_block< Device > > > m_blocks
 K-FAC per-layer blocks.
 
std::unordered_map< std::string, El::Matrix< DataType, Device > > m_workspace
 Workspace matrices that are used by m_blocks.
 

Friends

class ::lbann::KFAC
 

Additional Inherited Members

- Protected Member Functions inherited from lbann::ExecutionContext
 ExecutionContext (const ExecutionContext &other)=delete
 
ExecutionContext & operator= (const ExecutionContext &other)=delete
 
 ExecutionContext (ExecutionContext &&other)=default
 
ExecutionContext & operator= (ExecutionContext &&other)=default
 

Detailed Description

Definition at line 57 of file kfac/execution_context.hpp.

Constructor & Destructor Documentation

◆ KFACExecutionContext() [1/2]

lbann::kfac::KFACExecutionContext::KFACExecutionContext ( double  damping_act,
double  damping_err,
double  damping_bn_act,
double  damping_bn_err 
)

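Constructor. Initializes the current damping values from damping_act, damping_err, damping_bn_act, and damping_bn_err.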

◆ ~KFACExecutionContext()

lbann::kfac::KFACExecutionContext::~KFACExecutionContext ( )
default

Destructor.

◆ KFACExecutionContext() [2/2]

lbann::kfac::KFACExecutionContext::KFACExecutionContext ( const KFACExecutionContext & other)
delete

Copy constructor – deleted.
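Because both copy operations are deleted and, as far as this page shows, no move operations are declared, the context can be neither copied nor moved by value. A minimal sketch, using a hypothetical `KFACContextSketch` class rather than the real LBANN type, that mirrors the documented constructor and special members and checks those properties at compile time:

```cpp
#include <memory>
#include <type_traits>

// Hypothetical stand-in for KFACExecutionContext (not the LBANN class):
// it keeps only the four damping values and the documented special members.
class KFACContextSketch {
public:
  KFACContextSketch(double damping_act, double damping_err,
                    double damping_bn_act, double damping_bn_err)
      : m_damping_act{damping_act}, m_damping_err{damping_err},
        m_damping_bn_act{damping_bn_act}, m_damping_bn_err{damping_bn_err} {}
  ~KFACContextSketch() = default;

  // Copy operations are deleted, mirroring the declarations on this page.
  KFACContextSketch(const KFACContextSketch&) = delete;
  KFACContextSketch& operator=(const KFACContextSketch&) = delete;

  double damping_act() const noexcept { return m_damping_act; }
  double damping_err() const noexcept { return m_damping_err; }
  double damping_bn_act() const noexcept { return m_damping_bn_act; }
  double damping_bn_err() const noexcept { return m_damping_bn_err; }

private:
  double m_damping_act;
  double m_damping_err;
  double m_damping_bn_act;
  double m_damping_bn_err;
};

// User-declared copy operations (and the user-declared destructor) suppress
// the implicit move operations, so moves fall back to the deleted copies.
static_assert(!std::is_copy_constructible_v<KFACContextSketch>);
static_assert(!std::is_move_constructible_v<KFACContextSketch>);
static_assert(!std::is_copy_assignable_v<KFACContextSketch>);
static_assert(!std::is_move_assignable_v<KFACContextSketch>);

// Ownership is therefore passed around through a pointer instead.
inline std::unique_ptr<KFACContextSketch> make_context_sketch() {
  return std::make_unique<KFACContextSketch>(1e-3, 1e-3, 1e-4, 1e-4);
}
```

The damping values passed to `make_context_sketch()` are arbitrary illustrative numbers, not LBANN defaults.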

Member Function Documentation

◆ get_new()

std::unique_ptr<lbann::ExecutionContext> lbann::kfac::KFACExecutionContext::get_new ( ) const
override, virtual

Get a "clean" execution_context of the same type.

Implements lbann::ExecutionContext.
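`get_new()` is a virtual "clean clone": callers holding only a base-class pointer get back a fresh object of the same dynamic type. A minimal sketch of that pattern, where `ContextBase` and `KFACContextSketch` are hypothetical stand-ins and the choice to carry the damping configuration over while resetting the step counter is an assumption about what "clean" means:

```cpp
#include <cstddef>
#include <memory>

// Hypothetical base class standing in for lbann::ExecutionContext.
class ContextBase {
public:
  virtual ~ContextBase() = default;
  // Return a fresh ("clean") context of the same dynamic type.
  virtual std::unique_ptr<ContextBase> get_new() const = 0;
  std::size_t get_step() const noexcept { return m_step; }
  void inc_step() noexcept { ++m_step; }

private:
  std::size_t m_step = 0;
};

class KFACContextSketch final : public ContextBase {
public:
  explicit KFACContextSketch(double damping) : m_damping{damping} {}

  std::unique_ptr<ContextBase> get_new() const override {
    // Assumption: keep configuration (damping), start again from step 0.
    return std::make_unique<KFACContextSketch>(m_damping);
  }
  double damping() const noexcept { return m_damping; }

private:
  double m_damping;
};
```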

◆ get_sgd_execution_context()

SGDExecutionContext& lbann::kfac::KFACExecutionContext::get_sgd_execution_context ( )
inline, noexcept

Return execution context for SGD-family training algorithm.

Definition at line 88 of file kfac/execution_context.hpp.
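The private attributes show that the K-FAC context holds an SGDExecutionContext by value (m_sgd_execution_context), and this accessor hands out a reference to it. A minimal sketch of that composition, with hypothetical stand-in types whose fields (`epoch`, `step`) are illustrative only:

```cpp
#include <cstddef>

// Hypothetical stand-in for SGDExecutionContext.
struct SGDContextSketch {
  std::size_t epoch = 0;
  std::size_t step = 0;
};

class KFACContextSketch {
public:
  // Return a mutable reference to the embedded SGD context, so SGD-style
  // bookkeeping is updated in place rather than on a copy.
  SGDContextSketch& get_sgd_execution_context() noexcept { return m_sgd; }

private:
  SGDContextSketch m_sgd;  // held by value, like m_sgd_execution_context
};
```

Returning a reference keeps a single source of truth for the SGD state inside the K-FAC context.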


◆ get_state_string()

std::string lbann::kfac::KFACExecutionContext::get_state_string ( ) const
override, virtual, noexcept

Return the state of the execution context as a string.

Implements lbann::ExecutionContext.

◆ get_type()

std::string lbann::kfac::KFACExecutionContext::get_type ( ) const
override, virtual

Get a string identifying the type of execution context.

Should match the training algorithm.

Todo:
Absorb completely into get_state_string().

Implements lbann::ExecutionContext.
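The Todo suggests folding the type name into the state string. A minimal sketch of a `noexcept` state string that embeds the type name and step counter; the class, the `"kfac"` type name, and the `"<type>.step.<n>"` format are all hypothetical, not LBANN's actual output:

```cpp
#include <cstddef>
#include <sstream>
#include <string>

// Hypothetical sketch; the type name and string format are assumptions.
class KFACContextSketch {
public:
  std::string get_type() const { return "kfac"; }

  std::string get_state_string() const noexcept {
    // A noexcept function must not let exceptions escape; string building
    // can throw (e.g. std::bad_alloc), so fall back to an empty string.
    try {
      std::ostringstream oss;
      oss << get_type() << ".step." << m_step;
      return oss.str();
    } catch (...) {
      return {};
    }
  }

  void inc_step() noexcept { ++m_step; }

private:
  std::size_t m_step = 0;
};
```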

◆ get_workspace_matrix()

El::Matrix<DataType, Device>& lbann::kfac::KFACExecutionContext::get_workspace_matrix ( const std::string &  key,
const size_t  height,
const size_t  width 
)

Get the Kronecker-factor matrix of a fully-connected (FC) layer. Each key is tied to a single matrix instance, so calls with the same key return the same matrix.
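The key-to-instance guarantee can be met with a map that creates an entry on first lookup. A minimal sketch, where `MatrixSketch` is a hypothetical stand-in for `El::Matrix<DataType, Device>`, the key names are illustrative, and resizing on a shape mismatch is an assumption rather than documented behavior:

```cpp
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Minimal dense matrix standing in for El::Matrix<DataType, Device>.
struct MatrixSketch {
  std::size_t height = 0;
  std::size_t width = 0;
  std::vector<float> data;
  void resize(std::size_t h, std::size_t w) {
    height = h;
    width = w;
    data.resize(h * w);
  }
};

class WorkspaceSketch {
public:
  // Return the matrix tied to `key`, creating it on first use.
  // Assumption: an existing matrix is resized if the requested shape differs.
  MatrixSketch& get_workspace_matrix(const std::string& key,
                                     std::size_t height, std::size_t width) {
    auto& m = m_workspace[key];  // default-constructs on first lookup
    if (m.height != height || m.width != width) {
      m.resize(height, width);
    }
    // References to unordered_map elements stay valid across rehashing,
    // so callers may hold on to the returned reference.
    return m;
  }
  std::size_t size() const noexcept { return m_workspace.size(); }

private:
  std::unordered_map<std::string, MatrixSketch> m_workspace;
};
```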


◆ load_from_checkpoint_distributed()

void lbann::kfac::KFACExecutionContext::load_from_checkpoint_distributed ( persist & p)
override, virtual

Restore execution_context from a distributed checkpoint.

Implements lbann::ExecutionContext.


◆ load_from_checkpoint_shared()

void lbann::kfac::KFACExecutionContext::load_from_checkpoint_shared ( persist & p)
override, virtual

Restore execution_context from a shared checkpoint.

Implements lbann::ExecutionContext.


◆ operator=()

KFACExecutionContext& lbann::kfac::KFACExecutionContext::operator= ( const KFACExecutionContext & other)
delete

Copy assignment operator – deleted.

◆ print_workspace_size()

void lbann::kfac::KFACExecutionContext::print_workspace_size ( model & model)

Print the size of the K-FAC workspace matrices.
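This page gives no details of what is printed. A minimal sketch, assuming the goal is to report how much memory the workspace holds; `WorkspaceMap`, `workspace_bytes`, and the output format are hypothetical, and unlike the real member this version takes the workspace directly instead of a `model`:

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical workspace: key -> flat float buffer.
using WorkspaceMap = std::unordered_map<std::string, std::vector<float>>;

// Total the bytes held by all workspace buffers.
std::size_t workspace_bytes(const WorkspaceMap& ws) {
  std::size_t total = 0;
  for (const auto& entry : ws) {
    total += entry.second.size() * sizeof(float);
  }
  return total;
}

// Report the matrix count and total size on standard output.
void print_workspace_size(const WorkspaceMap& ws) {
  std::cout << "K-FAC workspace: " << ws.size() << " matrices, "
            << workspace_bytes(ws) << " bytes\n";
}
```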

◆ save_to_checkpoint_distributed()

void lbann::kfac::KFACExecutionContext::save_to_checkpoint_distributed ( persist & p)
override, virtual

Checkpoint execution_context to a distributed checkpoint.

Implements lbann::ExecutionContext.


◆ save_to_checkpoint_shared()

void lbann::kfac::KFACExecutionContext::save_to_checkpoint_shared ( persist & p)
override, virtual

Checkpoint execution_context to a shared checkpoint.

Implements lbann::ExecutionContext.
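This page does not spell out how the shared and distributed variants differ. Assuming the usual convention, where a shared checkpoint is written once by a single rank and a distributed checkpoint is written by every rank to its own file, the hypothetical helpers below sketch that decision. The enum, function names, and file layout are illustrative and are not LBANN's `persist` API:

```cpp
#include <string>

enum class CheckpointKind { shared, distributed };

// Assumption: shared checkpoints are written only by rank 0;
// distributed checkpoints are written by every rank.
bool rank_writes(CheckpointKind kind, int rank) {
  return kind == CheckpointKind::distributed || rank == 0;
}

// Hypothetical file layout: one file for shared, one file per rank otherwise.
std::string checkpoint_file(CheckpointKind kind, const std::string& dir,
                            int rank) {
  if (kind == CheckpointKind::shared) {
    return dir + "/kfac_context.bin";
  }
  return dir + "/rank" + std::to_string(rank) + "/kfac_context.bin";
}
```

The corresponding `load_from_checkpoint_*` members would read from the same locations.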


◆ serialize()

template<class Archive >
void lbann::kfac::KFACExecutionContext::serialize ( Archive &  ar)

Serialize the execution context to or from an archive for checkpoint and restart.
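A single `template<class Archive> serialize(Archive&)` can serve both saving and loading when the archive's call operator either writes or reads each argument. A minimal sketch using toy archives in place of a real serialization library; `OutArchiveSketch`, `InArchiveSketch`, and `KFACStateSketch` are hypothetical, and which members the real `serialize()` archives is an assumption:

```cpp
#include <cstddef>
#include <sstream>
#include <string>

// Toy archives standing in for a real serialization library. Each one's
// operator() visits every argument, so one serialize() works for both.
struct OutArchiveSketch {
  std::ostringstream os;
  OutArchiveSketch() { os.precision(17); }  // 17 digits round-trip any double
  template <class... Ts>
  void operator()(const Ts&... values) { ((os << values << ' '), ...); }
};

struct InArchiveSketch {
  std::istringstream is;
  explicit InArchiveSketch(const std::string& text) : is{text} {}
  template <class... Ts>
  void operator()(Ts&... values) { ((is >> values), ...); }
};

// Hypothetical subset of the context's state.
struct KFACStateSketch {
  double damping_act = 0.0, damping_err = 0.0;
  double damping_bn_act = 0.0, damping_bn_err = 0.0;
  std::size_t update_interval = 1;

  template <class Archive>
  void serialize(Archive& ar) {
    ar(damping_act, damping_err, damping_bn_act, damping_bn_err,
       update_interval);
  }
};
```

Saving calls `serialize()` with an output archive; restoring calls the same function with an input archive, so the member list cannot drift between the two directions.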


Friends And Related Function Documentation

◆ ::lbann::KFAC

friend class ::lbann::KFAC
friend

Definition at line 60 of file kfac/execution_context.hpp.
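The class exposes no public accessors for its damping values, update interval, blocks, or workspace; the friend declaration is what lets the KFAC training algorithm reach those private members. A minimal sketch of the mechanism with hypothetical class and method names:

```cpp
// Hypothetical sketch of the friend relationship: the algorithm class
// reads and updates the context's private state directly.
class KFACAlgorithmSketch;

class KFACContextSketch {
  friend class KFACAlgorithmSketch;
  double m_damping_act = 0.0;  // private, no public getter or setter
};

class KFACAlgorithmSketch {
public:
  // Both members compile only because of the friend declaration above.
  void set_damping(KFACContextSketch& ctx, double value) {
    ctx.m_damping_act = value;
  }
  double damping(const KFACContextSketch& ctx) const {
    return ctx.m_damping_act;
  }
};
```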

Member Data Documentation

◆ m_blocks

std::vector<std::shared_ptr<kfac_block<Device> > > lbann::kfac::KFACExecutionContext::m_blocks
private

K-FAC per-layer blocks.

Definition at line 128 of file kfac/execution_context.hpp.
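m_blocks stores one polymorphic block per layer behind `std::shared_ptr`, so the algorithm can treat different layer kinds uniformly. A minimal sketch of that container pattern; `BlockSketch`, `FCBlockSketch`, and their methods are hypothetical and do not reproduce the real `kfac_block<Device>` interface:

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical per-layer block interface standing in for kfac_block<Device>.
class BlockSketch {
public:
  virtual ~BlockSketch() = default;
  virtual std::string layer_name() const = 0;
  virtual void update_preconditioner() = 0;
};

// Hypothetical block for a fully-connected layer; counts its updates.
class FCBlockSketch final : public BlockSketch {
public:
  explicit FCBlockSketch(std::string name) : m_name{std::move(name)} {}
  std::string layer_name() const override { return m_name; }
  void update_preconditioner() override { ++m_updates; }
  int updates() const noexcept { return m_updates; }

private:
  std::string m_name;
  int m_updates = 0;
};

// Visit every block through the common interface.
void update_all(const std::vector<std::shared_ptr<BlockSketch>>& blocks) {
  for (const auto& b : blocks) {
    b->update_preconditioner();
  }
}
```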

◆ m_damping_act

double lbann::kfac::KFACExecutionContext::m_damping_act
private

The current damping values.

Definition at line 122 of file kfac/execution_context.hpp.
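As general K-FAC background (not taken from LBANN's source), damping is added to each Kronecker factor before inversion to keep the inverse well-conditioned. For a fully-connected layer with input activations $a$, back-propagated errors $g$, and weight gradient $\nabla_W \mathcal{L}$, a sketch of the damped update, assuming $\lambda_{\mathrm{act}}$ and $\lambda_{\mathrm{err}}$ correspond to m_damping_act and m_damping_err (and that the batch-normalization pair plays the same role for batch-normalization blocks):

```latex
A = \mathbb{E}\left[a\,a^{\top}\right], \qquad G = \mathbb{E}\left[g\,g^{\top}\right]

\Delta W = \left(G + \lambda_{\mathrm{err}} I\right)^{-1} \, \nabla_{W}\mathcal{L} \, \left(A + \lambda_{\mathrm{act}} I\right)^{-1}
```

With $W$ of shape (outputs × inputs), $G$ is outputs × outputs and $A$ is inputs × inputs, so the product has the same shape as $\nabla_W \mathcal{L}$.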

◆ m_damping_bn_act

double lbann::kfac::KFACExecutionContext::m_damping_bn_act
private

Definition at line 122 of file kfac/execution_context.hpp.

◆ m_damping_bn_err

double lbann::kfac::KFACExecutionContext::m_damping_bn_err
private

Definition at line 122 of file kfac/execution_context.hpp.

◆ m_damping_err

double lbann::kfac::KFACExecutionContext::m_damping_err
private

Definition at line 122 of file kfac/execution_context.hpp.

◆ m_sgd_execution_context

SGDExecutionContext lbann::kfac::KFACExecutionContext::m_sgd_execution_context
private

Definition at line 119 of file kfac/execution_context.hpp.

◆ m_update_interval

size_t lbann::kfac::KFACExecutionContext::m_update_interval
private

The current update interval.

Definition at line 125 of file kfac/execution_context.hpp.
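An update interval typically lets expensive work run only every few steps while intermediate steps reuse cached results. A minimal sketch, assuming the interval gates a recomputation such as refreshing factor inverses; the function name and the treatment of an interval of 0 or 1 as "every step" are hypothetical:

```cpp
#include <cstddef>

// Hypothetical sketch: refresh expensive K-FAC state only every
// `update_interval` steps; other steps reuse the cached results.
bool should_update(std::size_t step, std::size_t update_interval) {
  return update_interval <= 1 || step % update_interval == 0;
}
```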

◆ m_workspace

std::unordered_map<std::string, El::Matrix<DataType, Device> > lbann::kfac::KFACExecutionContext::m_workspace
private

Workspace matrices that are used by m_blocks.

Definition at line 131 of file kfac/execution_context.hpp.


The documentation for this class was generated from the following file:
kfac/execution_context.hpp