// Include guard, the LBANN_HAS_CUDNN gate, and the cuDNN status-checking
// macros, followed by the first backend-neutral alias (dnnHandle_t).
//  - CHECK_CUDNN_NODEBUG evaluates the call once and raises LBANN_ERROR with
//    cudnnGetErrorString() when the status is not CUDNN_STATUS_SUCCESS.
//  - CHECK_CUDNN_DEBUG runs LBANN_CUDA_CHECK_LAST_ERROR(true) and then
//    CHECK_CUDNN_NODEBUG.
//  - CHECK_CUDNN is defined as one of the two above; the closing #endif
//    comment names LBANN_DEBUG as the selecting condition.
//  - CHECK_CUDNN_DTOR wraps CHECK_CUDNN and, on a caught exception, prints a
//    message to std::cerr announcing a call to std::terminate().
// NOTE(review): several original lines of these macros are not visible in
// this excerpt, so the description covers only what is shown.
27 #ifndef LBANN_UTILS_DNN_LIB_CUDNN_HPP 28 #define LBANN_UTILS_DNN_LIB_CUDNN_HPP 32 #ifdef LBANN_HAS_CUDNN 37 #define CHECK_CUDNN_NODEBUG(cudnn_call) \ 39 const cudnnStatus_t status_CHECK_CUDNN = (cudnn_call); \ 40 if (status_CHECK_CUDNN != CUDNN_STATUS_SUCCESS) { \ 41 LBANN_ERROR("cuDNN error (", \ 42 cudnnGetErrorString(status_CHECK_CUDNN), \ 46 #define CHECK_CUDNN_DEBUG(cudnn_call) \ 48 LBANN_CUDA_CHECK_LAST_ERROR(true); \ 49 CHECK_CUDNN_NODEBUG(cudnn_call); \ 52 #define CHECK_CUDNN(cudnn_call) CHECK_CUDNN_DEBUG(cudnn_call) 54 #define CHECK_CUDNN(cudnn_call) CHECK_CUDNN_NODEBUG(cudnn_call) 55 #endif // #ifdef LBANN_DEBUG 57 #define CHECK_CUDNN_DTOR(cudnn_call) \ 59 CHECK_CUDNN(cudnn_call); \ 61 catch (std::exception const& e) { \ 62 std::cerr << "Caught exception:\n\n what(): " << e.what() \ 63 << "\n\nCalling std::terminate() now." << std::endl; \ 67 std::cerr << "Caught something that isn't an std::exception.\n\n" \ 68 << "Calling std::terminate() now." << std::endl; \ 79 using dnnHandle_t = cudnnHandle_t;
// Backend-neutral `dnn*` names for the cuDNN types used by the rest of the
// library. Each alias maps one-to-one onto the cuDNN type of the same stem.
80 using dnnDataType_t = cudnnDataType_t;
81 using dnnTensorDescriptor_t = cudnnTensorDescriptor_t;
82 using dnnFilterDescriptor_t = cudnnFilterDescriptor_t;
83 using dnnTensorFormat_t = cudnnTensorFormat_t;
84 using dnnDropoutDescriptor_t = cudnnDropoutDescriptor_t;
// The RNG type is a plain int rather than a cuDNN type.
// NOTE(review): presumably kept only for parity with another DNN backend
// that has an RNG-type enum -- confirm.
85 using dnnRNGType_t = int;
// Recurrent-network types.
86 using dnnRNNDescriptor_t = cudnnRNNDescriptor_t;
87 using dnnRNNAlgo_t = cudnnRNNAlgo_t;
88 using dnnRNNMode_t = cudnnRNNMode_t;
89 using dnnRNNBiasMode_t = cudnnRNNBiasMode_t;
90 using dnnDirectionMode_t = cudnnDirectionMode_t;
91 using dnnRNNInputMode_t = cudnnRNNInputMode_t;
92 using dnnMathType_t = cudnnMathType_t;
93 using dnnRNNDataDescriptor_t = cudnnRNNDataDescriptor_t;
94 using dnnRNNDataLayout_t = cudnnRNNDataLayout_t;
// Convolution, activation, pooling and LRN types.
95 using dnnConvolutionDescriptor_t = cudnnConvolutionDescriptor_t;
96 using dnnConvolutionMode_t = cudnnConvolutionMode_t;
97 using dnnActivationDescriptor_t = cudnnActivationDescriptor_t;
98 using dnnActivationMode_t = cudnnActivationMode_t;
99 using dnnNanPropagation_t = cudnnNanPropagation_t;
100 using dnnPoolingDescriptor_t = cudnnPoolingDescriptor_t;
101 using dnnPoolingMode_t = cudnnPoolingMode_t;
102 using dnnLRNDescriptor_t = cudnnLRNDescriptor_t;
103 using dnnLRNMode_t = cudnnLRNMode_t;
// Convolution algorithm selectors (forward, backward-data, backward-filter).
104 using dnnConvolutionFwdAlgo_t = cudnnConvolutionFwdAlgo_t;
105 using dnnConvolutionBwdDataAlgo_t = cudnnConvolutionBwdDataAlgo_t;
106 using dnnConvolutionBwdFilterAlgo_t = cudnnConvolutionBwdFilterAlgo_t;
// Backend-neutral `DNN_*` names for specific cuDNN enumerators, typed with
// the `dnn*` aliases.
108 constexpr dnnConvolutionMode_t DNN_CROSS_CORRELATION = CUDNN_CROSS_CORRELATION;
109 constexpr dnnNanPropagation_t DNN_PROPAGATE_NAN = CUDNN_PROPAGATE_NAN;
110 constexpr dnnMathType_t DNN_DEFAULT_MATH = CUDNN_DEFAULT_MATH;
111 constexpr dnnTensorFormat_t DNN_TENSOR_NCHW = CUDNN_TENSOR_NCHW;
// dnnRNGType_t is an int alias, so this constant is the literal 0 rather than
// a cuDNN enumerator.
112 constexpr dnnRNGType_t DNN_RNG_PSEUDO_XORWOW = 0;
// Note the mapping is to the DIM1 variant of cuDNN's cross-channel LRN mode.
113 constexpr dnnLRNMode_t DNN_LRN_CROSS_CHANNEL = CUDNN_LRN_CROSS_CHANNEL_DIM1;
114 constexpr dnnMathType_t DNN_TENSOR_OP_MATH_ALLOW_CONVERSION =
115 CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION;
// Converts an LBANN forward-convolution algorithm selector (fwd_conv_alg)
// into the corresponding cudnnConvolutionFwdAlgo_t value.
// Eight cuDNN forward algorithms are covered: IMPLICIT_GEMM,
// IMPLICIT_PRECOMP_GEMM, GEMM, DIRECT, FFT, FFT_TILING, WINOGRAD and
// WINOGRAD_NONFUSED. An LBANN_ERROR with the message below is the remaining
// path.
// NOTE(review): the switch/case labels that select each return are not
// visible in this excerpt; presumably the LBANN_ERROR is the default branch.
123 inline cudnnConvolutionFwdAlgo_t to_cudnn(
fwd_conv_alg a)
127 return CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
129 return CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
131 return CUDNN_CONVOLUTION_FWD_ALGO_GEMM;
133 return CUDNN_CONVOLUTION_FWD_ALGO_DIRECT;
135 return CUDNN_CONVOLUTION_FWD_ALGO_FFT;
137 return CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING;
139 return CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD;
141 return CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED;
143 LBANN_ERROR(
"Invalid forward convolution algorithm requested.");
// Converts a cuDNN forward-convolution algorithm value into the LBANN
// fwd_conv_alg selector. The case labels cover the same eight cuDNN
// algorithms that the fwd_conv_alg overload of to_cudnn returns.
// NOTE(review): the return statement under each case is not visible in this
// excerpt; presumably the LBANN_ERROR is the default branch for values
// outside the listed cases.
149 inline fwd_conv_alg from_cudnn(cudnnConvolutionFwdAlgo_t a)
152 case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
154 case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
156 case CUDNN_CONVOLUTION_FWD_ALGO_GEMM:
158 case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
160 case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
162 case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
164 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
166 case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
169 LBANN_ERROR(
"Invalid forward convolution algorithm requested.");
// Return values of the backward-data convolution algorithm conversion to
// cuDNN: ALGO_0, ALGO_1, FFT, FFT_TILING, WINOGRAD and WINOGRAD_NONFUSED.
// NOTE(review): the function signature and case labels are not visible in
// this excerpt; presumably this is the to_cudnn overload taking
// bwd_data_conv_alg, with LBANN_ERROR as its default branch -- confirm.
179 return CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
181 return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
183 return CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT;
185 return CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING;
187 return CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD;
189 return CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED;
191 LBANN_ERROR(
"Invalid backward convolution algorithm requested.");
// Case labels of the backward-data convolution algorithm conversion from
// cuDNN; they cover the six CUDNN_CONVOLUTION_BWD_DATA_ALGO_* values listed.
// NOTE(review): the function signature and the statement under each case are
// not visible in this excerpt; presumably a from_cudnn overload returning
// bwd_data_conv_alg, with LBANN_ERROR as its default branch -- confirm.
200 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
202 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
204 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
206 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
208 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
210 case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
213 LBANN_ERROR(
"Invalid backward convolution algorithm requested.");
// Return values of the backward-filter convolution algorithm conversion to
// cuDNN: ALGO_0, ALGO_1, FFT, ALGO_3, WINOGRAD_NONFUSED and FFT_TILING.
// NOTE(review): the function signature and case labels are not visible in
// this excerpt; presumably this is the to_cudnn overload taking
// bwd_filter_conv_alg, with LBANN_ERROR as its default branch -- confirm.
223 return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
225 return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
227 return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT;
229 return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3;
231 return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED;
233 return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING;
// The message says "filter" where an algorithm is meant; the text is left
// unchanged because it is emitted at runtime.
235 LBANN_ERROR(
"Invalid backward convolution filter requested.");
// Case labels of the backward-filter convolution algorithm conversion from
// cuDNN; they cover the six CUDNN_CONVOLUTION_BWD_FILTER_ALGO_* values listed.
// NOTE(review): the function signature and the statement under each case are
// not visible in this excerpt; presumably a from_cudnn overload returning
// bwd_filter_conv_alg, with LBANN_ERROR as its default branch -- confirm.
244 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
246 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
248 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
250 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
252 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
254 case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING:
257 LBANN_ERROR(
"Invalid backward convolution filter requested.");
// Return values of the pooling-mode conversion to cuDNN. The first pair of
// returns is guarded by LBANN_DETERMINISTIC: with it defined the
// deterministic max-pooling mode is returned, and CUDNN_POOLING_MAX appears
// as the alternative. The remaining returns cover the two average modes and
// the explicitly deterministic max mode.
// NOTE(review): the function signature, the case labels and the #else line
// are not visible in this excerpt; presumably this is the to_cudnn overload
// taking pooling_mode -- confirm.
265 #ifdef LBANN_DETERMINISTIC 266 return CUDNN_POOLING_MAX_DETERMINISTIC;
268 return CUDNN_POOLING_MAX;
269 #endif // LBANN_DETERMINISTIC 271 return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
273 return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
275 return CUDNN_POOLING_MAX_DETERMINISTIC;
// Case labels of the pooling-mode conversion from cuDNN, covering the four
// cudnnPoolingMode_t values listed.
// NOTE(review): the function signature and the statement under each case are
// not visible in this excerpt; presumably a from_cudnn overload returning
// pooling_mode -- confirm.
284 case CUDNN_POOLING_MAX:
286 case CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING:
288 case CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING:
290 case CUDNN_POOLING_MAX_DETERMINISTIC:
// Return values of the softmax-mode conversion to cuDNN: per-instance and
// per-channel modes.
// NOTE(review): the function signature and case labels are not visible in
// this excerpt; presumably this is the to_cudnn overload taking softmax_mode
// -- confirm.
302 return CUDNN_SOFTMAX_MODE_INSTANCE;
304 return CUDNN_SOFTMAX_MODE_CHANNEL;
// Converts an LBANN softmax algorithm selector (softmax_alg) into the
// corresponding cudnnSoftmaxAlgorithm_t value: FAST, ACCURATE or LOG.
// An LBANN_ERROR with the message below is the remaining path.
// NOTE(review): the switch/case labels that select each return are not
// visible in this excerpt; presumably the LBANN_ERROR is the default branch.
312 inline cudnnSoftmaxAlgorithm_t to_cudnn(
softmax_alg alg)
316 return CUDNN_SOFTMAX_FAST;
318 return CUDNN_SOFTMAX_ACCURATE;
320 return CUDNN_SOFTMAX_LOG;
322 LBANN_ERROR(
"Invalid softmax algorithm requested.");
330 using namespace cudnn;
// Wrapper class around a raw dnnRNNDataDescriptor_t handle, which is stored
// in the desc_ member. Only declarations are visible here; the member
// definitions live elsewhere.
// NOTE(review): braces, access specifiers and some parameter lines of this
// class are not visible in this excerpt.
333 class RNNDataDescriptor
// Constructs from a raw descriptor handle; the argument defaults to nullptr.
337 RNNDataDescriptor(dnnRNNDataDescriptor_t desc =
nullptr);
// NOTE(review): presumably destroys the held descriptor -- definition not
// visible, confirm.
339 ~RNNDataDescriptor();
// Copy construction is disabled.
342 RNNDataDescriptor(
const RNNDataDescriptor&) =
delete;
// Move construction is declared.
343 RNNDataDescriptor(RNNDataDescriptor&&);
// Assignment takes its argument by value. Because the copy constructor is
// deleted, the argument can only be initialized from an rvalue.
// NOTE(review): presumably implemented with the friend swap below -- confirm.
344 RNNDataDescriptor& operator=(RNNDataDescriptor);
// Non-member swap with access to the class internals.
345 friend void swap(RNNDataDescriptor& first, RNNDataDescriptor& second);
// Takes a raw handle, defaulting to nullptr.
// NOTE(review): presumably replaces the currently held descriptor -- confirm
// how the previous handle is disposed of.
348 void reset(dnnRNNDataDescriptor_t desc =
nullptr);
// Returns a raw handle.
// NOTE(review): presumably gives up ownership of it -- confirm.
350 dnnRNNDataDescriptor_t release();
// Non-throwing const accessor returning a raw handle.
352 dnnRNNDataDescriptor_t
get()
const noexcept;
// Implicit, non-throwing conversion to the raw handle type.
354 operator dnnRNNDataDescriptor_t() const noexcept;
// Configures the descriptor. Visible parameters: element data type, data
// layout, maximum sequence length and an array of per-sequence lengths.
// NOTE(review): further parameters sit on lines not visible in this excerpt.
366 void set(dnnDataType_t data_type,
367 dnnRNNDataLayout_t layout,
368 size_t max_seq_length,
371 const
int seq_length_array[],
// The wrapped raw handle; nullptr until one is supplied.
375 dnnRNNDataDescriptor_t desc_{
nullptr};
382 #endif // LBANN_HAS_CUDNN 383 #endif // LBANN_UTILS_DNN_LIB_CUDNN_HPP
bwd_data_conv_alg
Which backward-data convolution algorithm to use.
bwd_filter_conv_alg
Which backward-filter convolution algorithm to use.
softmax_alg
Internal LBANN names for supported softmax algorithms.
pooling_mode
Which pooling mode to use.
fwd_conv_alg
Which forward convolution algorithm to use.
softmax_mode
Which tensor dimensions to apply softmax over.