LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
variable_minibatch.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
25 //
26 // lbann_variable_minibatch .hpp .cpp - Callback for variable-size mini-batches
28 
29 #ifndef LBANN_CALLBACKS_VARIABLE_MINIBATCH_HPP_INCLUDED
30 #define LBANN_CALLBACKS_VARIABLE_MINIBATCH_HPP_INCLUDED
31 
33 
34 namespace lbann {
35 namespace callback {
36 
43 {
44 public:
45  variable_minibatch(size_t starting_mbsize);
46  variable_minibatch(const variable_minibatch&) = default;
49  void on_train_begin(model* m) override;
51  void on_epoch_end(model* m) override;
52 
53 protected:
67  virtual bool
68  schedule(model* m, size_t& new_mbsize, float& new_lr, size_t& ramp_time) = 0;
70  void change_learning_rate(model* m, float new_lr) const;
72  float get_current_learning_rate(model* m) const;
73 
84  size_t m_ramp_count = 0;
86  float m_lr_incr = 0.0f;
87 };
88 
94 {
95 public:
96  step_minibatch(size_t starting_mbsize, size_t step, size_t ramp_time = 0);
97  step_minibatch(const step_minibatch&) = default;
98  step_minibatch& operator=(const step_minibatch&) = delete;
99  step_minibatch* copy() const override { return new step_minibatch(*this); }
100  std::string name() const override { return "step minibatch"; }
101 
102 protected:
103  bool schedule(model* m,
104  size_t& new_mbsize,
105  float& new_lr,
106  size_t& ramp_time) override;
107 
108 private:
110  void write_specific_proto(lbann_data::Callback& proto) const final;
111 
113  size_t m_step;
115  size_t m_ramp_time;
116 };
117 
118 // Builder function: returns a callback_base built from a protobuf message (by its name, the step_minibatch callback).
119 std::unique_ptr<callback_base>
120 build_step_minibatch_callback_from_pbuf(const google::protobuf::Message&,
121  std::shared_ptr<lbann_summary> const&);
122 
124 {
125 public:
128  {
130  size_t epoch;
132  size_t mbsize;
134  float lr;
136  size_t ramp_time;
137  minibatch_step(size_t _epoch, size_t _mbsize, float _lr, size_t _ramp_time)
138  : epoch(_epoch), mbsize(_mbsize), lr(_lr), ramp_time(_ramp_time)
139  {}
140  };
141 
142  minibatch_schedule(size_t starting_mbsize, std::vector<minibatch_step> steps);
143  minibatch_schedule(const minibatch_schedule&) = default;
145  minibatch_schedule* copy() const override
146  {
147  return new minibatch_schedule(*this);
148  }
149  std::string name() const override { return "minibatch schedule"; }
150 
151 protected:
152  bool schedule(model* m,
153  size_t& new_mbsize,
154  float& new_lr,
155  size_t& ramp_time) override;
156 
157 private:
159  void write_specific_proto(lbann_data::Callback& proto) const final;
160 
162  std::vector<minibatch_step> m_steps;
163 };
164 
165 // Builder function: returns a callback_base built from a protobuf message (by its name, the minibatch_schedule callback).
166 std::unique_ptr<callback_base> build_minibatch_schedule_callback_from_pbuf(
167  const google::protobuf::Message&,
168  std::shared_ptr<lbann_summary> const&);
169 
170 } // namespace callback
171 } // namespace lbann
172 
173 #endif // LBANN_CALLBACKS_VARIABLE_MINIBATCH_HPP_INCLUDED
std::unique_ptr< callback_base > build_minibatch_schedule_callback_from_pbuf(const google::protobuf::Message &, std::shared_ptr< lbann_summary > const &)
step_minibatch * copy() const override
std::string name() const override
Return this callback's name.
float get_current_learning_rate(model *m) const
Get the current learning rate (assumes every layer has the same one).
std::vector< minibatch_step > m_steps
Steps in the mini-batch schedule, stored in reverse sorted order.
variable_minibatch & operator=(const variable_minibatch &)=default
float m_lr_incr
Amount to increment the learning rate by when ramping.
virtual bool schedule(model *m, size_t &new_mbsize, float &new_lr, size_t &ramp_time)=0
size_t m_ramp_count
Current number of epochs left to ramp the learning rate.
std::string name() const override
Return this callback's name.
std::unique_ptr< callback_base > build_step_minibatch_callback_from_pbuf(const google::protobuf::Message &, std::shared_ptr< lbann_summary > const &)
size_t m_ramp_time
Number of steps to ramp the learning rate over.
size_t epoch
Epoch for this schedule to start.
Represents a step in a schedule of mini-batch sizes.
minibatch_step(size_t _epoch, size_t _mbsize, float _lr, size_t _ramp_time)
Base class for callbacks during training/testing.
Definition: callback.hpp:76
Abstract base class for neural network models.
Definition: model.hpp:83
variable_minibatch(size_t starting_mbsize)
minibatch_schedule * copy() const override
virtual void write_specific_proto(lbann_data::Callback &proto) const =0
Add callback specific data to prototext.
size_t m_step
Number of epochs between mini-batch size increases.
void change_learning_rate(model *m, float new_lr) const
Change the learning rate of every layer in m to new_lr.
size_t ramp_time
Number of epochs to ramp the learning rate over.
size_t m_starting_mbsize
Initial mini-batch size.
void on_train_begin(model *m) override
Set the initial mini-batch size.
void on_epoch_end(model *m) override
Potentially change the mini-batch size.