27 #ifndef LBANN_UTILS_DNN_LIB_ONEDNN_HPP 28 #define LBANN_UTILS_DNN_LIB_ONEDNN_HPP 34 #include <h2/meta/Core.hpp> 35 #include <h2/meta/TypeList.hpp> 37 #ifdef LBANN_HAS_ONEDNN 55 struct IsSupportedTypeT : std::false_type
59 #define ADD_ONEDNN_TYPE_MAP(CPPTYPE, EVAL) \ 61 struct TypeMapT<CPPTYPE> \ 62 : std::integral_constant<dnnl::memory::data_type, \ 63 dnnl::memory::data_type::EVAL> \ 67 struct IsSupportedTypeT<CPPTYPE> : std::true_type \ 72 ADD_ONEDNN_TYPE_MAP(
float, f32);
73 ADD_ONEDNN_TYPE_MAP(int32_t, s32);
74 ADD_ONEDNN_TYPE_MAP(int8_t, s8);
75 ADD_ONEDNN_TYPE_MAP(uint8_t, u8);
79 #if defined LBANN_HAS_HALF 80 ADD_ONEDNN_TYPE_MAP(cpu_fp16, f16);
82 #if defined LBANN_HAS_GPU_FP16 83 ADD_ONEDNN_TYPE_MAP(fp16, f16);
89 using TypeMap =
typename details::TypeMapT<T>;
92 inline constexpr
auto DataTypeValue = TypeMap<T>::value;
95 inline constexpr
bool IsSupportedType = details::IsSupportedTypeT<T>::value;
97 template <
typename T, ::h2::meta::EnableWhen<IsSupportedType<T>,
int> = 0>
98 inline constexpr dnnl::memory::data_type get_data_type()
100 return DataTypeValue<T>;
103 template <
typename T, ::h2::meta::EnableUnless<IsSupportedType<T>,
int> = 0>
104 inline dnnl::memory::data_type get_data_type()
108 "\" is not supported " 109 "by the oneDNN runtime.");
112 template <El::Device D>
113 dnnl::engine& get_device_engine();
115 template <El::Device D>
116 dnnl::stream get_stream(dnnl::engine
const& e, El::SyncInfo<D>
const&);
120 template <El::Device D>
121 struct onednn_backend
123 static constexpr
auto device = D;
125 template <
typename T>
126 static auto data_type()
128 return onednn::get_data_type<T>();
131 class TensorDescriptor
135 static constexpr
auto device = D;
138 using backend_type = onednn_backend<D>;
141 using dnnTensorDescriptor_t = dnnl::memory;
144 using dnnDataType_t = dnnl::memory::data_type;
147 using dnnTensorFormat_t = dnnl::memory::format_tag;
150 using internal_descriptor_type = dnnl::memory::desc;
156 TensorDescriptor() =
default;
159 explicit TensorDescriptor(dnnTensorDescriptor_t desc)
160 : desc_{std::move(desc)}
164 ~TensorDescriptor() noexcept = default;
167 TensorDescriptor(TensorDescriptor const&) = default;
169 TensorDescriptor(TensorDescriptor&&) = default;
171 TensorDescriptor& operator=(TensorDescriptor const&) = default;
173 TensorDescriptor& operator=(TensorDescriptor&&) = default;
176 void swap(TensorDescriptor& other) { std::swap(desc_, other.desc_); }
179 void reset(dnnTensorDescriptor_t desc = dnnTensorDescriptor_t{})
181 desc_ = dnnl::memory{std::move(desc)};
185 dnnTensorDescriptor_t release() noexcept
187 dnnTensorDescriptor_t tmp = desc_;
188 desc_ = dnnl::memory{};
193 dnnTensorDescriptor_t
get()
const noexcept {
return desc_; }
196 operator dnnTensorDescriptor_t() const noexcept {
return desc_; }
202 void create() noexcept {}
204 void set(dnnDataType_t data_type,
205 dnnl::memory::dims dims,
206 dnnl::memory::dims strides = {})
211 auto md = dnnl::memory::desc(dims, data_type, strides);
213 dnnl::memory(md, onednn::get_device_engine<D>(), DNNL_MEMORY_NONE);
220 template <
typename DimT>
221 void set(dnnDataType_t data_type,
222 std::vector<DimT>
const& dims_in,
223 std::vector<DimT>
const& strides_in = {})
225 dnnl::memory::dims dims{cbegin(dims_in), cend(dims_in)};
226 dnnl::memory::dims strides{cbegin(strides_in), cend(strides_in)};
227 this->
set(data_type, dims, strides);
234 template <
typename... IntTs>
235 void set(dnnDataType_t data_type, IntTs... dims)
237 set(data_type, {
static_cast<dnnl::memory::dim
>(dims)...});
240 void set(dnnDataType_t data_type,
242 const std::vector<int>& dims)
244 this->
set(data_type, dims);
259 template <
typename DataT,
typename ScalarT>
261 TensorDescriptor
const& xDesc,
262 El::Matrix<DataT, D>
const& x,
263 ScalarT
const& beta_in,
264 TensorDescriptor
const& yDesc,
265 El::Matrix<DataT, D>& y,
266 El::SyncInfo<D>
const& si,
270 template <
typename DataT,
typename ScalarT>
271 static void logsoftmax_forward(ScalarT
const& alpha_in,
272 TensorDescriptor
const& xDesc,
273 El::Matrix<DataT, D>
const& x,
274 ScalarT
const& beta_in,
275 TensorDescriptor
const& yDesc,
276 El::Matrix<DataT, D>& y,
277 El::SyncInfo<D>
const& si,
283 template <
typename DataT,
typename ScalarT>
285 TensorDescriptor
const& yDesc,
286 El::Matrix<DataT, D>
const& y,
287 TensorDescriptor
const& dyDesc,
288 El::Matrix<DataT, D>
const& dy,
289 ScalarT
const& beta_in,
290 TensorDescriptor
const& dxDesc,
291 El::Matrix<DataT, D>& dx,
292 El::SyncInfo<D>
const& si,
296 template <
typename DataT,
typename ScalarT>
297 static void logsoftmax_backward(ScalarT
const& alpha_in,
298 TensorDescriptor
const& yDesc,
299 El::Matrix<DataT, D>
const& y,
300 TensorDescriptor
const& dyDesc,
301 El::Matrix<DataT, D>
const& dy,
302 ScalarT
const& beta_in,
303 TensorDescriptor
const& dxDesc,
304 El::Matrix<DataT, D>& dx,
305 El::SyncInfo<D>
const& si,
315 #endif // LBANN_HAS_ONEDNN 316 #endif // LBANN_UTILS_DNN_LIB_ONEDNN_HPP void softmax_backward(ScalarT const &alpha_in, TensorDescT const &yDesc, El::Matrix< DataT, D > const &y, TensorDescT const &dyDesc, El::Matrix< DataT, D > const &dy, ScalarT const &beta_in, TensorDescT const &dxDesc, El::Matrix< DataT, D > &dx, El::SyncInfo< D > const &si, softmax_mode mode, softmax_alg alg=softmax_alg::ACCURATE)
void softmax_forward(ScalarT const &alpha_in, TensorDescT const &xDesc, El::Matrix< DataT, D > const &x, ScalarT const &beta_in, TensorDescT const &yDesc, El::Matrix< DataT, D > &y, El::SyncInfo< D > const &si, softmax_mode mode, softmax_alg alg=softmax_alg::ACCURATE)
softmax_alg
Internal LBANN names for supported softmax algorithms.
softmax_mode
Which tensor dimensions to apply softmax over.
auto get_packed_strides(size_t ndims, T const *dims)