LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
sgd_training_algorithm.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "Licensee"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_SGD_TRAINING_ALGORITHM_HPP
28 #define LBANN_SGD_TRAINING_ALGORITHM_HPP
29 
30 #include "lbann/base.hpp"
37 #include "lbann/utils/memory.hpp"
39 #ifdef LBANN_HAS_GPU
41 #endif // LBANN_HAS_GPU
42 
43 #include <google/protobuf/message.h>
44 
45 #include <memory>
46 
47 namespace lbann {
48 
51 {
52 public:
54  SGDTrainingAlgorithm(std::string name,
55  std::unique_ptr<SGDTerminationCriteria> stop,
56  bool suppress_timer_output);
57 
58  SGDTrainingAlgorithm(const SGDTrainingAlgorithm& other) = delete;
59  SGDTrainingAlgorithm& operator=(const SGDTrainingAlgorithm& other) = delete;
60 
61  SGDTrainingAlgorithm(SGDTrainingAlgorithm&& other) = default;
63 
64  virtual ~SGDTrainingAlgorithm() = default;
66  // virtual sgd_training_algorithm* copy() const = default;
67 
68  std::string get_type() const override;
69 
70  // ===========================================
71  // Execution
72  // ===========================================
73 
76  void apply(ExecutionContext& c,
77  model& model,
78  data_coordinator& dc,
79  execution_mode mode) override;
80 
82  void train(SGDExecutionContext& c,
83  model& model,
84  data_coordinator& dc,
85  SGDTerminationCriteria const& term);
86 
89  model& model,
90  data_coordinator& dc,
91  execution_mode mode,
92  SGDTerminationCriteria const& term);
93 
100  std::unique_ptr<SGDExecutionContext> get_new_execution_context() const;
101 
102 protected:
105  model& model,
106  data_coordinator& dc,
107  ScopeTimer timer);
108 
111  model& model,
112  data_coordinator& dc,
113  execution_mode mode,
114  ScopeTimer timer);
115 
117  // Callbacks
119 
121  void do_train_begin_cbs(model& model, ScopeTimer timer);
123  void do_train_end_cbs(model& model, ScopeTimer timer);
125  void
126  do_evaluate_begin_cbs(model& model, execution_mode mode, ScopeTimer timer);
128  void do_evaluate_end_cbs(model& model, execution_mode mode, ScopeTimer timer);
130  void do_epoch_begin_cbs(model& model, ScopeTimer timer);
132  void do_epoch_end_cbs(model& model, ScopeTimer timer);
134  void do_batch_begin_cbs(model& model, execution_mode mode, ScopeTimer timer);
136  void do_batch_end_cbs(model& model, execution_mode mode, ScopeTimer timer);
137 
139 
140 private:
142  std::unique_ptr<SGDTerminationCriteria> m_stopping_criteria;
143 
144  // FIXME (trb 07/20/21): This is a hack. These aren't actually
145  // copyable objects (it wouldn't make sense), so when the training
146  // algorithm is copied, these are reset to defaults. "In the
147  // future", we'll externalize validation and this won't be an issue.
150 
156  bool m_suppress_timer = false;
157 
158 #ifdef LBANN_HAS_GPU
159  gpu_lib::event_wrapper m_data_prefetch_sync_event;
160 #endif // LBANN_HAS_GPU
161 };
162 
// Explicit specialization of the `make` template for SGDTrainingAlgorithm.
// Takes a generic protobuf message and returns an owning
// std::unique_ptr<SGDTrainingAlgorithm>.
// NOTE(review): the concrete protobuf message type expected in `params` and
// how its fields map onto the constructor arguments (name, stopping
// criteria, suppress_timer_output) are not visible here; confirm in the .cpp.
163 template <>
164 std::unique_ptr<SGDTrainingAlgorithm>
165 make<SGDTrainingAlgorithm>(google::protobuf::Message const& params);
166 
167 } // namespace lbann
168 
169 #endif // LBANN_SGD_TRAINING_ALGORITHM_HPP
SGD uses the step to track the current mini-batch step for the execution mode.
void evaluate(SGDExecutionContext &c, model &model, data_coordinator &dc, execution_mode mode, SGDTerminationCriteria const &term)
void do_batch_end_cbs(model &model, execution_mode mode, ScopeTimer timer)
void train(SGDExecutionContext &c, model &model, data_coordinator &dc, SGDTerminationCriteria const &term)
bool m_suppress_timer
Suppress timer output.
SGDTrainingAlgorithm(std::string name, std::unique_ptr< SGDTerminationCriteria > stop, bool suppress_timer_output)
Construct with a name.
std::unique_ptr< SGDTerminationCriteria > m_stopping_criteria
std::unique_ptr< SGDTrainingAlgorithm > make< SGDTrainingAlgorithm >(google::protobuf::Message const &params)
void do_epoch_begin_cbs(model &model, ScopeTimer timer)
void do_evaluate_end_cbs(model &model, execution_mode mode, ScopeTimer timer)
void do_batch_begin_cbs(model &model, execution_mode mode, ScopeTimer timer)
void do_evaluate_begin_cbs(model &model, execution_mode mode, ScopeTimer timer)
SGDExecutionContext * do_get_new_execution_context() const override
Covariant return-friendly implementation of get_new_execution_context().
bool evaluate_mini_batch(SGDExecutionContext &c, model &model, data_coordinator &dc, execution_mode mode, ScopeTimer timer)
Base class for LBANN SGD-family training algorithms.
A nesting inclusive-timer.
Definition: timer_map.hpp:51
Abstract base class for neural network models.
Definition: model.hpp:83
void do_train_begin_cbs(model &model, ScopeTimer timer)
execution_mode
Neural network execution mode.
Definition: base.hpp:229
SGDTrainingAlgorithm & operator=(const SGDTrainingAlgorithm &other)=delete
void apply(ExecutionContext &c, model &model, data_coordinator &dc, execution_mode mode) override
virtual ~SGDTrainingAlgorithm()=default
void do_train_end_cbs(model &model, ScopeTimer timer)
std::string get_type() const override
std::unique_ptr< SGDExecutionContext > get_new_execution_context() const
Get a default-initialized execution context.
Base class for LBANN training_algorithms.
bool train_mini_batch(SGDExecutionContext &c, model &model, data_coordinator &dc, ScopeTimer timer)
Base class for SGD stopping criteria.
void do_epoch_end_cbs(model &model, ScopeTimer timer)