LBANN  0.103.0
LivermoreBigArtificialNeuralNetworkToolkit
dnn_lib.hpp
Go to the documentation of this file.
1 
27 #ifndef LBANN_UTILS_DNN_LIB_DNN_LIB_HPP
28 #define LBANN_UTILS_DNN_LIB_DNN_LIB_HPP
29 
30 #include "lbann/base.hpp"
32 #include "lbann/layers/layer.hpp"
35 #include <vector>
36 
37 #include "lbann/proto/layers.pb.h"
38 
39 #ifdef LBANN_HAS_DNN_LIB
40 
41 namespace lbann {
42 namespace dnn_lib {
43 
44 #if defined LBANN_HAS_CUDNN
45 using namespace cudnn;
46 #elif defined LBANN_HAS_MIOPEN
47 using namespace miopen;
48 #endif // LBANN_HAS_CUDNN
49 
50 template <typename T>
51 struct ScalingParameterT
52 {
53  using type = T;
54 };
55 
56 template <typename T>
57 using ScalingParamType = typename ScalingParameterT<T>::type;
58 
59 #ifdef LBANN_HAS_GPU_FP16
60 template <>
61 struct ScalingParameterT<fp16>
62 {
63  using type = float;
64 };
65 #endif // LBANN_USE_GPU_FP16
66 
68 // Global DNN library objects
70 
72 void initialize();
74 void destroy();
79 dnnHandle_t& get_handle();
80 
82 // Helper functions for DNN library types
84 
86 template <typename TensorDataType>
87 dnnDataType_t get_data_type();
88 
90 // Wrapper classes for DNN library types
92 
93 template <typename T>
94 using BackendHandleType = typename T::handle_type;
95 
97 class TensorDescriptor
98 {
99 public:
100  using handle_type = dnnTensorDescriptor_t;
101 
102 public:
103  explicit TensorDescriptor(dnnTensorDescriptor_t desc = nullptr);
104 
105  ~TensorDescriptor();
106 
107  // Copy-and-swap idiom
108  TensorDescriptor(const TensorDescriptor&);
109  TensorDescriptor(TensorDescriptor&&);
110  TensorDescriptor& operator=(TensorDescriptor);
111  friend void swap(TensorDescriptor& first, TensorDescriptor& second);
112 
114  void reset(dnnTensorDescriptor_t desc = nullptr);
116  dnnTensorDescriptor_t release() noexcept;
118  dnnTensorDescriptor_t get() const noexcept;
120  operator dnnTensorDescriptor_t() const noexcept;
121 
126  void create();
131  void set(dnnDataType_t data_type,
132  std::vector<int> dims,
133  std::vector<int> strides = {});
138  template <typename... IntTs>
139  void set(dnnDataType_t data_type, IntTs... dims)
140  {
141  set(data_type, {static_cast<int>(dims)...});
142  }
143 #if !(defined LBANN_HAS_CUDNN)
144  void set(dnnDataType_t data_type,
145  dnnTensorFormat_t /*format*/,
146  const std::vector<int>& dims)
147  {
148  this->set(data_type, dims);
149  }
150 #endif // !LBANN_HAS_CUDNN
151 
152 private:
153  dnnTensorDescriptor_t desc_ = nullptr;
154 };
155 
156 #ifdef LBANN_HAS_CUDNN
157 
158 class FilterDescriptor
159 {
160 public:
161  using handle_type = dnnFilterDescriptor_t;
162 
163 public:
164  explicit FilterDescriptor(dnnFilterDescriptor_t desc = nullptr);
165 
166  ~FilterDescriptor();
167 
168  // Copy-and-swap idiom
169  FilterDescriptor(const FilterDescriptor&);
170  FilterDescriptor(FilterDescriptor&&);
171  FilterDescriptor& operator=(FilterDescriptor);
172  friend void swap(FilterDescriptor& first, FilterDescriptor& second);
173 
175  void reset(dnnFilterDescriptor_t desc = nullptr);
177  dnnFilterDescriptor_t release() noexcept;
179  dnnFilterDescriptor_t get() const noexcept;
181  operator dnnFilterDescriptor_t() const noexcept;
182 
187  void create();
192  void set(dnnDataType_t data_type,
193  dnnTensorFormat_t format,
194  const std::vector<int>& dims);
199  template <typename... IntTs>
200  void set(dnnDataType_t data_type, dnnTensorFormat_t format, IntTs... dims)
201  {
202  set(data_type, format, {static_cast<int>(dims)...});
203  }
204 
205 private:
206  dnnFilterDescriptor_t desc_ = nullptr;
207 };
208 #else // MIOpen and OneDNN
209 using FilterDescriptor = TensorDescriptor;
210 #endif // LBANN_HAS_CUDNN
211 
213 class DropoutDescriptor
214 {
215 
216 public:
217  explicit DropoutDescriptor(dnnDropoutDescriptor_t desc = nullptr);
218  DropoutDescriptor(float dropout,
219  void* states,
220  size_t states_size,
221  unsigned long long seed,
222  bool use_mask,
223  bool state_evo,
224  dnnRNGType_t rng_mode)
225  {
226  this
227  ->set(dropout, states, states_size, seed, use_mask, state_evo, rng_mode);
228  }
229 
230  ~DropoutDescriptor();
231 
232  // Copy-and-swap idiom
233  DropoutDescriptor(const DropoutDescriptor&);
234  DropoutDescriptor(DropoutDescriptor&&);
235  DropoutDescriptor& operator=(DropoutDescriptor);
236  friend void swap(DropoutDescriptor& first, DropoutDescriptor& second);
237 
239  void reset(dnnDropoutDescriptor_t desc = nullptr);
241  dnnDropoutDescriptor_t release() noexcept;
243  dnnDropoutDescriptor_t get() const noexcept;
245  operator dnnDropoutDescriptor_t() const noexcept;
246 
251  void create();
256  void set(float dropout,
257  void* states,
258  size_t states_size,
259  unsigned long long seed,
260  bool use_mask = false,
261  bool state_evo = false,
262  dnnRNGType_t rng_mode = DNN_RNG_PSEUDO_XORWOW);
263 
264 private:
265  dnnDropoutDescriptor_t desc_ = nullptr;
266 };
267 
269 class RNNDescriptor
270 {
271 
272 public:
273  explicit RNNDescriptor(dnnRNNDescriptor_t desc = nullptr);
274 
275  RNNDescriptor(const RNNDescriptor&) = delete;
276  ~RNNDescriptor();
277 
278  // Copy-and-swap idiom
279  RNNDescriptor(RNNDescriptor&&);
280  RNNDescriptor& operator=(RNNDescriptor);
281  friend void swap(RNNDescriptor& first, RNNDescriptor& second);
282 
284  void reset(dnnRNNDescriptor_t desc = nullptr);
286  dnnRNNDescriptor_t release() noexcept;
288  dnnRNNDescriptor_t get() const noexcept;
290  operator dnnRNNDescriptor_t() const noexcept;
291 
296  void create();
301  void set(dnnRNNAlgo_t algorithm,
302  dnnRNNMode_t cell_mode,
303  dnnRNNBiasMode_t bias_mode,
304  dnnDirectionMode_t direction_mode,
305  dnnRNNInputMode_t input_mode,
306  dnnDataType_t data_type,
307  dnnDataType_t math_precision,
308  dnnMathType_t math_type,
309  size_t input_size,
310  size_t hidden_size,
311  size_t proj_size,
312  size_t num_layers,
313  dnnDropoutDescriptor_t dropout_desc,
314  uint32_t aux_flags);
315 
316 private:
317  dnnRNNDescriptor_t desc_ = nullptr;
318 };
319 
321 class ConvolutionDescriptor
322 {
323 public:
325  using DescriptorHandle_t = dnnConvolutionDescriptor_t;
326 
327 public:
329 
332  explicit ConvolutionDescriptor(DescriptorHandle_t desc = nullptr);
333 
335  ~ConvolutionDescriptor();
336 
341  ConvolutionDescriptor(const ConvolutionDescriptor&);
343  ConvolutionDescriptor(ConvolutionDescriptor&&);
344 
346  ConvolutionDescriptor& operator=(ConvolutionDescriptor);
347 
349 
350 
353  DescriptorHandle_t release() noexcept;
355  DescriptorHandle_t get() const noexcept;
359  operator DescriptorHandle_t() const noexcept;
360 
362 
363 
366  void swap(ConvolutionDescriptor& other);
367 
369  void reset(DescriptorHandle_t desc = nullptr);
370 
375  void create();
376 
381  void set(std::vector<int> const& pad,
382  std::vector<int> const& stride,
383  std::vector<int> const& dilation,
384  dnnDataType_t data_type,
385  dnnConvolutionMode_t mode = DNN_CROSS_CORRELATION);
386  void set(size_t array_dim,
387  int const pad[],
388  int const stride[],
389  int const dilation[],
390  dnnDataType_t data_type,
391  dnnConvolutionMode_t mode = DNN_CROSS_CORRELATION);
392 
394  void set_math_mode(dnnMathType_t math_type);
395 
397  void set_group_count(int num_groups);
398 
400 
401 private:
402  DescriptorHandle_t desc_ = nullptr;
403 };
404 
406 void swap(ConvolutionDescriptor& lhs, ConvolutionDescriptor& rhs);
407 
409 class PoolingDescriptor
410 {
411 public:
413  using DescriptorHandle_t = dnnPoolingDescriptor_t;
414 
415 public:
417 
420  explicit PoolingDescriptor(DescriptorHandle_t desc = nullptr);
421 
423  ~PoolingDescriptor();
424 
429  PoolingDescriptor(const PoolingDescriptor&);
431  PoolingDescriptor(PoolingDescriptor&&);
432 
434  PoolingDescriptor& operator=(PoolingDescriptor);
435 
437 
438 
441  DescriptorHandle_t release() noexcept;
443  DescriptorHandle_t get() const noexcept;
447  operator DescriptorHandle_t() const noexcept;
448 
450 
451 
454  void swap(PoolingDescriptor& other);
455 
457  void reset(DescriptorHandle_t desc = nullptr);
458 
463  void create();
468  void set(pooling_mode mode,
469  dnnNanPropagation_t maxpoolingNanOpt,
470  std::vector<int> const& window_dims,
471  std::vector<int> const& padding,
472  std::vector<int> const& stride);
473  void set(pooling_mode mode,
474  dnnNanPropagation_t nan_prop,
475  int num_dims,
476  int const window_dims[],
477  int const padding[],
478  int const stride[]);
479 
481 
482 private:
483  DescriptorHandle_t desc_ = nullptr;
484 };
485 
487 void swap(PoolingDescriptor& lhs, PoolingDescriptor& rhs);
488 
490 class LRNDescriptor
491 {
492 public:
494  using DescriptorHandle_t = dnnLRNDescriptor_t;
495 
496 public:
498 
501  explicit LRNDescriptor(DescriptorHandle_t desc = nullptr);
502 
504  ~LRNDescriptor();
505 
510  LRNDescriptor(const LRNDescriptor&);
512  LRNDescriptor(LRNDescriptor&&);
513 
515  LRNDescriptor& operator=(LRNDescriptor);
516 
518 
519 
522  DescriptorHandle_t release() noexcept;
524  DescriptorHandle_t get() const noexcept;
528  operator DescriptorHandle_t() const noexcept;
529 
531 
532 
535  void swap(LRNDescriptor& other);
536 
538  void reset(DescriptorHandle_t desc = nullptr);
539 
544  void create();
549  void set(unsigned n,
550  double alpha,
551  double beta,
552  double k,
553  dnnLRNMode_t mode = DNN_LRN_CROSS_CHANNEL);
554 
556 
557 private:
558  DescriptorHandle_t desc_ = nullptr;
559 };
560 
562 void swap(LRNDescriptor& lhs, LRNDescriptor& rhs);
563 
565 // DNN library tensor managers
567 
569 template <typename TensorDataType>
570 class layer_tensor_manager
571 {
572 public:
573  using LayerType = data_type_layer<TensorDataType>;
574 
575 public:
576  layer_tensor_manager(const LayerType* l = nullptr);
577  virtual ~layer_tensor_manager() = default;
578 
580  const LayerType* get_layer() const { return m_layer; }
582  void set_layer(const LayerType* l);
583 
585  virtual TensorDescriptor& get_prev_activations(int parent_index = 0) = 0;
587  virtual TensorDescriptor& get_activations(int child_index = 0) = 0;
589  virtual TensorDescriptor& get_prev_error_signals(int child_index = 0) = 0;
591  virtual TensorDescriptor& get_error_signals(int parent_index = 0) = 0;
592 
593 protected:
594  layer_tensor_manager(const layer_tensor_manager&) = default;
595  layer_tensor_manager& operator=(const layer_tensor_manager&) = default;
596  layer_tensor_manager(layer_tensor_manager&&) = default;
597  layer_tensor_manager& operator=(layer_tensor_manager&&) = default;
598 
600  void set_num_parents(int num_parents);
602  void set_num_children(int num_children);
603 
605  const LayerType* m_layer;
607  std::vector<TensorDescriptor> m_prev_activations;
609  std::vector<TensorDescriptor> m_activations;
611  std::vector<TensorDescriptor> m_prev_error_signals;
613  std::vector<TensorDescriptor> m_error_signals;
614 };
615 
617 template <typename TensorDataType>
618 class data_parallel_layer_tensor_manager
619  : public layer_tensor_manager<TensorDataType>
620 {
621 public:
622  using LayerType = data_type_layer<TensorDataType>;
623 
624 public:
625  data_parallel_layer_tensor_manager(const LayerType* l = nullptr);
626  data_parallel_layer_tensor_manager(
627  const data_parallel_layer_tensor_manager&) = default;
628  data_parallel_layer_tensor_manager&
629  operator=(const data_parallel_layer_tensor_manager&) = default;
630  data_parallel_layer_tensor_manager(data_parallel_layer_tensor_manager&&) =
631  default;
632  data_parallel_layer_tensor_manager&
633  operator=(data_parallel_layer_tensor_manager&&) = default;
634  ~data_parallel_layer_tensor_manager() = default;
635  TensorDescriptor& get_prev_activations(int parent_index = 0) override;
636  TensorDescriptor& get_activations(int child_index = 0) override;
637  TensorDescriptor& get_prev_error_signals(int child_index = 0) override;
638  TensorDescriptor& get_error_signals(int parent_index = 0) override;
639 };
640 
642 template <typename TensorDataType>
643 class entrywise_layer_tensor_manager
644  : public layer_tensor_manager<TensorDataType>
645 {
646 public:
647  using LayerType = data_type_layer<TensorDataType>;
648 
649 public:
650  entrywise_layer_tensor_manager(const LayerType* l = nullptr);
651  entrywise_layer_tensor_manager(const entrywise_layer_tensor_manager&) =
652  default;
653  entrywise_layer_tensor_manager&
654  operator=(const entrywise_layer_tensor_manager&) = default;
655  entrywise_layer_tensor_manager(entrywise_layer_tensor_manager&&) = default;
656  entrywise_layer_tensor_manager&
657  operator=(entrywise_layer_tensor_manager&&) = default;
658  ~entrywise_layer_tensor_manager() = default;
659  TensorDescriptor& get_prev_activations(int parent_index = 0) override;
660  TensorDescriptor& get_activations(int child_index = 0) override;
661  TensorDescriptor& get_prev_error_signals(int child_index = 0) override;
662  TensorDescriptor& get_error_signals(int parent_index = 0) override;
663 };
664 
666 // DNN library algorithm selection
668 
679 fwd_conv_alg get_fwd_algorithm(bool autotune,
680  bool deterministic,
681  const TensorDescriptor& input_desc,
682  const void* input,
683  const FilterDescriptor& kernel_desc,
684  const void* kernel,
685  const ConvolutionDescriptor& conv_desc,
686  const TensorDescriptor& output_desc,
687  void* output,
688  size_t ws_size,
689  void* ws);
690 
701 get_bwd_data_algorithm(bool autotune,
702  bool deterministic,
703  const FilterDescriptor& kernel_desc,
704  const void* kernel,
705  const TensorDescriptor& prev_error_signal_desc,
706  const void* prev_error_signal,
707  const ConvolutionDescriptor& conv_desc,
708  const TensorDescriptor& error_signal_desc,
709  void* error_signal,
710  size_t ws_size,
711  void* ws);
712 
723 get_bwd_filter_algorithm(bool autotune,
724  bool deterministic,
725  const TensorDescriptor& input_desc,
726  const void* input,
727  const TensorDescriptor& prev_error_signal_desc,
728  const void* prev_error_signal,
729  const ConvolutionDescriptor& conv_desc,
730  const FilterDescriptor& kernel_gradient_desc,
731  void* kernel_gradient,
732  size_t ws_size,
733  void* ws);
734 
738 void default_to_tensor_ops() noexcept;
739 
744 dnnMathType_t get_default_convolution_math_type() noexcept;
745 
746 using ProtoTensorOpEnumType = decltype(lbann_data::DEFAULT_TENSOR_OPS);
748 dnnMathType_t convert_to_dnn_math_type(ProtoTensorOpEnumType mt);
750 ProtoTensorOpEnumType convert_to_proto_math_type(dnnMathType_t mt);
751 
752 } // namespace dnn_lib
753 } // namespace lbann
754 #endif // LBANN_HAS_DNN_LIB
755 #endif // LBANN_UTILS_DNN_LIB_DNN_LIB_HPP
std::basic_string< T > pad(const std::basic_string< T > &s, typename std::basic_string< T >::size_type n, T c)
Definition: file_utils.hpp:93
bwd_data_conv_alg
Which backward convolution algorithm to use.
Definition: dnn_enums.hpp:45
bwd_filter_conv_alg
Which backward convolution filter algorithm to use.
Definition: dnn_enums.hpp:57
pooling_mode
Which pooling mode to use.
Definition: dnn_enums.hpp:78
fwd_conv_alg
Which forward convolution algorithm to use.
Definition: dnn_enums.hpp:32
std::string get()
world_comm_ptr initialize(int &argc, char **&argv)