LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
rooted_archive_adaptor.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "Licensee"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 #pragma once
27 #ifndef LBANN_UTILS_SERIALIZATION_ROOTED_ARCHIVE_ADAPTOR_HPP_
28 #define LBANN_UTILS_SERIALIZATION_ROOTED_ARCHIVE_ADAPTOR_HPP_
29 
30 #if !(defined __CUDACC__)
31 
32 #include "cereal_utils.hpp"
33 
34 #include <El.hpp>
35 
36 #include <optional>
37 #include <string>
38 
39 namespace details {
40 
41 template <typename ArchiveT, lbann::utils::WhenTextArchive<ArchiveT> = 1>
42 void set_next_name(ArchiveT& ar, char const* name)
43 {
44  ar.setNextName(name);
45 }
46 template <typename ArchiveT, lbann::utils::WhenNotTextArchive<ArchiveT> = 1>
47 void set_next_name(ArchiveT&, char const*)
48 {}
49 
50 } // namespace details
51 
52 namespace lbann {
53 
54 // An archive that collects data to the root of a grid on save and
55 // broadcasts/scatters it on load.
56 template <typename OutputArchiveT>
58  : public cereal::OutputArchive<RootedOutputArchiveAdaptor<OutputArchiveT>>
59 {
60  static_assert(lbann::utils::IsBuiltinArchive<OutputArchiveT>,
61  "At this time only built-in Cereal archives are supported.");
62  static_assert(lbann::utils::IsOutputArchive<OutputArchiveT>,
63  "The given archive type must be an \"output\" archive type.");
64 
65 public:
66  using archive_type = OutputArchiveT;
67 
68 private:
70  using BaseType_ = cereal::OutputArchive<ThisType_>;
71 
72 public:
73  RootedOutputArchiveAdaptor(std::ostream& os,
74  El::Grid const& g,
75  El::Int root = 0)
76  : BaseType_{this},
77  ar_(g.Rank() == root ? std::make_optional<archive_type>(os)
78  : std::nullopt),
79  grid_{&g},
80  root_{root}
81  {}
82 
83  El::Grid const& grid() const noexcept { return *grid_; }
84 
85  El::Int root() const noexcept { return root_; }
86 
87  bool am_root() const noexcept { return (this->root() == grid_->Rank()); }
88 
89  void set_next_name(char const* name)
90  {
91  if (name && this->am_root())
92  ::details::set_next_name(ar_.value(), name);
93  }
94 
95  template <typename T>
96  void save_on_root(T const& data)
97  {
98  if (this->am_root())
99  ar_.value()(data);
100  }
101 
102  template <typename T>
103  void prologue_on_root(T const& data)
104  {
105  if (this->am_root())
106  prologue(ar_.value(), data);
107  }
108 
109  template <typename T>
110  void epilogue_on_root(T const& data)
111  {
112  if (this->am_root())
113  epilogue(ar_.value(), data);
114  }
115 
116 private:
117  std::optional<archive_type> ar_;
118  El::Grid const* grid_;
119  El::Int root_;
120 }; // RootedOutputArchiveAdaptor
121 
122 template <typename InputArchiveT>
124  : public cereal::InputArchive<RootedInputArchiveAdaptor<InputArchiveT>>
125 {
126  static_assert(lbann::utils::IsBuiltinArchive<InputArchiveT>,
127  "At this time only built-in Cereal archives are supported.");
128  static_assert(lbann::utils::IsInputArchive<InputArchiveT>,
129  "The given archive type must be an \"input\" archive type.");
130 
131 public:
132  using archive_type = InputArchiveT;
133 
134 private:
136  using BaseType_ = cereal::InputArchive<ThisType_>;
137 
138 public:
139  RootedInputArchiveAdaptor(std::istream& is,
140  El::Grid const& g,
141  El::Int root = 0)
142  : BaseType_{this},
143  ar_(g.Rank() == root ? std::make_optional<archive_type>(is)
144  : std::nullopt),
145  grid_{&g},
146  root_{root}
147  {}
148 
149  El::Grid const& grid() const noexcept { return *grid_; }
150 
151  El::Int root() const noexcept { return root_; }
152 
153  bool am_root() const noexcept { return (this->root() == grid_->Rank()); }
154 
155  void set_next_name(char const* name)
156  {
157  if (this->am_root())
158  ::details::set_next_name(ar_.value(), name);
159  }
160 
161  template <typename T>
163  {
164  if (this->am_root())
165  ar_.value()(data);
166  }
167 
168  template <typename T>
169  void prologue_on_root(T const& data)
170  {
171  if (this->am_root())
172  prologue(ar_.value(), data);
173  }
174 
175  template <typename T>
176  void epilogue_on_root(T const& data)
177  {
178  if (this->am_root())
179  epilogue(ar_.value(), data);
180  }
181 
182 private:
183  std::optional<archive_type> ar_;
184  El::Grid const* grid_;
185  El::Int root_;
186 }; // RootedInputArchiveAdaptor
187 
188 #ifdef LBANN_HAS_CEREAL_BINARY_ARCHIVES
189 using RootedBinaryInputArchive =
191 using RootedBinaryOutputArchive =
193 #endif // LBANN_HAS_CEREAL_BINARY_ARCHIVES
194 
195 #ifdef LBANN_HAS_CEREAL_JSON_ARCHIVES
196 using RootedJSONInputArchive =
198 using RootedJSONOutputArchive =
200 #endif // LBANN_HAS_CEREAL_JSON_ARCHIVES
201 
202 #ifdef LBANN_HAS_CEREAL_PORTABLE_BINARY_ARCHIVES
203 using RootedPortableBinaryInputArchive =
205 using RootedPortableBinaryOutputArchive =
207 #endif // LBANN_HAS_CEREAL_PORTABLE_BINARY_ARCHIVES
208 
209 #ifdef LBANN_HAS_CEREAL_XML_ARCHIVES
210 using RootedXMLInputArchive =
212 using RootedXMLOutputArchive =
214 #endif // LBANN_HAS_CEREAL_XML_ARCHIVES
215 
216 } // namespace lbann
217 
218 namespace cereal {
219 
220 // POD types are "broadcast" types by default. That is, the root value
221 // is stored in the archive and merely "forgotten" on non-root
222 // processes. Ideally, this would be controlled by a
223 // "HasValidMPIDataType" trait or something.
224 template <typename OutputArchiveT, typename DataT>
225 h2::meta::EnableWhen<std::is_arithmetic_v<DataT>, void>
227  DataT const& val)
228 {
229  ar.save_on_root(val);
230 }
231 
232 template <typename OutputArchiveT>
235  bool const& b)
236 {
237  ar.save_on_root(b);
238 }
239 
240 template <typename OutputArchiveT, typename DataT>
243  NameValuePair<DataT> const& nvp)
244 {
245  ar.set_next_name(nvp.name);
246  ar(nvp.value);
247 }
248 
249 // POD types are "broadcast" types by default. They are read on the
250 // root and broadcast across the grid.
251 template <typename InputArchiveT, typename DataT>
252 h2::meta::EnableWhen<std::is_arithmetic_v<DataT>, void>
254  DataT& val)
255 {
256  static_assert(!std::is_same_v<DataT, char>,
257  "Don't be a basic char. "
258  "Apparently Hydrogen doesn't support them.");
259 
260  ar.load_on_root(val);
261  El::mpi::Broadcast(val,
262  ar.root(),
263  ar.grid().Comm(),
264  El::SyncInfo<El::Device::CPU>{});
265 }
266 
267 template <typename InputArchiveT>
270  bool& b)
271 {
272  ar.load_on_root(b);
273  int val = b;
274  El::mpi::Broadcast(val,
275  ar.root(),
276  ar.grid().Comm(),
277  El::SyncInfo<El::Device::CPU>{});
278  if (!ar.am_root())
279  b = val;
280 }
281 
282 template <typename ArchiveT, typename CharT, typename TraitsT, typename AllocT>
285  std::basic_string<CharT, TraitsT, AllocT> const& str)
286 {
287  ar.save_on_root(str);
288 }
289 
290 template <typename ArchiveT, typename CharT, typename TraitsT, typename AllocT>
292  std::basic_string<CharT, TraitsT, AllocT>& str)
293 {
294  ar.load_on_root(str);
295  auto str_len = str.size();
296  El::mpi::Broadcast(str_len,
297  ar.root(),
298  ar.grid().Comm(),
299  El::SyncInfo<El::Device::CPU>{});
300  str.resize(str_len);
301  // I was seeing an undefined reference if using plain ol' char. I
302  // fear the day someone uses a wstring in here.
303  El::mpi::Broadcast(reinterpret_cast<El::byte*>(str.data()),
304  str_len * sizeof(CharT),
305  ar.root(),
306  ar.grid().Comm(),
307  El::SyncInfo<El::Device::CPU>{});
308 }
309 
310 // TODO: This may need some work. The current implementation is
311 // inspired by the XML archives in Cereal.
312 template <class InputArchiveT, class DataT>
315  NameValuePair<DataT>& nvp)
316 {
317  ar.set_next_name(nvp.name);
318  ar(nvp.value);
319 }
320 
321 template <class ArchiveT, class T>
323  SizeTag<T> const& tag)
324 {
325  ar.save_on_root(tag);
326 }
327 
328 template <class ArchiveT, class T>
330  SizeTag<T>& tag)
331 {
332  ar.load_on_root(tag);
333  El::mpi::Broadcast(tag.size,
334  ar.root(),
335  ar.grid().Comm(),
336  El::SyncInfo<El::Device::CPU>{});
337 }
338 
339 template <
340  class ArchiveT,
341  class T,
342  h2::meta::EnableWhen<
343  !std::is_arithmetic_v<T> &&
344  !::cereal::traits::has_minimal_base_class_serialization<
345  T,
346  ::cereal::traits::has_minimal_output_serialization,
347  ArchiveT>::value &&
348  !::cereal::traits::has_minimal_output_serialization<T, ArchiveT>::value,
349  int> = 1>
351 {
352  ar.prologue_on_root(data);
353 }
354 
355 template <
356  class ArchiveT,
357  class T,
358  h2::meta::EnableWhen<
359  !std::is_arithmetic_v<T> &&
360  !::cereal::traits::has_minimal_base_class_serialization<
361  T,
362  ::cereal::traits::has_minimal_output_serialization,
363  ArchiveT>::value &&
364  !::cereal::traits::has_minimal_output_serialization<T, ArchiveT>::value,
365  int> = 1>
367 {
368  ar.epilogue_on_root(data);
369 }
370 
371 template <
372  class ArchiveT,
373  class T,
374  h2::meta::EnableWhen<
375  !std::is_arithmetic_v<T> &&
376  !::cereal::traits::has_minimal_base_class_serialization<
377  T,
378  ::cereal::traits::has_minimal_input_serialization,
379  ArchiveT>::value &&
380  !::cereal::traits::has_minimal_input_serialization<T, ArchiveT>::value,
381  int> = 1>
383 {
384  ar.prologue_on_root(data);
385 }
386 
387 template <
388  class ArchiveT,
389  class T,
390  h2::meta::EnableWhen<
391  !std::is_arithmetic_v<T> &&
392  !::cereal::traits::has_minimal_base_class_serialization<
393  T,
394  ::cereal::traits::has_minimal_input_serialization,
395  ArchiveT>::value &&
396  !::cereal::traits::has_minimal_input_serialization<T, ArchiveT>::value,
397  int> = 1>
399 {
400  ar.epilogue_on_root(data);
401 }
402 
403 // For strings:
404 template <typename ArchiveT,
405  typename CharT,
406  typename TraitsT,
407  typename AllocatorT>
409  std::basic_string<CharT, TraitsT, AllocatorT> const&)
410 {}
411 
412 template <typename ArchiveT,
413  typename CharT,
414  typename TraitsT,
415  typename AllocatorT>
417  std::basic_string<CharT, TraitsT, AllocatorT> const&)
418 {}
419 
420 template <typename ArchiveT,
421  typename CharT,
422  typename TraitsT,
423  typename AllocatorT>
425  std::basic_string<CharT, TraitsT, AllocatorT> const&)
426 {}
427 
428 template <typename ArchiveT,
429  typename CharT,
430  typename TraitsT,
431  typename AllocatorT>
433  std::basic_string<CharT, TraitsT, AllocatorT> const&)
434 {}
435 
436 } // namespace cereal
437 
438 #ifdef LBANN_HAS_CEREAL_BINARY_ARCHIVES
439 CEREAL_REGISTER_ARCHIVE(
441 CEREAL_REGISTER_ARCHIVE(
443 CEREAL_SETUP_ARCHIVE_TRAITS(
446 #endif // LBANN_HAS_CEREAL_BINARY_ARCHIVES
447 
448 #ifdef LBANN_HAS_CEREAL_JSON_ARCHIVES
449 CEREAL_REGISTER_ARCHIVE(
451 CEREAL_REGISTER_ARCHIVE(
453 CEREAL_SETUP_ARCHIVE_TRAITS(
456 #endif // LBANN_HAS_CEREAL_JSON_ARCHIVES
457 
458 #ifdef LBANN_HAS_CEREAL_PORTABLE_BINARY_ARCHIVES
459 CEREAL_REGISTER_ARCHIVE(
461 CEREAL_REGISTER_ARCHIVE(
463 CEREAL_SETUP_ARCHIVE_TRAITS(
466 #endif // LBANN_HAS_CEREAL_PORTABLE_BINARY_ARCHIVES
467 
468 #ifdef LBANN_HAS_CEREAL_XML_ARCHIVES
469 CEREAL_REGISTER_ARCHIVE(
471 CEREAL_REGISTER_ARCHIVE(
473 CEREAL_SETUP_ARCHIVE_TRAITS(
476 #endif // LBANN_HAS_CEREAL_XML_ARCHIVES
477 
478 #endif // __CUDACC__
479 #endif // LBANN_UTILS_SERIALIZATION_ROOTED_ARCHIVE_ADAPTOR_HPP_
cereal::InputArchive< ThisType_ > BaseType_
cereal::OutputArchive< ThisType_ > BaseType_
El::Grid Grid
Definition: base.hpp:126
std::optional< archive_type > ar_
T & data(const cnpy::NpyArray &na, const std::vector< size_t > indices)
Definition: cnpy_utils.hpp:75
El::Grid const & grid() const noexcept
RootedInputArchiveAdaptor(std::istream &is, El::Grid const &g, El::Int root=0)
void CEREAL_SAVE_FUNCTION_NAME(lbann::RootedOutputArchiveAdaptor< ArchiveT > &ar, SizeTag< T > const &tag)
void epilogue(lbann::RootedInputArchiveAdaptor< ArchiveT > &, std::basic_string< CharT, TraitsT, AllocatorT > const &)
void set_next_name(ArchiveT &ar, char const *name)
void CEREAL_LOAD_FUNCTION_NAME(lbann::RootedInputArchiveAdaptor< ArchiveT > &ar, SizeTag< T > &tag)
void prologue(lbann::RootedInputArchiveAdaptor< ArchiveT > &, std::basic_string< CharT, TraitsT, AllocatorT > const &)
El::Grid const & grid() const noexcept
RootedOutputArchiveAdaptor(std::ostream &os, El::Grid const &g, El::Int root=0)