LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
persist_impl.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
25 //
26 // lbann_file_io .hpp .cpp - Input / output utilities
28 
29 #ifndef LBANN_IO_PERSIST_IMPL_H
30 #define LBANN_IO_PERSIST_IMPL_H
31 
32 #include "lbann/comm_impl.hpp"
33 #include "lbann/io/persist.hpp"
35 
36 namespace lbann {
37 
39 {
40  switch (m) {
54  default:
55  LBANN_ERROR("Invalid execution mode specified");
56  }
57 }
58 
59 inline std::string to_string(persist_type pt)
60 {
61  switch (pt) {
63  return "model";
65  return "metrics";
67  return "train";
69  return "validate";
71  return "test";
73  return "prediction";
75  return "training";
77  return "validation";
79  return "tournament";
81  return "testing";
83  return "inference";
84  default:
85  LBANN_ERROR("Invalid persist type specified");
86  }
87 }
88 
90 template <class Archive>
91 void persist::serialize(Archive& ar)
92 {
93  ar(CEREAL_NVP(ckpt_type));
94 }
95 
96 template <typename C>
97 void write_cereal_archive(C& obj, const std::string& filename)
98 {
99  std::ofstream os(filename);
100  if (!os.is_open()) {
101  throw NonexistentArchiveFile(filename);
102  }
103 #ifdef LBANN_HAS_CEREAL_XML_ARCHIVES
104  cereal::XMLOutputArchive archive(os);
105 #else // defined LBANN_HAS_CEREAL_BINARY_ARCHIVES
106  cereal::BinaryOutputArchive archive(os);
107 #endif // LBANN_HAS_CEREAL_XML_ARCHIVES
108  archive(obj);
109 }
110 
111 template <typename C>
112 void write_cereal_archive(C& obj, persist& p, const std::string& filename)
113 {
114  write_cereal_archive<C>(obj, p.get_checkpoint_dir() + "/" + filename);
115 }
116 
117 template <typename C>
119  persist& p,
120  persist_type pt,
121  const std::string& suffix)
122 {
123  write_cereal_archive<C>(obj, p.get_filename(pt) + suffix);
124 }
125 
126 template <typename C>
128  persist& p,
129  execution_mode mode,
130  const std::string& suffix)
131 {
133  write_cereal_archive<C>(obj, p, pt, suffix);
134 }
135 
136 template <typename C>
137 void read_cereal_archive(C& obj, const std::string& filename)
138 {
139  std::ifstream is(filename);
140  if (!is.is_open()) {
141  throw NonexistentArchiveFile(filename);
142  }
143 #ifdef LBANN_HAS_CEREAL_XML_ARCHIVES
144  cereal::XMLInputArchive archive(is);
145 #else // defined LBANN_HAS_CEREAL_BINARY_ARCHIVES
146  cereal::BinaryInputArchive archive(is);
147 #endif // LBANN_HAS_CEREAL_XML_ARCHIVES
148  archive(obj);
149 }
150 
151 template <typename C>
152 void read_cereal_archive(C& obj, persist& p, const std::string& filename)
153 {
154  read_cereal_archive(obj, p.get_checkpoint_dir() + "/" + filename);
155 }
156 
157 template <typename C>
159  persist& p,
160  persist_type pt,
161  const std::string& suffix)
162 {
163  read_cereal_archive(obj, p.get_filename(pt) + suffix);
164 }
165 
166 template <typename C>
168  persist& p,
169  execution_mode mode,
170  const std::string& suffix)
171 {
173  read_cereal_archive<C>(obj, p, pt, suffix);
174 }
175 
176 template <typename C>
178 {
179  std::ostringstream ss;
180  {
181  cereal::BinaryOutputArchive archive(ss);
182  archive(obj);
183  } // archive goes out of scope, ensuring all contents are flushed
184  return ss.str();
185 }
186 
187 template <typename C>
188 void unpack_cereal_archive_binary_string(C& obj, const std::string& buf)
189 {
190  std::istringstream ss(buf);
191  {
192  cereal::BinaryInputArchive archive(ss);
193  archive(obj);
194  } // archive goes out of scope, ensuring all contents are flushed
195 }
196 
197 template <typename C>
199  lbann_comm& comm,
200  const std::string& filename)
201 {
202  std::string buf;
203  if (comm.am_trainer_master()) {
204  read_cereal_archive<C>(obj, filename);
205  buf = create_cereal_archive_binary_string<C>(obj);
206  }
207  else {
208  // If you are not the trainer master, still check to see if the file exists
209  std::ifstream is(filename);
210  if (!is.is_open()) {
211  throw NonexistentArchiveFile(filename);
212  }
213  }
214 
215  // TODO: this assumes homogeneous processors
216  // broadcast state from rank 0
217  comm.trainer_broadcast(0, buf);
218 
219  if (!comm.am_trainer_master()) {
220  unpack_cereal_archive_binary_string<C>(obj, buf);
221  }
222 }
223 
224 template <typename C>
226  persist& p,
227  lbann_comm& comm,
228  const std::string& filename)
229 {
230  load_from_shared_cereal_archive(obj, comm, p.get_checkpoint_dir() + filename);
231 }
232 
233 template <typename C>
235  persist& p,
236  persist_type pt,
237  lbann_comm& comm,
238  const std::string& suffix)
239 {
240  load_from_shared_cereal_archive(obj, comm, p.get_filename(pt) + suffix);
241 }
242 
243 template <typename C>
245  persist& p,
246  execution_mode mode,
247  lbann_comm& comm,
248  const std::string& suffix)
249 {
251  load_from_shared_cereal_archive<C>(obj, p, pt, comm, suffix);
252 }
253 
254 } // namespace lbann
255 #endif // LBANN_IO_PERSIST_IMPL_H
callback_type ckpt_type
Definition: persist.hpp:77
void trainer_broadcast(int root, T &val) const
Within-trainer broadcast of a scalar.
Definition: comm_impl.hpp:48
#define LBANN_ERROR(...)
Definition: exception.hpp:37
void serialize(Archive &ar)
persist_type execution_mode_to_persist_type(execution_mode m)
const std::string & get_checkpoint_dir() const
Definition: persist.hpp:136
void read_cereal_archive(C &obj, const std::string &filename)
std::string to_string(El::Device const &d)
void load_from_shared_cereal_archive(C &obj, lbann_comm &comm, const std::string &filename)
std::string create_cereal_archive_binary_string(C &obj)
execution_mode
Neural network execution mode.
Definition: base.hpp:229
persist_type
Definition: persist.hpp:39
bool am_trainer_master() const noexcept
Definition: comm.hpp:192
void write_cereal_archive(C &obj, const std::string &filename)
std::string get_filename(persist_type type) const
void unpack_cereal_archive_binary_string(C &obj, const std::string &buf)