LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
execution_algorithms/ltfb.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 #ifndef LBANN_EXECUTION_ALGORITHMS_LTFB_HPP_INCLUDED
27 #define LBANN_EXECUTION_ALGORITHMS_LTFB_HPP_INCLUDED
28 
35 
39 
40 #include <google/protobuf/message.h>
41 #include <memory>
42 
43 namespace lbann {
44 
66 class LTFB final : public TrainingAlgorithm
67 {
68 public:
71 
72 public:
74 
83  LTFB(std::string name,
84  std::unique_ptr<TrainingAlgorithm> local_training_algorithm,
85  std::unique_ptr<ltfb::MetaLearningStrategy> meta_learning_strategy,
86  ltfb::LTFBTerminationCriteria stopping_criteria,
87  bool suppress_timer)
88  : TrainingAlgorithm{std::move(name)},
89  m_local_algo{std::move(local_training_algorithm)},
90  m_meta_learning_strategy{std::move(meta_learning_strategy)},
91  m_termination_criteria{std::move(stopping_criteria)},
92  m_suppress_timer{suppress_timer}
93  {}
94 
95  ~LTFB() noexcept = default;
96  LTFB(LTFB const& other) = delete;
97  LTFB& operator=(LTFB const&) = delete;
98  LTFB(LTFB&&) = default;
99  LTFB& operator=(LTFB&&) = default;
101 
102  std::string get_type() const final { return "LTFB"; }
105 
106 
114  void apply(ExecutionContext& context,
115  model& m,
116  data_coordinator& dc,
117  execution_mode mode) final;
119 protected:
124  {
125  return new ltfb::LTFBExecutionContext();
126  }
127 
128 private:
130  std::unique_ptr<TrainingAlgorithm> m_local_algo;
131 
133  std::unique_ptr<ltfb::MetaLearningStrategy> m_meta_learning_strategy;
134 
137 
143  bool m_suppress_timer = false;
144 }; // class LTFB
145 
146 } // namespace lbann
147 
151 template <>
152 std::unique_ptr<lbann::LTFB>
153 lbann::make<lbann::LTFB>(google::protobuf::Message const& msg);
154 
155 #endif // LBANN_EXECUTION_ALGORITHMS_LTFB_HPP_INCLUDED
LTFB & operator=(LTFB const &)=delete
void apply(ExecutionContext &context, model &m, data_coordinator &dc, execution_mode mode) final
Apply the training algorithm to refine model weights.
std::unique_ptr< TrainingAlgorithm > m_local_algo
The training algorithm for trainer-local training.
std::unique_ptr< ltfb::MetaLearningStrategy > m_meta_learning_strategy
The strategy for postprocessing local training outputs.
std::string get_type() const final
Queries.
An implementation of the LTFB training algorithm.
Abstract base class for neural network models.
Definition: model.hpp:83
execution_mode
Neural network execution mode.
Definition: base.hpp:229
ltfb::LTFBTerminationCriteria m_termination_criteria
The LTFB stopping criteria.
ltfb::LTFBExecutionContext * do_get_new_execution_context() const final
Covariant return-friendly implementation of get_new_execution_context().
bool m_suppress_timer
Suppress timer output.
Base class for LBANN training_algorithms.
~LTFB() noexcept=default
LTFB(std::string name, std::unique_ptr< TrainingAlgorithm > local_training_algorithm, std::unique_ptr< ltfb::MetaLearningStrategy > meta_learning_strategy, ltfb::LTFBTerminationCriteria stopping_criteria, bool suppress_timer)
Construct LTFB from its component pieces.