LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
fftw_wrapper.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "Licensee"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 #ifndef LBANN_UTILS_FFTW_WRAPPER_HPP_
27 #define LBANN_UTILS_FFTW_WRAPPER_HPP_
28 
29 #include <lbann/base.hpp>
33 
34 #include <fftw3.h>
35 
36 namespace lbann {
37 
38 namespace fftw {
39 
40 template <typename T>
41 struct FFTWTypeT;
42 
43 template <>
44 struct FFTWTypeT<float>
45 {
46  using type = float;
47 };
48 template <>
49 struct FFTWTypeT<double>
50 {
51  using type = double;
52 };
53 template <>
54 struct FFTWTypeT<El::Complex<float>>
55 {
56  using type = fftwf_complex;
57 };
58 template <>
59 struct FFTWTypeT<El::Complex<double>>
60 {
61  using type = fftw_complex;
62 };
63 
64 template <typename T>
65 using FFTWType = typename FFTWTypeT<T>::type;
66 
67 template <typename T>
68 auto AsFFTWType(T* buffer)
69 {
70  return reinterpret_cast<FFTWType<T>*>(buffer);
71 }
72 
/** @brief Compile-time table of FFTW entry points for one
 *         (input scalar, output scalar) transform pair.
 *
 *  Only declared here. The explicit specializations are generated by
 *  the BUILD_FFTW_R2C_TRAITS and BUILD_FFTW_C2C_TRAITS macros.
 */
template <typename InT, typename OutT>
struct FFTWTraits;
75 
/** @brief Generate FFTWTraits for a real-to-complex transform pair.
 *
 *  Expands to the explicit specialization
 *  FFTWTraits<INTYPE, El::Complex<INTYPE>>, whose members point at the
 *  FFTW API of the given precision. The r2c routines serve the forward
 *  direction and the c2r routines serve the backward direction.
 *
 *  @param INTYPE      Real scalar type (e.g. float, double).
 *  @param FFTW_PREFIX FFTW API prefix for that precision (fftwf, fftw).
 *
 *  The expansion ends with the closing brace only; the invocation site
 *  supplies the terminating semicolon. Comments inside the macro use
 *  block comments: a line comment would swallow the continuation line.
 */
#define BUILD_FFTW_R2C_TRAITS(INTYPE, FFTW_PREFIX)                             \
  template <>                                                                  \
  struct FFTWTraits<INTYPE, El::Complex<INTYPE>>                               \
  {                                                                            \
    /* Plan handle and guru dimension-descriptor types. */                     \
    using plan_type = FFTW_PREFIX##_plan;                                      \
    using iodim_type = FFTW_PREFIX##_iodim;                                    \
    /* Forward direction: real input, complex output. */                       \
    static constexpr auto plan_many_fwd = &FFTW_PREFIX##_plan_many_dft_r2c;    \
    static constexpr auto plan_guru_fwd = &FFTW_PREFIX##_plan_guru_dft_r2c;    \
    static constexpr auto execute_plan_fwd = &FFTW_PREFIX##_execute_dft_r2c;   \
    /* Backward direction: complex input, real output. */                      \
    static constexpr auto plan_many_bwd = &FFTW_PREFIX##_plan_many_dft_c2r;    \
    static constexpr auto plan_guru_bwd = &FFTW_PREFIX##_plan_guru_dft_c2r;    \
    static constexpr auto execute_plan_bwd = &FFTW_PREFIX##_execute_dft_c2r;   \
    /* Direction-agnostic plan operations. */                                  \
    static constexpr auto plain_execute = &FFTW_PREFIX##_execute;              \
    static constexpr auto destroy_plan = &FFTW_PREFIX##_destroy_plan;          \
  }
91 
/** @brief Generate FFTWTraits for a complex-to-complex transform pair.
 *
 *  Expands to the explicit specialization
 *  FFTWTraits<El::Complex<INTYPE>, El::Complex<INTYPE>>. FFTW's c2c
 *  planners take the transform sign as an argument, so the public
 *  *_fwd / *_bwd planners forward to private helpers with FFTW_FORWARD
 *  or FFTW_BACKWARD respectively. Execution uses the same
 *  <prefix>_execute_dft routine for both directions, because the
 *  direction is fixed when the plan is built.
 *
 *  @param INTYPE      Real component type (e.g. float, double).
 *  @param FFTW_PREFIX FFTW API prefix for that precision (fftwf, fftw).
 *
 *  The expansion ends with the closing brace only; the invocation site
 *  supplies the terminating semicolon. Comments inside the macro use
 *  block comments: a line comment would swallow the continuation line.
 */
#define BUILD_FFTW_C2C_TRAITS(INTYPE, FFTW_PREFIX)                             \
  template <>                                                                  \
  struct FFTWTraits<El::Complex<INTYPE>, El::Complex<INTYPE>>                  \
  {                                                                            \
    using plan_type = FFTW_PREFIX##_plan;                                      \
    using iodim_type = FFTW_PREFIX##_iodim;                                    \
                                                                               \
    /* Sign is baked into the plan, so both directions share one routine. */   \
    static constexpr auto execute_plan_fwd = &FFTW_PREFIX##_execute_dft;       \
    static constexpr auto execute_plan_bwd = &FFTW_PREFIX##_execute_dft;       \
    static constexpr auto plain_execute = &FFTW_PREFIX##_execute;              \
    static constexpr auto destroy_plan = &FFTW_PREFIX##_destroy_plan;          \
                                                                               \
    /* "Advanced" (strided, batched) planners. */                              \
    static plan_type plan_many_fwd(int rank,                                   \
                                   const int* n,                               \
                                   int howmany,                                \
                                   FFTW_PREFIX##_complex* in,                  \
                                   const int* inembed,                         \
                                   int istride,                                \
                                   int idist,                                  \
                                   FFTW_PREFIX##_complex* out,                 \
                                   const int* onembed,                         \
                                   int ostride,                                \
                                   int odist,                                  \
                                   unsigned flags)                             \
    {                                                                          \
      return plan_many_signed(FFTW_FORWARD, rank, n, howmany,                  \
                              in, inembed, istride, idist,                     \
                              out, onembed, ostride, odist, flags);            \
    }                                                                          \
    static plan_type plan_many_bwd(int rank,                                   \
                                   const int* n,                               \
                                   int howmany,                                \
                                   FFTW_PREFIX##_complex* in,                  \
                                   const int* inembed,                         \
                                   int istride,                                \
                                   int idist,                                  \
                                   FFTW_PREFIX##_complex* out,                 \
                                   const int* onembed,                         \
                                   int ostride,                                \
                                   int odist,                                  \
                                   unsigned flags)                             \
    {                                                                          \
      return plan_many_signed(FFTW_BACKWARD, rank, n, howmany,                 \
                              in, inembed, istride, idist,                     \
                              out, onembed, ostride, odist, flags);            \
    }                                                                          \
                                                                               \
    /* "Guru" planners (arbitrary strides via iodim descriptors). */           \
    static plan_type plan_guru_fwd(int rank,                                   \
                                   const FFTW_PREFIX##_iodim* dims,            \
                                   int howmany_rank,                           \
                                   const FFTW_PREFIX##_iodim* howmany_dims,    \
                                   FFTW_PREFIX##_complex* in,                  \
                                   FFTW_PREFIX##_complex* out,                 \
                                   unsigned flags)                             \
    {                                                                          \
      return plan_guru_signed(FFTW_FORWARD, rank, dims, howmany_rank,          \
                              howmany_dims, in, out, flags);                   \
    }                                                                          \
    static plan_type plan_guru_bwd(int rank,                                   \
                                   const FFTW_PREFIX##_iodim* dims,            \
                                   int howmany_rank,                           \
                                   const FFTW_PREFIX##_iodim* howmany_dims,    \
                                   FFTW_PREFIX##_complex* in,                  \
                                   FFTW_PREFIX##_complex* out,                 \
                                   unsigned flags)                             \
    {                                                                          \
      return plan_guru_signed(FFTW_BACKWARD, rank, dims, howmany_rank,         \
                              howmany_dims, in, out, flags);                   \
    }                                                                          \
                                                                               \
  private:                                                                     \
    /* Shared bodies: identical to the public planners except for sign. */     \
    static plan_type plan_many_signed(int sign,                                \
                                      int rank,                                \
                                      const int* n,                            \
                                      int howmany,                             \
                                      FFTW_PREFIX##_complex* in,               \
                                      const int* inembed,                      \
                                      int istride,                             \
                                      int idist,                               \
                                      FFTW_PREFIX##_complex* out,              \
                                      const int* onembed,                      \
                                      int ostride,                             \
                                      int odist,                               \
                                      unsigned flags)                          \
    {                                                                          \
      return FFTW_PREFIX##_plan_many_dft(rank, n, howmany,                     \
                                         in, inembed, istride, idist,          \
                                         out, onembed, ostride, odist,         \
                                         sign, flags);                         \
    }                                                                          \
    static plan_type plan_guru_signed(int sign,                                \
                                      int rank,                                \
                                      const FFTW_PREFIX##_iodim* dims,         \
                                      int howmany_rank,                        \
                                      const FFTW_PREFIX##_iodim* howmany_dims, \
                                      FFTW_PREFIX##_complex* in,               \
                                      FFTW_PREFIX##_complex* out,              \
                                      unsigned flags)                          \
    {                                                                          \
      return FFTW_PREFIX##_plan_guru_dft(rank, dims, howmany_rank,             \
                                         howmany_dims, in, out,                \
                                         sign, flags);                         \
    }                                                                          \
  }
191 
192 BUILD_FFTW_R2C_TRAITS(float, fftwf);
193 BUILD_FFTW_R2C_TRAITS(double, fftw);
194 
195 BUILD_FFTW_C2C_TRAITS(float, fftwf);
196 BUILD_FFTW_C2C_TRAITS(double, fftw);
197 
207 template <typename InputTypeT>
209 {
210 public:
211  using InputType = InputTypeT;
213 
217 
218  using RealMatType = El::Matrix<RealType, El::Device::CPU>;
219  using ComplexMatType = El::Matrix<ComplexType, El::Device::CPU>;
220 
221  using InputMatType = El::Matrix<InputType, El::Device::CPU>;
222  using OutputMatType = El::Matrix<OutputType, El::Device::CPU>;
223 
224  using PlanType = typename TraitsType::plan_type;
225 
226 private:
228  {
229  PlanType plan_ = nullptr;
230  int num_samples_ = -1; // It's just an int in fftw
231  InternalPlanType(PlanType plan, int n) : plan_{plan}, num_samples_{n} {}
233  {
234  if (plan_ != nullptr) {
235  TraitsType::destroy_plan(plan_);
236  plan_ = nullptr;
237  }
238  }
240  : plan_{other.plan_}, num_samples_{other.num_samples_}
241  {
242  other.plan_ = nullptr;
243  other.num_samples_ = -1;
244  }
245  }; // struct InternalPlanType
246 
247 public:
248  FFTWWrapper() = default;
249  ~FFTWWrapper() = default;
250  // Movable, not copyable.
251  FFTWWrapper(FFTWWrapper&& other) noexcept = default;
252  FFTWWrapper(FFTWWrapper const&) = delete;
261  OutputMatType& out,
262  std::vector<int> const& full_dims)
263  {
264  setup_common(in,
265  out,
266  full_dims,
267  fwd_plans_,
268  TraitsType::plan_many_fwd,
269  TraitsType::plan_guru_fwd);
270  }
277  void setup_forward(InputMatType& in, std::vector<int> const& full_dims)
278  {
280  setup_forward(in, in, full_dims);
281  }
282 
291  InputMatType& out,
292  std::vector<int> const& full_dims)
293  {
294  setup_common(in,
295  out,
296  full_dims,
297  bwd_plans_,
298  TraitsType::plan_many_bwd,
299  TraitsType::plan_guru_bwd);
300  }
301 
308  void setup_backward(OutputMatType& in, std::vector<int> const& full_dims)
309  {
310  setup_backward(in, in, full_dims);
311  }
312 
314  {
315  auto const num_samples = in.Width();
316  auto const good_plan =
317  std::find_if(cbegin(fwd_plans_),
318  cend(fwd_plans_),
319  [num_samples](InternalPlanType const& a) {
320  return a.num_samples_ == num_samples;
321  });
322  if (good_plan == cend(fwd_plans_))
323  LBANN_ERROR("No valid FFTW plan found.");
324 
325  // Initial tests suggest there's no performance reason to *not*
326  // use the "new-array" interface.
327  TraitsType::execute_plan_fwd(good_plan->plan_,
328  AsFFTWType(in.Buffer()),
329  AsFFTWType(out.Buffer()));
330  }
331 
333  {
334  return compute_forward(in, in);
335  }
336 
338  {
339  auto const num_samples = in.Width();
340  auto const good_plan =
341  std::find_if(cbegin(bwd_plans_),
342  cend(bwd_plans_),
343  [num_samples](InternalPlanType const& a) {
344  return a.num_samples_ == num_samples;
345  });
346  if (good_plan == cend(bwd_plans_))
347  LBANN_ERROR("No valid FFTW plan found.");
348 
349  // Initial tests suggest there's no performance reason to *not*
350  // use the "new-array" interface.
351  TraitsType::execute_plan_bwd(good_plan->plan_,
352  AsFFTWType(in.Buffer()),
353  AsFFTWType(out.Buffer()));
354  }
355 
357  {
358  return compute_backward(in, in);
359  }
360 
361 private:
362  template <typename InMatT,
363  typename OutMatT,
364  typename SetupManyFunctorT,
365  typename SetupGuruFunctorT>
366  void setup_common(InMatT& in,
367  OutMatT& out,
368  std::vector<int> const& full_dims,
369  std::vector<InternalPlanType>& plans,
370  SetupManyFunctorT many_functor,
371  SetupGuruFunctorT guru_functor)
372  {
373  using in_data_type = typename InMatT::value_type;
374  using out_data_type = typename OutMatT::value_type;
376 
377  // Look for an acceptable plan
378  int const num_samples = in.Width();
379  auto const good_plan =
380  std::find_if(cbegin(plans),
381  cend(plans),
382  [num_samples](InternalPlanType const& a) {
383  return a.num_samples_ == num_samples;
384  });
385 
386  // We don't have a plan for this yet; let's create one!
387  if (good_plan == cend(plans)) {
388  PlanType plan;
389 
390  auto const& input_dims = Dims::input_dims(full_dims);
391  auto const& output_dims = Dims::output_dims(full_dims);
392  int const num_feature_maps = full_dims.front();
393  int const feature_map_ndims = full_dims.size() - 1;
394  bool const contiguous_samples = (in.Contiguous()) && (out.Contiguous());
395 
396  // Handle the easy case
397  if (contiguous_samples) {
398  int const num_transforms = num_samples * num_feature_maps;
399  int const input_feature_map_size =
400  get_linear_size(feature_map_ndims, input_dims.data() + 1);
401  int const output_feature_map_size =
402  get_linear_size(feature_map_ndims, output_dims.data() + 1);
403  plan = many_functor(feature_map_ndims,
404  full_dims.data() + 1,
405  num_transforms,
406  AsFFTWType(in.Buffer()),
407  nullptr,
408  1,
409  input_feature_map_size,
410  AsFFTWType(out.Buffer()),
411  nullptr,
412  1,
413  output_feature_map_size,
414  /*flags=*/0); // FFTW_PRESERVE_INPUT);
415  }
416  else {
417  using IODimType = typename TraitsType::iodim_type;
418 
419  std::vector<IODimType> dims(feature_map_ndims), how_many(2);
420 
421  auto input_strides = get_packed_strides(input_dims);
422  auto output_strides = get_packed_strides(output_dims);
423 
424  // Setup the "dims"
425  for (int d = 0; d < feature_map_ndims; ++d) {
426  dims[d].n = full_dims[d + 1];
427  dims[d].is = input_strides[d + 1];
428  dims[d].os = output_strides[d + 1];
429  }
430 
431  // Setup the "howmany"
432  how_many[0].n = num_feature_maps;
433  how_many[0].is = input_strides.front();
434  how_many[0].os = output_strides.front();
435 
436  how_many[1].n = num_samples;
437  how_many[1].is = in.LDim();
438  how_many[1].os = out.LDim();
439 
440  plan = guru_functor(dims.size(),
441  dims.data(),
442  how_many.size(),
443  how_many.data(),
444  AsFFTWType(in.Buffer()),
445  AsFFTWType(out.Buffer()),
446  /*flags=*/0); // FFTW_PRESERVE_INPUT);
447  }
448 
449  if (plan == nullptr)
450  LBANN_ERROR(__PRETTY_FUNCTION__,
451  ": FFTW plan construction failed.\n"
452  " contiguous: ",
453  contiguous_samples);
454 
455  plans.emplace_back(plan, num_samples);
456  }
457  }
458 
459 private:
460  // These are likely to be so few in number that a linear search is
461  // going to be fine.
462  std::vector<InternalPlanType> fwd_plans_;
463  std::vector<InternalPlanType> bwd_plans_;
464 
465 }; // class FFTWWrapper
466 
467 } // namespace fftw
468 
469 } // namespace lbann
470 #endif // LBANN_UTILS_FFTW_WRAPPER_HPP_
BUILD_FFTW_R2C_TRAITS(float, fftwf)
typename ToRealT< T >::type ToReal
Definition: fft_common.hpp:50
ToReal< InputType > RealType
Wrapper around FFTW.
auto get_linear_size(std::vector< T > const &dims)
Definition: dim_helpers.hpp:59
typename ToComplexT< T >::type ToComplex
Definition: fft_common.hpp:65
typename FFTWTypeT< T >::type FFTWType
std::vector< InternalPlanType > fwd_plans_
#define LBANN_ERROR(...)
Definition: exception.hpp:37
void setup_forward(InputMatType &in, std::vector< int > const &full_dims)
Setup an in-place forward transform.
ToComplex< InputType > OutputType
void setup_common(InMatT &in, OutMatT &out, std::vector< int > const &full_dims, std::vector< InternalPlanType > &plans, SetupManyFunctorT many_functor, SetupGuruFunctorT guru_functor)
void compute_backward(OutputMatType &in, InputMatType &out) const
El::Matrix< OutputType, El::Device::CPU > OutputMatType
El::Matrix< ComplexType, El::Device::CPU > ComplexMatType
void compute_forward(InputMatType &in, OutputMatType &out) const
El::Matrix< RealType, El::Device::CPU > RealMatType
El::Matrix< InputType, El::Device::CPU > InputMatType
void setup_backward(OutputMatType &in, std::vector< int > const &full_dims)
Setup the in-place backward (inverse) transform.
InternalPlanType(InternalPlanType &&other) noexcept
auto AsFFTWType(T *buffer)
void setup_backward(OutputMatType &in, InputMatType &out, std::vector< int > const &full_dims)
Setup the backward (inverse) transform.
ToComplex< InputType > ComplexType
typename TraitsType::plan_type PlanType
void compute_forward(InputMatType &in) const
std::vector< InternalPlanType > bwd_plans_
void compute_backward(OutputMatType &in) const
void setup_forward(InputMatType &in, OutputMatType &out, std::vector< int > const &full_dims)
Setup the forward transform.
BUILD_FFTW_C2C_TRAITS(float, fftwf)
auto get_packed_strides(size_t ndims, T const *dims)
Definition: dim_helpers.hpp:91