LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
sgd_execution_context.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "Licensee"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_SGD_EXECUTION_CONTEXT_HPP
28 #define LBANN_SGD_EXECUTION_CONTEXT_HPP
29 
33 #include "lbann/utils/timer.hpp"
34 
35 namespace lbann {
36 
42 {
43 public:
47  virtual ~SGDExecutionContext() = default;
48 
50  SGDExecutionContext(SGDExecutionContext&& other) = default;
56  SGDExecutionContext(const SGDExecutionContext& other) = delete;
58  SGDExecutionContext& operator=(const SGDExecutionContext& other) = delete;
59 
60  std::unique_ptr<ExecutionContext> get_new() const override
61  {
62  return std::make_unique<SGDExecutionContext>(execution_mode::invalid);
63  }
64 
/** @brief Archive-generic serialization hook.
 *  @details Presumably driven by cereal (cereal::access is a friend of
 *           this class) — TODO confirm against the out-of-line definition.
 */
template <typename Archive>
void serialize(Archive& ar);
68 
70  std::string get_state_string() const noexcept override
71  {
72  return build_string("sgd.",
74  ".epoch.",
75  get_epoch(),
76  ".step.",
77  get_step());
78  }
79 
81  inline size_t get_epoch() const noexcept { return m_epoch; }
82 
87  void inc_epoch() noexcept { ++m_epoch; }
88 
90  void save_to_checkpoint_shared(persist& p) override;
93  void load_from_checkpoint_shared(persist& p) override;
94  void save_to_checkpoint_distributed(persist& p) override;
95  void load_from_checkpoint_distributed(persist& p) override;
96 
97  std::string get_type() const override;
98 
100  void set_execution_mode(execution_mode mode) noexcept
101  {
102  m_execution_mode = mode;
103  }
104 
106  execution_mode get_execution_mode() const noexcept override
107  {
108  return m_execution_mode;
109  }
110 
111  void set_early_stop(bool stop) noexcept { m_stop_early = stop; }
112  bool get_early_stop() const noexcept { return m_stop_early; }
113 
114  void start_timer() noexcept { m_timer.start(); }
115  void stop_timer() noexcept { m_timer.stop(); }
116  double get_current_execution_time() const noexcept { return m_timer.check(); }
117 
118 private:
119  friend class cereal::access;
120  SGDExecutionContext() = default;
121 
122 private:
125 
127  size_t m_epoch = 0;
128 
130 
131  bool m_stop_early = false;
132 };
133 
136  : public TerminationCriteria,
137  public Cloneable<HasAbstractFunction<SGDTerminationCriteria>>
138 {
139 public:
140  SGDTerminationCriteria() = default;
141  virtual ~SGDTerminationCriteria() = default;
142  bool operator()(ExecutionContext const& c_in) const final
143  {
144  auto const& c = dynamic_cast<SGDExecutionContext const&>(c_in);
145  return c.get_early_stop() || this->is_done(c);
146  }
147 
148 private:
149  virtual bool is_done(SGDExecutionContext const& c) const noexcept = 0;
150 };
151 
158  : public Cloneable<BatchTerminationCriteria, SGDTerminationCriteria>
159 {
160 public:
161  BatchTerminationCriteria(size_t num_batches) : m_max_batches{num_batches} {}
162 
163 private:
164  bool is_done(SGDExecutionContext const& c) const noexcept final
165  {
166  return c.get_step() >= m_max_batches;
167  }
168 
169 private:
171 };
172 
174  : public Cloneable<EpochTerminationCriteria, SGDTerminationCriteria>
175 {
176 public:
177  EpochTerminationCriteria(size_t num_epochs) : m_max_epochs{num_epochs} {}
178 
179 private:
180  bool is_done(SGDExecutionContext const& c) const noexcept final
181  {
182  return c.get_epoch() >= m_max_epochs;
183  }
184 
185 private:
186  size_t m_max_epochs;
187 };
188 
190  : public Cloneable<SecondsTerminationCriteria, SGDTerminationCriteria>
191 {
192 public:
193  SecondsTerminationCriteria(double seconds) : m_max_seconds{seconds} {}
194 
195 private:
196  bool is_done(SGDExecutionContext const& c) const noexcept final;
197 
198 private:
200 };
201 } // namespace lbann
202 
203 #endif // LBANN_SGD_EXECUTION_CONTEXT_HPP
std::unique_ptr< ExecutionContext > get_new() const override
SGD uses the step to track the current mini-batch step for the execution mode.
lbann::Timer m_timer
Timer tracking execution time.
Inject polymorphic clone functions into hierarchies.
Definition: cloneable.hpp:94
bool is_done(SGDExecutionContext const &c) const noexcept final
bool operator()(ExecutionContext const &c_in) const final
void save_to_checkpoint_distributed(persist &p) override
Checkpoint execution_context to a distributed checkpoint.
bool get_early_stop() const noexcept
Stop SGD based on a fixed batch count.
bool is_done(SGDExecutionContext const &c) const noexcept final
double check() const noexcept
Get the current elapsed time (seconds) without stopping.
Definition: utils/timer.hpp:97
void load_from_checkpoint_shared(persist &p) override
double get_current_execution_time() const noexcept
void set_execution_mode(execution_mode mode) noexcept
std::string build_string(Args &&... args)
Build a string from the arguments.
Definition: exception.hpp:157
std::string to_string(El::Device const &d)
Specifies when to stop a training algorithm.
void inc_epoch() noexcept
Increment the current epoch in the execution context.
void save_to_checkpoint_shared(persist &p) override
execution_mode
Neural network execution mode.
Definition: base.hpp:229
void set_early_stop(bool stop) noexcept
virtual ~SGDExecutionContext()=default
SGDExecutionContext & operator=(SGDExecutionContext &&other)=default
void serialize(Archive &ar)
execution_mode get_execution_mode() const noexcept override
std::string get_type() const override
Get a string identifying the type of execution context.
An exceedingly simple duration calculator.
Definition: utils/timer.hpp:52
double stop() noexcept
Get the total elapsed time in seconds.
Definition: utils/timer.hpp:89
std::string get_state_string() const noexcept override
Return the state of the execution context as a string.
size_t get_step() const noexcept
Current step in the training algorithm.
size_t get_epoch() const noexcept
void start() noexcept
Start counting time.
Definition: utils/timer.hpp:80
void load_from_checkpoint_distributed(persist &p) override
Restore execution_context from a distributed checkpoint.
Base class for SGD stopping.