LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
kfac/execution_context.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the License.
26 #ifndef LBANN_EXECUTION_ALGORITHMS_KFAC_EXECUTION_CONTEXT_HPP_INCLUDED
27 #define LBANN_EXECUTION_ALGORITHMS_KFAC_EXECUTION_CONTEXT_HPP_INCLUDED
28 
33 #include <memory>
34 #include <string>
35 
36 // Forward declarations of the LBANN types this header refers to by name only.
37 namespace lbann {
38 class KFAC;
39 template <El::Device Device>
40 class kfac_block;
41 class model;
42 } // namespace lbann
43 
44 namespace lbann {
45 namespace kfac {
46 
47 // Compile-time device selection: GPU when LBANN_HAS_GPU is defined, CPU otherwise.
48 #ifdef LBANN_HAS_GPU
49 constexpr El::Device Device = El::Device::GPU;
50 #else
51 constexpr El::Device Device = El::Device::CPU;
52 #endif // LBANN_HAS_GPU
53 
58 {
59 public:
60  friend class ::lbann::KFAC;
61 
63  KFACExecutionContext(double damping_act,
64  double damping_err,
65  double damping_bn_act,
66  double damping_bn_err);
68  ~KFACExecutionContext() = default;
69 
71  KFACExecutionContext(const KFACExecutionContext& other) = delete;
73  KFACExecutionContext& operator=(const KFACExecutionContext& other) = delete;
74 
76  std::unique_ptr<lbann::ExecutionContext> get_new() const override;
77 
82  std::string get_type() const override;
83 
85  std::string get_state_string() const noexcept override;
86 
89  {
91  }
92 
95  El::Matrix<DataType, Device>& get_workspace_matrix(const std::string& key,
96  const size_t height,
97  const size_t width);
98 
100 
103  template <class Archive>
104  void serialize(Archive& ar);
105 
107  void save_to_checkpoint_shared(persist& p) override;
109  void load_from_checkpoint_shared(persist& p) override;
111  void save_to_checkpoint_distributed(persist& p) override;
113  void load_from_checkpoint_distributed(persist& p) override;
115 
117 
118 private:
120 
123 
126 
128  std::vector<std::shared_ptr<kfac_block<Device>>> m_blocks;
129 
131  std::unordered_map<std::string, El::Matrix<DataType, Device>> m_workspace;
132 
133 }; // class KFACExecutionContext
134 
135 } // namespace kfac
136 } // namespace lbann
137 #endif // LBANN_EXECUTION_ALGORITHMS_KFAC_EXECUTION_CONTEXT_HPP_INCLUDED
void save_to_checkpoint_distributed(persist &p) override
Checkpoint execution_context to a distributed checkpoint.
double m_damping_act
The current damping values.
SGD uses the step to track the current mini-batch step for the execution mode.
constexpr El::Device Device
std::unordered_map< std::string, El::Matrix< DataType, Device > > m_workspace
Workspace matrices that are used by m_blocks.
std::unique_ptr< lbann::ExecutionContext > get_new() const override
KFACExecutionContext & operator=(const KFACExecutionContext &other)=delete
std::vector< std::shared_ptr< kfac_block< Device > > > m_blocks
K-FAC per-layer blocks.
std::string get_type() const override
Get a string identifying the type of execution context.
std::string get_state_string() const noexcept override
Return the state of the execution context as a string.
SGDExecutionContext & get_sgd_execution_context() noexcept
Return execution context for SGD-family training algorithm.
KFACExecutionContext(double damping_act, double damping_err, double damping_bn_act, double damping_bn_err)
El::Matrix< DataType, Device > & get_workspace_matrix(const std::string &key, const size_t height, const size_t width)
Gets the Kronecker factor matrix of an FC layer. The same key is tied with the same matrix instance...
void load_from_checkpoint_shared(persist &p) override
Restore execution_context from a shared checkpoint.
Abstract base class for neural network models.
Definition: model.hpp:83
void load_from_checkpoint_distributed(persist &p) override
Restore execution_context from a distributed checkpoint.
size_t m_update_interval
The current update interval.
void save_to_checkpoint_shared(persist &p) override
Checkpoint execution_context to a shared checkpoint.
void print_workspace_size(model &model)