// NOTE(review): this text appears to be a scrape of a rendered source listing.
// Original line numbers are fused into the code, and several original lines
// (braces, break statements, some template parameters) are not present, so it
// does not compile as-is. The comments below describe only what is visible.
//
// save(ar, AbstractMatrix<T>): dispatches on mat.GetDevice() and forwards to
// the save() overload for the concrete ::El::Matrix<T, CPU> or
// ::El::Matrix<T, GPU> type via static_cast.
// NOTE(review): the listing shows the closing `#endif // LBANN_HAS_GPU` but not
// the matching `#if`, and no `break` between cases is visible -- presumably
// both are on the omitted lines; verify against the real file.
27 #ifndef LBANN_UTILS_SERIALIZATION_SERIALIZE_MATRICES_IMPL_HPP_ 28 #define LBANN_UTILS_SERIALIZATION_SERIALIZE_MATRICES_IMPL_HPP_ 32 #include <El/blas_like/level1/Copy/Translate.hpp> 33 #include <El/blas_like/level1/Copy/TranslateBetweenGrids.hpp> 38 template <
typename ArchiveT,
typename T>
39 void save(ArchiveT& ar, ::El::AbstractMatrix<T>
const& mat)
41 switch (mat.GetDevice()) {
42 case ::El::Device::CPU:
43 save(ar,
static_cast<::El::Matrix<T, ::El::Device::CPU> const&
>(mat));
46 case ::El::Device::GPU:
47 save(ar,
static_cast<::El::Matrix<T, ::El::Device::GPU> const&
>(mat));
49 #endif // LBANN_HAS_GPU 55 template <
// save(ar, Matrix<T, D>): in the visible lines, writes only the matrix
// dimensions as named values "height" and "width" (no element data is
// written here). The remaining template parameters are not visible.
typename ArchiveT,
59 void save(ArchiveT& ar, ::El::Matrix<T, D>
const& mat)
62 ar(::cereal::make_nvp(
"height", mat.Height()),
63 ::cereal::make_nvp(
"width", mat.Width()));
// do_save(ar, Matrix<T, CPU>): writes the (unnamed) height and width, then the
// raw element bytes via ::cereal::binary_data. If mat.Contiguous() holds, the
// buffer is written in one call; otherwise each column is written separately
// as Height() elements starting at offset col * LDim().
// NOTE(review): the contiguous branch writes LDim() * Width() * sizeof(T) bytes,
// but the CPU-matrix load() in this file reads Height() * Width() * sizeof(T)
// bytes. They agree only when LDim() == Height(). If Contiguous() can be true
// while LDim() > Height() (possibly a single-column view -- TODO confirm against
// Hydrogen's definition of Contiguous()), the archive would be misaligned on
// load. Consider writing Height() * Width() * sizeof(T) bytes here.
67 template <
typename ArchiveT,
70 void do_save(ArchiveT& ar, ::El::Matrix<T, ::El::Device::CPU>
const& mat)
73 ar(mat.Height(), mat.Width());
74 if (mat.Contiguous()) {
75 ar(::cereal::binary_data(mat.LockedBuffer(),
76 mat.LDim() * mat.Width() *
sizeof(T)));
// Non-contiguous case: skip the LDim() - Height() padding between columns.
79 for (::El::Int col = 0; col < mat.Width(); ++col)
80 ar(::cereal::binary_data(mat.LockedBuffer() + col * mat.LDim(),
81 mat.Height() *
sizeof(T)));
// do_save(ar, Matrix<T, GPU>): copies the GPU matrix into a CPU matrix
// (cpu_mat). The rest of the body is not visible in this listing; presumably
// cpu_mat is then serialized with the CPU do_save() -- TODO confirm.
86 template <
typename ArchiveT,
89 void do_save(ArchiveT& ar, ::El::Matrix<T, ::El::Device::GPU>
const& mat)
91 ::El::Matrix<T, ::El::Device::CPU> cpu_mat(mat);
94 #endif // LBANN_HAS_GPU 98 template <
// save(ar, Matrix<T, D>): a second overload for Matrix<T, D>. Its remaining
// template parameters (presumably an archive-kind constraint) and its body
// are not visible in this listing.
typename ArchiveT,
102 void save(ArchiveT& ar, ::El::Matrix<T, D>
const& mat)
// Another overload taking ::El::Matrix<T, D> const&; its name, leading
// parameter(s) and body are not visible in this listing.
108 template <
typename ArchiveT,
typename T, ::El::Device D>
110 ::El::Matrix<T, D>
const& mat)
// load(archive, AbstractMatrix<T>&): dispatches on mat.GetDevice() and forwards
// to the load() overload for the concrete ::El::Matrix<T, CPU> or
// ::El::Matrix<T, GPU> type via static_cast.
// NOTE(review): as in save(), the matching `#if` and any `break` statements
// are not visible in this listing.
117 template <
typename ArchiveT,
120 void load(ArchiveT& archive, ::El::AbstractMatrix<T>& mat)
122 switch (mat.GetDevice()) {
123 case ::El::Device::CPU:
124 load(archive,
static_cast<::El::Matrix<T, ::El::Device::CPU>&
>(mat));
127 case ::El::Device::GPU:
128 load(archive,
static_cast<::El::Matrix<T, ::El::Device::GPU>&
>(mat));
130 #endif // LBANN_HAS_GPU 136 template <
// load(archive, Matrix<T, D>&): reads the named "height" and "width" values
// and resizes mat accordingly. No element data is read in the visible lines,
// mirroring the metadata-only save(ar, Matrix<T, D>).
typename ArchiveT,
140 void load(ArchiveT& archive, ::El::Matrix<T, D>& mat)
143 ::El::Int height, width;
144 archive(CEREAL_NVP(height), CEREAL_NVP(width));
145 mat.Resize(height, width);
// load(archive, Matrix<T, CPU>&): reads height and width, resizes mat, then
// reads Height() * Width() * sizeof(T) raw bytes directly into mat.Buffer().
// NOTE(review): this assumes the resized matrix is densely packed
// (LDim() == Height()) -- TODO confirm. It also must match the byte count
// written by the CPU do_save().
148 template <
typename ArchiveT,
151 void load(ArchiveT& archive, ::El::Matrix<T, ::El::Device::CPU>& mat)
154 ::El::Int height, width;
155 archive(CEREAL_NVP(height), CEREAL_NVP(width));
156 mat.Resize(height, width);
157 archive(::cereal::binary_data(mat.Buffer(),
158 mat.Height() * mat.Width() *
sizeof(T)));
// load(archive, Matrix<T, GPU>&): deserializes into a temporary CPU matrix with
// the CPU load() overload, then copies it into the GPU matrix via ::El::Copy.
161 #if defined LBANN_HAS_GPU 162 template <
typename ArchiveT,
165 void load(ArchiveT& archive, ::El::Matrix<T, ::El::Device::GPU>& mat)
168 ::El::Matrix<T, ::El::Device::CPU> cpu_mat;
169 load(archive, cpu_mat);
170 ::El::Copy(cpu_mat, mat);
172 #endif // defined LBANN_HAS_GPU 174 template <
// Overload taking ::El::Matrix<T, D>& (name and leading parameters not
// visible). Broadcasts mat's height and width with ::El::mpi::Broadcast
// (using a CPU SyncInfo), resizes mat to the broadcast dimensions, then
// broadcasts the matrix contents with ::El::Broadcast. The broadcast root and
// communicator arguments are on omitted lines (presumably the archive's root
// rank -- TODO confirm).
typename ArchiveT,
typename T, ::El::Device D>
176 ::El::Matrix<T, D>& mat)
184 auto height = mat.Height();
185 auto width = mat.Width();
186 ::El::mpi::Broadcast(height,
189 ::El::SyncInfo<::El::Device::CPU>{});
190 ::El::mpi::Broadcast(width,
193 ::El::SyncInfo<::El::Device::CPU>{});
197 mat.Resize(height, width);
200 ::El::Broadcast(
static_cast<::El::AbstractMatrix<T>&
>(mat),
// save(ar, AbstractDistMatrix<T>): in the visible lines, writes only the global
// dimensions as named values "global_height" and "global_width".
207 template <
typename ArchiveT,
210 void save(ArchiveT& ar, ::El::AbstractDistMatrix<T>
const& mat)
213 ar(::cereal::make_nvp(
"global_height", mat.Height()),
214 ::cereal::make_nvp(
"global_width", mat.Width()));
// load(ar, AbstractDistMatrix<T>&): reads the named "global_height" and
// "global_width" values and resizes the distributed matrix to them.
217 template <
typename ArchiveT,
220 void load(ArchiveT& ar, ::El::AbstractDistMatrix<T>& mat)
223 ::El::Int global_height, global_width;
224 ar(::cereal::make_nvp(
"global_height", global_height),
225 ::cereal::make_nvp(
"global_width", global_width));
226 mat.Resize(global_height, global_width);
// save(ar, AbstractDistMatrix<T>): writes the (unnamed) global height and width,
// followed by this process's local matrix (mat.LockedMatrix()).
229 template <
typename ArchiveT,
232 void save(ArchiveT& ar, ::El::AbstractDistMatrix<T>
const& mat)
237 ar(mat.Height(), mat.Width(), mat.LockedMatrix());
// load(ar, AbstractDistMatrix<T>&): reads the global height and width, resizes
// mat, then assigns a CPU matrix (mat_cpu) to the local matrix mat.Matrix().
// The lines that fill mat_cpu are not visible here; presumably it is read
// from ar -- TODO confirm.
240 template <
typename ArchiveT,
243 void load(ArchiveT& ar, ::El::AbstractDistMatrix<T>& mat)
246 ::El::Int global_height, global_width;
247 ar(global_height, global_width);
248 mat.Resize(global_height, global_width);
250 ::El::Matrix<T, ::El::Device::CPU> mat_cpu;
254 mat.Matrix() = mat_cpu;
// Overload taking ::El::AbstractDistMatrix<T> const& (name, leading parameter
// and archive type not visible). Writes only the global dimensions as named
// values "global_height" and "global_width".
260 template <
typename ArchiveT,
264 ::El::AbstractDistMatrix<T>
const& mat)
266 ar(::cereal::make_nvp(
"global_height", mat.Height()),
267 ::cereal::make_nvp(
"global_width", mat.Width()));
// Overload taking ::El::AbstractDistMatrix<T>& (name, leading parameter and
// archive type not visible). Reads the named "global_height" and
// "global_width" values and resizes mat.
270 template <
typename ArchiveT,
274 ::El::AbstractDistMatrix<T>& mat)
276 El::Int height, width;
277 ar(::cereal::make_nvp(
"global_height", height),
278 ::cereal::make_nvp(
"global_width", width));
279 mat.Resize(height, width);
// Overload taking ::El::AbstractDistMatrix<T> const& (name and leading
// parameter not visible). Copies mat into a [CIRC, CIRC] CPU DistMatrix on
// mat's own grid (circ_mat), then gets the data into a [CIRC, CIRC] matrix
// bound to the archive's grid and root (circ_mat_ar). It moves the data when
// the two matrices' DistData compare equal, and otherwise uses
// ::El::copy::Translate. Finally it serializes circ_mat_ar.
282 template <
typename ArchiveT,
286 ::El::AbstractDistMatrix<T>
const& mat)
289 using CircMatType = ::El::
290 DistMatrix<T, ::El::CIRC, ::El::CIRC, ::El::ELEMENT, ::El::Device::CPU>;
291 CircMatType circ_mat(mat);
292 CircMatType circ_mat_ar(ar.
grid(), ar.
root());
293 if (circ_mat.DistData() == circ_mat_ar.DistData()) {
294 circ_mat_ar = std::move(circ_mat);
// Distributions differ (e.g. different grid/root): translate between them.
297 ::El::copy::Translate(circ_mat, circ_mat_ar);
299 save(ar, circ_mat_ar);
// Overload taking ::El::AbstractDistMatrix<T>& (name and leading parameter not
// visible). This is the inverse of the CIRC-based save: it deserializes into a
// [CIRC, CIRC] CPU DistMatrix bound to the archive's grid and root
// (circ_mat_ar). It then moves or ::El::copy::Translate's it into a
// [CIRC, CIRC] matrix on mat's grid and root (circ_mat), and finally
// redistributes into mat with ::El::Copy.
302 template <
typename ArchiveT,
306 ::El::AbstractDistMatrix<T>& mat)
309 using CircMatType = ::El::
310 DistMatrix<T, ::El::CIRC, ::El::CIRC, ::El::ELEMENT, ::El::Device::CPU>;
313 CircMatType circ_mat(mat.Grid(), mat.Root());
314 CircMatType circ_mat_ar(ar.
grid(), ar.
root());
315 load(ar, circ_mat_ar);
316 if (circ_mat.DistData() == circ_mat_ar.DistData()) {
317 circ_mat = std::move(circ_mat_ar);
// Distributions differ: translate from the archive's layout to mat's.
320 ::El::copy::Translate(circ_mat_ar, circ_mat);
324 ::El::Copy(circ_mat, mat);
// Overload taking ::El::DistMatrix<T, CIRC, CIRC> const& (name and leading
// parameter not visible). Writes the global dimensions as named values
// "global_height" and "global_width", then saves the local matrix
// (mat.LockedMatrix()) wrapped as the named value "matrix_data". Lines
// 332-335 of the original are not visible (possibly a root-rank check --
// TODO confirm).
327 template <
typename ArchiveT,
331 ::El::DistMatrix<T, ::El::CIRC, ::El::CIRC>
const& mat)
336 ar(::cereal::make_nvp(
"global_height", mat.Height()),
337 ::cereal::make_nvp(
"global_width", mat.Width()));
338 save(ar, ::cereal::make_nvp(
"matrix_data", mat.LockedMatrix()));
// Overload taking ::El::DistMatrix<T, CIRC, CIRC>& (name and leading parameter
// not visible). In the visible lines, reads the named "global_height" and
// "global_width" values and resizes mat. The code that reads the element
// data is on omitted lines.
341 template <
typename ArchiveT,
345 ::El::DistMatrix<T, ::El::CIRC, ::El::CIRC>& mat)
354 ::El::Int height, width;
355 ar(::cereal::make_nvp(
"global_height", height),
356 ::cereal::make_nvp(
"global_width", width));
359 mat.Resize(height, width);
// LoadAndConstruct<DistMatrix<...>>::load_and_construct for archives satisfying
// ::lbann::utils::IsBuiltinArchive. The visible body ends by deserializing
// into the constructed object via load(ar, *construct.ptr()). The code that
// invokes `construct` is on omitted lines (379-387) -- TODO confirm.
369 template <
typename DataT,
376 ::h2::meta::EnableWhen<::lbann::utils::IsBuiltinArchive<ArchiveT>,
int>>
377 void LoadAndConstruct<::El::DistMatrix<DataT, CDist, RDist, Wrap, D>>::
378 load_and_construct(ArchiveT& ar, cereal::construct<DistMatrixType>& construct)
388 load(ar, *construct.ptr());
// Another LoadAndConstruct<DistMatrix<...>>::load_and_construct overload (its
// archive constraint is not visible). Constructs the DistMatrix in place with
// ar.grid() and 0 (presumably the root rank -- TODO confirm), then
// deserializes into it via load(ar, *construct.ptr()).
391 template <
typename DataT,
396 template <
typename ArchiveT>
397 void LoadAndConstruct<::El::DistMatrix<DataT, CDist, RDist, Wrap, D>>::
399 cereal::construct<DistMatrixType>& construct)
401 construct(ar.
grid(), 0);
402 load(ar, *construct.ptr());
// Closes the include guard opened at the top of the header.
407 #endif // LBANN_UTILS_SERIALIZATION_SERIALIZE_MATRICES_IMPL_HPP_
EnableWhen<!IsTextArchive< ArchiveT > &&IsBuiltinArchive< ArchiveT >, ResultT > WhenNotTextArchive
SFINAE helper for splitting text-based and non-text-based serialization functions.
Grid const & get_current_grid() noexcept
Get the current grid being used for deserialization.
constexpr El::Device Device
#define LBANN_ASSERT(cond)
void save_on_root(T const &data)
El::Int root() const noexcept
void load(ArchiveT &archive, ::El::AbstractMatrix< T > &mat)
void save(ArchiveT &ar, ::El::AbstractMatrix< T > const &mat)
Save a matrix to a text-based archive.
EnableWhen< IsTextArchive< ArchiveT > &&IsBuiltinArchive< ArchiveT >, ResultT > WhenTextArchive
SFINAE helper for splitting text-based and non-text-based serialization functions.
void do_save(ArchiveT &ar, ::El::Matrix< T, ::El::Device::CPU > const &mat)
Save a CPU matrix to a non-text-based archive.
El::Grid const & grid() const noexcept
::distconv::tensor::Distribution Dist