27 #ifndef LBANN_SGD_TRAINING_ALGORITHM_HPP 28 #define LBANN_SGD_TRAINING_ALGORITHM_HPP 41 #endif // LBANN_HAS_GPU 43 #include <google/protobuf/message.h> 55 std::unique_ptr<SGDTerminationCriteria> stop,
56 bool suppress_timer_output);
68 std::string
get_type()
const override;
159 gpu_lib::event_wrapper m_data_prefetch_sync_event;
160 #endif // LBANN_HAS_GPU 164 std::unique_ptr<SGDTrainingAlgorithm>
169 #endif // LBANN_SGD_TRAINING_ALGORITHM_HPP SGD Uses the step to track the Current mini-batch step for execution mode.
void evaluate(SGDExecutionContext &c, model &model, data_coordinator &dc, execution_mode mode, SGDTerminationCriteria const &term)
void do_batch_end_cbs(model &model, execution_mode mode, ScopeTimer timer)
size_t m_validation_epochs
void train(SGDExecutionContext &c, model &model, data_coordinator &dc, SGDTerminationCriteria const &term)
bool m_suppress_timer
Suppress timer output.
SGDTrainingAlgorithm(std::string name, std::unique_ptr< SGDTerminationCriteria > stop, bool suppress_timer_output)
Construct with a name.
std::unique_ptr< SGDTerminationCriteria > m_stopping_criteria
std::unique_ptr< SGDTrainingAlgorithm > make< SGDTrainingAlgorithm >(google::protobuf::Message const ¶ms)
void do_epoch_begin_cbs(model &model, ScopeTimer timer)
void do_evaluate_end_cbs(model &model, execution_mode mode, ScopeTimer timer)
void do_batch_begin_cbs(model &model, execution_mode mode, ScopeTimer timer)
void do_evaluate_begin_cbs(model &model, execution_mode mode, ScopeTimer timer)
SGDExecutionContext * do_get_new_execution_context() const override
Covariant return-friendly implementation of get_new_exection_context().
bool evaluate_mini_batch(SGDExecutionContext &c, model &model, data_coordinator &dc, execution_mode mode, ScopeTimer timer)
Base class for LBANN SGD-family training algorithms.
A nesting inclusive-timer.
SGDExecutionContext m_validation_context
Abstract base class for neural network models.
void do_train_begin_cbs(model &model, ScopeTimer timer)
execution_mode
Neural network execution mode.
SGDTrainingAlgorithm & operator=(const SGDTrainingAlgorithm &other)=delete
void apply(ExecutionContext &c, model &model, data_coordinator &dc, execution_mode mode) override
virtual ~SGDTrainingAlgorithm()=default
void do_train_end_cbs(model &model, ScopeTimer timer)
std::string get_type() const override
std::unique_ptr< SGDExecutionContext > get_new_execution_context() const
Get a default-initialized execution context.
Base class for LBANN training_algorithms.
bool train_mini_batch(SGDExecutionContext &c, model &model, data_coordinator &dc, ScopeTimer timer)
Base class for SGD stopping.
void do_epoch_end_cbs(model &model, ScopeTimer timer)