LBANN 0.103.0
Livermore Big Artificial Neural Network Toolkit
training_algorithm.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_EXECUTION_ALGORITHMS_TRAINING_ALGORITHM_HPP_INCLUDED
28 #define LBANN_EXECUTION_ALGORITHMS_TRAINING_ALGORITHM_HPP_INCLUDED
29 
30 #include "lbann/base.hpp"
33 #include "lbann/utils/memory.hpp"
34 #include <google/protobuf/message.h>
35 #include <memory>
36 
37 namespace lbann {
38 
39 // Forward-declarations
40 class data_coordinator;
41 class ExecutionContext;
42 class model;
43 
87 {
88 public:
90 
94  TrainingAlgorithm(std::string name);
95  virtual ~TrainingAlgorithm() = default;
97 
99 
101  virtual std::string get_type() const = 0;
102 
104  std::string const& get_name() const noexcept;
105 
107 
108 
118  virtual void apply(ExecutionContext& context,
119  model& model,
120  data_coordinator& dc,
121  execution_mode mode) = 0;
122 
/** @brief Apply the algorithm to the given model.
 *  @param model The model the algorithm is applied to.
 *  @param dc    The data coordinator to use.
 *
 *  Non-virtual two-argument overload of the pure-virtual four-argument
 *  apply(); it takes no execution context or execution mode from the caller.
 *  NOTE(review): the body (original line 131) is elided from this listing.
 *  Presumably it forwards to the four-argument apply() with a fresh
 *  context and a default mode — confirm against the real header.
 */
129  void apply(model& model, data_coordinator& dc)
130  {
132  }
133 
140  void setup_models(std::vector<observer_ptr<model>> const& models,
141  size_t max_mini_batch_size,
142  const std::vector<El::Grid*>& grids);
143 
/** @brief Get a default-initialized execution context that fits this
 *         training algorithm.
 *  @return An owning pointer to the newly created context.
 *
 *  The pure-virtual do_get_new_execution_context() is the covariant
 *  return-friendly implementation hook behind this function.
 *  NOTE(review): the body (original line 159) is elided from this listing.
 *  Presumably it wraps the raw pointer from do_get_new_execution_context()
 *  with to_unique_ptr() — confirm against the real header.
 */
157  std::unique_ptr<ExecutionContext> get_new_execution_context() const
158  {
160  }
162 
163 protected:
165  TrainingAlgorithm(const TrainingAlgorithm& other) = delete;
167  TrainingAlgorithm& operator=(const TrainingAlgorithm& other) = delete;
168  TrainingAlgorithm(TrainingAlgorithm&& other) = default;
169  TrainingAlgorithm& operator=(TrainingAlgorithm&& other) = default;
171 
175  virtual ExecutionContext* do_get_new_execution_context() const = 0;
176 
177 private:
/** @brief The user-defined name of the algorithm. */
std::string m_name{};
180 };
181 
182 } // namespace lbann
183 
184 #endif // LBANN_EXECUTION_ALGORITHMS_TRAINING_ALGORITHM_HPP_INCLUDED
virtual std::string get_type() const =0
A string identifying the type of the object.
std::string m_name
The user-defined name of the algorithm.
virtual ~TrainingAlgorithm()=default
std::unique_ptr< ExecutionContext > get_new_execution_context() const
Get a default-initialized execution context that fits this training algorithm.
std::string const & get_name() const noexcept
A user-defined string identifying the algorithm object.
void apply(model &model, data_coordinator &dc)
Apply the algorithm to the given model.
Abstract base class for neural network models.
Definition: model.hpp:83
The execution context for a KFAC algorithm.
TrainingAlgorithm & operator=(const TrainingAlgorithm &other)=delete
typename std::add_pointer< T >::type observer_ptr
Creating an observer_ptr to complement the unique_ptr and shared_ptr.
Definition: base.hpp:54
execution_mode
Neural network execution mode.
Definition: base.hpp:229
virtual void apply(ExecutionContext &context, model &model, data_coordinator &dc, execution_mode mode)=0
Apply the algorithm to the given model.
virtual ExecutionContext * do_get_new_execution_context() const =0
Covariant return-friendly implementation of get_new_execution_context().
std::unique_ptr< T > to_unique_ptr(T *ptr)
Convert the raw pointer to a unique_ptr.
Definition: memory.hpp:38
void setup_models(std::vector< observer_ptr< model >> const &models, size_t max_mini_batch_size, const std::vector< El::Grid *> &grids)
Set up a collection of models.
TrainingAlgorithm(std::string name)
Constructor.
Base class for LBANN training_algorithms.