26 #ifndef LBANN_UTILS_FFTW_WRAPPER_HPP_ 27 #define LBANN_UTILS_FFTW_WRAPPER_HPP_ 56 using type = fftwf_complex;
73 template <
typename InputT,
typename OutputT>
76 #define BUILD_FFTW_R2C_TRAITS(INTYPE, FFTW_PREFIX) \ 78 struct FFTWTraits<INTYPE, El::Complex<INTYPE>> \ 80 using plan_type = FFTW_PREFIX##_plan; \ 81 using iodim_type = FFTW_PREFIX##_iodim; \ 82 static constexpr auto plan_many_fwd = &FFTW_PREFIX##_plan_many_dft_r2c; \ 83 static constexpr auto plan_many_bwd = &FFTW_PREFIX##_plan_many_dft_c2r; \ 84 static constexpr auto plan_guru_fwd = &FFTW_PREFIX##_plan_guru_dft_r2c; \ 85 static constexpr auto plan_guru_bwd = &FFTW_PREFIX##_plan_guru_dft_c2r; \ 86 static constexpr auto execute_plan_fwd = &FFTW_PREFIX##_execute_dft_r2c; \ 87 static constexpr auto execute_plan_bwd = &FFTW_PREFIX##_execute_dft_c2r; \ 88 static constexpr auto destroy_plan = &FFTW_PREFIX##_destroy_plan; \ 89 static constexpr auto plain_execute = &FFTW_PREFIX##_execute; \ 92 #define BUILD_FFTW_C2C_TRAITS(INTYPE, FFTW_PREFIX) \ 94 struct FFTWTraits<El::Complex<INTYPE>, El::Complex<INTYPE>> \ 96 using plan_type = FFTW_PREFIX##_plan; \ 97 using iodim_type = FFTW_PREFIX##_iodim; \ 98 static constexpr auto execute_plan_fwd = &FFTW_PREFIX##_execute_dft; \ 99 static constexpr auto execute_plan_bwd = &FFTW_PREFIX##_execute_dft; \ 100 static constexpr auto destroy_plan = &FFTW_PREFIX##_destroy_plan; \ 101 static constexpr auto plain_execute = &FFTW_PREFIX##_execute; \ 102 static plan_type plan_many_fwd(int rank, \ 105 FFTW_PREFIX##_complex* in, \ 106 const int* inembed, \ 109 FFTW_PREFIX##_complex* out, \ 110 const int* onembed, \ 115 return FFTW_PREFIX##_plan_many_dft(rank, \ 129 static plan_type plan_many_bwd(int rank, \ 132 FFTW_PREFIX##_complex* in, \ 133 const int* inembed, \ 136 FFTW_PREFIX##_complex* out, \ 137 const int* onembed, \ 142 return FFTW_PREFIX##_plan_many_dft(rank, \ 156 static plan_type plan_guru_fwd(int rank, \ 157 const FFTW_PREFIX##_iodim* dims, \ 159 const FFTW_PREFIX##_iodim* howmany_dims, \ 160 FFTW_PREFIX##_complex* in, \ 161 FFTW_PREFIX##_complex* out, \ 164 return FFTW_PREFIX##_plan_guru_dft(rank, \ 173 static plan_type plan_guru_bwd(int rank, \ 174 const FFTW_PREFIX##_iodim* dims, \ 176 const FFTW_PREFIX##_iodim* howmany_dims, \ 177 FFTW_PREFIX##_complex* in, \ 178 FFTW_PREFIX##_complex* out, \ 181 return FFTW_PREFIX##_plan_guru_dft(rank, \ 207 template <
typename InputTypeT>
230 int num_samples_ = -1;
234 if (plan_ !=
nullptr) {
235 TraitsType::destroy_plan(plan_);
240 : plan_{other.plan_}, num_samples_{other.num_samples_}
242 other.plan_ =
nullptr;
243 other.num_samples_ = -1;
262 std::vector<int>
const& full_dims)
268 TraitsType::plan_many_fwd,
269 TraitsType::plan_guru_fwd);
280 setup_forward(in, in, full_dims);
292 std::vector<int>
const& full_dims)
298 TraitsType::plan_many_bwd,
299 TraitsType::plan_guru_bwd);
310 setup_backward(in, in, full_dims);
315 auto const num_samples = in.Width();
316 auto const good_plan =
317 std::find_if(cbegin(fwd_plans_),
322 if (good_plan == cend(fwd_plans_))
327 TraitsType::execute_plan_fwd(good_plan->plan_,
334 return compute_forward(in, in);
339 auto const num_samples = in.Width();
340 auto const good_plan =
341 std::find_if(cbegin(bwd_plans_),
346 if (good_plan == cend(bwd_plans_))
351 TraitsType::execute_plan_bwd(good_plan->plan_,
358 return compute_backward(in, in);
362 template <
typename InMatT,
364 typename SetupManyFunctorT,
365 typename SetupGuruFunctorT>
368 std::vector<int>
const& full_dims,
369 std::vector<InternalPlanType>& plans,
370 SetupManyFunctorT many_functor,
371 SetupGuruFunctorT guru_functor)
373 using in_data_type =
typename InMatT::value_type;
374 using out_data_type =
typename OutMatT::value_type;
378 int const num_samples = in.Width();
379 auto const good_plan =
380 std::find_if(cbegin(plans),
387 if (good_plan == cend(plans)) {
390 auto const& input_dims = Dims::input_dims(full_dims);
391 auto const& output_dims = Dims::output_dims(full_dims);
392 int const num_feature_maps = full_dims.front();
393 int const feature_map_ndims = full_dims.size() - 1;
394 bool const contiguous_samples = (in.Contiguous()) && (out.Contiguous());
397 if (contiguous_samples) {
398 int const num_transforms = num_samples * num_feature_maps;
399 int const input_feature_map_size =
401 int const output_feature_map_size =
403 plan = many_functor(feature_map_ndims,
404 full_dims.data() + 1,
409 input_feature_map_size,
413 output_feature_map_size,
417 using IODimType =
typename TraitsType::iodim_type;
419 std::vector<IODimType> dims(feature_map_ndims), how_many(2);
425 for (
int d = 0; d < feature_map_ndims; ++d) {
426 dims[d].n = full_dims[d + 1];
427 dims[d].is = input_strides[d + 1];
428 dims[d].os = output_strides[d + 1];
432 how_many[0].n = num_feature_maps;
433 how_many[0].is = input_strides.front();
434 how_many[0].os = output_strides.front();
436 how_many[1].n = num_samples;
437 how_many[1].is = in.LDim();
438 how_many[1].os = out.LDim();
440 plan = guru_functor(dims.size(),
451 ": FFTW plan construction failed.\n" 455 plans.emplace_back(plan, num_samples);
470 #endif // LBANN_UTILS_FFTW_WRAPPER_HPP_
BUILD_FFTW_R2C_TRAITS(float, fftwf)
typename ToRealT< T >::type ToReal
ToReal< InputType > RealType
auto get_linear_size(std::vector< T > const &dims)
typename ToComplexT< T >::type ToComplex
typename FFTWTypeT< T >::type FFTWType
std::vector< InternalPlanType > fwd_plans_
void setup_forward(InputMatType &in, std::vector< int > const &full_dims)
Setup an in-place forward transform.
ToComplex< InputType > OutputType
void setup_common(InMatT &in, OutMatT &out, std::vector< int > const &full_dims, std::vector< InternalPlanType > &plans, SetupManyFunctorT many_functor, SetupGuruFunctorT guru_functor)
void compute_backward(OutputMatType &in, InputMatType &out) const
El::Matrix< OutputType, El::Device::CPU > OutputMatType
El::Matrix< ComplexType, El::Device::CPU > ComplexMatType
void compute_forward(InputMatType &in, OutputMatType &out) const
El::Matrix< RealType, El::Device::CPU > RealMatType
El::Matrix< InputType, El::Device::CPU > InputMatType
void setup_backward(OutputMatType &in, std::vector< int > const &full_dims)
Setup the in-place backward (inverse) transform.
InternalPlanType(InternalPlanType &&other) noexcept
auto AsFFTWType(T *buffer)
void setup_backward(OutputMatType &in, InputMatType &out, std::vector< int > const &full_dims)
Setup the backward (inverse) transform.
ToComplex< InputType > ComplexType
typename TraitsType::plan_type PlanType
void compute_forward(InputMatType &in) const
std::vector< InternalPlanType > bwd_plans_
void compute_backward(OutputMatType &in) const
void setup_forward(InputMatType &in, OutputMatType &out, std::vector< int > const &full_dims)
Setup the forward transform.
BUILD_FFTW_C2C_TRAITS(float, fftwf)
InternalPlanType(PlanType plan, int n)
auto get_packed_strides(size_t ndims, T const *dims)