27 #ifndef LBANN_LAYERS_LEARNING_BASE_CONVOLUTION_HPP_INCLUDED 28 #define LBANN_LAYERS_LEARNING_BASE_CONVOLUTION_HPP_INCLUDED 32 #ifdef LBANN_HAS_DNN_LIB 35 #endif // LBANN_HAS_DNN_LIB 40 #ifdef LBANN_HAS_DISTCONV 41 #include "distconv/dnn_backend/convolution.hpp" 47 #ifdef LBANN_HAS_DISTCONV 49 using Backend = ::distconv::BackendDNNLib;
/** @brief Alias template for the distconv convolution type.
 *
 *  Resolves to ::distconv::Convolution instantiated with the Backend alias
 *  declared immediately before it and the given tensor data type.
 */
50 template <
typename TensorDataType>
51 using Convolution = ::distconv::Convolution<Backend, TensorDataType>;
/** @brief Distconv adapter for convolution layers.
 *
 *  Derives from data_type_distconv_adapter<TensorDataType>, overrides the
 *  base-class setup hooks, and owns (via std::unique_ptr) the distconv
 *  convolution object plus the kernel and bias tensors and their gradients.
 *
 *  NOTE(review): only declarations are visible here; the behavior of each
 *  member function lives in its out-of-line definition.
 */
54 template <
typename TensorDataType, El::Device Device>
55 class base_convolution_adapter
56 :
public data_type_distconv_adapter<TensorDataType>
/** Construct the adapter for @p layer, forwarding it to the base adapter. */
62 base_convolution_adapter(Layer& layer)
63 : data_type_distconv_adapter<TensorDataType>(layer)
/** Virtual destructor, defaulted. */
65 virtual ~base_convolution_adapter() =
default;
/* Setup hooks overridden from the base class. */
67 void setup_fp_tensors()
override;
68 void setup_bp_tensors()
override;
/** @param workspace_capacity Workspace size passed in by the caller
 *         (units not stated in this excerpt -- presumably bytes; confirm). */
69 void setup_layer(
size_t workspace_capacity)
override;
/** Returns a newly owned tensor for the error signal selected by @p index. */
70 std::unique_ptr<TensorDevType>
71 setup_error_signals_i(
int index)
const override;
/* Compute entry points. Judging by the names these are the forward pass and
 * the two backward passes (w.r.t. data and w.r.t. filter) -- confirm against
 * the definitions. */
73 void fp_compute_convolution();
76 void bp_compute_convolution_data();
77 void bp_compute_convolution_filter();
/** Owned distconv convolution object. */
79 std::unique_ptr<dc::Convolution<TensorDataType>> m_conv;
/* Owned device tensors for the kernel, the bias, and their gradients. */
80 std::unique_ptr<TensorDevType> m_kernel;
81 std::unique_ptr<TensorDevType> m_bias;
82 std::unique_ptr<TensorDevType> m_kernel_gradient;
83 std::unique_ptr<TensorDevType> m_bias_gradient;
/* Algorithm selections stored as strings -- presumably algorithm names for
 * the forward, backward-data and backward-filter passes; how they are
 * interpreted is not shown here (TODO confirm). */
85 std::string m_fwd_algo;
86 std::string m_bwd_data_algo;
87 std::string m_bwd_filter_algo;
89 #endif // LBANN_HAS_DISTCONV 93 template <
typename TensorDataType, El::Device Device>
/** @brief Shorthand for an El::Matrix of TensorDataType on device @p D. */
106 template <El::Device D>
107 using DMatDT = El::Matrix<TensorDataType, D>;
109 #ifdef LBANN_HAS_DNN_LIB 110 using ScalingType = dnn_lib::ScalingParamType<TensorDataType>;
113 #endif // LBANN_HAS_DNN_LIB 119 const std::vector<int>&
/* Read-only accessor: returns the m_pads member by const reference, so no
 * copy is made. */
get_pads()
const {
return m_pads; }
/* DNN-library state (compiled only when LBANN_HAS_DNN_LIB is defined).
 * The math type below is initialized from
 * dnn_lib::get_default_convolution_math_type(). */
147 #ifdef LBANN_HAS_DNN_LIB 152 dnn_lib::dnnMathType_t m_convolution_math_type =
153 dnn_lib::get_default_convolution_math_type();
/** DNN-library filter descriptor for the kernel. */
155 dnn_lib::FilterDescriptor m_kernel_dnn_desc;
/** DNN-library convolution descriptor. */
157 dnn_lib::ConvolutionDescriptor m_convolution_dnn_desc;
/** DNN-library tensor descriptor for the bias. */
159 dnn_lib::TensorDescriptor m_bias_dnn_desc;
/* Per-key caches of the chosen forward, backward-data and backward-filter
 * algorithms. NOTE(review): the int key is presumably the local mini-batch
 * size (the algorithm-selection helpers take a local_mini_batch_size
 * argument) -- confirm against the definitions. */
164 std::unordered_map<int, fwd_conv_alg> m_fwd_dnn_algos;
166 std::unordered_map<int, bwd_data_conv_alg> m_bwd_data_dnn_algos;
168 std::unordered_map<int, bwd_filter_conv_alg> m_bwd_filter_dnn_algos;
170 #endif // LBANN_HAS_DNN_LIB 176 std::vector<int> conv_dims,
177 std::vector<int> pads,
178 std::vector<int> strides,
179 std::vector<int> dilations,
189 #ifdef LBANN_HAS_DNN_LIB 190 void set_dnn_math_mode(dnn_lib::dnnMathType_t math_type) noexcept;
191 #endif // LBANN_HAS_DNN_LIB 194 void setup_dims()
override;
/* Setup hooks overridden from the base class; their bodies are defined
 * out of line and are not visible in this excerpt. */
/** @param max_mini_batch_size Largest mini-batch size, as supplied by the
 *         caller. */
199 void setup_data(
size_t max_mini_batch_size)
override;
/** GPU-side setup hook. */
202 void setup_gpu()
override;
207 template <
typename ArchiveT>
/** @brief Kernel dimensions, returned by value as a vector of ints.
 *
 *  Pure virtual: every concrete derived class must implement it.
 */
214 virtual std::vector<int> get_kernel_dims()
const = 0;
/* Computation helpers (declarations only; definitions are out of line).
 * The apply_*convolution* helpers take a during_forward_prop flag and the
 * compute_gradients_* helpers take a using_transposed_convolution flag.
 * NOTE(review): the _dnn suffix presumably marks the DNN-library code path
 * and the _im2col / _cpu suffixes the CPU code path -- confirm against the
 * definitions. */
217 void apply_convolution_dnn(
bool during_forward_prop);
220 void apply_transposed_convolution_dnn(
bool during_forward_prop);
222 void apply_bias_dnn();
223 void compute_gradients_dnn(
bool using_transposed_convolution);
/* im2col / CPU counterparts with matching signatures. */
226 void apply_convolution_im2col(
bool during_forward_prop);
229 void apply_transposed_convolution_im2col(
bool during_forward_prop);
231 void apply_bias_cpu();
233 void compute_gradients_im2col(
bool using_transposed_convolution);
236 #ifdef LBANN_HAS_DNN_LIB 240 get_forward_algo_dnn(
const int local_mini_batch_size,
241 const dnn_lib::TensorDescriptor& input_desc,
242 const TensorDataType* input,
243 const dnn_lib::FilterDescriptor& kernel_desc,
244 const TensorDataType* kernel,
245 const dnn_lib::ConvolutionDescriptor& conv_desc,
246 const dnn_lib::TensorDescriptor& output_desc,
247 TensorDataType* output,
253 const int local_mini_batch_size,
254 const dnn_lib::FilterDescriptor& kernel_desc,
255 const TensorDataType* kernel,
256 const dnn_lib::TensorDescriptor& prev_error_signal_desc,
257 const TensorDataType* prev_error_signal,
258 const dnn_lib::ConvolutionDescriptor& conv_desc,
259 const dnn_lib::TensorDescriptor& error_signal_desc,
260 TensorDataType* error_signal,
269 const int local_mini_batch_size,
270 const dnn_lib::TensorDescriptor& input_desc,
271 const TensorDataType* input,
272 const dnn_lib::TensorDescriptor& prev_error_signal_desc,
273 const TensorDataType* prev_error_signal,
274 const dnn_lib::ConvolutionDescriptor& conv_desc,
275 const dnn_lib::FilterDescriptor& kernel_gradient_desc,
278 #endif // LBANN_HAS_DNN_LIB 280 #ifdef LBANN_HAS_DISTCONV 281 friend class base_convolution_adapter<TensorDataType,
Device>;
/** Adapter type used by this layer: base_convolution_adapter instantiated
 *  with the layer's own TensorDataType and Device. */
284 using BaseConvAdapterType = base_convolution_adapter<TensorDataType, Device>;
/** Overridden hook that sets up the distconv adapter (definition not shown
 *  in this excerpt). */
285 void setup_distconv_adapter()
override;
/* Mutable and const accessors for the adapter, typed as BaseConvAdapterType.
 * Both override base-class accessors. */
286 BaseConvAdapterType& get_distconv_adapter()
override;
287 const BaseConvAdapterType& get_distconv_adapter()
const override;
288 #endif // LBANN_HAS_DISTCONV 292 #endif // LBANN_LAYERS_LEARNING_BASE_CONVOLUTION_HPP_INCLUDED
const std::vector< int > & get_pads() const
std::vector< int > m_conv_dims
Spatial dimensions for convolution kernel.
ScalingType m_bias_scaling_factor
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
Generates nicely formatted description messages.
constexpr El::Device Device
bwd_data_conv_alg
Which backward convolution algorithm to use.
bwd_filter_conv_alg
Which backward convolution filter algorithm to use.
const std::vector< int > & get_conv_dims() const
Get convolutional layer parameters.
const std::vector< int > & get_dilations() const
fwd_conv_alg
Which forward convolution algorithm to use.
std::vector< int > m_dilations
std::vector< int > m_pads
El::Matrix< TensorDataType, D > DMatDT
TensorDataType ScalingType
Computation kernels for convolution and deconvolution layers.
const std::vector< int > & get_strides() const
dc::TensorDev< OutputTensorDataType > TensorDevType
std::vector< int > m_strides