27 #ifndef LBANN_UTILS_DNN_LIB_DNN_LIB_HPP 28 #define LBANN_UTILS_DNN_LIB_DNN_LIB_HPP 37 #include "lbann/proto/layers.pb.h" 39 #ifdef LBANN_HAS_DNN_LIB 44 #if defined LBANN_HAS_CUDNN 45 using namespace cudnn;
46 #elif defined LBANN_HAS_MIOPEN 47 using namespace miopen;
48 #endif // LBANN_HAS_CUDNN 51 struct ScalingParameterT
// Convenience alias for the nested `type` member of ScalingParameterT<T>.
// NOTE(review): the `template <typename T>` header is not visible in this
// listing.
57 using ScalingParamType =
typename ScalingParameterT<T>::type;
// Specialization of ScalingParameterT for fp16; it is only declared when
// LBANN_HAS_GPU_FP16 is defined. The body is not visible in this listing.
59 #ifdef LBANN_HAS_GPU_FP16 61 struct ScalingParameterT<fp16>
65 #endif // LBANN_HAS_GPU_FP16 79 dnnHandle_t& get_handle();
// Function template parameterized on the tensor element type; returns a
// dnnDataType_t value.
// NOTE(review): presumably maps TensorDataType to the backend's data-type
// enumerator -- confirm against the definition.
86 template <
typename TensorDataType>
87 dnnDataType_t get_data_type();
// Alias for the nested `handle_type` of a descriptor wrapper type T.
// NOTE(review): the `template <typename T>` header is not visible in this
// listing.
94 using BackendHandleType =
typename T::handle_type;
// Wrapper class around a dnnTensorDescriptor_t handle (held in desc_).
// NOTE(review): braces and several lines of the class are missing from this
// listing; ownership and lifetime rules are not visible here -- confirm in
// the implementation file.
97 class TensorDescriptor
// Type of the wrapped backend handle.
100 using handle_type = dnnTensorDescriptor_t;
// Construct from an existing handle; the argument defaults to nullptr.
103 explicit TensorDescriptor(dnnTensorDescriptor_t desc =
nullptr);
// Copy construction, move construction, assignment taking its argument by
// value, and a friend swap of two descriptors.
108 TensorDescriptor(
const TensorDescriptor&);
109 TensorDescriptor(TensorDescriptor&&);
110 TensorDescriptor& operator=(TensorDescriptor);
111 friend void swap(TensorDescriptor& first, TensorDescriptor& second);
// Replace the wrapped handle with `desc` (defaults to nullptr).
// NOTE(review): what happens to the previous handle is not visible here.
114 void reset(dnnTensorDescriptor_t desc =
nullptr);
// Returns a handle and does not throw.
// NOTE(review): presumably gives up ownership of the handle -- confirm.
116 dnnTensorDescriptor_t release() noexcept;
// Returns the handle without modifying the object; does not throw.
118 dnnTensorDescriptor_t
get() const noexcept;
// Implicit conversion to the raw handle type; does not throw.
120 operator dnnTensorDescriptor_t() const noexcept;
// Configure the descriptor from a data type, dimensions, and optional
// strides (default: empty vector). Both vectors are taken by value.
// NOTE(review): how empty strides are interpreted is not visible here.
131 void set(dnnDataType_t data_type,
132 std::vector<
int> dims,
133 std::vector<
int> strides = {});
// Variadic convenience overload: converts every dimension argument to int
// with static_cast and forwards them as a braced list to set().
138 template <
typename... IntTs>
139 void set(dnnDataType_t data_type, IntTs... dims)
141 set(data_type, {
static_cast<int>(dims)...});
// Overload compiled only when LBANN_HAS_CUDNN is not defined; it forwards
// data_type and dims to set(data_type, dims).
// NOTE(review): one parameter line between data_type and dims is missing
// from this listing.
143 #if !(defined LBANN_HAS_CUDNN) 144 void set(dnnDataType_t data_type,
146 const std::vector<int>& dims)
148 this->
set(data_type, dims);
// Wrapped backend handle; default member initializer is nullptr.
150 #endif // !LBANN_HAS_CUDNN 153 dnnTensorDescriptor_t desc_ =
nullptr;
// Wrapper class around a dnnFilterDescriptor_t handle (held in desc_).
// The class is only declared when LBANN_HAS_CUDNN is defined; otherwise
// FilterDescriptor is an alias of TensorDescriptor (see the #else branch).
// NOTE(review): braces and several lines are missing from this listing.
156 #ifdef LBANN_HAS_CUDNN 158 class FilterDescriptor
// Type of the wrapped backend handle.
161 using handle_type = dnnFilterDescriptor_t;
// Construct from an existing handle; the argument defaults to nullptr.
164 explicit FilterDescriptor(dnnFilterDescriptor_t desc =
nullptr);
// Copy construction, move construction, assignment taking its argument by
// value, and a friend swap of two descriptors.
169 FilterDescriptor(
const FilterDescriptor&);
170 FilterDescriptor(FilterDescriptor&&);
171 FilterDescriptor& operator=(FilterDescriptor);
172 friend void swap(FilterDescriptor& first, FilterDescriptor& second);
// Replace the wrapped handle with `desc` (defaults to nullptr).
// NOTE(review): what happens to the previous handle is not visible here.
175 void reset(dnnFilterDescriptor_t desc =
nullptr);
// Returns a handle and does not throw.
// NOTE(review): presumably gives up ownership of the handle -- confirm.
177 dnnFilterDescriptor_t release() noexcept;
// Returns the handle without modifying the object; does not throw.
179 dnnFilterDescriptor_t
get() const noexcept;
// Implicit conversion to the raw handle type; does not throw.
181 operator dnnFilterDescriptor_t() const noexcept;
// Configure the descriptor from a data type, a tensor format, and a vector
// of dimensions passed by const reference.
192 void set(dnnDataType_t data_type,
193 dnnTensorFormat_t format,
194 const std::vector<
int>& dims);
// Variadic convenience overload: converts every dimension argument to int
// with static_cast and forwards them as a braced list to set().
199 template <typename... IntTs>
200 void set(dnnDataType_t data_type, dnnTensorFormat_t format, IntTs... dims)
202 set(data_type, format, {
static_cast<int>(dims)...});
// Wrapped backend handle; default member initializer is nullptr.
206 dnnFilterDescriptor_t desc_ =
nullptr;
// Without cuDNN, filters are described with TensorDescriptor.
208 #else // MIOpen and OneDNN 209 using FilterDescriptor = TensorDescriptor;
// Wrapper class around a dnnDropoutDescriptor_t handle (held in desc_).
// NOTE(review): braces and several lines are missing from this listing.
210 #endif // LBANN_HAS_CUDNN 213 class DropoutDescriptor
// Construct from an existing handle; the argument defaults to nullptr.
217 explicit DropoutDescriptor(dnnDropoutDescriptor_t desc =
nullptr);
// Constructor that forwards dropout, states, states_size, seed, use_mask,
// state_evo and rng_mode to set().
// NOTE(review): several parameter lines of this constructor are missing
// from this listing.
218 DropoutDescriptor(
float dropout,
221 unsigned long long seed,
224 dnnRNGType_t rng_mode)
227 ->set(dropout, states, states_size, seed, use_mask, state_evo, rng_mode);
// User-declared destructor.
// NOTE(review): presumably destroys the wrapped handle -- confirm.
230 ~DropoutDescriptor();
// Copy construction, move construction, assignment taking its argument by
// value, and a friend swap of two descriptors.
233 DropoutDescriptor(
const DropoutDescriptor&);
234 DropoutDescriptor(DropoutDescriptor&&);
235 DropoutDescriptor& operator=(DropoutDescriptor);
236 friend void swap(DropoutDescriptor& first, DropoutDescriptor& second);
// Replace the wrapped handle with `desc` (defaults to nullptr).
239 void reset(dnnDropoutDescriptor_t desc =
nullptr);
// Returns a handle and does not throw.
// NOTE(review): presumably gives up ownership of the handle -- confirm.
241 dnnDropoutDescriptor_t release() noexcept;
// Returns the handle without modifying the object; does not throw.
243 dnnDropoutDescriptor_t
get() const noexcept;
// Implicit conversion to the raw handle type; does not throw.
245 operator dnnDropoutDescriptor_t() const noexcept;
// Configure the descriptor. use_mask and state_evo default to false and
// rng_mode defaults to DNN_RNG_PSEUDO_XORWOW.
// NOTE(review): the parameter lines between dropout and seed are missing
// from this listing.
256 void set(
float dropout,
259 unsigned long long seed,
260 bool use_mask = false,
261 bool state_evo = false,
262 dnnRNGType_t rng_mode = DNN_RNG_PSEUDO_XORWOW);
// Wrapped backend handle; default member initializer is nullptr.
265 dnnDropoutDescriptor_t desc_ =
nullptr;
// Members of RNNDescriptor, a wrapper around a dnnRNNDescriptor_t handle
// (held in desc_).
// NOTE(review): the class header, braces and several lines are missing from
// this listing.
// Construct from an existing handle; the argument defaults to nullptr.
273 explicit RNNDescriptor(dnnRNNDescriptor_t desc =
nullptr);
// Copy construction is explicitly deleted.
275 RNNDescriptor(
const RNNDescriptor&) =
delete;
// Move construction, assignment taking its argument by value, and a friend
// swap of two descriptors.
279 RNNDescriptor(RNNDescriptor&&);
280 RNNDescriptor& operator=(RNNDescriptor);
281 friend void swap(RNNDescriptor& first, RNNDescriptor& second);
// Replace the wrapped handle with `desc` (defaults to nullptr).
284 void reset(dnnRNNDescriptor_t desc =
nullptr);
// Returns a handle and does not throw.
// NOTE(review): presumably gives up ownership of the handle -- confirm.
286 dnnRNNDescriptor_t release() noexcept;
// Returns the handle without modifying the object; does not throw.
288 dnnRNNDescriptor_t
get() const noexcept;
// Implicit conversion to the raw handle type; does not throw.
290 operator dnnRNNDescriptor_t() const noexcept;
// Configure the RNN descriptor from the algorithm, cell mode, bias mode,
// direction mode, input mode, data type, math precision, math type and a
// dropout descriptor handle.
// NOTE(review): some parameters and the end of this declaration are missing
// from this listing.
301 void set(dnnRNNAlgo_t algorithm,
302 dnnRNNMode_t cell_mode,
303 dnnRNNBiasMode_t bias_mode,
304 dnnDirectionMode_t direction_mode,
305 dnnRNNInputMode_t input_mode,
306 dnnDataType_t data_type,
307 dnnDataType_t math_precision,
308 dnnMathType_t math_type,
313 dnnDropoutDescriptor_t dropout_desc,
// Wrapped backend handle; default member initializer is nullptr.
317 dnnRNNDescriptor_t desc_ =
nullptr;
// Wrapper class around a dnnConvolutionDescriptor_t handle (held in desc_).
// NOTE(review): braces and several lines are missing from this listing.
321 class ConvolutionDescriptor
// Type of the wrapped backend handle.
325 using DescriptorHandle_t = dnnConvolutionDescriptor_t;
// Construct from an existing handle; the argument defaults to nullptr.
332 explicit ConvolutionDescriptor(DescriptorHandle_t desc =
nullptr);
// User-declared destructor.
// NOTE(review): presumably destroys the wrapped handle -- confirm.
335 ~ConvolutionDescriptor();
// Copy construction, move construction, and assignment taking its argument
// by value.
341 ConvolutionDescriptor(
const ConvolutionDescriptor&);
343 ConvolutionDescriptor(ConvolutionDescriptor&&);
346 ConvolutionDescriptor& operator=(ConvolutionDescriptor);
// Returns a handle and does not throw.
// NOTE(review): presumably gives up ownership of the handle -- confirm.
353 DescriptorHandle_t release() noexcept;
// Returns the handle without modifying the object; does not throw.
355 DescriptorHandle_t
get() const noexcept;
// Implicit conversion to the raw handle type; does not throw.
359 operator DescriptorHandle_t() const noexcept;
// Member swap with another descriptor.
366 void swap(ConvolutionDescriptor& other);
// Replace the wrapped handle with `desc` (defaults to nullptr).
369 void reset(DescriptorHandle_t desc =
nullptr);
// Configure the convolution from padding, stride and dilation vectors, a
// data type, and a convolution mode (default DNN_CROSS_CORRELATION).
381 void set(std::vector<
int> const&
pad,
382 std::vector<
int> const& stride,
383 std::vector<
int> const& dilation,
384 dnnDataType_t data_type,
385 dnnConvolutionMode_t mode = DNN_CROSS_CORRELATION);
// Array-based overload taking an array dimension, a dilation array, a data
// type, and a mode (default DNN_CROSS_CORRELATION).
// NOTE(review): the parameter lines between array_dim and dilation are
// missing from this listing.
386 void set(
size_t array_dim,
389 int const dilation[],
390 dnnDataType_t data_type,
391 dnnConvolutionMode_t mode = DNN_CROSS_CORRELATION);
// Set the math type used by this convolution descriptor.
394 void set_math_mode(dnnMathType_t math_type);
// Set the number of groups for this convolution descriptor.
397 void set_group_count(
int num_groups);
// Wrapped backend handle; default member initializer is nullptr.
402 DescriptorHandle_t desc_ =
nullptr;
// Free-function swap for two ConvolutionDescriptor objects.
406 void swap(ConvolutionDescriptor& lhs, ConvolutionDescriptor& rhs);
// Wrapper class around a dnnPoolingDescriptor_t handle (held in desc_).
// NOTE(review): braces and several lines are missing from this listing.
409 class PoolingDescriptor
// Type of the wrapped backend handle.
413 using DescriptorHandle_t = dnnPoolingDescriptor_t;
// Construct from an existing handle; the argument defaults to nullptr.
420 explicit PoolingDescriptor(DescriptorHandle_t desc =
nullptr);
// User-declared destructor.
// NOTE(review): presumably destroys the wrapped handle -- confirm.
423 ~PoolingDescriptor();
// Copy construction, move construction, and assignment taking its argument
// by value.
429 PoolingDescriptor(
const PoolingDescriptor&);
431 PoolingDescriptor(PoolingDescriptor&&);
434 PoolingDescriptor& operator=(PoolingDescriptor);
// Returns a handle and does not throw.
// NOTE(review): presumably gives up ownership of the handle -- confirm.
441 DescriptorHandle_t release() noexcept;
// Returns the handle without modifying the object; does not throw.
443 DescriptorHandle_t
get() const noexcept;
// Implicit conversion to the raw handle type; does not throw.
447 operator DescriptorHandle_t() const noexcept;
// Member swap with another descriptor.
454 void swap(PoolingDescriptor& other);
// Replace the wrapped handle with `desc` (defaults to nullptr).
457 void reset(DescriptorHandle_t desc =
nullptr);
// Tail of a set() overload taking a NaN-propagation option and vectors of
// window dimensions, padding and stride.
// NOTE(review): the start of this declaration (name and leading parameters)
// is missing from this listing.
469 dnnNanPropagation_t maxpoolingNanOpt,
470 std::vector<
int> const& window_dims,
471 std::vector<
int> const& padding,
472 std::vector<
int> const& stride);
// Fragment of an array-based overload; most of its parameter lines are
// missing from this listing.
474 dnnNanPropagation_t nan_prop,
476 int const window_dims[],
// Wrapped backend handle; default member initializer is nullptr.
483 DescriptorHandle_t desc_ =
nullptr;
// Free-function swap for two PoolingDescriptor objects.
487 void swap(PoolingDescriptor& lhs, PoolingDescriptor& rhs);
// Members of LRNDescriptor, a wrapper around a dnnLRNDescriptor_t handle
// (held in desc_).
// NOTE(review): the class header, braces and several lines are missing from
// this listing.
// Type of the wrapped backend handle.
494 using DescriptorHandle_t = dnnLRNDescriptor_t;
// Construct from an existing handle; the argument defaults to nullptr.
501 explicit LRNDescriptor(DescriptorHandle_t desc =
nullptr);
// Copy construction, move construction, and assignment taking its argument
// by value.
510 LRNDescriptor(
const LRNDescriptor&);
512 LRNDescriptor(LRNDescriptor&&);
515 LRNDescriptor& operator=(LRNDescriptor);
// Returns a handle and does not throw.
// NOTE(review): presumably gives up ownership of the handle -- confirm.
522 DescriptorHandle_t release() noexcept;
// Returns the handle without modifying the object; does not throw.
524 DescriptorHandle_t
get() const noexcept;
// Implicit conversion to the raw handle type; does not throw.
528 operator DescriptorHandle_t() const noexcept;
// Member swap with another descriptor.
535 void swap(LRNDescriptor& other);
// Replace the wrapped handle with `desc` (defaults to nullptr).
538 void reset(DescriptorHandle_t desc =
nullptr);
// Last parameter of a declaration whose start is missing from this listing;
// the LRN mode defaults to DNN_LRN_CROSS_CHANNEL.
553 dnnLRNMode_t mode = DNN_LRN_CROSS_CHANNEL);
// Wrapped backend handle; default member initializer is nullptr.
558 DescriptorHandle_t desc_ =
nullptr;
// Free-function swap for two LRNDescriptor objects.
562 void swap(LRNDescriptor& lhs, LRNDescriptor& rhs);
// Abstract base class template that holds a pointer to a layer and four
// vectors of TensorDescriptor objects (previous activations, activations,
// previous error signals, error signals). The accessors are pure virtual.
// NOTE(review): braces and access specifiers are missing from this listing.
569 template <typename TensorDataType>
570 class layer_tensor_manager
// Layer type this manager refers to.
573 using LayerType = data_type_layer<TensorDataType>;
// Construct with an optional layer pointer (defaults to nullptr).
576 layer_tensor_manager(
const LayerType* l =
nullptr);
// Virtual destructor, defaulted.
577 virtual ~layer_tensor_manager() =
default;
// Returns the stored layer pointer (m_layer).
580 const LayerType* get_layer()
const {
return m_layer; }
// Set the layer pointer.
582 void set_layer(
const LayerType* l);
// Pure virtual accessors returning a TensorDescriptor reference for the
// given parent/child index (default 0).
585 virtual TensorDescriptor& get_prev_activations(
int parent_index = 0) = 0;
587 virtual TensorDescriptor& get_activations(
int child_index = 0) = 0;
589 virtual TensorDescriptor& get_prev_error_signals(
int child_index = 0) = 0;
591 virtual TensorDescriptor& get_error_signals(
int parent_index = 0) = 0;
// Defaulted copy and move operations.
594 layer_tensor_manager(
const layer_tensor_manager&) =
default;
595 layer_tensor_manager& operator=(
const layer_tensor_manager&) =
default;
596 layer_tensor_manager(layer_tensor_manager&&) =
default;
597 layer_tensor_manager& operator=(layer_tensor_manager&&) =
default;
// NOTE(review): presumably resize the descriptor vectors indexed by parent
// and by child respectively -- confirm against the definitions.
600 void set_num_parents(
int num_parents);
602 void set_num_children(
int num_children);
// Layer this manager is associated with (non-owning raw pointer to const).
605 const LayerType* m_layer;
// Descriptor storage, one vector per tensor role.
607 std::vector<TensorDescriptor> m_prev_activations;
609 std::vector<TensorDescriptor> m_activations;
611 std::vector<TensorDescriptor> m_prev_error_signals;
613 std::vector<TensorDescriptor> m_error_signals;
// Concrete tensor manager that publicly derives from
// layer_tensor_manager<TensorDataType> and overrides its four accessors.
// NOTE(review): the name suggests it serves data-parallel layers; the
// accessor definitions are not visible here. Braces and at least one line
// are missing from this listing.
617 template <
typename TensorDataType>
618 class data_parallel_layer_tensor_manager
619 :
public layer_tensor_manager<TensorDataType>
// Layer type this manager refers to.
622 using LayerType = data_type_layer<TensorDataType>;
// Construct with an optional layer pointer (defaults to nullptr).
625 data_parallel_layer_tensor_manager(
const LayerType* l =
nullptr);
// Defaulted copy and move operations and destructor.
626 data_parallel_layer_tensor_manager(
627 const data_parallel_layer_tensor_manager&) =
default;
628 data_parallel_layer_tensor_manager&
629 operator=(
const data_parallel_layer_tensor_manager&) =
default;
630 data_parallel_layer_tensor_manager(data_parallel_layer_tensor_manager&&) =
632 data_parallel_layer_tensor_manager&
633 operator=(data_parallel_layer_tensor_manager&&) =
default;
634 ~data_parallel_layer_tensor_manager() =
default;
// Overrides of the base-class accessors; each returns a TensorDescriptor
// reference for the given parent/child index (default 0).
635 TensorDescriptor& get_prev_activations(
int parent_index = 0)
override;
636 TensorDescriptor& get_activations(
int child_index = 0)
override;
637 TensorDescriptor& get_prev_error_signals(
int child_index = 0)
override;
638 TensorDescriptor& get_error_signals(
int parent_index = 0)
override;
// Concrete tensor manager that publicly derives from
// layer_tensor_manager<TensorDataType> and overrides its four accessors.
// NOTE(review): the name suggests it serves entry-wise layers; the accessor
// definitions are not visible here. Braces and at least one line are
// missing from this listing.
642 template <
typename TensorDataType>
643 class entrywise_layer_tensor_manager
644 :
public layer_tensor_manager<TensorDataType>
// Layer type this manager refers to.
647 using LayerType = data_type_layer<TensorDataType>;
// Construct with an optional layer pointer (defaults to nullptr).
650 entrywise_layer_tensor_manager(
const LayerType* l =
nullptr);
// Defaulted copy and move operations and destructor.
651 entrywise_layer_tensor_manager(
const entrywise_layer_tensor_manager&) =
653 entrywise_layer_tensor_manager&
654 operator=(
const entrywise_layer_tensor_manager&) =
default;
655 entrywise_layer_tensor_manager(entrywise_layer_tensor_manager&&) =
default;
656 entrywise_layer_tensor_manager&
657 operator=(entrywise_layer_tensor_manager&&) =
default;
658 ~entrywise_layer_tensor_manager() =
default;
// Overrides of the base-class accessors; each returns a TensorDescriptor
// reference for the given parent/child index (default 0).
659 TensorDescriptor& get_prev_activations(
int parent_index = 0)
override;
660 TensorDescriptor& get_activations(
int child_index = 0)
override;
661 TensorDescriptor& get_prev_error_signals(
int child_index = 0)
override;
662 TensorDescriptor& get_error_signals(
int parent_index = 0)
override;
681 const TensorDescriptor& input_desc,
683 const FilterDescriptor& kernel_desc,
685 const ConvolutionDescriptor& conv_desc,
686 const TensorDescriptor& output_desc,
// Declaration fragment of get_bwd_data_algorithm. Visible parameters: an
// autotune flag, the kernel descriptor, the previous-error-signal descriptor
// and data pointer, the convolution descriptor, and the error-signal
// descriptor.
// NOTE(review): the return type, some parameters and the end of the
// declaration are missing from this listing; presumably selects the
// backward-data convolution algorithm -- confirm.
701 get_bwd_data_algorithm(
bool autotune,
703 const FilterDescriptor& kernel_desc,
705 const TensorDescriptor& prev_error_signal_desc,
706 const void* prev_error_signal,
707 const ConvolutionDescriptor& conv_desc,
708 const TensorDescriptor& error_signal_desc,
// Declaration fragment of get_bwd_filter_algorithm. Visible parameters: an
// autotune flag, the input descriptor, the previous-error-signal descriptor
// and data pointer, the convolution descriptor, and the kernel-gradient
// descriptor and (mutable) data pointer.
// NOTE(review): the return type, some parameters and the end of the
// declaration are missing from this listing; presumably selects the
// backward-filter convolution algorithm -- confirm.
723 get_bwd_filter_algorithm(
bool autotune,
725 const TensorDescriptor& input_desc,
727 const TensorDescriptor& prev_error_signal_desc,
728 const void* prev_error_signal,
729 const ConvolutionDescriptor& conv_desc,
730 const FilterDescriptor& kernel_gradient_desc,
731 void* kernel_gradient,
// Takes no arguments, returns nothing, and does not throw.
// NOTE(review): presumably switches the default convolution math type to
// tensor ops -- confirm against the definition.
738 void default_to_tensor_ops() noexcept;
// Returns the default math type for convolutions; does not throw.
744 dnnMathType_t get_default_convolution_math_type() noexcept;
// Type of the protobuf enumerator lbann_data::DEFAULT_TENSOR_OPS.
746 using ProtoTensorOpEnumType = decltype(
lbann_data::DEFAULT_TENSOR_OPS);
// Convert a protobuf math-type enum value to the backend's dnnMathType_t.
748 dnnMathType_t convert_to_dnn_math_type(ProtoTensorOpEnumType mt);
// Convert a backend dnnMathType_t value to the protobuf math-type enum.
750 ProtoTensorOpEnumType convert_to_proto_math_type(dnnMathType_t mt);
754 #endif // LBANN_HAS_DNN_LIB 755 #endif // LBANN_UTILS_DNN_LIB_DNN_LIB_HPP std::basic_string< T > pad(const std::basic_string< T > &s, typename std::basic_string< T >::size_type n, T c)
bwd_data_conv_alg
Which backward convolution algorithm to use.
bwd_filter_conv_alg
Which backward convolution filter algorithm to use.
pooling_mode
Which pooling mode to use.
fwd_conv_alg
Which forward convolution algorithm to use.
world_comm_ptr initialize(int &argc, char **&argv)