LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
adam.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the License.
26 
27 #ifndef LBANN_OPTIMIZERS_ADAM_HPP_INCLUDED
28 #define LBANN_OPTIMIZERS_ADAM_HPP_INCLUDED
29 
30 #include "lbann/io/persist.hpp"
32 #include "lbann/proto/optimizers.pb.h"
33 
34 namespace lbann {
// Forward declaration of callback::perturb_adam, presumably the callback
// documented as "Hyperparameter exploration with Adam optimizers.".
// NOTE(review): nothing in this rendered listing references it; it is
// presumably used by a declaration on a source line dropped from this
// render (e.g. a friend declaration in adam) -- confirm against adam.hpp.
35 namespace callback {
36 class perturb_adam;
37 } // namespace callback
38 
// adam<TensorDataType>: the Adam optimizer. It derives from Cloneable over
// data_type_optimizer<TensorDataType>; Cloneable injects the polymorphic
// clone functions.
// NOTE(review): this is an extracted doxygen render, and some source lines
// are missing:
//   - the right-hand side of BaseType (after line 50);
//   - the WeightsType alias used on line 161;
//   - the default-constructor head on line 175.
// Edit adam.hpp itself, not this listing.
46 template <typename TensorDataType>
47 class adam
48  : public Cloneable<adam<TensorDataType>, data_type_optimizer<TensorDataType>>
49 {
// Alias for the immediate base class. Its right-hand side (source line 51)
// is not shown in this render.
50  using BaseType =
52 
53 public:
55 
// Matrix type used for the weight values, gradients and moment matrices
// handled by this optimizer.
58  using AbsDistMatrixType = El::AbstractDistMatrix<TensorDataType>;
59 
// The data-type-specific optimizer base class.
61  using OptimizerType = data_type_optimizer<TensorDataType>;
62 
65 
67 
68 public:
70 
// Construct with an explicit learning rate. Defaults:
//   beta1 = 0.9, beta2 = 0.99, eps = 1e-8, adamw_weight_decay = 0.0.
// NOTE(review): beta2 = 0.99 differs from the 0.999 suggested in the Adam
// paper (Kingma & Ba); confirm this default is intentional.
// NOTE(review): these defaults convert implicitly from double, whereas the
// default constructor (line 175) uses El::To<TensorDataType>. For fp16,
// eps = 1e-8 is below the smallest subnormal (~6e-8) and rounds to 0.
// Confirm behaviour for half-precision TensorDataType.
72  adam(TensorDataType learning_rate,
73  TensorDataType beta1 = 0.9,
74  TensorDataType beta2 = 0.99,
75  TensorDataType eps = 1e-8,
76  TensorDataType adamw_weight_decay = 0.0);
// Copy operations are user-declared and defined outside this header listing.
// NOTE(review): no move operations are declared, so moving an adam falls
// back to these copy operations.
77  adam(const adam& other);
78  adam& operator=(const adam& other);
79  ~adam() = default;
80 
// Serialization hook (used via the befriended cereal::access below).
// Defined outside this header listing.
82  template <class Archive>
83  void serialize(Archive& ar);
84 
86 
88 
// Human-readable optimizer type name; always "Adam".
91  std::string get_type() const override { return "Adam"; }
// Builds a description object for this optimizer (defined elsewhere).
93  description get_description() const override;
94 
96 
98 
// Hyperparameter accessors. Each setter only assigns its member; in
// particular, set_beta1/set_beta2 leave m_current_beta1/m_current_beta2
// untouched.
101  TensorDataType get_beta1() const noexcept { return m_beta1; }
103  void set_beta1(TensorDataType beta1) { m_beta1 = beta1; }
105  TensorDataType get_beta2() const noexcept { return m_beta2; }
107  void set_beta2(TensorDataType beta2) { m_beta2 = beta2; }
109  TensorDataType get_eps() const noexcept { return m_eps; }
111  void set_eps(TensorDataType eps) { m_eps = eps; }
113  TensorDataType get_adamw_weight_decay() const noexcept
114  {
115  return m_adamw_weight_decay;
116  }
118  void set_adamw_weight_decay(TensorDataType adamw_weight_decay)
119  {
120  m_adamw_weight_decay = adamw_weight_decay;
121  }
122 
// Access to the moment matrices, presumably backed by m_moment1/m_moment2
// (definitions are not in this header listing).
// NOTE(review): presumably valid only once setup() has allocated them.
124  const AbsDistMatrixType& get_moment1() const;
126  AbsDistMatrixType& get_moment1();
128  const AbsDistMatrixType& get_moment2() const;
130  AbsDistMatrixType& get_moment2();
131 
// Running values paired with beta1/beta2; both start at 1 (lines 197, 199).
// Presumably these are the accumulated powers beta^t used for Adam's bias
// correction -- confirm against step_compute's implementation.
135  TensorDataType get_current_beta1() const noexcept { return m_current_beta1; }
139  void set_current_beta1(TensorDataType current_beta1)
140  {
141  m_current_beta1 = current_beta1;
142  }
146  TensorDataType get_current_beta2() const noexcept { return m_current_beta2; }
150  void set_current_beta2(TensorDataType current_beta2)
151  {
152  m_current_beta2 = current_beta2;
153  }
154 
156 
158 
// The using-declaration keeps the base-class setup overloads visible next to
// this override. WeightsType is declared on a line missing from this render.
160  using OptimizerType::setup;
161  void setup(WeightsType* w = nullptr) override;
162 
164 
// Writes this optimizer's configuration into an lbann_data::Optimizer message.
166  void write_proto(lbann_data::Optimizer& opt) const final;
167 
168 protected:
// cereal::access is befriended, presumably so cereal can reach the protected
// default constructor and serialize() when loading.
169  friend cereal::access;
170 
// Default constructor. Its head, adam(), is on source line 175, which is
// missing here. It delegates with learning rate 1, beta1 0.9, beta2 0.99,
// eps 1e-8 and weight decay 0, all converted via El::To<TensorDataType>.
176  : adam(El::To<TensorDataType>(1.f),
177  El::To<TensorDataType>(0.9),
178  El::To<TensorDataType>(0.99),
179  El::To<TensorDataType>(1e-8),
180  El::To<TensorDataType>(0))
181  {}
182 
// Applies one optimization step to `values` from `gradient` (defined
// elsewhere).
184  void step_compute(AbsDistMatrixType& values,
185  const AbsDistMatrixType& gradient) override;
186 
187 private:
// Decay rates for the first/second moment estimates (standard Adam
// meaning, presumably).
189  TensorDataType m_beta1;
191  TensorDataType m_beta2;
// Presumably the small stabilising constant in the update's denominator.
193  TensorDataType m_eps;
// Presumably the decoupled (AdamW-style) weight-decay coefficient; the
// public constructor defaults it to 0.
195  TensorDataType m_adamw_weight_decay;
// Running beta values, initialised to 1.
197  TensorDataType m_current_beta1 = TensorDataType(1.);
199  TensorDataType m_current_beta2 = TensorDataType(1.);
// Owned moment-estimate matrices (first and second moment, presumably).
201  std::unique_ptr<AbsDistMatrixType> m_moment1;
203  std::unique_ptr<AbsDistMatrixType> m_moment2;
204 
207 
// Backend step implementations. `correction` is presumably the
// bias-correction factor derived from the running betas -- confirm.
// The GPU variant exists only when LBANN_HAS_GPU is defined.
209  void step_compute_cpu(AbsDistMatrixType& values,
210  const AbsDistMatrixType& gradient,
211  const TensorDataType& correction);
212 #ifdef LBANN_HAS_GPU
213 
214  void step_compute_gpu(AbsDistMatrixType& values,
215  const AbsDistMatrixType& gradient,
216  const TensorDataType& correction);
217 #endif // LBANN_HAS_GPU
218 };
219 
// Factory declaration: takes a protobuf message and returns a
// std::unique_ptr<optimizer>. Per its name, it builds an Adam optimizer for
// TensorDataType. Defined outside this header listing.
// NOTE(review): the concrete message type expected is not shown here
// (presumably the Adam entry of lbann_data::Optimizer) -- confirm.
220 template <typename TensorDataType>
221 std::unique_ptr<optimizer>
222 build_adam_optimizer_from_pbuf(google::protobuf::Message const&);
223 
224 } // namespace lbann
225 
226 #endif // LBANN_OPTIMIZERS_ADAM_HPP_INCLUDED
std::unique_ptr< AbsDistMatrixType > m_moment2
Definition: adam.hpp:203
lbann::adam
Adam optimizer.
Definition: adam.hpp:47
TensorDataType get_beta2() const noexcept
Definition: adam.hpp:105
Cloneable
Inject polymorphic clone functions into hierarchies.
Definition: cloneable.hpp:94
void set_current_beta1(TensorDataType current_beta1)
Definition: adam.hpp:139
std::string get_type() const override
Definition: adam.hpp:91
TensorDataType get_current_beta1() const noexcept
Definition: adam.hpp:135
TensorDataType m_eps
Definition: adam.hpp:193
adam()
Default constructor.
Definition: adam.hpp:175
TensorDataType get_current_beta2() const noexcept
Definition: adam.hpp:146
TensorDataType get_beta1() const noexcept
Definition: adam.hpp:101
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
description
Generates nicely formatted description messages.
Definition: description.hpp:49
El::AbstractDistMatrix< TensorDataType > AbsDistMatrixType
The tensor type expected in this object.
Definition: adam.hpp:58
TensorDataType m_beta2
Definition: adam.hpp:191
lbann::callback::perturb_adam
Hyperparameter exploration with Adam optimizers.
std::unique_ptr< AbsDistMatrixType > m_moment1
Definition: adam.hpp:201
TensorDataType m_beta1
Definition: adam.hpp:189
void set_beta1(TensorDataType beta1)
Definition: adam.hpp:103
void set_beta2(TensorDataType beta2)
Definition: adam.hpp:107
void set_adamw_weight_decay(TensorDataType adamw_weight_decay)
Definition: adam.hpp:118
void set_eps(TensorDataType eps)
Definition: adam.hpp:111
TensorDataType m_adamw_weight_decay
Definition: adam.hpp:195
TensorDataType get_adamw_weight_decay() const noexcept
Definition: adam.hpp:113
TensorDataType get_eps() const noexcept
Definition: adam.hpp:109
void set_current_beta2(TensorDataType current_beta2)
Definition: adam.hpp:150
std::unique_ptr< optimizer > build_adam_optimizer_from_pbuf(google::protobuf::Message const &)