LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
checkpoint_common.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "Licensee"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 #ifndef LBANN_SRC_EXECUTION_ALGORITHMS_LTFB_CHECKPOINT_COMMON_HPP_INCLUDED
27 #define LBANN_SRC_EXECUTION_ALGORITHMS_LTFB_CHECKPOINT_COMMON_HPP_INCLUDED
28 
29 #include "lbann/models/model.hpp"
31 
32 #include <unordered_set>
33 
34 namespace lbann {
35 namespace ltfb {
36 
37 // Pack model to ship off
38 inline static std::string pack(model const& m)
39 {
40  std::ostringstream oss;
41  {
42  RootedBinaryOutputArchive ar(oss, m.get_comm()->get_trainer_grid());
43  ar(m);
44  }
45  return oss.str();
46 }
47 
48 // Send a string to the root of the destination trainer
49 inline static void send_string(lbann_comm const& comm,
50  std::string const& str,
51  int destination_trainer)
52 {
53  size_t size = str.length();
54  comm.send(&size, 1, destination_trainer, /*rank=*/0);
55  comm.send(reinterpret_cast<El::byte const*>(str.data()),
56  size,
57  destination_trainer,
58  /*rank=*/0);
59 }
60 
61 // Receive a string from the root of src_trainer
62 inline static std::string recv_string(lbann_comm const& comm, int src_trainer)
63 {
64  size_t size = 0;
65  comm.recv(&size, 1, src_trainer);
66  std::string buf;
67  buf.resize(size);
68  comm.recv(reinterpret_cast<El::byte*>(buf.data()), size, src_trainer);
69  return buf;
70 }
71 
72 // Unpack received model
73 inline static void unpack(model& m, std::string const& str)
74 {
75  std::istringstream iss(str);
76  {
77  RootedBinaryInputArchive ar(iss, m.get_comm()->get_trainer_grid());
78  ar(m);
79  }
80 }
81 
82 inline static void restore_model_weights(
83  model& m,
84  std::unordered_map<std::string, std::unique_ptr<weights>>& restore_weights)
85 {
86  // Restore weights that shouldn't be exchanged
87  if (restore_weights.empty())
88  return;
89 
90  // FIXME: Generalize this; enable ptr move??
91  for (auto w : m.get_weights()) {
92  if (restore_weights.count(w->get_name()) > 0) {
93  using TensorDataType = DataType;
94  using WeightsType = data_type_weights<TensorDataType>;
95  dynamic_cast<WeightsType&>(*w) =
96  dynamic_cast<WeightsType&>(*restore_weights[w->get_name()]);
97  }
98  }
99 }
100 
101 inline static std::string sendrecv_string(lbann_comm const& c,
102  std::string const& src,
103  El::Int partner_trainer)
104 {
105 #ifdef LBANN_HAS_ALUMINUM
106  El::mpi::EnsureComm<size_t, El::Collective::SENDRECV>(
107  c.get_world_comm(),
108  El::SyncInfo<El::Device::CPU>{});
109 #endif
110 
111  if (!c.am_trainer_master())
112  return "";
113 
114  // Exchange sizes
115  size_t my_size = src.size();
116  size_t other_size = src.max_size() + 1;
117  c.sendrecv(&my_size,
118  1,
119  partner_trainer,
120  0,
121  &other_size,
122  1,
123  partner_trainer,
124  0,
125  El::SyncInfo<El::Device::CPU>{});
126 
127  // Exchange strings
128  std::string tgt(other_size, '\0');
129 
130  auto const* send_buf = reinterpret_cast<El::byte const*>(src.data());
131  auto* recv_buf = reinterpret_cast<El::byte*>(tgt.data());
132 
133  // Get the max blk size
134  int constexpr max_blk_size_int = std::numeric_limits<int>::max();
135  std::size_t constexpr max_blk_size_size_t = max_blk_size_int;
136 
137  while (my_size || other_size) {
138  int const this_blk_send_size =
139  (my_size > max_blk_size_size_t ? max_blk_size_int : my_size);
140  int const this_blk_recv_size =
141  (other_size > max_blk_size_size_t ? max_blk_size_int : other_size);
142 
143  c.sendrecv(send_buf,
144  this_blk_send_size,
145  partner_trainer,
146  0,
147  recv_buf,
148  this_blk_recv_size,
149  partner_trainer,
150  0,
151  El::SyncInfo<El::Device::CPU>{});
152 
153  send_buf += this_blk_send_size;
154  recv_buf += this_blk_recv_size;
155  my_size =
156  (my_size > max_blk_size_size_t ? my_size - max_blk_size_size_t : 0);
157  other_size =
158  (other_size > max_blk_size_size_t ? other_size - max_blk_size_size_t : 0);
159  }
160  return tgt;
161 }
162 
163 template <typename T>
164 inline static void
165 exchange(lbann_comm const& c, T& object, El::Int partner_trainer)
166 {
167  std::ostringstream oss;
168  {
169  RootedBinaryOutputArchive ar(oss, c.get_trainer_grid());
170  c.trainer_barrier();
171  ar(object);
172  }
173  c.trainer_barrier(); // I don't think this is necessary
174  {
175  std::istringstream iss{sendrecv_string(c, oss.str(), partner_trainer)};
176  RootedBinaryInputArchive ar(iss, c.get_trainer_grid());
177  ar(object);
178  }
179  c.trainer_barrier(); // I don't think this is necessary either
180 }
181 
182 } // namespace ltfb
183 } // namespace lbann
184 #endif // LBANN_SRC_EXECUTION_ALGORITHMS_LTFB_CHECKPOINT_COMMON_HPP_INCLUDED
static void exchange(lbann_comm const &c, T &object, El::Int partner_trainer)
lbann_comm * get_comm() const noexcept
Get the model's comm.
Definition: model.hpp:652
static std::string recv_string(lbann_comm const &comm, int src_trainer)
void trainer_barrier() const
void send(const T *data, int count, int trainer, int rank) const
Definition: comm_impl.hpp:761
std::vector< weights * > get_weights()
Abstract base class for neural network models.
Definition: model.hpp:83
bool am_trainer_master() const noexcept
Definition: comm.hpp:192
static std::string sendrecv_string(lbann_comm const &c, std::string const &src, El::Int partner_trainer)
static void restore_model_weights(model &m, std::unordered_map< std::string, std::unique_ptr< weights >> &restore_weights)
void recv(T *data, int count, int trainer, int rank) const
Definition: comm_impl.hpp:828
const El::mpi::Comm & get_world_comm() const noexcept
Definition: comm.hpp:901
static void unpack(model &m, std::string const &str)
void sendrecv(const T *snd, int send_count, int send_trainer, int send_rank, T *rcv, int recv_count, int recv_trainer, int recv_rank) const
Definition: comm_impl.hpp:923
static void send_string(lbann_comm const &comm, std::string const &str, int destination_trainer)
El::Grid & get_trainer_grid()
Definition: comm.hpp:202
static std::string pack(model const &m)