LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
cudnn.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_UTILS_DNN_LIB_CUDNN_HPP
28 #define LBANN_UTILS_DNN_LIB_CUDNN_HPP
29 
31 
32 #ifdef LBANN_HAS_CUDNN
33 
34 #include <cudnn.h>
35 
36 // Error utility macros
37 #define CHECK_CUDNN_NODEBUG(cudnn_call) \
38  do { \
39  const cudnnStatus_t status_CHECK_CUDNN = (cudnn_call); \
40  if (status_CHECK_CUDNN != CUDNN_STATUS_SUCCESS) { \
41  LBANN_ERROR("cuDNN error (", \
42  cudnnGetErrorString(status_CHECK_CUDNN), \
43  ")"); \
44  } \
45  } while (0)
46 #define CHECK_CUDNN_DEBUG(cudnn_call) \
47  do { \
48  LBANN_CUDA_CHECK_LAST_ERROR(true); \
49  CHECK_CUDNN_NODEBUG(cudnn_call); \
50  } while (0)
51 #ifdef LBANN_DEBUG
52 #define CHECK_CUDNN(cudnn_call) CHECK_CUDNN_DEBUG(cudnn_call)
53 #else
54 #define CHECK_CUDNN(cudnn_call) CHECK_CUDNN_NODEBUG(cudnn_call)
55 #endif // #ifdef LBANN_DEBUG
56 
// CHECK_CUDNN_DTOR performs the same check as CHECK_CUDNN, but no exception
// escapes: anything thrown is reported on std::cerr and std::terminate() is
// called. It is intended for contexts such as destructors.
//
// The body is wrapped in do { ... } while (0) so that the macro is a single
// statement requiring a trailing semicolon, like CHECK_CUDNN. A bare
// try/catch followed by the caller's ';' breaks an unbraced if/else.
// NOTE(review): uses std::cerr / std::terminate / std::exception; this
// header relies on <iostream> and <exception> being included elsewhere.
#define CHECK_CUDNN_DTOR(cudnn_call)                                           \
  do {                                                                         \
    try {                                                                      \
      CHECK_CUDNN(cudnn_call);                                                 \
    }                                                                          \
    catch (std::exception const& e) {                                          \
      std::cerr << "Caught exception:\n\n what(): " << e.what()                \
                << "\n\nCalling std::terminate() now." << std::endl;           \
      std::terminate();                                                        \
    }                                                                          \
    catch (...) {                                                              \
      std::cerr << "Caught something that isn't an std::exception.\n\n"       \
                << "Calling std::terminate() now." << std::endl;               \
      std::terminate();                                                        \
    }                                                                          \
  } while (0)
71 
72 namespace lbann {
73 
74 // Forward declaration
75 class Layer;
76 
77 namespace cudnn {
78 
79 using dnnHandle_t = cudnnHandle_t;
80 using dnnDataType_t = cudnnDataType_t;
81 using dnnTensorDescriptor_t = cudnnTensorDescriptor_t;
82 using dnnFilterDescriptor_t = cudnnFilterDescriptor_t;
83 using dnnTensorFormat_t = cudnnTensorFormat_t;
84 using dnnDropoutDescriptor_t = cudnnDropoutDescriptor_t;
85 using dnnRNGType_t = int;
86 using dnnRNNDescriptor_t = cudnnRNNDescriptor_t;
87 using dnnRNNAlgo_t = cudnnRNNAlgo_t;
88 using dnnRNNMode_t = cudnnRNNMode_t;
89 using dnnRNNBiasMode_t = cudnnRNNBiasMode_t;
90 using dnnDirectionMode_t = cudnnDirectionMode_t;
91 using dnnRNNInputMode_t = cudnnRNNInputMode_t;
92 using dnnMathType_t = cudnnMathType_t;
93 using dnnRNNDataDescriptor_t = cudnnRNNDataDescriptor_t;
94 using dnnRNNDataLayout_t = cudnnRNNDataLayout_t;
95 using dnnConvolutionDescriptor_t = cudnnConvolutionDescriptor_t;
96 using dnnConvolutionMode_t = cudnnConvolutionMode_t;
97 using dnnActivationDescriptor_t = cudnnActivationDescriptor_t;
98 using dnnActivationMode_t = cudnnActivationMode_t;
99 using dnnNanPropagation_t = cudnnNanPropagation_t;
100 using dnnPoolingDescriptor_t = cudnnPoolingDescriptor_t;
101 using dnnPoolingMode_t = cudnnPoolingMode_t;
102 using dnnLRNDescriptor_t = cudnnLRNDescriptor_t;
103 using dnnLRNMode_t = cudnnLRNMode_t;
104 using dnnConvolutionFwdAlgo_t = cudnnConvolutionFwdAlgo_t;
105 using dnnConvolutionBwdDataAlgo_t = cudnnConvolutionBwdDataAlgo_t;
106 using dnnConvolutionBwdFilterAlgo_t = cudnnConvolutionBwdFilterAlgo_t;
107 
108 constexpr dnnConvolutionMode_t DNN_CROSS_CORRELATION = CUDNN_CROSS_CORRELATION;
109 constexpr dnnNanPropagation_t DNN_PROPAGATE_NAN = CUDNN_PROPAGATE_NAN;
110 constexpr dnnMathType_t DNN_DEFAULT_MATH = CUDNN_DEFAULT_MATH;
111 constexpr dnnTensorFormat_t DNN_TENSOR_NCHW = CUDNN_TENSOR_NCHW;
112 constexpr dnnRNGType_t DNN_RNG_PSEUDO_XORWOW = 0;
113 constexpr dnnLRNMode_t DNN_LRN_CROSS_CHANNEL = CUDNN_LRN_CROSS_CHANNEL_DIM1;
114 constexpr dnnMathType_t DNN_TENSOR_OP_MATH_ALLOW_CONVERSION =
115  CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION;
116 
118 // Functions for to/from cuDNN types conversion
120 
123 inline cudnnConvolutionFwdAlgo_t to_cudnn(fwd_conv_alg a)
124 {
125  switch (a) {
127  return CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
129  return CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
130  case fwd_conv_alg::GEMM:
131  return CUDNN_CONVOLUTION_FWD_ALGO_GEMM;
133  return CUDNN_CONVOLUTION_FWD_ALGO_DIRECT;
134  case fwd_conv_alg::FFT:
135  return CUDNN_CONVOLUTION_FWD_ALGO_FFT;
137  return CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING;
139  return CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD;
141  return CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED;
142  default:
143  LBANN_ERROR("Invalid forward convolution algorithm requested.");
144  }
145 }
146 
149 inline fwd_conv_alg from_cudnn(cudnnConvolutionFwdAlgo_t a)
150 {
151  switch (a) {
152  case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM:
154  case CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM:
156  case CUDNN_CONVOLUTION_FWD_ALGO_GEMM:
157  return fwd_conv_alg::GEMM;
158  case CUDNN_CONVOLUTION_FWD_ALGO_DIRECT:
159  return fwd_conv_alg::DIRECT;
160  case CUDNN_CONVOLUTION_FWD_ALGO_FFT:
161  return fwd_conv_alg::FFT;
162  case CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING:
164  case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD:
165  return fwd_conv_alg::WINOGRAD;
166  case CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED:
168  default:
169  LBANN_ERROR("Invalid forward convolution algorithm requested.");
170  }
171 }
172 
175 inline cudnnConvolutionBwdDataAlgo_t to_cudnn(bwd_data_conv_alg a)
176 {
177  switch (a) {
179  return CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
181  return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
183  return CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT;
185  return CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING;
187  return CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD;
189  return CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED;
190  default:
191  LBANN_ERROR("Invalid backward convolution algorithm requested.");
192  }
193 }
194 
197 inline bwd_data_conv_alg from_cudnn(cudnnConvolutionBwdDataAlgo_t a)
198 {
199  switch (a) {
200  case CUDNN_CONVOLUTION_BWD_DATA_ALGO_0:
202  case CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
204  case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
205  return bwd_data_conv_alg::FFT;
206  case CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
208  case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD:
210  case CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED:
212  default:
213  LBANN_ERROR("Invalid backward convolution algorithm requested.");
214  }
215 }
216 
219 inline cudnnConvolutionBwdFilterAlgo_t to_cudnn(bwd_filter_conv_alg a)
220 {
221  switch (a) {
223  return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
225  return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
227  return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT;
229  return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3;
231  return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED;
233  return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING;
234  default:
235  LBANN_ERROR("Invalid backward convolution filter requested.");
236  }
237 }
238 
241 inline bwd_filter_conv_alg from_cudnn(cudnnConvolutionBwdFilterAlgo_t a)
242 {
243  switch (a) {
244  case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0:
246  case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1:
248  case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT:
250  case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3:
252  case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED:
254  case CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING:
256  default:
257  LBANN_ERROR("Invalid backward convolution filter requested.");
258  }
259 }
260 
261 inline cudnnPoolingMode_t to_cudnn(pooling_mode m)
262 {
263  switch (m) {
264  case pooling_mode::MAX:
265 #ifdef LBANN_DETERMINISTIC
266  return CUDNN_POOLING_MAX_DETERMINISTIC;
267 #else
268  return CUDNN_POOLING_MAX;
269 #endif // LBANN_DETERMINISTIC
271  return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
273  return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
275  return CUDNN_POOLING_MAX_DETERMINISTIC;
276  default:
277  LBANN_ERROR("Invalid pooling mode requested.");
278  }
279 }
280 
281 inline pooling_mode from_cudnn(cudnnPoolingMode_t m)
282 {
283  switch (m) {
284  case CUDNN_POOLING_MAX:
285  return pooling_mode::MAX;
286  case CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING:
288  case CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING:
290  case CUDNN_POOLING_MAX_DETERMINISTIC:
292  default:
293  LBANN_ERROR("Invalid pooling mode requested.");
294  }
295 }
296 
298 inline cudnnSoftmaxMode_t to_cudnn(softmax_mode m)
299 {
300  switch (m) {
302  return CUDNN_SOFTMAX_MODE_INSTANCE;
304  return CUDNN_SOFTMAX_MODE_CHANNEL;
306  default:
307  LBANN_ERROR("Invalid softmax mode requested.");
308  }
309 }
310 
312 inline cudnnSoftmaxAlgorithm_t to_cudnn(softmax_alg alg)
313 {
314  switch (alg) {
315  case softmax_alg::FAST:
316  return CUDNN_SOFTMAX_FAST;
318  return CUDNN_SOFTMAX_ACCURATE;
319  case softmax_alg::LOG:
320  return CUDNN_SOFTMAX_LOG;
321  default:
322  LBANN_ERROR("Invalid softmax algorithm requested.");
323  }
324 }
325 
326 } // namespace cudnn
327 
328 namespace dnn_lib {
329 
330 using namespace cudnn;
331 
333 class RNNDataDescriptor
334 {
335 
336 public:
337  RNNDataDescriptor(dnnRNNDataDescriptor_t desc = nullptr);
338 
339  ~RNNDataDescriptor();
340 
341  // Copy-and-swap idiom
342  RNNDataDescriptor(const RNNDataDescriptor&) = delete;
343  RNNDataDescriptor(RNNDataDescriptor&&);
344  RNNDataDescriptor& operator=(RNNDataDescriptor);
345  friend void swap(RNNDataDescriptor& first, RNNDataDescriptor& second);
346 
348  void reset(dnnRNNDataDescriptor_t desc = nullptr);
350  dnnRNNDataDescriptor_t release();
352  dnnRNNDataDescriptor_t get() const noexcept;
354  operator dnnRNNDataDescriptor_t() const noexcept;
355 
360  void create();
361 
366  void set(dnnDataType_t data_type,
367  dnnRNNDataLayout_t layout,
368  size_t max_seq_length,
369  size_t batch_size,
370  size_t vector_size,
371  const int seq_length_array[],
372  void* padding_fill);
373 
374 private:
375  dnnRNNDataDescriptor_t desc_{nullptr};
376 };
377 
378 } // namespace dnn_lib
379 
380 } // namespace lbann
381 
382 #endif // LBANN_HAS_CUDNN
383 #endif // LBANN_UTILS_DNN_LIB_CUDNN_HPP
Position-wise softmax.
#define LBANN_ERROR(...)
Definition: exception.hpp:37
Sample-wise softmax.
bwd_data_conv_alg
Which backward convolution algorithm to use.
Definition: dnn_enums.hpp:45
bwd_filter_conv_alg
Which backward convolution filter algorithm to use.
Definition: dnn_enums.hpp:57
softmax_alg
Internal LBANN names for supported softmax algorithms.
Definition: dnn_enums.hpp:110
pooling_mode
Which pooling mode to use.
Definition: dnn_enums.hpp:78
fwd_conv_alg
Which forward convolution algorithm to use.
Definition: dnn_enums.hpp:32
softmax_mode
Which tensor dimensions to apply softmax over.
Definition: dnn_enums.hpp:87