26 #ifndef LBANN_SRC_EXECUTION_ALGORITHMS_LTFB_CHECKPOINT_COMMON_HPP_INCLUDED 27 #define LBANN_SRC_EXECUTION_ALGORITHMS_LTFB_CHECKPOINT_COMMON_HPP_INCLUDED 32 #include <unordered_set> 40 std::ostringstream oss;
// send_string: ships `str` to `destination_trainer` as two consecutive
// messages. The start of the signature and the tail of the second send call
// lie on lines that are not part of this excerpt.
50 std::string
const& str,
51 int destination_trainer)
// Message 1: the length, sent as a single size_t. The trailing 0 is passed in
// the `rank` position of lbann_comm::send(data, count, trainer, rank).
53 size_t size = str.length();
54 comm.
send(&size, 1, destination_trainer, 0);
// Message 2: the character data, reinterpreted as raw El::byte.
// (Remaining arguments of this call are not visible here.)
55 comm.
send(reinterpret_cast<El::byte const*>(str.data()),
// recv_string: the first receive reads one element into `size` from
// `src_trainer` (presumably the byte count sent by send_string — confirm).
// `size` and `buf` are declared on lines not shown in this excerpt.
65 comm.
recv(&size, 1, src_trainer);
// The second receive writes `size` bytes straight into buf's storage.
// NOTE(review): presumably `buf` was sized to `size` just before this call;
// confirm against the full source.
68 comm.
recv(reinterpret_cast<El::byte*>(buf.data()), size, src_trainer);
// unpack: wrap the string `str` in an input stream. How the stream is consumed
// is not visible in this excerpt.
75 std::istringstream iss(str);
// restore_model_weights: overwrites a weights object `w` with the copy stored
// in `restore_weights` under the same name. The loop that yields `w` and the
// definition of WeightsType are on lines not shown in this excerpt.
84 std::unordered_map<std::string, std::unique_ptr<weights>>& restore_weights)
// Special-cases an empty map (the branch body is not visible here;
// presumably an early return — confirm).
87 if (restore_weights.empty())
// Only weights whose name is a key in the map are touched.
92 if (restore_weights.count(w->get_name()) > 0) {
93 using TensorDataType = DataType;
// Both sides are downcast to WeightsType& before assigning, so WeightsType's
// assignment operator is the one used. A reference dynamic_cast throws
// std::bad_cast if the dynamic type does not match.
95 dynamic_cast<WeightsType&
>(*w) =
// NOTE(review): operator[] repeats the lookup already done by count() above.
96 dynamic_cast<WeightsType&
>(*restore_weights[w->get_name()]);
// sendrecv_string: exchanges the bytes of `src` with `partner_trainer`.
// Incoming bytes are written into a local string `tgt`. Several call heads and
// assignment targets are on lines not shown in this excerpt; comments below
// describe only what is visible.
102 std::string
const& src,
103 El::Int partner_trainer)
// Aluminum builds only: El::mpi::EnsureComm is invoked for size_t / SENDRECV
// with a CPU SyncInfo (presumably to set up the communicator for that backend
// — confirm against Hydrogen). The matching #endif is not visible here.
105 #ifdef LBANN_HAS_ALUMINUM 106 El::mpi::EnsureComm<size_t, El::Collective::SENDRECV>(
108 El::SyncInfo<El::Device::CPU>{});
// my_size: number of bytes this side will send.
115 size_t my_size = src.size();
// other_size starts at max_size() + 1, a length no std::string can have.
// NOTE(review): presumably a sentinel that the size exchange ending on the
// next visible line overwrites with the partner's real length — confirm.
116 size_t other_size = src.max_size() + 1;
125 El::SyncInfo<El::Device::CPU>{});
// Receive buffer: other_size NUL characters.
128 std::string tgt(other_size,
'\0');
// Raw byte cursors into the send and receive storage; both are advanced
// block by block in the loop below.
130 auto const* send_buf =
reinterpret_cast<El::byte const*
>(src.data());
131 auto* recv_buf =
reinterpret_cast<El::byte*
>(tgt.data());
// Per-call transfer sizes are held in `int` below, so one block is capped at
// the largest int. The size_t copy lets the comparisons against
// my_size / other_size be done without mixing signedness.
134 int constexpr max_blk_size_int = std::numeric_limits<int>::max();
135 std::size_t constexpr max_blk_size_size_t = max_blk_size_int;
// Keep going while either direction still has bytes outstanding; the two
// directions may finish on different iterations.
137 while (my_size || other_size) {
// Size of this iteration's outgoing block: the remainder, clamped to the cap.
138 int const this_blk_send_size =
139 (my_size > max_blk_size_size_t ? max_blk_size_int : my_size);
// Same clamp for the incoming block.
140 int const this_blk_recv_size =
141 (other_size > max_blk_size_size_t ? max_blk_size_int : other_size);
// (Tail of the block transfer call; its head is not visible in this excerpt.)
151 El::SyncInfo<El::Device::CPU>{});
// Step both cursors past the block just transferred.
153 send_buf += this_blk_send_size;
154 recv_buf += this_blk_recv_size;
// Bytes still outstanding after this block: subtract the cap, or 0 once the
// remainder fit in one block. The assignment targets are on lines not shown;
// presumably my_size and other_size respectively — confirm.
156 (my_size > max_blk_size_size_t ? my_size - max_blk_size_size_t : 0);
158 (other_size > max_blk_size_size_t ? other_size - max_blk_size_size_t : 0);
// exchange<T>: the visible lines create an output stream, pass its contents to
// sendrecv_string together with `partner_trainer`, and wrap the returned
// string in an input stream. How `oss` is filled and how `iss` is consumed are
// on lines not shown in this excerpt.
163 template <
typename T>
167 std::ostringstream oss;
// The string returned by sendrecv_string directly initializes `iss`.
175 std::istringstream iss{
sendrecv_string(c, oss.str(), partner_trainer)};
184 #endif // LBANN_SRC_EXECUTION_ALGORITHMS_LTFB_CHECKPOINT_COMMON_HPP_INCLUDED static void exchange(lbann_comm const &c, T &object, El::Int partner_trainer)
lbann_comm * get_comm() const noexcept
Get the model's comm.
static std::string recv_string(lbann_comm const &comm, int src_trainer)
void trainer_barrier() const
void send(const T *data, int count, int trainer, int rank) const
std::vector< weights * > get_weights()
Abstract base class for neural network models.
bool am_trainer_master() const noexcept
static std::string sendrecv_string(lbann_comm const &c, std::string const &src, El::Int partner_trainer)
static void restore_model_weights(model &m, std::unordered_map< std::string, std::unique_ptr< weights >> &restore_weights)
void recv(T *data, int count, int trainer, int rank) const
const El::mpi::Comm & get_world_comm() const noexcept
static void unpack(model &m, std::string const &str)
void sendrecv(const T *snd, int send_count, int send_trainer, int send_rank, T *rcv, int recv_count, int recv_trainer, int recv_rank) const
static void send_string(lbann_comm const &comm, std::string const &str, int destination_trainer)
El::Grid & get_trainer_grid()
static std::string pack(model const &m)