LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
weights/weights.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the License.
26 
27 #ifndef LBANN_WEIGHTS_HPP
28 #define LBANN_WEIGHTS_HPP
29 
#include "lbann/base.hpp"

#include <memory>
#include <string>
#include <utility>
#include <vector>
#ifdef LBANN_HAS_ONNX
#include <onnx/onnx_pb.h>
#endif // LBANN_HAS_ONNX
40 
41 namespace lbann_data {
42 class Weights;
43 }
44 
45 namespace lbann {
46 
47 namespace Al {
48 // Forward declaration
49 struct request;
50 } // namespace Al
51 
52 // Forward declaration
53 class lbann_comm;
54 class weights;
55 class weights_initializer;
56 class optimizer;
57 
76 using OwningWeightsPtr = std::shared_ptr<weights>;
84 using ViewingWeightsPtr = std::weak_ptr<weights>;
85 
100 class weights : public Cloneable<HasAbstractFunction<weights>>
101 {
102 private:
103  weights();
104  // -----------------------------------------------
105  // Internal method for setting the comm pointer
106  // -----------------------------------------------
107  void set_comm(lbann_comm& comm);
108  void setup_default_matrix_distribution();
109 
110 public:
111  weights(lbann_comm& comm);
112  virtual ~weights() = default;
113 
118  void set_name(std::string name) { m_name = name; }
119 
121  std::string get_name() const { return m_name; }
122 
123  lbann_comm& get_comm() const { return *m_comm; }
124 
126  description get_description() const;
127 
129  virtual std::string get_datatype_name() const = 0;
130 
131  virtual bool has_optimizer() const = 0;
132 
133  // -----------------------------------------------
134  // Dimension accessors
135  // -----------------------------------------------
148  std::vector<size_t> get_dims() const;
150  size_t get_size() const;
155  std::vector<size_t> get_matrix_height_dims() const;
160  std::vector<size_t> get_matrix_width_dims() const;
164  size_t get_matrix_height() const;
168  size_t get_matrix_width() const;
172  void set_dims(std::vector<size_t> matrix_height_dims,
173  std::vector<size_t> matrix_width_dims = {});
175  void set_dims(size_t size) { set_dims({size}, {}); }
176 
177  // -----------------------------------------------
178  // Matrix distribution accessors
179  // -----------------------------------------------
180  El::DistData get_matrix_distribution() const;
181  void set_matrix_distribution(El::DistData dist);
182 
184 
198  void set_values(El::BaseDistMatrix const& values);
199 
201  virtual El::BaseDistMatrix& get_values() = 0;
202  virtual El::BaseDistMatrix const& get_values() const = 0;
204 
205  // -----------------------------------------------
206  // Initializer accessors
207  // -----------------------------------------------
209  virtual weights_initializer* get_initializer() = 0;
211  virtual const weights_initializer* get_initializer() const = 0;
215  virtual void set_initializer(std::unique_ptr<weights_initializer>&& init) = 0;
216 
217  // -----------------------------------------------
218  // Optimizer accessors
219  // -----------------------------------------------
223  virtual optimizer* get_optimizer() = 0;
227  virtual const optimizer* get_optimizer() const = 0;
231  virtual void set_optimizer(std::unique_ptr<optimizer>&& opt) = 0;
232 
233  // -----------------------------------------------
234  // Setup
235  // -----------------------------------------------
236  void setup();
237 
238  // -----------------------------------------------
239  // Freezing
240  // -----------------------------------------------
242  void freeze() { m_frozen = true; }
244  void unfreeze() { m_frozen = false; }
246  bool is_frozen() const { return m_frozen; }
247 
248  // -----------------------------------------------
249  // Weight matrix accessors
250  // -----------------------------------------------
251 
256  virtual void reconcile_values() = 0;
261  virtual void reconcile_values(Al::request& req) = 0;
262 
263  virtual bool load_from_save(std::string const& ckpt_dir,
264  std::vector<std::string> const& weight_list) = 0;
265 
267  virtual void write_proto(lbann_data::Weights& proto) const = 0;
268 
270 
277  template <typename ArchiveT>
278  void serialize(ArchiveT& ar);
279 
280 #ifdef LBANN_HAS_ONNX
281 
282  virtual void fill_onnx_node(onnx::GraphProto& graph) const = 0;
283 #endif // LBANN_HAS_ONNX
284 
286 
287 
304  void steal_values(weights& other);
305 
307 protected:
308  weights(const weights& other) = default;
309  weights& operator=(const weights& other) = default;
310 
311 private:
312  virtual void do_augment_description_(description&) const = 0;
313  virtual void do_setup_() = 0;
314 
315  virtual void do_set_dims_(std::vector<size_t> const& matrix_height_dims,
316  std::vector<size_t> const& matrix_width_dims) = 0;
317  virtual void do_steal_values_(weights& other) = 0;
318 
319 private:
324  std::string m_name;
325 
328 
332  std::vector<size_t> m_matrix_height_dims;
336  std::vector<size_t> m_matrix_width_dims;
338  El::DistData m_matrix_dist;
339 
341  bool m_frozen;
342 };
343 
344 } // namespace lbann
345 
346 #endif // LBANN_WEIGHTS_HPP
std::vector< size_t > m_matrix_height_dims
std::vector< size_t > m_matrix_width_dims
Inject polymorphic clone functions into hierarchies.
Definition: cloneable.hpp:94
El::DistData m_matrix_dist
void set_name(std::string name)
std::string m_name
lbann_comm & get_comm() const
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
Generates nicely formatted description messages.
Definition: description.hpp:49
Abstract base class for gradient-based optimization algorithms.
Definition: optimizer.hpp:85
std::shared_ptr< weights > OwningWeightsPtr
Smart pointer to manage ownership of a weights object.
Definition: model.hpp:77
std::weak_ptr< weights > ViewingWeightsPtr
Smart pointer to reference a weights object.
Definition: layer.hpp:89
Scheme for initializing weight values.
Definition: initializer.hpp:43
lbann_comm * m_comm
std::string get_name() const
bool is_frozen() const
void set_dims(size_t size)