27 #ifndef LBANN_SGD_EXECUTION_CONTEXT_HPP 28 #define LBANN_SGD_EXECUTION_CONTEXT_HPP 60 std::unique_ptr<ExecutionContext>
get_new()
const override 66 template <
class Archive>
97 std::string
get_type()
const override;
137 public Cloneable<HasAbstractFunction<SGDTerminationCriteria>>
158 :
public Cloneable<BatchTerminationCriteria, SGDTerminationCriteria>
166 return c.get_step() >= m_max_batches;
174 :
public Cloneable<EpochTerminationCriteria, SGDTerminationCriteria>
182 return c.get_epoch() >= m_max_epochs;
190 :
public Cloneable<SecondsTerminationCriteria, SGDTerminationCriteria>
203 #endif // LBANN_SGD_EXECUTION_CONTEXT_HPP std::unique_ptr< ExecutionContext > get_new() const override
SGD uses the step to track the current mini-batch step for the execution mode.
lbann::Timer m_timer
Timer tracking execution time.
friend class cereal::access
Inject polymorphic clone functions into hierarchies.
bool is_done(SGDExecutionContext const &c) const noexcept final
bool operator()(ExecutionContext const &c_in) const final
void save_to_checkpoint_distributed(persist &p) override
Checkpoint execution_context to a distributed checkpoint.
bool get_early_stop() const noexcept
Stop SGD based on a fixed batch count.
bool is_done(SGDExecutionContext const &c) const noexcept final
double check() const noexcept
Get the current elapsed time (seconds) without stopping.
void load_from_checkpoint_shared(persist &p) override
double get_current_execution_time() const noexcept
void set_execution_mode(execution_mode mode) noexcept
std::string build_string(Args &&... args)
Build a string from the arguments.
std::string to_string(El::Device const &d)
Specifies when to stop a training algorithm.
EpochTerminationCriteria(size_t num_epochs)
void inc_epoch() noexcept
Increment the current epoch in the execution context.
void save_to_checkpoint_shared(persist &p) override
execution_mode
Neural network execution mode.
SecondsTerminationCriteria(double seconds)
void set_early_stop(bool stop) noexcept
virtual ~SGDExecutionContext()=default
SGDExecutionContext & operator=(SGDExecutionContext &&other)=default
void start_timer() noexcept
void serialize(Archive &ar)
void stop_timer() noexcept
BatchTerminationCriteria(size_t num_batches)
execution_mode m_execution_mode
execution_mode get_execution_mode() const noexcept override
std::string get_type() const override
Get a string identifying the type of execution context.
An exceedingly simple duration calculator.
double stop() noexcept
Get the total elapsed time in seconds.
std::string get_state_string() const noexcept override
Return the state of the execution context as a string.
size_t get_step() const noexcept
Current step in the training algorithm.
size_t get_epoch() const noexcept
void start() noexcept
Start counting time.
SGDExecutionContext()=default
void load_from_checkpoint_distributed(persist &p) override
Restore execution_context from a distributed checkpoint.
Base class for SGD stopping criteria.