LBANN 0.103.0
Livermore Big Artificial Neural Network Toolkit
dump_weights.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
25 //
26 // dump_weights.hpp, dump_weights.cpp - Callbacks to dump weight matrices
28 
29 #ifndef LBANN_CALLBACKS_CALLBACK_DUMP_WEIGHTS_HPP_INCLUDED
30 #define LBANN_CALLBACKS_CALLBACK_DUMP_WEIGHTS_HPP_INCLUDED
31 
32 #include <utility>
33 
36 
37 namespace lbann {
38 namespace callback {
39 
40 // Forward declaration: in this header FileFormat is only referenced through std::unique_ptr
41 namespace dump_weights_internal {
42 class FileFormat;
43 } // namespace dump_weights_internal
44 
57 {
58 public:
64  dump_weights(std::string dir, // presumably stored as m_directory (basename for writing files) - confirm in .cpp
65  El::Int epoch_interval, // presumably stored as m_epoch_interval (interval at which to dump weights) - confirm
66  std::unique_ptr<dump_weights_internal::FileFormat> file_format); // taken by value: ownership of the weight file format moves into the callback
67  dump_weights(dump_weights const& other);
68  dump_weights& operator=(dump_weights const& other);
69  dump_weights* copy() const override { return new dump_weights{*this}; } // heap-allocates a copy via the copy constructor above
70  void on_train_begin(model* mdl) override; // callback hook; defined out of line
71  void on_epoch_end(model* mdl) override; // callback hook; defined out of line
72  std::string name() const override { return std::string("dump weights"); } // this callback's name
73  void set_target_dir(const std::string& dir) { m_directory = dir; } // replaces the output basename (m_directory)
74  const std::string& get_target_dir() const { return m_directory; } // const-qualified: read-only access, also callable through const references
75 
77 
80  template <typename Archive>
81  void serialize(Archive& archive); // presumably the cereal serialization entry point (see friend cereal::access)
82 
84 
85 private:
87  void write_specific_proto(lbann_data::Callback& msg) const final; // presumably writes this callback's settings into msg - confirm in .cpp
88 
89  friend class cereal::access; // grants cereal::access access to private members such as the default constructor below
90  dump_weights(); // private default constructor, presumably needed only for cereal deserialization
91 
93  std::string m_directory; // basename for writing files
97  std::unique_ptr<dump_weights_internal::FileFormat> m_file_format; // weight file format (owned)
98 
100  void do_dump_weights(model const& m, visitor_hook hook); // shared worker; presumably called from the on_* hooks - confirm in .cpp
101 };
102 
103 // Builder function: presumably creates a dump_weights callback from its protobuf description
104 std::unique_ptr<callback_base>
105 build_dump_weights_callback_from_pbuf(google::protobuf::Message const& proto_msg,
106  std::shared_ptr<lbann_summary> const& summarizer);
107 
108 } // namespace callback
109 } // namespace lbann
110 
111 #endif // LBANN_CALLBACKS_CALLBACK_DUMP_WEIGHTS_HPP_INCLUDED
std::string name() const override
Return this callback's name.
std::string m_directory
Basename for writing files.
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
std::unique_ptr< dump_weights_internal::FileFormat > m_file_format
Weight file format.
const std::string & get_target_dir()
El::Int m_epoch_interval
Interval at which to dump weights.
Base class for callbacks during training/testing.
Definition: callback.hpp:76
Abstract base class for neural network models.
Definition: model.hpp:83
dump_weights * copy() const override
void set_target_dir(const std::string &dir)
visitor_hook
Neural network execution mode.
std::unique_ptr< callback_base > build_dump_weights_callback_from_pbuf(const google::protobuf::Message &, std::shared_ptr< lbann_summary > const &)
Dump weights to files.