LBANN  0.103.0
LivermoreBigArtificialNeuralNetworkToolkit
comm_impl.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_COMM_HPP_IMPL_INCLUDED
28 #define LBANN_COMM_HPP_IMPL_INCLUDED
29 
30 #include "lbann/comm.hpp"
31 
32 namespace lbann {
33 
35 template <typename T>
36 void lbann_comm::world_broadcast(int root, T& val) const
37 {
38  broadcast(root, val, get_world_comm());
39 }
41 template <typename T>
42 void lbann_comm::intertrainer_broadcast(int root, T& val) const
43 {
44  broadcast(root, val, get_intertrainer_comm());
45 }
47 template <typename T>
48 void lbann_comm::trainer_broadcast(int root, T& val) const
49 {
50  broadcast(root, val, get_trainer_comm());
51 }
52 
58 // Default to cpu memory
59 template <typename T>
60 void lbann_comm::broadcast(const int root,
61  T* const data,
62  const int count,
63  const El::mpi::Comm& c) const
64 {
65  broadcast(root, data, count, c, El::SyncInfo<El::Device::CPU>{});
66 }
67 
69 template <typename T>
70 void lbann_comm::world_broadcast(const int root,
71  T* const data,
72  const int count) const
73 {
74  world_broadcast(root, data, count, El::SyncInfo<El::Device::CPU>{});
75 }
76 
77 template <typename T, El::Device D>
78 void lbann_comm::world_broadcast(const int root,
79  T* const data,
80  const int count,
81  El::SyncInfo<D> const& syncInfo) const
82 {
83  broadcast(root, data, count, get_world_comm(), syncInfo);
84 }
86 template <typename T>
88  T* data,
89  const int count) const
90 {
91  intertrainer_broadcast(root, data, count, El::SyncInfo<El::Device::CPU>{});
92 }
93 template <typename T, El::Device D>
95  T* const data,
96  const int count,
97  El::SyncInfo<D> const& syncInfo) const
98 {
99  broadcast(root, data, count, get_intertrainer_comm(), syncInfo);
100 }
102 template <typename T>
103 void lbann_comm::trainer_broadcast(const int root,
104  T* const data,
105  const int count) const
106 {
107  trainer_broadcast(root, data, count, El::SyncInfo<El::Device::CPU>{});
108 }
109 
110 template <typename T, El::Device D>
111 void lbann_comm::trainer_broadcast(const int root,
112  T* const data,
113  const int count,
114  El::SyncInfo<D> const& syncInfo) const
115 {
116  broadcast(root, data, count, get_trainer_comm(), syncInfo);
117 }
118 
122 template <typename T>
123 size_t lbann_comm::resize(const int root,
124  std::vector<T>& data,
125  const El::mpi::Comm& c) const
126 {
127  auto const rank_c = El::mpi::Rank(c);
128  size_t count = data.size();
129  El::mpi::Broadcast(&count, 1, root, c, El::SyncInfo<El::Device::CPU>{});
130  count_bytes_broadcast(sizeof(size_t), rank_c, root);
131  data.resize(count);
132  return count;
133 }
134 
139 template <typename T>
140 void lbann_comm::broadcast(const int root,
141  std::vector<T>& data,
142  const El::mpi::Comm& c) const
143 {
144  const int count = static_cast<int>(resize(root, data, c));
145  if (count <= 0) {
146  return;
147  }
148  broadcast(root, data.data(), count, c, El::SyncInfo<El::Device::CPU>{});
149 }
151 template <typename T>
152 void lbann_comm::world_broadcast(const int root, std::vector<T>& data) const
153 {
154  broadcast(root, data, get_world_comm());
155 }
160 template <typename T>
163  std::vector<T>& data) const
164 {
165  broadcast(root, data, get_intertrainer_comm());
166 }
168 template <typename T>
169 void lbann_comm::trainer_broadcast(const int root, std::vector<T>& data) const
170 {
171  broadcast(root, data, get_trainer_comm());
172 }
173 
175 template <typename T>
176 void lbann_comm::all_gather(const T* const src,
177  const int src_count,
178  T* const rcv,
179  const int rcv_count,
180  const El::mpi::Comm& c) const
181 {
182  all_gather(src,
183  src_count,
184  rcv,
185  rcv_count,
186  c,
187  El::SyncInfo<El::Device::CPU>{});
188 }
189 template <typename T, El::Device D>
190 void lbann_comm::all_gather(const T* const src,
191  const int src_count,
192  T* const rcv,
193  const int rcv_count,
194  const El::mpi::Comm& c,
195  El::SyncInfo<D> const& syncInfo) const
196 {
197  El::mpi::AllGather(src, src_count, rcv, rcv_count, c, syncInfo);
198 }
199 
204 template <typename T>
205 void lbann_comm::all_gather(std::vector<T> const& src,
206  std::vector<T>& rcs,
207  std::vector<int> const& rcv_counts,
208  std::vector<int> const& rcv_disp,
209  const El::mpi::Comm& c) const
210 {
211  if (src.size() == 0) {
212  std::ostringstream err;
213  err << __FILE__ << " " << __LINE__ << " :: "
214  << "all_gather for vector<>: vector.size() == 0;\n"
215  << "this doesn't work!";
216  lbann_comm_abort(err.str());
217  }
218  El::mpi::AllGather(src.data(),
219  src.size(),
220  rcs.data(),
221  rcv_counts.data(),
222  rcv_disp.data(),
223  c,
224  El::SyncInfo<El::Device::CPU>{});
225 }
230 template <typename T>
231 void lbann_comm::trainer_all_gather(std::vector<T> const& src,
232  std::vector<T>& rcs,
233  std::vector<int> const& rcv_counts,
234  std::vector<int> const& rcv_disp) const
235 {
236  all_gather(src, rcs, rcv_counts, rcv_disp, get_trainer_comm());
237 }
242 template <typename T>
243 void lbann_comm::all_gather(T const& src,
244  std::vector<T>& data,
245  const El::mpi::Comm& c) const
246 {
247  El::mpi::AllGather(&src,
248  1,
249  data.data(),
250  1,
251  c,
252  El::SyncInfo<El::Device::CPU>{});
253 }
258 template <typename T>
259 void lbann_comm::world_all_gather(T const& src, std::vector<T>& data) const
260 {
261  all_gather(src, data, get_world_comm());
262 }
267 template <typename T>
268 void lbann_comm::trainer_all_gather(T const& src, std::vector<T>& data) const
269 {
270  all_gather(src, data, get_trainer_comm());
271 }
272 
274 template <typename T>
275 void lbann_comm::trainer_gather(const T snd, const int root) const
276 {
277  gather(snd, root, m_trainer_comm);
278 }
280 template <typename T>
281 void lbann_comm::trainer_gather(const T snd, T* rcv) const
282 {
283  gather(snd, rcv, m_trainer_comm);
284 }
286 template <typename T>
287 void lbann_comm::trainer_gather(T const* snd,
288  const int count,
289  const int root) const
290 {
291  gather(snd, count, root, m_trainer_comm);
292 }
294 template <typename T>
295 void lbann_comm::trainer_gather(T const* const snd,
296  const int count,
297  T* const rcv) const
298 {
299  gather(snd, count, rcv, m_trainer_comm);
300 }
302 template <typename T>
303 void lbann_comm::trainer_gatherv(T const* snd,
304  const int count,
305  const int root) const
306 {
307  m_bytes_sent += sizeof(T) * count;
308  El::mpi::Gather(snd, count, nullptr, nullptr, nullptr, root, m_trainer_comm);
309 }
310 template <typename T>
311 void lbann_comm::trainer_gatherv(T const* const snd,
312  const int count,
313  T* const rcv,
314  int const* const rcv_counts,
315  int const* const rcv_displacements) const
316 {
317  El::mpi::Gather(snd,
318  count,
319  rcv,
320  rcv_counts,
321  rcv_displacements,
325  sizeof(T) *
326  (std::accumulate(rcv_counts, &rcv_counts[get_procs_per_trainer()], 0) -
327  rcv_counts[get_rank_in_trainer()]);
328 }
330 template <typename T>
331 void lbann_comm::intertrainer_gather(const T snd, const int root) const
332 {
333  gather(snd, root, m_intertrainer_comm);
334 }
336 template <typename T>
337 void lbann_comm::intertrainer_gather(const T snd, std::vector<T>& rcv) const
338 {
339  gather(snd, rcv, m_intertrainer_comm);
340 }
342 template <typename T>
343 void lbann_comm::intertrainer_gather(T const* const snd,
344  const int count,
345  const int root) const
346 {
347  gather(snd, count, root, m_intertrainer_comm);
348 }
350 template <typename T>
351 void lbann_comm::intertrainer_gather(T const* const snd,
352  const int count,
353  T* const rcv) const
354 {
355  gather(snd, count, rcv, m_intertrainer_comm);
356 }
358 template <typename T>
359 void lbann_comm::gather(const T snd,
360  const int root,
361  const El::mpi::Comm& c) const
362 {
363  m_bytes_sent += sizeof(T);
364  El::mpi::Gather(&snd,
365  1,
366  (T*)nullptr,
367  0,
368  root,
369  c,
370  El::SyncInfo<El::Device::CPU>{});
371 }
373 template <typename T>
374 void lbann_comm::gather(const T snd, T* const rcv, const El::mpi::Comm& c) const
375 {
376  auto const size_c = El::mpi::Size(c);
377  auto const rank_c = El::mpi::Rank(c);
378  El::mpi::Gather(&snd, 1, rcv, 1, rank_c, c, El::SyncInfo<El::Device::CPU>{});
379  m_bytes_received += sizeof(T) * (size_c - 1);
380 }
382 template <typename T>
383 void lbann_comm::gather(const T snd,
384  std::vector<T>& rcv,
385  const El::mpi::Comm& c) const
386 {
387  gather(snd, rcv.data(), c);
388 }
390 template <typename T>
391 void lbann_comm::gather(T const* const snd,
392  const int count,
393  const int root,
394  const El::mpi::Comm& c) const
395 {
396  gather(snd, count, root, c, El::SyncInfo<El::Device::CPU>{});
397 }
398 template <typename T, El::Device D>
399 void lbann_comm::gather(T const* const snd,
400  const int count,
401  const int root,
402  const El::mpi::Comm& c,
403  El::SyncInfo<D> const& syncInfo) const
404 {
405  m_bytes_sent += sizeof(T) * count;
406  El::mpi::Gather(snd, count, (T*)nullptr, 0, root, c, syncInfo);
407 }
409 template <typename T>
410 void lbann_comm::gather(T const* const snd,
411  const int count,
412  T* const rcv,
413  const El::mpi::Comm& c) const
414 {
415  gather(snd, count, rcv, c, El::SyncInfo<El::Device::CPU>{});
416 }
417 template <typename T, El::Device D>
418 void lbann_comm::gather(T const* const snd,
419  const int count,
420  T* const rcv,
421  const El::mpi::Comm& c,
422  El::SyncInfo<D> const& syncInfo) const
423 {
424  auto const size_c = El::mpi::Size(c);
425  auto const rank_c = El::mpi::Rank(c);
426  El::mpi::Gather(snd, count, rcv, count, rank_c, c, syncInfo);
427  m_bytes_received += sizeof(T) * count * (size_c - 1);
428 }
430 template <typename T>
431 T lbann_comm::scatter(const int root, const El::mpi::Comm& c) const
432 {
433  T val = {};
434  El::mpi::Scatter((T*)nullptr,
435  1,
436  &val,
437  1,
438  root,
439  c,
440  El::SyncInfo<El::Device::CPU>{});
441  m_bytes_received += sizeof(T);
442  return val;
443 }
445 template <typename T>
446 T lbann_comm::scatter(T const* const snd, const El::mpi::Comm& c) const
447 {
448  m_bytes_sent += sizeof(T) * (El::mpi::Size(c) - 1);
449  T val = {};
450  auto root = El::mpi::Rank(c);
451  El::mpi::Scatter(snd, 1, &val, 1, root, c, El::SyncInfo<El::Device::CPU>{});
452  return val;
453 }
455 template <typename T>
457  const int root,
458  const El::mpi::Op op) const
459 {
460  reduce(snd, root, m_intertrainer_comm, op);
461 }
463 template <typename T>
464 T lbann_comm::intertrainer_reduce(const T snd, const El::mpi::Op op) const
465 {
466  return reduce(snd, m_intertrainer_comm, op);
467 }
469 template <typename T>
470 void lbann_comm::trainer_reduce(const T snd,
471  const int root,
472  const El::mpi::Op op) const
473 {
474  reduce(snd, root, m_trainer_comm, op);
475 }
477 template <typename T>
478 T lbann_comm::trainer_reduce(const T snd, const El::mpi::Op op) const
479 {
480  return reduce(snd, m_trainer_comm, op);
481 }
483 template <typename T>
484 void lbann_comm::trainer_reduce(T const* const snd,
485  const int count,
486  const int root,
487  const El::mpi::Op op) const
488 {
489  reduce(snd, count, root, m_trainer_comm, op);
490 }
492 template <typename T>
493 void lbann_comm::trainer_reduce(T const* const snd,
494  const int count,
495  T* const rcv,
496  const El::mpi::Op op) const
497 {
498  reduce(snd, count, rcv, m_trainer_comm, op);
499 }
501 template <typename T>
502 void lbann_comm::reduce(const T snd,
503  const int root,
504  const El::mpi::Comm& c,
505  const El::mpi::Op op) const
506 {
507  m_bytes_sent += sizeof(T);
508  El::mpi::Reduce(&snd,
509  (T*)NULL,
510  1,
511  op,
512  root,
513  c,
514  El::SyncInfo<El::Device::CPU>{});
515 }
517 template <typename T>
518 T lbann_comm::reduce(const T snd,
519  const El::mpi::Comm& c,
520  const El::mpi::Op op) const
521 {
522  T val = {};
523  auto const size_c = El::mpi::Size(c);
524  auto const rank_c = El::mpi::Rank(c);
525  El::mpi::Reduce(&snd,
526  &val,
527  1,
528  op,
529  rank_c,
530  c,
531  El::SyncInfo<El::Device::CPU>{});
532  m_bytes_received += sizeof(T) * (size_c - 1);
533  return val;
534 }
535 
537 // Op is "SUM"
538 template <typename T>
539 void lbann_comm::reduce(T const* const snd,
540  const int count,
541  const int root,
542  const El::mpi::Comm& c) const
543 {
544  reduce(snd, count, root, c, El::mpi::SUM, El::SyncInfo<El::Device::CPU>{});
545 }
546 template <typename T, El::Device D>
547 void lbann_comm::reduce(T const* const snd,
548  const int count,
549  const int root,
550  const El::mpi::Comm& c,
551  El::SyncInfo<D> const& syncInfo) const
552 {
553  reduce(snd, count, root, c, El::mpi::SUM, syncInfo);
554 }
555 
556 template <typename T>
557 void lbann_comm::reduce(T const* const snd,
558  const int count,
559  const int root,
560  const El::mpi::Comm& c,
561  const El::mpi::Op op) const
562 {
563  reduce(snd, count, root, c, op, El::SyncInfo<El::Device::CPU>{});
564 }
565 template <typename T, El::Device D>
566 void lbann_comm::reduce(T const* const snd,
567  const int count,
568  const int root,
569  const El::mpi::Comm& c,
570  const El::mpi::Op op,
571  El::SyncInfo<D> const& syncInfo) const
572 {
573  m_bytes_sent += sizeof(T) * count;
574  El::mpi::Reduce(snd, (T*)nullptr, count, op, root, c, syncInfo);
575 }
577 template <typename T, El::Device D>
578 void lbann_comm::reduce(T const* const snd,
579  const int count,
580  T* const rcv,
581  const El::mpi::Comm& c,
582  El::SyncInfo<D> const& syncInfo) const
583 {
584  reduce(snd, count, rcv, c, El::mpi::SUM, syncInfo);
585 }
586 template <typename T>
587 void lbann_comm::reduce(T const* const snd,
588  const int count,
589  T* const rcv,
590  const El::mpi::Comm& c) const
591 {
592  reduce(snd, count, rcv, c, El::mpi::SUM, El::SyncInfo<El::Device::CPU>{});
593 }
594 
595 template <typename T>
596 void lbann_comm::reduce(T const* const snd,
597  const int count,
598  T* const rcv,
599  const El::mpi::Comm& c,
600  const El::mpi::Op op) const
601 {
602  reduce(snd, count, rcv, c, op, El::SyncInfo<El::Device::CPU>{});
603 }
604 template <typename T, El::Device D>
605 void lbann_comm::reduce(T const* snd,
606  const int count,
607  T* const rcv,
608  const El::mpi::Comm& c,
609  El::mpi::Op op,
610  El::SyncInfo<D> const& syncInfo) const
611 {
612  if (snd == rcv) {
613  snd = (T const*)MPI_IN_PLACE;
614  }
615  auto const rank_c = El::mpi::Rank(c);
616  auto const size_c = El::mpi::Size(c);
617  El::mpi::Reduce(snd, rcv, count, op, rank_c, c, syncInfo);
618  m_bytes_received += sizeof(T) * count * (size_c - 1);
619 }
621 template <typename T>
622 T lbann_comm::intertrainer_allreduce(const T snd, const El::mpi::Op op) const
623 {
624  return allreduce(snd, m_intertrainer_comm, op);
625 }
627 template <typename T>
628 T lbann_comm::trainer_allreduce(const T snd, const El::mpi::Op op) const
629 {
630  return allreduce(snd, m_trainer_comm, op);
631 }
633 template <typename T>
634 void lbann_comm::trainer_allreduce(T const* const snd,
635  const int count,
636  T* const rcv,
637  const El::mpi::Op op) const
638 {
639  allreduce(snd, count, rcv, m_trainer_comm, op);
640 }
642 template <typename T>
644  const El::mpi::Comm& c,
645  const El::mpi::Op op) const
646 {
647  auto const size_c = El::mpi::Size(c);
648  m_bytes_sent += sizeof(T);
649  allreduce(&snd, 1, c, op);
650  m_bytes_received += sizeof(T) * (size_c - 1);
651  return snd;
652 }
653 
654 // FIXME (trb): Based on the backend choice of "MPIBackend", I'm
655 // assuming this is intended as a CPU-only call.
657 template <typename T>
658 void lbann_comm::allreduce(T const* const snd,
659  const int count,
660  T* const rcv,
661  const El::mpi::Comm& c,
662  const El::mpi::Op op) const
663 {
664  auto const size_c = El::mpi::Size(c);
665  m_bytes_sent += count * sizeof(T);
666 #ifdef LBANN_HAS_ALUMINUM
667 #ifdef LBANN_ALUMINUM_MPI_PASSTHROUGH
668  ::Al::MPIAllreduceAlgorithm algo =
669  ::Al::MPIAllreduceAlgorithm::mpi_passthrough;
670 #else
671  ::Al::MPIAllreduceAlgorithm algo = ::Al::MPIAllreduceAlgorithm::automatic;
672 #endif
673  ::Al::Allreduce<::Al::MPIBackend>(
674  snd,
675  rcv,
676  count,
677  mpi_op_to_al_op(op),
678  c.template GetComm<::Al::MPIBackend>(El::SyncInfo<El::Device::CPU>{}),
679  algo);
680 #else
681  El::mpi::AllReduce(snd, rcv, count, op, c, El::SyncInfo<El::Device::CPU>{});
682 #endif
683  m_bytes_received += count * sizeof(T) * (size_c - 1);
684 }
686 template <typename T>
688  const int count,
689  const El::mpi::Comm& c,
690  const El::mpi::Op op) const
691 {
692  auto const size_c = El::mpi::Size(c);
693  m_bytes_sent += count * sizeof(T);
694 #ifdef LBANN_HAS_ALUMINUM
695 #ifdef LBANN_ALUMINUM_MPI_PASSTHROUGH
696  ::Al::MPIAllreduceAlgorithm algo =
697  ::Al::MPIAllreduceAlgorithm::mpi_passthrough;
698 #else
699  ::Al::MPIAllreduceAlgorithm algo = ::Al::MPIAllreduceAlgorithm::automatic;
700 #endif
701  ::Al::Allreduce<::Al::MPIBackend>(
702  data,
703  count,
704  mpi_op_to_al_op(op),
705  c.template GetComm<::Al::MPIBackend>(El::SyncInfo<El::Device::CPU>{}),
706  algo);
707 #else
708  El::mpi::AllReduce(data, count, op, c, El::SyncInfo<El::Device::CPU>{});
709 #endif
710  m_bytes_received += count * sizeof(T) * (size_c - 1);
711 }
717 template <typename T>
719  const int count,
720  const El::mpi::Comm& c,
721  Al::request& req,
722  const El::mpi::Op op) const
723 {
724  m_bytes_sent += count * sizeof(T);
725 #ifdef LBANN_HAS_ALUMINUM
727  ::Al::NonblockingAllreduce<::Al::MPIBackend>(
728  data,
729  count,
730  mpi_op_to_al_op(op),
731  c.template GetComm<::Al::MPIBackend>(El::SyncInfo<El::Device::CPU>{}),
732  req.mpi_req);
733 #else
734  MPI_Iallreduce(MPI_IN_PLACE,
735  data,
736  count,
737  El::mpi::TypeMap<T>(),
738  op.op,
739  c.GetMPIComm(),
740  &(req.raw_mpi_req));
741 #endif // LBANN_HAS_ALUMINUM
742  m_bytes_received += count * sizeof(T) * (El::mpi::Size(c) - 1);
743 }
744 
746 template <typename T>
747 void lbann_comm::wait_all(std::vector<El::mpi::Request<T>>& req) const
748 {
749  El::mpi::WaitAll(req.size(), req.data());
750 }
751 
753 template <typename T>
754 void lbann_comm::wait(El::mpi::Request<T>& req) const
755 {
756  El::mpi::Wait(req);
757 }
758 
760 template <typename T>
761 void lbann_comm::send(const T* const data,
762  const int count,
763  const int trainer,
764  const int rank) const
765 {
766  send(data, count, trainer, rank, El::SyncInfo<El::Device::CPU>{});
767 }
768 template <typename T, El::Device D>
769 void lbann_comm::send(const T* const data,
770  const int count,
771  const int trainer,
772  const int rank,
773  El::SyncInfo<D> const& syncInfo) const
774 {
775  m_bytes_sent += sizeof(T) * count;
776  El::mpi::Send(data,
777  count,
778  get_world_rank(trainer, rank),
779  get_world_comm(),
780  syncInfo);
781 }
782 template <typename T, El::Device D>
783 void lbann_comm::send(const T* const data,
784  const int count,
785  const int trainer,
786  El::SyncInfo<D> const& syncInfo) const
787 {
788  send(data, count, trainer, m_rank_in_trainer, syncInfo);
789 }
790 
792 template <typename T>
793 void lbann_comm::nb_send(const T* const data,
794  const int count,
795  const int trainer,
796  const int rank,
797  El::mpi::Request<T>& req) const
798 {
799  m_bytes_sent += sizeof(T) * count;
800  El::mpi::ISend(data,
801  count,
802  get_world_rank(trainer, rank),
803  get_world_comm(),
804  req);
805 }
806 template <typename T>
807 void lbann_comm::nb_tagged_send(const T* const data,
808  const int count,
809  const int rank,
810  const int tag,
811  El::mpi::Request<T>& req,
812  const El::mpi::Comm& c) const
813 {
814  m_bytes_sent += sizeof(T) * count;
815  El::mpi::TaggedISend(data, count, rank, tag, c, req);
816 }
817 template <typename T>
818 void lbann_comm::nb_send(const T* const data,
819  const int count,
820  const int trainer,
821  El::mpi::Request<T>& req) const
822 {
823  nb_send(data, count, trainer, m_rank_in_trainer, req);
824 }
825 
827 template <typename T>
828 void lbann_comm::recv(T* const data,
829  const int count,
830  const int trainer,
831  const int rank) const
832 {
833  recv(data, count, trainer, rank, El::SyncInfo<El::Device::CPU>{});
834 }
835 template <typename T>
836 void lbann_comm::recv(T* data, const int count, const int trainer) const
837 {
838  recv(data, count, trainer, m_rank_in_trainer);
839 }
840 template <typename T>
841 void lbann_comm::recv(T* const data, const int count) const
842 {
843  recv(data, count, El::SyncInfo<El::Device::CPU>{});
844 }
845 template <typename T, El::Device D>
846 void lbann_comm::recv(T* const data,
847  const int count,
848  const int trainer,
849  const int rank,
850  El::SyncInfo<D> const& syncInfo) const
851 {
852  El::mpi::Recv(data,
853  count,
854  get_world_rank(trainer, rank),
855  get_world_comm(),
856  syncInfo);
857  m_bytes_received += sizeof(T) * count;
858 }
859 template <typename T, El::Device D>
860 void lbann_comm::recv(T* const data,
861  const int count,
862  const int trainer,
863  El::SyncInfo<D> const& syncInfo) const
864 {
865  recv(data, count, trainer, m_rank_in_trainer, syncInfo);
866 }
868 template <typename T, El::Device D>
869 void lbann_comm::recv(T* const data,
870  const int count,
871  El::SyncInfo<D> const& syncInfo) const
872 {
873  El::mpi::Recv(data, count, El::mpi::ANY_SOURCE, get_world_comm(), syncInfo);
874  m_bytes_received += sizeof(T) * count;
875 }
876 
878 template <typename T>
879 void lbann_comm::nb_recv(T* const data,
880  const int count,
881  const int trainer,
882  const int rank,
883  El::mpi::Request<T>& req) const
884 {
885  El::mpi::IRecv(data,
886  count,
887  get_world_rank(trainer, rank),
888  get_world_comm(),
889  req);
890  m_bytes_received += sizeof(T) * count;
891 }
892 template <typename T>
894  const int count,
895  const int rank,
896  const int tag,
897  El::mpi::Request<T>& req,
898  const El::mpi::Comm& c) const
899 {
900  El::mpi::TaggedIRecv(data, count, rank, tag, c, req);
901  m_bytes_received += sizeof(T) * count;
902 }
903 
904 template <typename T>
905 void lbann_comm::nb_recv(T* const data,
906  const int count,
907  const int trainer,
908  El::mpi::Request<T>& req) const
909 {
910  nb_recv(data, count, trainer, m_rank_in_trainer, req);
911 }
912 template <typename T>
913 void lbann_comm::nb_recv(T* const data,
914  const int count,
915  El::mpi::Request<T>& req) const
916 {
917  El::mpi::IRecv(data, count, El::mpi::ANY_SOURCE, get_world_comm(), req);
918  m_bytes_received += sizeof(T) * count;
919 }
920 
922 template <typename T, El::Device D>
923 void lbann_comm::sendrecv(const T* const snd,
924  const int send_count,
925  const int send_trainer,
926  const int send_rank,
927  T* const rcv,
928  const int recv_count,
929  const int recv_trainer,
930  const int recv_rank) const
931 {
932  sendrecv(snd,
933  send_count,
934  send_trainer,
935  send_rank,
936  rcv,
937  recv_count,
938  recv_trainer,
939  recv_rank,
940  El::SyncInfo<El::Device::CPU>{});
941 }
942 template <typename T, El::Device D>
943 void lbann_comm::sendrecv(const T* const snd,
944  const int send_count,
945  const int send_trainer,
946  T* const rcv,
947  const int recv_count,
948  const int recv_trainer) const
949 {
950  sendrecv(snd,
951  send_count,
952  send_trainer,
954  rcv,
955  recv_count,
956  recv_trainer,
958  El::SyncInfo<El::Device::CPU>{});
959 }
960 
961 template <typename T, El::Device D>
962 void lbann_comm::sendrecv(const T* const snd,
963  const int send_count,
964  const int send_trainer,
965  const int send_rank,
966  T* const rcv,
967  const int recv_count,
968  const int recv_trainer,
969  const int recv_rank,
970  El::SyncInfo<D> const& syncInfo) const
971 {
972  m_bytes_sent += sizeof(T) * send_count;
973  m_bytes_received += sizeof(T) * recv_count;
974  El::mpi::SendRecv(snd,
975  send_count,
976  get_world_rank(send_trainer, send_rank),
977  rcv,
978  recv_count,
979  get_world_rank(recv_trainer, recv_rank),
980  get_world_comm(),
981  syncInfo);
982 }
983 template <typename T, El::Device D>
984 void lbann_comm::sendrecv(const T* const snd,
985  const int send_count,
986  const int send_trainer,
987  T* const rcv,
988  const int recv_count,
989  const int recv_trainer,
990  El::SyncInfo<D> const& syncInfo) const
991 {
992  sendrecv(snd,
993  send_count,
994  send_trainer,
996  rcv,
997  recv_count,
998  recv_trainer,
1000  syncInfo);
1001 }
1002 
1004 template <typename T>
1005 int lbann_comm::get_count(const int trainer, const int rank) const
1006 {
1007  MPI_Status status;
1008  MPI_Probe(get_world_rank(trainer, rank),
1009  MPI_ANY_TAG,
1010  MPI_COMM_WORLD,
1011  &status);
1012  return El::mpi::GetCount<T>(status);
1013 }
1014 template <typename T>
1015 int lbann_comm::get_count(const int trainer) const
1016 {
1017  return get_count<T>(trainer, m_rank_in_trainer);
1018 }
1019 
1020 template <typename T, bool S>
1021 void lbann_comm::broadcast(const int root, T& val, const El::mpi::Comm& c) const
1022 {
1023  auto const rank_c = El::mpi::Rank(c);
1024  if (S) {
1025  // Avoid linking error from uninstantiated El::mpi routine if !S by
1026  // converting T to El::byte
1027  using TT = typename interpret_as_byte_if_needed<S, T>::type;
1028  broadcast_native<TT>(root, reinterpret_cast<TT&>(val), c);
1029  }
1030  else {
1031  broadcast_custom(root, val, c);
1032  }
1033  count_bytes_broadcast(sizeof(T), rank_c, root);
1034 }
1035 
1036 template <typename T>
1037 void lbann_comm::broadcast_native(const int root,
1038  T& val,
1039  const El::mpi::Comm& c) const
1040 {
1041  El::mpi::Broadcast(val, root, c, El::SyncInfo<El::Device::CPU>{});
1042 }
1043 
1044 template <typename T>
1045 void lbann_comm::broadcast_custom(const int root,
1046  T& val,
1047  const El::mpi::Comm& c) const
1048 {
1049  const int bytes = static_cast<int>(sizeof(T));
1050  El::mpi::Broadcast<El::byte>(reinterpret_cast<El::byte*>(&val),
1051  bytes,
1052  root,
1053  c,
1054  El::SyncInfo<El::Device::CPU>{});
1055 }
1056 
1057 template <typename T, El::Device D, bool S>
1058 void lbann_comm::broadcast(const int root,
1059  T* const data,
1060  const int count,
1061  const El::mpi::Comm& c,
1062  El::SyncInfo<D> const& syncInfo) const
1063 {
1064  auto const rank_c = El::mpi::Rank(c);
1065  const int size = static_cast<int>(S ? count : sizeof(T) * count);
1066  // Avoid linking error from uninstantiated El::mpi routine if !S by converting
1067  // T to El::byte
1068  using TT = typename interpret_as_byte_if_needed<S, T>::type;
1069  El::mpi::Broadcast<TT>(reinterpret_cast<TT*>(data), size, root, c, syncInfo);
1070  count_bytes_broadcast(sizeof(T) * count, rank_c, root);
1071 }
1072 
1074 template <>
1075 void lbann_comm::broadcast<std::string>(int root,
1076  std::string& str,
1077  const El::mpi::Comm& c) const;
1078 
1079 #ifndef LBANN_COMM_INSTANTIATE
// PROTO(T) expands to extern template declarations of the matrix overloads
// of allreduce() and nb_allreduce(), for both El::AbstractMatrix<T> and
// El::AbstractDistMatrix<T>. The final declaration has no trailing
// semicolon; the expansion site supplies it.
// NOTE(review): presumably applied once per tensor data type by an include
// that is missing from this listing (listing line 1098) — confirm.
1080 #define PROTO(T) \
1081  extern template void lbann_comm::allreduce(El::AbstractMatrix<T>& m, \
1082  const El::mpi::Comm& c, \
1083  El::mpi::Op op) const; \
1084  extern template void lbann_comm::allreduce(El::AbstractDistMatrix<T>& m, \
1085  const El::mpi::Comm& c, \
1086  El::mpi::Op op) const; \
1087  extern template void lbann_comm::nb_allreduce(El::AbstractMatrix<T>& m, \
1088  const El::mpi::Comm& c, \
1089  Al::request& req, \
1090  El::mpi::Op op) const; \
1091  extern template void lbann_comm::nb_allreduce(El::AbstractDistMatrix<T>& m, \
1092  const El::mpi::Comm& c, \
1093  Al::request& req, \
1094  El::mpi::Op op) const
1095 
1096 #define LBANN_INSTANTIATE_CPU_HALF
1097 #define LBANN_INSTANTIATE_GPU_HALF
1099 #undef PROTO
1100 #undef LBANN_INSTANTIATE_CPU_HALF
1101 #undef LBANN_INSTANTIATE_GPU_HALF
1102 #endif // LBANN_COMM_INSTANTIATE
1103 
1104 } // namespace lbann
1105 
1106 #endif // LBANN_COMM_HPP_IMPL_INCLUDED
int get_rank_in_trainer() const noexcept
Definition: comm.hpp:157
void nb_tagged_recv(T *data, int count, int rank, int tag, El::mpi::Request< T > &req, const El::mpi::Comm &c) const
Definition: comm_impl.hpp:893
int get_count(int trainer, int rank) const
Definition: comm_impl.hpp:1005
void reduce(T snd, int root, const El::mpi::Comm &c, El::mpi::Op op=El::mpi::SUM) const
Definition: comm_impl.hpp:502
void trainer_all_gather(std::vector< T > const &src, std::vector< T > &rcs, std::vector< int > const &rcv_counts, std::vector< int > const &rcv_disp) const
Definition: comm_impl.hpp:231
void intertrainer_gather(T snd, int root) const
Definition: comm_impl.hpp:331
void nb_tagged_send(const T *data, int count, int rank, int tag, El::mpi::Request< T > &req, const El::mpi::Comm &c) const
Definition: comm_impl.hpp:807
void gather(T snd, int root, const El::mpi::Comm &c) const
Definition: comm_impl.hpp:359
void nb_recv(T *data, int count, int trainer, int rank, El::mpi::Request< T > &req) const
Definition: comm_impl.hpp:879
El::mpi::Comm m_trainer_comm
Definition: comm.hpp:945
void trainer_broadcast(int root, T &val) const
Within-trainer broadcast of a scalar.
Definition: comm_impl.hpp:48
const El::mpi::Comm & get_intertrainer_comm() const noexcept
Definition: comm.hpp:883
void nb_send(const T *data, int count, int trainer, int rank, El::mpi::Request< T > &req) const
Definition: comm_impl.hpp:793
size_t m_bytes_sent
Definition: comm.hpp:998
T & data(const cnpy::NpyArray &na, const std::vector< size_t > indices)
Definition: cnpy_utils.hpp:75
void nb_allreduce(El::AbstractMatrix< TensorDataType > &m, const El::mpi::Comm &c, Al::request &req, El::mpi::Op op=El::mpi::SUM) const
static const mpi_req_type mpi_null_req
void broadcast_custom(int root, T &val, const El::mpi::Comm &c) const
Definition: comm_impl.hpp:1045
void broadcast_native(int root, T &val, const El::mpi::Comm &c) const
Definition: comm_impl.hpp:1037
size_t resize(const int root, std::vector< T > &data, const El::mpi::Comm &c) const
Definition: comm_impl.hpp:123
void trainer_reduce(T snd, int root, El::mpi::Op op=El::mpi::SUM) const
Definition: comm_impl.hpp:470
void trainer_gather(T snd, int root) const
Definition: comm_impl.hpp:275
void send(const T *data, int count, int trainer, int rank) const
Definition: comm_impl.hpp:761
void count_bytes_broadcast(const size_t bytes, const int rank, const int root) const noexcept
Definition: comm.hpp:1017
void trainer_gatherv(T const *snd, int count, int root) const
Definition: comm_impl.hpp:303
void all_gather(const T *src, int src_count, T *rcv, int rcv_count, const El::mpi::Comm &c) const
Definition: comm_impl.hpp:176
El::mpi::Comm m_intertrainer_comm
Definition: comm.hpp:947
T allreduce(T snd, const El::mpi::Comm &c, El::mpi::Op op=El::mpi::SUM) const
Definition: comm_impl.hpp:643
void world_broadcast(int root, T &val) const
World broadcast of a scalar.
Definition: comm_impl.hpp:36
void intertrainer_broadcast(int root, T &val) const
Inter-trainer broadcast of a scalar.
Definition: comm_impl.hpp:42
size_t m_bytes_received
Definition: comm.hpp:999
void world_all_gather(T const &src, std::vector< T > &data) const
Definition: comm_impl.hpp:259
User-facing class that represents a set of compute resources.
Definition: trainer.hpp:60
T intertrainer_allreduce(T snd, El::mpi::Op op=El::mpi::SUM) const
Definition: comm_impl.hpp:622
const El::mpi::Comm & get_trainer_comm() const noexcept
Definition: comm.hpp:889
void recv(T *data, int count, int trainer, int rank) const
Definition: comm_impl.hpp:828
T scatter(int root, const El::mpi::Comm &c) const
Definition: comm_impl.hpp:431
int get_world_rank(int trainer, int rank) const noexcept
Definition: comm.hpp:164
int get_procs_per_trainer() const noexcept
Definition: comm.hpp:222
const El::mpi::Comm & get_world_comm() const noexcept
Definition: comm.hpp:901
int m_rank_in_trainer
Definition: comm.hpp:967
T trainer_allreduce(T snd, El::mpi::Op op=El::mpi::SUM) const
Definition: comm_impl.hpp:628
void sendrecv(const T *snd, int send_count, int send_trainer, int send_rank, T *rcv, int recv_count, int recv_trainer, int recv_rank) const
Definition: comm_impl.hpp:923
void broadcast(int root, T &val, const El::mpi::Comm &c) const
Broadcast a scalar value over an arbitrary communicator.
Definition: comm_impl.hpp:1021
void lbann_comm_abort(std::string msg) const
void wait_all(std::vector< El::mpi::Request< T >> &req) const
Definition: comm_impl.hpp:747
void wait(El::mpi::Request< T > &req) const
Definition: comm_impl.hpp:754
void intertrainer_reduce(T snd, int root, El::mpi::Op op=El::mpi::SUM) const
Definition: comm_impl.hpp:456