26 #ifndef LBANN_SRC_LAYERS_TRANSFORM_CUTENSOR_SUPPORT_HPP_INCLUDED 27 #define LBANN_SRC_LAYERS_TRANSFORM_CUTENSOR_SUPPORT_HPP_INCLUDED 32 #include <cuda_runtime.h> 35 #define CHECK_CUTENSOR(cmd) \ 37 auto const lbann_chk_cutensor_status__ = (cmd); \ 38 if (CUTENSOR_STATUS_SUCCESS != lbann_chk_cutensor_status__) { \ 39 LBANN_ERROR("cuTENSOR error (status=", \ 40 lbann_chk_cutensor_status__, \ 42 cutensorGetErrorString(lbann_chk_cutensor_status__)); \ 48 template <
typename CppType>
55 static constexpr
auto value = CUDA_R_16F;
62 static constexpr
auto value = CUDA_R_32F;
68 static constexpr
auto value = CUDA_R_64F;
74 static constexpr
auto value = CUDA_C_32F;
80 static constexpr
auto value = CUDA_C_64F;
83 template <
typename CppType>
86 template <
typename CppType>
89 template <
typename CppType>
94 cutensorHandle_t handle;
106 #endif // LBANN_SRC_LAYERS_TRANSFORM_CUTENSOR_SUPPORT_HPP_INCLUDED static cutensorHandle_t * get_handle_ptr()
El::Complex< float > scalar_type
#define CHECK_CUTENSOR(cmd)
static cutensorHandle_t make_handle()
constexpr auto CUDAScalarType
typename CUDATypeT< CppType >::scalar_type CUDAScalar
El::Complex< double > scalar_type