LBANN  0.103.0
LivermoreBigArtificialNeuralNetworkToolkit
persist.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
25 //
26 // persist.hpp - Input / output utilities for checkpoint and restart
28 
29 #ifndef LBANN_PERSIST_H
30 #define LBANN_PERSIST_H
31 
32 #include "El.hpp"
33 #include "lbann/base.hpp"
35 #include <sstream>
36 
37 namespace lbann {
38 
// Category of persisted data. class persist keys its m_bytes and m_filenames
// maps by this enum, so each category has its own byte counter and filename.
39 enum class persist_type
40 {
41  train, // data should be saved in file with train data
42  model, // data should be saved in file with model data
43  metrics,
44  validate,
45  testing,
// NOTE(review): lines 46-51 of persist.hpp are absent from this listing, so
// further enumerators may exist that are not shown here.
52 };
53 
57 
// Maps an execution mode to the persist category used for its data.
// NOTE(review): declared `inline` with no body here. An inline function must be
// defined in every translation unit that odr-uses it, so confirm the definition
// lives in a header and not only in a .cpp file.
58 inline persist_type execution_mode_to_persist_type(execution_mode m);
59
// String form of a persist category (presumably for filenames/logging; TODO
// confirm). Same `inline` caveat as above.
60 inline std::string to_string(persist_type pt);
61 
// Kind of checkpoint callback recorded on a persist object
// (persist::get_cb_type / persist::set_cb_type).
// `invalid` presumably denotes "not set" -- TODO confirm.
63 enum class callback_type
64 {
65  model_only,
// NOTE(review): lines 66-68 of persist.hpp are absent from this listing;
// further enumerators may exist between model_only and invalid.
69  invalid
70 };
71 
72 class persist
73 {
74 private:
75  std::map<persist_type, uint64_t> m_bytes;
76  std::map<persist_type, std::string> m_filenames;
78 
79 public:
80  std::string m_checkpoint_dir;
81 
82 public:
83  persist();
84  ~persist(){};
85 
87  template <class Archive>
88  void serialize(Archive& ar);
89 
90  callback_type get_cb_type() const { return ckpt_type; }
91 
92  void set_cb_type(callback_type type) { ckpt_type = type; }
93 
94  void open_checkpoint_dir(const std::string& dir, bool create_dir);
95  void open_checkpoint(const std::string& dir, bool create_dir);
96  void close_checkpoint();
97 
98  void open_restart(const std::string& dir);
99  void close_restart();
100  void set_restart_dir(const std::string& dir) { m_checkpoint_dir = dir; }
101 
102  uint64_t get_bytes() const
103  {
104  uint64_t bytes = 0;
105  for (auto& pt : m_bytes) {
106  bytes += pt.second;
107  }
108  return bytes;
109  }
110 
111  void reset_bytes()
112  {
113  for (auto& pt : m_bytes) {
114  pt.second = 0;
115  }
116  }
117 
118  template <typename TensorDataType>
119  bool write_rank_distmat(persist_type type,
120  const char* name,
121  const El::AbstractDistMatrix<TensorDataType>& M);
122  template <typename TensorDataType>
123  bool read_rank_distmat(persist_type type,
124  const char* name,
125  El::AbstractDistMatrix<TensorDataType>& M);
126 
127  template <typename TensorDataType>
128  bool write_distmat(persist_type type,
129  const char* name,
130  El::AbstractDistMatrix<TensorDataType>* M);
131  template <typename TensorDataType>
132  bool read_distmat(persist_type type,
133  const char* name,
134  El::AbstractDistMatrix<TensorDataType>* M);
135 
136  const std::string& get_checkpoint_dir() const { return m_checkpoint_dir; }
137 
138  std::string get_filename(persist_type type) const;
139 };
140 
// Low-level helpers that move `size` bytes between `buf` and descriptor `fd`
// (presumably a POSIX file descriptor -- TODO confirm).
// Each returns a bool, presumably success. `name` is presumably a label
// used in diagnostics -- verify against the definitions.
141 bool write_bytes(int fd, const char* name, const void* buf, size_t size);
142 bool read_bytes(int fd, const char* name, void* buf, size_t size);
143
// String variants of the helpers above (same parameter layout).
144 bool write_string(int fd, const char* name, const char* buf, size_t size);
145 bool read_string(int fd, const char* name, char* buf, size_t size);
146 
/** Exception for a missing archive file.
 *
 *  what() is "Archive file not found: " followed by the filename.
 *  The constructor is explicit so a std::string never silently converts
 *  into an exception object.
 */
class NonexistentArchiveFile : public std::runtime_error
{
public:
  /// @param filename Path of the archive that could not be found.
  explicit NonexistentArchiveFile(std::string const& filename)
    : std::runtime_error(std::string("Archive file not found: ") + filename)
  {}
};
154 
// Serialize `obj` to an archive file. The overloads take an explicit filename,
// or a persist object plus a filename, a persist category, or an execution mode
// with a `suffix` (presumably appended to a path derived from `p` -- TODO confirm).
155 template <typename C>
156 void write_cereal_archive(C& obj, const std::string& filename);
157 
158 template <typename C>
159 void write_cereal_archive(C& obj, persist& p, const std::string& filename);
160 
161 template <typename C>
162 void write_cereal_archive(C& obj,
163  persist& p,
164  persist_type pt,
165  const std::string& suffix);
166 
167 template <typename C>
168 void write_cereal_archive(C& obj,
169  persist& p,
170  execution_mode mode,
171  const std::string& suffix);
172 
// Counterparts of write_cereal_archive: deserialize `obj` from the archive.
// NOTE(review): presumably throw NonexistentArchiveFile when the file is
// missing -- confirm in the definitions.
173 template <typename C>
174 void read_cereal_archive(C& obj, const std::string& filename);
175 
176 template <typename C>
177 void read_cereal_archive(C& obj, persist& p, const std::string& filename);
178 
179 template <typename C>
180 void read_cereal_archive(C& obj,
181  persist& p,
182  persist_type pt,
183  const std::string& suffix);
184 
185 template <typename C>
186 void read_cereal_archive(C& obj,
187  persist& p,
188  execution_mode mode,
189  const std::string& suffix);
190 
// In-memory round trip: serialize `obj` into a std::string, and restore `obj`
// from such a string (binary cereal format, going by the names -- TODO confirm).
191 template <typename C>
192 std::string create_cereal_archive_binary_string(C& obj);
193 
194 template <typename C>
195 void unpack_cereal_archive_binary_string(C& obj, const std::string& buf);
196 
197 template <typename C>
199  lbann_comm& comm,
200  const std::string& filename);
201 
202 template <typename C>
204  persist& p,
205  lbann_comm& comm,
206  const std::string& filename);
207 
208 template <typename C>
210  persist& p,
211  persist_type pt,
212  lbann_comm& comm,
213  const std::string& suffix);
214 
215 template <typename C>
217  persist& p,
218  execution_mode mode,
219  lbann_comm& comm,
220  const std::string& suffix);
221 
// Outside the translation unit that defines LBANN_PERSIST_INSTANTIATE, declare
// the persist distmat read/write member templates `extern template` for each
// type T that PROTO is applied to. Those TUs then use the explicit
// instantiations instead of instantiating the templates themselves.
222 #ifndef LBANN_PERSIST_INSTANTIATE
223 #define PROTO(T) \
224  extern template bool persist::write_rank_distmat<T>( \
225  persist_type type, \
226  const char* name, \
227  const El::AbstractDistMatrix<T>& M); \
228  extern template bool persist::read_rank_distmat<T>( \
229  persist_type type, \
230  const char* name, \
231  El::AbstractDistMatrix<T>& M); \
232  extern template bool persist::write_distmat<T>( \
233  persist_type type, \
234  const char* name, \
235  El::AbstractDistMatrix<T>* M); \
236  extern template bool persist::read_distmat<T>(persist_type type, \
237  const char* name, \
238  El::AbstractDistMatrix<T>* M)
239 
// Presumably opt-in flags asking the (missing) expansion below to include
// half-precision types as well -- TODO confirm.
240 #define LBANN_INSTANTIATE_CPU_HALF
241 #define LBANN_INSTANTIATE_GPU_HALF
// NOTE(review): line 242 of persist.hpp is absent from this listing. It
// presumably includes the header that invokes PROTO(T) for each instantiated
// type; as shown here, PROTO is defined and undefined without ever being used.
243 #undef PROTO
244 #undef LBANN_INSTANTIATE_CPU_HALF
245 #undef LBANN_INSTANTIATE_GPU_HALF
246 #endif // LBANN_PERSIST_INSTANTIATE
247 
248 } // namespace lbann
249 
250 #endif // LBANN_PERSIST_H
callback_type ckpt_type
Definition: persist.hpp:77
Create an iterator that goes over a contiguous (unit-step) enum class.
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
persist_type execution_mode_to_persist_type(execution_mode m)
const std::string & get_checkpoint_dir() const
Definition: persist.hpp:136
uint64_t get_bytes() const
Definition: persist.hpp:102
std::map< persist_type, uint64_t > m_bytes
Definition: persist.hpp:75
void read_cereal_archive(C &obj, const std::string &filename)
std::string to_string(El::Device const &d)
void load_from_shared_cereal_archive(C &obj, lbann_comm &comm, const std::string &filename)
callback_type
Definition: persist.hpp:63
void reset_bytes()
Definition: persist.hpp:111
NonexistentArchiveFile(std::string const &filename)
Definition: persist.hpp:150
std::string create_cereal_archive_binary_string(C &obj)
callback_type get_cb_type() const
Definition: persist.hpp:90
bool read_bytes(int fd, const char *name, void *buf, size_t size)
void set_restart_dir(const std::string &dir)
Definition: persist.hpp:100
execution_mode
Neural network execution mode.
Definition: base.hpp:229
persist_type
Definition: persist.hpp:39
bool create_dir(const std::string output_dir)
bool read_string(int fd, const char *name, char *buf, size_t size)
bool write_string(int fd, const char *name, const char *buf, size_t size)
bool write_bytes(int fd, const char *name, const void *buf, size_t size)
void write_cereal_archive(C &obj, const std::string &filename)
std::map< persist_type, std::string > m_filenames
Definition: persist.hpp:76
void set_cb_type(callback_type type)
Definition: persist.hpp:92
std::string m_checkpoint_dir
Definition: persist.hpp:80
void unpack_cereal_archive_binary_string(C &obj, const std::string &buf)