27 #ifndef LBANN_COMM_HPP_IMPL_INCLUDED 28 #define LBANN_COMM_HPP_IMPL_INCLUDED 63 const El::mpi::Comm& c)
const 65 broadcast(root, data, count, c, El::SyncInfo<El::Device::CPU>{});
72 const int count)
const 77 template <
typename T, El::Device D>
81 El::SyncInfo<D>
const& syncInfo)
const 89 const int count)
const 93 template <
typename T, El::Device D>
97 El::SyncInfo<D>
const& syncInfo)
const 102 template <
typename T>
105 const int count)
const 110 template <
typename T, El::Device D>
114 El::SyncInfo<D>
const& syncInfo)
const 122 template <
typename T>
124 std::vector<T>&
data,
125 const El::mpi::Comm& c)
const 127 auto const rank_c = El::mpi::Rank(c);
128 size_t count = data.size();
129 El::mpi::Broadcast(&count, 1, root, c, El::SyncInfo<El::Device::CPU>{});
139 template <
typename T>
141 std::vector<T>&
data,
142 const El::mpi::Comm& c)
const 144 const int count =
static_cast<int>(
resize(root, data, c));
148 broadcast(root, data.data(), count, c, El::SyncInfo<El::Device::CPU>{});
151 template <
typename T>
160 template <
typename T>
163 std::vector<T>&
data)
const 168 template <
typename T>
175 template <
typename T>
180 const El::mpi::Comm& c)
const 187 El::SyncInfo<El::Device::CPU>{});
189 template <
typename T, El::Device D>
194 const El::mpi::Comm& c,
195 El::SyncInfo<D>
const& syncInfo)
const 197 El::mpi::AllGather(src, src_count, rcv, rcv_count, c, syncInfo);
204 template <
typename T>
207 std::vector<int>
const& rcv_counts,
208 std::vector<int>
const& rcv_disp,
209 const El::mpi::Comm& c)
const 211 if (src.size() == 0) {
212 std::ostringstream err;
213 err << __FILE__ <<
" " << __LINE__ <<
" :: " 214 <<
"all_gather for vector<>: vector.size() == 0;\n" 215 <<
"this doesn't work!";
218 El::mpi::AllGather(src.data(),
224 El::SyncInfo<El::Device::CPU>{});
230 template <
typename T>
233 std::vector<int>
const& rcv_counts,
234 std::vector<int>
const& rcv_disp)
const 242 template <
typename T>
244 std::vector<T>&
data,
245 const El::mpi::Comm& c)
const 247 El::mpi::AllGather(&src,
252 El::SyncInfo<El::Device::CPU>{});
258 template <
typename T>
267 template <
typename T>
274 template <
typename T>
280 template <
typename T>
286 template <
typename T>
289 const int root)
const 294 template <
typename T>
302 template <
typename T>
305 const int root)
const 308 El::mpi::Gather(snd, count,
nullptr,
nullptr,
nullptr, root,
m_trainer_comm);
310 template <
typename T>
314 int const*
const rcv_counts,
315 int const*
const rcv_displacements)
const 330 template <
typename T>
336 template <
typename T>
342 template <
typename T>
345 const int root)
const 350 template <
typename T>
358 template <
typename T>
361 const El::mpi::Comm& c)
const 364 El::mpi::Gather(&snd,
370 El::SyncInfo<El::Device::CPU>{});
373 template <
typename T>
376 auto const size_c = El::mpi::Size(c);
377 auto const rank_c = El::mpi::Rank(c);
378 El::mpi::Gather(&snd, 1, rcv, 1, rank_c, c, El::SyncInfo<El::Device::CPU>{});
382 template <
typename T>
385 const El::mpi::Comm& c)
const 387 gather(snd, rcv.data(), c);
390 template <
typename T>
394 const El::mpi::Comm& c)
const 396 gather(snd, count, root, c, El::SyncInfo<El::Device::CPU>{});
398 template <
typename T, El::Device D>
402 const El::mpi::Comm& c,
403 El::SyncInfo<D>
const& syncInfo)
const 406 El::mpi::Gather(snd, count, (T*)
nullptr, 0, root, c, syncInfo);
409 template <
typename T>
413 const El::mpi::Comm& c)
const 415 gather(snd, count, rcv, c, El::SyncInfo<El::Device::CPU>{});
417 template <
typename T, El::Device D>
421 const El::mpi::Comm& c,
422 El::SyncInfo<D>
const& syncInfo)
const 424 auto const size_c = El::mpi::Size(c);
425 auto const rank_c = El::mpi::Rank(c);
426 El::mpi::Gather(snd, count, rcv, count, rank_c, c, syncInfo);
430 template <
typename T>
434 El::mpi::Scatter((T*)
nullptr,
440 El::SyncInfo<El::Device::CPU>{});
445 template <
typename T>
450 auto root = El::mpi::Rank(c);
451 El::mpi::Scatter(snd, 1, &val, 1, root, c, El::SyncInfo<El::Device::CPU>{});
455 template <
typename T>
458 const El::mpi::Op op)
const 463 template <
typename T>
469 template <
typename T>
472 const El::mpi::Op op)
const 477 template <
typename T>
483 template <
typename T>
487 const El::mpi::Op op)
const 492 template <
typename T>
496 const El::mpi::Op op)
const 501 template <
typename T>
504 const El::mpi::Comm& c,
505 const El::mpi::Op op)
const 508 El::mpi::Reduce(&snd,
514 El::SyncInfo<El::Device::CPU>{});
517 template <
typename T>
519 const El::mpi::Comm& c,
520 const El::mpi::Op op)
const 523 auto const size_c = El::mpi::Size(c);
524 auto const rank_c = El::mpi::Rank(c);
525 El::mpi::Reduce(&snd,
531 El::SyncInfo<El::Device::CPU>{});
538 template <
typename T>
542 const El::mpi::Comm& c)
const 544 reduce(snd, count, root, c, El::mpi::SUM, El::SyncInfo<El::Device::CPU>{});
546 template <
typename T, El::Device D>
550 const El::mpi::Comm& c,
551 El::SyncInfo<D>
const& syncInfo)
const 553 reduce(snd, count, root, c, El::mpi::SUM, syncInfo);
556 template <
typename T>
560 const El::mpi::Comm& c,
561 const El::mpi::Op op)
const 563 reduce(snd, count, root, c, op, El::SyncInfo<El::Device::CPU>{});
565 template <
typename T, El::Device D>
569 const El::mpi::Comm& c,
570 const El::mpi::Op op,
571 El::SyncInfo<D>
const& syncInfo)
const 574 El::mpi::Reduce(snd, (T*)
nullptr, count, op, root, c, syncInfo);
577 template <
typename T, El::Device D>
581 const El::mpi::Comm& c,
582 El::SyncInfo<D>
const& syncInfo)
const 584 reduce(snd, count, rcv, c, El::mpi::SUM, syncInfo);
586 template <
typename T>
590 const El::mpi::Comm& c)
const 592 reduce(snd, count, rcv, c, El::mpi::SUM, El::SyncInfo<El::Device::CPU>{});
595 template <
typename T>
599 const El::mpi::Comm& c,
600 const El::mpi::Op op)
const 602 reduce(snd, count, rcv, c, op, El::SyncInfo<El::Device::CPU>{});
604 template <
typename T, El::Device D>
608 const El::mpi::Comm& c,
610 El::SyncInfo<D>
const& syncInfo)
const 613 snd = (T
const*)MPI_IN_PLACE;
615 auto const rank_c = El::mpi::Rank(c);
616 auto const size_c = El::mpi::Size(c);
617 El::mpi::Reduce(snd, rcv, count, op, rank_c, c, syncInfo);
621 template <
typename T>
627 template <
typename T>
633 template <
typename T>
637 const El::mpi::Op op)
const 642 template <
typename T>
644 const El::mpi::Comm& c,
645 const El::mpi::Op op)
const 647 auto const size_c = El::mpi::Size(c);
657 template <
typename T>
661 const El::mpi::Comm& c,
662 const El::mpi::Op op)
const 664 auto const size_c = El::mpi::Size(c);
666 #ifdef LBANN_HAS_ALUMINUM 667 #ifdef LBANN_ALUMINUM_MPI_PASSTHROUGH 668 ::Al::MPIAllreduceAlgorithm algo =
669 ::Al::MPIAllreduceAlgorithm::mpi_passthrough;
671 ::Al::MPIAllreduceAlgorithm algo = ::Al::MPIAllreduceAlgorithm::automatic;
673 ::Al::Allreduce<::Al::MPIBackend>(
678 c.template GetComm<::Al::MPIBackend>(El::SyncInfo<El::Device::CPU>{}),
681 El::mpi::AllReduce(snd, rcv, count, op, c, El::SyncInfo<El::Device::CPU>{});
686 template <
typename T>
689 const El::mpi::Comm& c,
690 const El::mpi::Op op)
const 692 auto const size_c = El::mpi::Size(c);
694 #ifdef LBANN_HAS_ALUMINUM 695 #ifdef LBANN_ALUMINUM_MPI_PASSTHROUGH 696 ::Al::MPIAllreduceAlgorithm algo =
697 ::Al::MPIAllreduceAlgorithm::mpi_passthrough;
699 ::Al::MPIAllreduceAlgorithm algo = ::Al::MPIAllreduceAlgorithm::automatic;
701 ::Al::Allreduce<::Al::MPIBackend>(
705 c.template GetComm<::Al::MPIBackend>(El::SyncInfo<El::Device::CPU>{}),
708 El::mpi::AllReduce(data, count, op, c, El::SyncInfo<El::Device::CPU>{});
717 template <
typename T>
720 const El::mpi::Comm& c,
722 const El::mpi::Op op)
const 725 #ifdef LBANN_HAS_ALUMINUM 727 ::Al::NonblockingAllreduce<::Al::MPIBackend>(
731 c.template GetComm<::Al::MPIBackend>(El::SyncInfo<El::Device::CPU>{}),
734 MPI_Iallreduce(MPI_IN_PLACE,
737 El::mpi::TypeMap<T>(),
741 #endif // LBANN_HAS_ALUMINUM 746 template <
typename T>
749 El::mpi::WaitAll(req.size(), req.data());
753 template <
typename T>
760 template <
typename T>
764 const int rank)
const 766 send(data, count, trainer, rank, El::SyncInfo<El::Device::CPU>{});
768 template <
typename T, El::Device D>
773 El::SyncInfo<D>
const& syncInfo)
const 782 template <
typename T, El::Device D>
786 El::SyncInfo<D>
const& syncInfo)
const 792 template <
typename T>
797 El::mpi::Request<T>& req)
const 806 template <
typename T>
811 El::mpi::Request<T>& req,
812 const El::mpi::Comm& c)
const 815 El::mpi::TaggedISend(data, count, rank, tag, c, req);
817 template <
typename T>
821 El::mpi::Request<T>& req)
const 827 template <
typename T>
831 const int rank)
const 833 recv(data, count, trainer, rank, El::SyncInfo<El::Device::CPU>{});
835 template <
typename T>
840 template <
typename T>
843 recv(data, count, El::SyncInfo<El::Device::CPU>{});
845 template <
typename T, El::Device D>
850 El::SyncInfo<D>
const& syncInfo)
const 859 template <
typename T, El::Device D>
863 El::SyncInfo<D>
const& syncInfo)
const 868 template <
typename T, El::Device D>
871 El::SyncInfo<D>
const& syncInfo)
const 873 El::mpi::Recv(data, count, El::mpi::ANY_SOURCE,
get_world_comm(), syncInfo);
878 template <
typename T>
883 El::mpi::Request<T>& req)
const 892 template <
typename T>
897 El::mpi::Request<T>& req,
898 const El::mpi::Comm& c)
const 900 El::mpi::TaggedIRecv(data, count, rank, tag, c, req);
904 template <
typename T>
908 El::mpi::Request<T>& req)
const 912 template <
typename T>
915 El::mpi::Request<T>& req)
const 917 El::mpi::IRecv(data, count, El::mpi::ANY_SOURCE,
get_world_comm(), req);
922 template <
typename T, El::Device D>
924 const int send_count,
925 const int send_trainer,
928 const int recv_count,
929 const int recv_trainer,
930 const int recv_rank)
const 940 El::SyncInfo<El::Device::CPU>{});
942 template <
typename T, El::Device D>
944 const int send_count,
945 const int send_trainer,
947 const int recv_count,
948 const int recv_trainer)
const 958 El::SyncInfo<El::Device::CPU>{});
961 template <
typename T, El::Device D>
963 const int send_count,
964 const int send_trainer,
967 const int recv_count,
968 const int recv_trainer,
970 El::SyncInfo<D>
const& syncInfo)
const 974 El::mpi::SendRecv(snd,
983 template <
typename T, El::Device D>
985 const int send_count,
986 const int send_trainer,
988 const int recv_count,
989 const int recv_trainer,
990 El::SyncInfo<D>
const& syncInfo)
const 1004 template <
typename T>
1012 return El::mpi::GetCount<T>(status);
1014 template <
typename T>
1020 template <
typename T,
bool S>
1023 auto const rank_c = El::mpi::Rank(c);
1028 broadcast_native<TT>(root,
reinterpret_cast<TT&
>(val), c);
1036 template <
typename T>
1039 const El::mpi::Comm& c)
const 1041 El::mpi::Broadcast(val, root, c, El::SyncInfo<El::Device::CPU>{});
1044 template <
typename T>
1047 const El::mpi::Comm& c)
const 1049 const int bytes =
static_cast<int>(
sizeof(T));
1050 El::mpi::Broadcast<El::byte>(
reinterpret_cast<El::byte*
>(&val),
1054 El::SyncInfo<El::Device::CPU>{});
1057 template <
typename T, El::Device D,
bool S>
1061 const El::mpi::Comm& c,
1062 El::SyncInfo<D>
const& syncInfo)
const 1064 auto const rank_c = El::mpi::Rank(c);
1065 const int size =
static_cast<int>(S ? count :
sizeof(T) * count);
1069 El::mpi::Broadcast<TT>(
reinterpret_cast<TT*
>(
data), size, root, c, syncInfo);
1075 void lbann_comm::broadcast<std::string>(
int root,
1077 const El::mpi::Comm& c)
const;
1079 #ifndef LBANN_COMM_INSTANTIATE 1081 extern template void lbann_comm::allreduce(El::AbstractMatrix<T>& m, \ 1082 const El::mpi::Comm& c, \ 1083 El::mpi::Op op) const; \ 1084 extern template void lbann_comm::allreduce(El::AbstractDistMatrix<T>& m, \ 1085 const El::mpi::Comm& c, \ 1086 El::mpi::Op op) const; \ 1087 extern template void lbann_comm::nb_allreduce(El::AbstractMatrix<T>& m, \ 1088 const El::mpi::Comm& c, \ 1090 El::mpi::Op op) const; \ 1091 extern template void lbann_comm::nb_allreduce(El::AbstractDistMatrix<T>& m, \ 1092 const El::mpi::Comm& c, \ 1094 El::mpi::Op op) const 1096 #define LBANN_INSTANTIATE_CPU_HALF 1097 #define LBANN_INSTANTIATE_GPU_HALF 1100 #undef LBANN_INSTANTIATE_CPU_HALF 1101 #undef LBANN_INSTANTIATE_GPU_HALF 1102 #endif // LBANN_COMM_INSTANTIATE 1106 #endif // LBANN_COMM_IMPL_HPP_INCLUDED int get_rank_in_trainer() const noexcept
void nb_tagged_recv(T *data, int count, int rank, int tag, El::mpi::Request< T > &req, const El::mpi::Comm &c) const
int get_count(int trainer, int rank) const
void reduce(T snd, int root, const El::mpi::Comm &c, El::mpi::Op op=El::mpi::SUM) const
void trainer_all_gather(std::vector< T > const &src, std::vector< T > &rcs, std::vector< int > const &rcv_counts, std::vector< int > const &rcv_disp) const
void intertrainer_gather(T snd, int root) const
void nb_tagged_send(const T *data, int count, int rank, int tag, El::mpi::Request< T > &req, const El::mpi::Comm &c) const
void gather(T snd, int root, const El::mpi::Comm &c) const
void nb_recv(T *data, int count, int trainer, int rank, El::mpi::Request< T > &req) const
El::mpi::Comm m_trainer_comm
void trainer_broadcast(int root, T &val) const
Within-trainer broadcast of a scalar.
const El::mpi::Comm & get_intertrainer_comm() const noexcept
void nb_send(const T *data, int count, int trainer, int rank, El::mpi::Request< T > &req) const
T & data(const cnpy::NpyArray &na, const std::vector< size_t > indices)
void nb_allreduce(El::AbstractMatrix< TensorDataType > &m, const El::mpi::Comm &c, Al::request &req, El::mpi::Op op=El::mpi::SUM) const
static const mpi_req_type mpi_null_req
void broadcast_custom(int root, T &val, const El::mpi::Comm &c) const
void broadcast_native(int root, T &val, const El::mpi::Comm &c) const
size_t resize(const int root, std::vector< T > &data, const El::mpi::Comm &c) const
void trainer_reduce(T snd, int root, El::mpi::Op op=El::mpi::SUM) const
void trainer_gather(T snd, int root) const
void send(const T *data, int count, int trainer, int rank) const
void count_bytes_broadcast(const size_t bytes, const int rank, const int root) const noexcept
void trainer_gatherv(T const *snd, int count, int root) const
void all_gather(const T *src, int src_count, T *rcv, int rcv_count, const El::mpi::Comm &c) const
El::mpi::Comm m_intertrainer_comm
T allreduce(T snd, const El::mpi::Comm &c, El::mpi::Op op=El::mpi::SUM) const
void world_broadcast(int root, T &val) const
World broadcast of a scalar.
void intertrainer_broadcast(int root, T &val) const
Inter-trainer broadcast of a scalar.
void world_all_gather(T const &src, std::vector< T > &data) const
User-facing class that represents a set of compute resources.
T intertrainer_allreduce(T snd, El::mpi::Op op=El::mpi::SUM) const
const El::mpi::Comm & get_trainer_comm() const noexcept
void recv(T *data, int count, int trainer, int rank) const
T scatter(int root, const El::mpi::Comm &c) const
int get_world_rank(int trainer, int rank) const noexcept
int get_procs_per_trainer() const noexcept
const El::mpi::Comm & get_world_comm() const noexcept
T trainer_allreduce(T snd, El::mpi::Op op=El::mpi::SUM) const
void sendrecv(const T *snd, int send_count, int send_trainer, int send_rank, T *rcv, int recv_count, int recv_trainer, int recv_rank) const
void broadcast(int root, T &val, const El::mpi::Comm &c) const
Broadcast a scalar value over an arbitrary communicator.
void lbann_comm_abort(std::string msg) const
void wait_all(std::vector< El::mpi::Request< T >> &req) const
void wait(El::mpi::Request< T > &req) const
void intertrainer_reduce(T snd, int root, El::mpi::Op op=El::mpi::SUM) const