LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
random_pairwise_exchange.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 #ifndef LBANN_EXECUTION_ALGORITHMS_LTFB_RANDOM_PAIRWISE_EXCHANGE_HPP_INCLUDED
27 #define LBANN_EXECUTION_ALGORITHMS_LTFB_RANDOM_PAIRWISE_EXCHANGE_HPP_INCLUDED
28 
29 #include "mutation_strategy.hpp"
30 
32 
33 #include <google/protobuf/message.h>
34 
35 #include <cstddef>
36 #include <memory>
37 #include <string>
38 #include <unordered_map>
39 
40 namespace lbann {
41 namespace ltfb {
42 
61  : public Cloneable<RandomPairwiseExchange, MetaLearningStrategy>
62 {
63 public:
72  : public Cloneable<HasAbstractFunction<ExchangeStrategy>>
73  {
74  public:
80  ExchangeStrategy(std::set<std::string> weights_names)
81  : m_weights_names{std::move(weights_names)}
82  {}
83  virtual ~ExchangeStrategy() = default;
84 
95  virtual std::unique_ptr<model>
96  get_partner_model(model const& m, El::Int partner_trainer, size_t step) = 0;
97  // Better API, but complicates "sendrecv_weights":
98  // virtual std::unique_ptr<model> get_partner_model(
99  // lbann_comm const& c, El::Int partner_trainer);
100  protected:
102  std::set<std::string> const& weights_names() const noexcept
103  {
104  return m_weights_names;
105  }
106 
107  private:
108  std::set<std::string> m_weights_names;
109  }; // class ExchangeStrategy
110 
111  enum class metric_strategy // Ranking direction for one metric; paired per metric name in m_metrics.
112  {
113  LOWER_IS_BETTER, // a lower metric value ranks a model higher
114  HIGHER_IS_BETTER, // a higher metric value ranks a model higher
115  }; // enum class metric_strategy
116 
117 public:
119 
129  RandomPairwiseExchange(std::string metric_name, // Constructor for a single metric.
130  metric_strategy winner_strategy, // how values of metric_name are ranked
131  std::unique_ptr<ExchangeStrategy> comm_algo, // strategy for exchanging two models (ownership transferred)
132  std::unique_ptr<MutationStrategy> mutate_algo); // strategy for mutation of a model (ownership transferred)
133 
144  std::unordered_map<std::string, metric_strategy> metrics,
145  std::unique_ptr<ExchangeStrategy> comm_algo,
146  std::unique_ptr<MutationStrategy> mutate_algo);
147 
148  ~RandomPairwiseExchange() = default;
151 
161  void select_next(model& m,
163  data_coordinator& dc) const final;
164 
165 private:
167  std::unordered_map<std::string, EvalType>
169  LTFBExecutionContext& ctxt,
170  data_coordinator& dc) const;
172  int get_partner_trainer(lbann_comm const& c) const noexcept; // Generate a new trainer partner from the comm.
183  bool local_is_better( // Compare local vs. partner scores per the metric strategies in m_metrics.
184  std::unordered_map<std::string, EvalType> const& local_scores, // metric name -> local model's value (presumably)
185  std::unordered_map<std::string, EvalType> const& partner_scores) const; // metric name -> partner model's value (presumably)
186 
187 private:
193  std::unordered_map<std::string, metric_strategy> m_metrics; // The list of metric/strategy pairs.
194 
203  std::unique_ptr<ExchangeStrategy> m_comm_algo; // The strategy for exchanging two models.
204 
212  std::unique_ptr<MutationStrategy> m_mutate_algo; // The strategy for mutation of a model.
213 
214 }; // class RandomPairwiseExchange
215 
221 class SendRecvWeights final
222  : public Cloneable<SendRecvWeights, RandomPairwiseExchange::ExchangeStrategy>
223 {
224  using BaseType =
226 
227 public:
234  SendRecvWeights(std::set<std::string> const& weights_names,
235  bool exchange_hyperparameters);
236 
243  SendRecvWeights(std::set<std::string>&& weights_names,
244  bool exchange_hyperparameters);
245 
246  SendRecvWeights(SendRecvWeights const&) = default;
247  SendRecvWeights(SendRecvWeights&&) = default;
248 
249  std::unique_ptr<model> get_partner_model(model const& m,
250  El::Int partner_trainer,
251  size_t /*step*/) final;
252 
253 private:
255 }; // class SendRecvWeights
256 
258 class CheckpointFile final
259  : public Cloneable<CheckpointFile, RandomPairwiseExchange::ExchangeStrategy>
260 {
261  using BaseType =
263 
264 public:
265  CheckpointFile(std::set<std::string> const& weights_names,
266  std::string const& ckpt_basedir);
267  CheckpointFile(std::set<std::string>&& weights_names,
268  std::string const& ckpt_basedir);
269  std::unique_ptr<model>
270  get_partner_model(model const& m, El::Int partner_trainer, size_t step) final;
271 
272 private:
273  std::string ckpt_basedir_;
274 }; // class CheckpointFile
275 
276 class CheckpointBinary final
277  : public Cloneable<CheckpointBinary, RandomPairwiseExchange::ExchangeStrategy>
278 {
279  using BaseType =
281 
282 public:
283  CheckpointBinary(std::set<std::string> const& weights_names);
284  CheckpointBinary(std::set<std::string>&& weights_names);
285  std::unique_ptr<model> get_partner_model(model const& m,
286  El::Int partner_trainer,
287  size_t /*step*/) final;
288 }; // class CheckpointBinary
289 
290 } // namespace ltfb
291 
293 
296 template <> // Concrete builder: construct a RandomPairwiseExchange from a protobuf message.
297 std::unique_ptr<ltfb::RandomPairwiseExchange>
298 make(google::protobuf::Message const&);
299 
301 
302 } // namespace lbann
303 #endif // LBANN_EXECUTION_ALGORITHMS_LTFB_RANDOM_PAIRWISE_EXCHANGE_HPP_INCLUDED
ExchangeStrategy(std::set< std::string > weights_names)
Construct with weights names.
Inject polymorphic clone functions into hierarchies.
Definition: cloneable.hpp:94
virtual std::unique_ptr< model > get_partner_model(model const &m, El::Int partner_trainer, size_t step)=0
Get the model from a partner trainer.
int get_partner_trainer(lbann_comm const &c) const noexcept
Generate a new trainer partner from the comm.
std::unordered_map< std::string, EvalType > evaluate_model(model &m, LTFBExecutionContext &ctxt, data_coordinator &dc) const
Evaluate the model and return the value of each metric, as a map from metric name to value.
std::unique_ptr< MutationStrategy > m_mutate_algo
The strategy for mutation of a model.
See lbann::callbacks::ltfb::communication_algorithm::checkpoint_file.
Exchange model weights directly using sendrecvs.
void select_next(model &m, ltfb::LTFBExecutionContext &ctxt, data_coordinator &dc) const final
Engage in a tournament with a partner trainer.
RandomPairwiseExchange(std::string metric_name, metric_strategy winner_strategy, std::unique_ptr< ExchangeStrategy > comm_algo, std::unique_ptr< MutationStrategy > mutate_algo)
Constructor.
std::set< std::string > const & weights_names() const noexcept
Access weights_names.
Abstract base class for neural network models.
Definition: model.hpp:83
std::unique_ptr< ltfb::RandomPairwiseExchange > make(google::protobuf::Message const &)
Concrete builder for RandomPairwiseExchange.
std::unique_ptr< ExchangeStrategy > m_comm_algo
The strategy for exchanging two models.
A method for exchanging models with a partner trainer.
std::unordered_map< std::string, metric_strategy > m_metrics
The list of metric/strategy pairs.
bool local_is_better(std::unordered_map< std::string, EvalType > const &local_scores, std::unordered_map< std::string, EvalType > const &partner_scores) const
Compare the scores of the local and partner models according to the configured metric strategies; returns true if the local model is better.