26 #ifndef LBANN_EXECUTION_ALGORITHMS_KFAC_EXECUTION_CONTEXT_HPP_INCLUDED 27 #define LBANN_EXECUTION_ALGORITHMS_KFAC_EXECUTION_CONTEXT_HPP_INCLUDED 39 template <El::Device Device>
52 #endif // LBANN_HAS_GPU 60 friend class ::lbann::KFAC;
65 double damping_bn_act,
66 double damping_bn_err);
76 std::unique_ptr<lbann::ExecutionContext>
get_new()
const override;
82 std::string
get_type()
const override;
103 template <
class Archive>
128 std::vector<std::shared_ptr<kfac_block<Device>>>
m_blocks;
131 std::unordered_map<std::string, El::Matrix<DataType, Device>>
m_workspace;
137 #endif // LBANN_EXECUTION_ALGORITHMS_KFAC_EXECUTION_CONTEXT_HPP_INCLUDED void save_to_checkpoint_distributed(persist &p) override
Checkpoint execution_context to a distributed checkpoint.
double m_damping_act
The current damping values.
SGD uses the step to track the current mini-batch step for the execution mode.
constexpr El::Device Device
std::unordered_map< std::string, El::Matrix< DataType, Device > > m_workspace
Workspace matrices that are used by m_blocks.
std::unique_ptr< lbann::ExecutionContext > get_new() const override
KFACExecutionContext & operator=(const KFACExecutionContext &other)=delete
std::vector< std::shared_ptr< kfac_block< Device > > > m_blocks
K-FAC per-layer blocks.
~KFACExecutionContext()=default
std::string get_type() const override
Get a string identifying the type of execution context.
std::string get_state_string() const noexcept override
Return the state of the execution context as a string.
SGDExecutionContext & get_sgd_execution_context() noexcept
Return execution context for SGD-family training algorithm.
KFACExecutionContext(double damping_act, double damping_err, double damping_bn_act, double damping_bn_err)
El::Matrix< DataType, Device > & get_workspace_matrix(const std::string &key, const size_t height, const size_t width)
Gets the Kronecker factor matrix of an FC layer. The same key is tied with the same matrix instance...
void load_from_checkpoint_shared(persist &p) override
Restore execution_context from a shared checkpoint.
Abstract base class for neural network models.
void load_from_checkpoint_distributed(persist &p) override
Restore execution_context from a distributed checkpoint.
size_t m_update_interval
The current update interval.
void save_to_checkpoint_shared(persist &p) override
Checkpoint execution_context to a shared checkpoint.
SGDExecutionContext m_sgd_execution_context
void print_workspace_size(model &model)
void serialize(Archive &ar)