// Include guard (LBANN_UTILS_NVSHMEM_HPP_INCLUDED), the LBANN_HAS_NVSHMEM
// feature gate, the NVSHMEM_USE_NCCL define, and the first state-query
// declaration all share the next line in this listing.
// NOTE(review): the leading integers look like line numbers carried over from
// a rendered source listing, and the lines between them are not shown here.
// is_initialized(): noexcept query returning bool; presumably reports whether
// NVSHMEM has been initialized -- inferred from the name only, the definition
// is not visible in this file.
27 #ifndef LBANN_UTILS_NVSHMEM_HPP_INCLUDED 28 #define LBANN_UTILS_NVSHMEM_HPP_INCLUDED 31 #ifdef LBANN_HAS_NVSHMEM 35 #define NVSHMEM_USE_NCCL 43 bool is_initialized() noexcept;
// is_finalized(): noexcept query returning bool; presumably reports whether
// NVSHMEM has been finalized -- inferred from the name, confirm against the
// definition.
46 bool is_finalized() noexcept;
// is_active(): noexcept query returning bool.
// NOTE(review): likely "initialized and not yet finalized" -- confirm against
// the definition, which is not part of this header.
53 bool is_active() noexcept;
// initialize(): takes an MPI communicator, defaulting to MPI_COMM_WORLD when
// the caller passes nothing. What it does with the communicator is not
// visible here.
61 void initialize(MPI_Comm comm = MPI_COMM_WORLD);
// Declaration of the malloc function template.
//   T    - element type of the returned pointer; defaults to void.
//   size - number of T objects to allocate (the definition further down
//          multiplies it by sizeof(T) before calling nvshmem_malloc).
// Returns a T* to the allocated buffer.
74 template <typename T =
void>
75 T* malloc(
size_t size);
// Declaration of the realloc function template.
//   T    - element type of the pointers; defaults to void.
//   ptr  - existing buffer, may be nullptr (the definition further down
//          tests for that case).
//   size - forwarded unchanged to malloc<T>(size) by the definition.
// Returns a T* obtained from malloc<T>.
81 template <typename T =
void>
82 T* realloc(T* ptr,
size_t size);
// Definition of malloc<T>: obtains room for `size` objects of type T from
// nvshmem_malloc and hands the buffer back as a T*.
// NOTE(review): this listing skips several original lines (the template
// header, the braces, and the end of the LBANN_ERROR call are not shown), so
// only the statements that are visible are described below.
95 T* malloc(
size_t size)
// Block the host until the CUDA device has finished its outstanding work.
// NOTE(review): presumably needed because the allocation that follows is a
// collective operation -- confirm against the NVSHMEM documentation.
101 CHECK_CUDA(cudaDeviceSynchronize());
// Request size * sizeof(T) bytes.
// NOTE(review): the declaration defaults T to void, and sizeof(void) is not
// valid standard C++ (some compilers accept it as an extension) -- confirm how
// malloc<void> is meant to behave. The multiplication is also not checked for
// overflow in the lines shown.
102 auto* ptr = nvshmem_malloc(size *
sizeof(T));
// A null result is reported through LBANN_ERROR; the message names the
// symmetric heap as the source of the buffer. The remaining arguments of the
// call are not visible in this listing.
103 if (ptr ==
nullptr) {
104 LBANN_ERROR(
"NVSHMEM failed to allocate a GPU buffer ",
105 "from the symmetric heap ",
// Convert the pointer returned by nvshmem_malloc to the requested element
// pointer type.
110 return reinterpret_cast<T*
>(ptr);
// Definition of realloc<T>: returns a buffer obtained from malloc<T>(size).
// NOTE(review): the body of the `if (ptr != nullptr)` branch is missing from
// this listing; it presumably releases the old buffer -- confirm. No copy of
// the old contents into the new buffer is visible in the lines shown, so do
// not assume C realloc semantics without checking the full source.
113 template <
typename T>
114 T* realloc(T* ptr,
size_t size)
// Only a non-null input pointer gets the (not shown) special handling.
119 if (ptr !=
nullptr) {
// The result always comes from a fresh malloc<T> call with the same size.
122 return malloc<T>(size);
128 #endif // LBANN_HAS_NVSHMEM 130 #endif // LBANN_UTILS_NVSHMEM_HPP_INCLUDED
// Signature only: initialize() takes argc and argv by reference (so the callee
// is able to modify the caller's copies) and returns a world_comm_ptr.
// NOTE(review): this line has no terminating ';' or body and sits after the
// include guard's closing #endif; it looks like cross-reference residue from
// the listing rather than part of this header -- confirm before relying on it.
world_comm_ptr initialize(int &argc, char **&argv)
// Signature only: finalize() takes an optional lbann_comm pointer that
// defaults to nullptr and returns nothing.
// NOTE(review): like the line above it, this has no terminating ';' or body
// and appears after the include guard's closing #endif -- likely listing
// residue; confirm against the real declaration.
void finalize(lbann_comm *comm=nullptr)