27 #ifndef LBANN_WEIGHTS_HPP 28 #define LBANN_WEIGHTS_HPP 38 #include <onnx/onnx_pb.h> 39 #endif // LBANN_HAS_ONNX 55 class weights_initializer;
/**
 * Takes no arguments and returns nothing.
 *
 * NOTE(review): judging by the name this (re)establishes a default matrix
 * distribution; what "default" means is decided in the implementation,
 * which is not shown here -- confirm before relying on it.
 */
108 void setup_default_matrix_distribution();
/**
 * Pure virtual, const: every concrete subclass must implement it.
 *
 * @return A std::string by value. Presumably a printable name of the
 *         element data type -- TODO confirm against an implementation.
 */
129 virtual std::string get_datatype_name()
const = 0;
/**
 * Pure virtual, const query.
 *
 * @return bool. Presumably true when an optimizer is attached (compare
 *         get_optimizer() / set_optimizer()) -- TODO confirm.
 */
131 virtual bool has_optimizer()
const = 0;
/**
 * Const accessor for the dimensions.
 *
 * @return A std::vector<size_t> by value, so the caller owns the result.
 *
 * NOTE(review): how the result is assembled from the matrix height/width
 * dims is defined in the implementation, which is not visible here.
 */
148 std::vector<size_t> get_dims()
const;
/**
 * Const accessor returning a size_t.
 *
 * NOTE(review): presumably the total number of entries implied by
 * get_dims() -- TODO confirm.
 */
150 size_t get_size()
const;
/**
 * Const accessor for the dimensions associated with the matrix height.
 *
 * @return A std::vector<size_t> by value. Presumably mirrors the
 *         m_matrix_height_dims member -- confirm in the implementation.
 */
155 std::vector<size_t> get_matrix_height_dims()
const;
/**
 * Const accessor for the dimensions associated with the matrix width.
 *
 * @return A std::vector<size_t> by value. Presumably mirrors the
 *         m_matrix_width_dims member -- confirm in the implementation.
 */
160 std::vector<size_t> get_matrix_width_dims()
const;
/**
 * Const accessor returning the matrix height as a size_t.
 *
 * NOTE(review): presumably derived from get_matrix_height_dims()
 * (e.g. their product) -- TODO confirm.
 */
164 size_t get_matrix_height()
const;
/**
 * Const accessor returning the matrix width as a size_t.
 *
 * NOTE(review): presumably derived from get_matrix_width_dims()
 * (e.g. their product) -- TODO confirm.
 */
168 size_t get_matrix_width()
const;
/**
 * Sets the dimensions from two lists, both taken by value.
 *
 * @param matrix_height_dims  Dimensions associated with the matrix height.
 * @param matrix_width_dims   Dimensions associated with the matrix width;
 *                            defaults to an empty list when omitted.
 *
 * NOTE(review): a set_dims(size_t) overload also appears to exist; keep the
 * two consistent when changing either.
 */
172 void set_dims(std::vector<size_t> matrix_height_dims,
173 std::vector<size_t> matrix_width_dims = {});
/**
 * Const accessor for the matrix distribution.
 *
 * @return An El::DistData by value. Presumably a copy of the m_matrix_dist
 *         member -- confirm in the implementation.
 */
180 El::DistData get_matrix_distribution()
const;
/**
 * Sets the matrix distribution.
 *
 * @param dist  The new distribution, taken by value.
 *
 * NOTE(review): whether this may be called after values are allocated is
 * not visible from the declaration -- check the implementation.
 */
181 void set_matrix_distribution(El::DistData dist);
/**
 * Sets the values from an existing distributed matrix.
 *
 * @param values  Source matrix, passed by const reference (this function
 *                cannot modify it through this parameter).
 *
 * NOTE(review): any required agreement between `values` and the current
 * dims/distribution is enforced (or not) in the implementation -- confirm.
 */
198 void set_values(El::BaseDistMatrix
const& values);
/**
 * Pure virtual, non-const overload.
 *
 * @return A mutable reference to an El::BaseDistMatrix. The lifetime of
 *         the referenced matrix is determined by the implementing class.
 */
201 virtual El::BaseDistMatrix& get_values() = 0;
/**
 * Pure virtual, const overload of get_values().
 *
 * @return A const reference to an El::BaseDistMatrix; read-only access
 *         for callers holding a const object.
 */
202 virtual El::BaseDistMatrix
const& get_values()
const = 0;
/**
 * Pure virtual. Installs a weights_initializer (the scheme for
 * initializing weight values).
 *
 * @param init  Taken as an rvalue reference to std::unique_ptr, so callers
 *              must pass a temporary or std::move an owned pointer.
 *
 * NOTE(review): whether a null pointer is accepted is not visible here.
 */
215 virtual void set_initializer(std::unique_ptr<weights_initializer>&& init) = 0;
/**
 * Pure virtual, const accessor for the optimizer (optimizer is the
 * abstract base class for gradient-based optimization algorithms).
 *
 * @return A raw pointer to a const optimizer.
 *
 * NOTE(review): presumably nullptr when no optimizer is set (see
 * has_optimizer()) -- TODO confirm.
 */
227 virtual const optimizer* get_optimizer()
const = 0;
/**
 * Pure virtual. Installs an optimizer.
 *
 * @param opt  Taken as an rvalue reference to std::unique_ptr, so callers
 *             must pass a temporary or std::move an owned pointer.
 *
 * NOTE(review): whether a null pointer is accepted is not visible here.
 */
231 virtual void set_optimizer(std::unique_ptr<optimizer>&& opt) = 0;
/**
 * Pure virtual, argument-free overload.
 *
 * NOTE(review): what "reconcile" synchronizes, and across which processes,
 * is defined by implementations that are not visible here -- confirm.
 */
256 virtual void reconcile_values() = 0;
/**
 * Pure virtual overload taking a request object.
 *
 * @param req  An Al::request passed by mutable reference.
 *
 * NOTE(review): presumably a non-blocking variant whose completion is
 * tracked through `req` -- TODO confirm against an implementation.
 */
261 virtual void reconcile_values(
Al::request& req) = 0;
/**
 * Pure virtual.
 *
 * @param ckpt_dir     Checkpoint directory, as a const std::string reference.
 * @param weight_list  List of strings, by const reference. Presumably the
 *                     saved weight (file) names to search -- TODO confirm.
 * @return bool. Presumably whether the values were loaded -- TODO confirm.
 */
263 virtual bool load_from_save(std::string
const& ckpt_dir,
264 std::vector<std::string>
const& weight_list) = 0;
/**
 * Pure virtual, const.
 *
 * @param proto  An lbann_data::Weights message passed by mutable reference.
 *               Presumably the output that the implementation fills in --
 *               TODO confirm.
 */
267 virtual void write_proto(lbann_data::Weights& proto)
const = 0;
277 template <
typename ArchiveT>
/**
 * Two declarations follow.
 *
 * fill_onnx_node: pure virtual, const; only declared when LBANN_HAS_ONNX
 * is defined. Takes an onnx::GraphProto by mutable reference; presumably
 * adds this object's representation to the graph -- TODO confirm.
 *
 * steal_values: declared outside the LBANN_HAS_ONNX guard. Takes another
 * weights object by mutable reference.
 * NOTE(review): presumably takes over `other`'s values, leaving `other`
 * in an implementation-defined state (see do_steal_values_()) -- confirm.
 */
280 #ifdef LBANN_HAS_ONNX 282 virtual void fill_onnx_node(onnx::GraphProto& graph)
const = 0;
283 #endif // LBANN_HAS_ONNX 304 void steal_values(
weights& other);
/**
 * Pure virtual, const hook.
 *
 * Takes a description (the helper that generates formatted description
 * messages) by mutable reference.
 *
 * NOTE(review): presumably lets subclasses append their own details to a
 * description built by the base class -- TODO confirm the caller.
 */
312 virtual void do_augment_description_(
description&)
const = 0;
/**
 * Pure virtual hook with no arguments.
 *
 * NOTE(review): presumably the subclass-specific part of setup; the caller
 * and ordering guarantees are not visible here -- confirm.
 */
313 virtual void do_setup_() = 0;
/**
 * Pure virtual hook. The parameters mirror those of
 * set_dims(matrix_height_dims, matrix_width_dims) but are passed by const
 * reference here.
 *
 * @param matrix_height_dims  Dimensions associated with the matrix height.
 * @param matrix_width_dims   Dimensions associated with the matrix width.
 *
 * NOTE(review): presumably invoked from set_dims() -- TODO confirm.
 */
315 virtual void do_set_dims_(std::vector<size_t>
const& matrix_height_dims,
316 std::vector<size_t>
const& matrix_width_dims) = 0;
/**
 * Pure virtual hook with the same parameter as steal_values(weights&).
 *
 * @param other  Another weights object, by mutable reference.
 *
 * NOTE(review): presumably the subclass-specific part of steal_values();
 * the state `other` is left in is implementation-defined -- confirm.
 */
317 virtual void do_steal_values_(
weights& other) = 0;
346 #endif // LBANN_WEIGHTS_HPP std::vector< size_t > m_matrix_height_dims
std::vector< size_t > m_matrix_width_dims
Inject polymorphic clone functions into hierarchies.
El::DistData m_matrix_dist
void set_name(std::string name)
lbann_comm & get_comm() const
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
Generates nicely formatted description messages.
Abstract base class for gradient-based optimization algorithms.
std::shared_ptr< weights > OwningWeightsPtr
Smart pointer to manage ownership of a weights object.
std::weak_ptr< weights > ViewingWeightsPtr
Smart pointer to reference a weights object.
Scheme for initializing weight values.
std::string get_name() const
void set_dims(size_t size)