LBANN  0.103.0
LivermoreBigArtificialNeuralNetworkToolkit
fft_common.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the License.
26 #ifndef LBANN_UTILS_FFT_COMMON_HPP_
27 #define LBANN_UTILS_FFT_COMMON_HPP_
28 
29 #include <lbann/base.hpp>
32 
33 namespace lbann {
34 // Some metaprogramming. This isn't specific to FFT and should
35 // probably move elsewhere sometime soon.
36 
37 template <typename T>
38 struct ToRealT
39 {
40  using type = T;
41 };
42 
43 template <typename T>
44 struct ToRealT<El::Complex<T>>
45 {
46  using type = T;
47 };
48 
49 template <typename T>
50 using ToReal = typename ToRealT<T>::type;
51 
52 template <typename T>
53 struct ToComplexT
54 {
55  using type = El::Complex<T>;
56 };
57 
58 template <typename T>
59 struct ToComplexT<El::Complex<T>>
60 {
61  using type = El::Complex<T>;
62 };
63 
64 template <typename T>
65 using ToComplex = typename ToComplexT<T>::type;
66 
67 namespace fft {
/** @brief Compute the dimensions of the reduced ("r2c") layout.
 *
 *  Returns a copy of @c dims in which the last entry @c n has been
 *  replaced by <tt>n / 2 + 1</tt> (division in the arithmetic of
 *  @c T); all other entries are copied unchanged.
 *
 *  @param dims The full dimensions.
 *  @return A new vector holding the adjusted dimensions. If @c dims is
 *          empty, an empty vector is returned.
 */
template <typename T>
auto get_r2c_output_dims(std::vector<T> const& dims)
{
  std::vector<T> r2c_dims(dims);
  // Calling back() on an empty vector is undefined behavior; an empty
  // shape has no last dimension to shrink, so pass it through as-is.
  if (!r2c_dims.empty()) {
    r2c_dims.back() = r2c_dims.back() / 2 + 1;
  }
  return r2c_dims;
}
75 
76 // D1=float, D2=Complex<float>
77 // input_dims<D1,D2>(full_dims) = full_dims
78 // output_dims<D1,D2>(full_dims) = r2c_dims
79 //
80 // D1=Complex<float>, D2=Complex<float>
81 // input_dims<D1,D2>(full_dims) = full_dims
82 // output_dims<D1,D2>(full_dims) = full_dims
83 //
84 // D1=Complex<float>, D2=float
85 // input_dims<D1,D2>(full_dims) = r2c_dims
86 // output_dims<D1,D2>(full_dims) = full_dims
87 template <typename InT, typename OutT>
88 struct DimsHelper;
89 
90 template <typename InOutT>
91 struct DimsHelper<InOutT, InOutT>
92 {
93  static auto input_dims(std::vector<int> const& full_dims)
94  {
95  return full_dims;
96  }
97  static auto output_dims(std::vector<int> const& full_dims)
98  {
99  return full_dims;
100  }
101 };
102 
103 template <typename RealT>
104 struct DimsHelper<RealT, El::Complex<RealT>>
105 {
106  static auto input_dims(std::vector<int> const& full_dims)
107  {
108  return full_dims;
109  }
110  static auto output_dims(std::vector<int> const& full_dims)
111  {
112  return get_r2c_output_dims(full_dims);
113  }
114 };
115 
116 template <typename RealT>
117 struct DimsHelper<El::Complex<RealT>, RealT>
118 {
119  static auto input_dims(std::vector<int> const& full_dims)
120  {
121  return get_r2c_output_dims(full_dims);
122  }
123  static auto output_dims(std::vector<int> const& full_dims)
124  {
125  return full_dims;
126  }
127 };
128 
129 template <typename InT, typename OutT, typename TransformT>
130 void r2c_to_full_1d(El::Matrix<InT, El::Device::CPU> const& r2c_input,
131  El::Matrix<OutT, El::Device::CPU>& full_output,
132  std::vector<int> const& full_dims,
133  TransformT transform)
134 {
135  if (full_dims.size() != 2UL)
136  LBANN_ERROR("Only valid for 1-D feature maps.");
137 
138  auto const r2c_dims = lbann::fft::get_r2c_output_dims(full_dims);
139 
140  auto const feat_map_ndims = full_dims.size() - 1;
141  auto const r2c_feat_map_size =
142  lbann::get_linear_size(feat_map_ndims, r2c_dims.data() + 1);
143  auto const num_samples = r2c_input.Width();
144  auto const num_feat_maps = full_dims[0];
145  auto const num_entries_full = full_dims[1];
146  auto const num_entries_r2c = r2c_dims[1];
147  auto const num_diff_entries = num_entries_full - num_entries_r2c;
148 
149  // A function to conjugate an element and then apply the transform.
150  auto conj_transform = [&t = transform](InT const& x) {
151  return t(El::Conj(x));
152  };
153 
154  // Make sure output is setup.
155  full_output.Resize(lbann::get_linear_size(full_dims), num_samples);
156  for (int sample_id = 0; sample_id < num_samples; ++sample_id) {
157  auto output_start = full_output.Buffer() +
158  sample_id * full_output.LDim(); // Get to this sample
159  for (int feat_map_id = 0; feat_map_id < num_feat_maps; ++feat_map_id) {
160  // This is the part that gets copied directly.
161  auto const r2c_fm_start =
162  r2c_input.LockedBuffer() +
163  sample_id * r2c_input.LDim() // Get to this sample
164  + feat_map_id * r2c_feat_map_size; // Get to this feature map
165  // This is the part that gets reverse-copied.
166  auto const r2c_conj_fm_start =
167  r2c_input.LockedBuffer() +
168  sample_id * r2c_input.LDim() // Get to this sample
169  + feat_map_id * r2c_feat_map_size // Get to this feature map
170  + 1;
171  auto const r2c_conj_fm_end = r2c_conj_fm_start + num_diff_entries;
172 
173  // Direct copy bit.
174  output_start = std::transform(r2c_fm_start,
175  r2c_fm_start + num_entries_r2c,
176  output_start,
177  transform);
178 
179  // Reverse conjugate-and-copy bit.
180  auto const r2c_conj_rbegin =
181  std::reverse_iterator<InT const*>(r2c_conj_fm_end);
182  auto const r2c_conj_rend =
183  std::reverse_iterator<InT const*>(r2c_conj_fm_start);
184  output_start = std::transform(r2c_conj_rbegin,
185  r2c_conj_rend,
186  output_start,
187  conj_transform);
188  }
189  }
190 }
191 
192 template <typename InT, typename OutT, typename TransformT>
193 void r2c_to_full_2d(El::Matrix<InT, El::Device::CPU> const& r2c_input,
194  El::Matrix<OutT, El::Device::CPU>& full_output,
195  std::vector<int> const& full_dims,
196  TransformT transform)
197 {
198  if (full_dims.size() != 3UL)
199  LBANN_ERROR("Only valid for 2-D feature maps.");
200 
201  auto const r2c_dims = lbann::fft::get_r2c_output_dims(full_dims);
202 
203  auto const feat_map_ndims = 2;
204  auto const r2c_feat_map_size =
205  lbann::get_linear_size(feat_map_ndims, r2c_dims.data() + 1);
206  auto const num_samples = r2c_input.Width();
207  auto const num_feat_maps = full_dims[0];
208  auto const num_rows = full_dims[1];
209  auto const num_cols_full = full_dims[2];
210  auto const num_cols_r2c = r2c_dims[2];
211  auto const num_diff_cols = num_cols_full - num_cols_r2c;
212 
213  // Convenience function:
214  auto conj_transform = [&t = transform](InT const& x) { return t(Conj(x)); };
215 
216  // Make sure output is setup.
217  full_output.Resize(lbann::get_linear_size(full_dims), num_samples);
218  for (int sample_id = 0; sample_id < num_samples; ++sample_id) {
219  // This is the start of the feature map.
220  auto output_start = full_output.Buffer() +
221  sample_id * full_output.LDim(); // Get to this sample
222  for (int feat_map_id = 0; feat_map_id < num_feat_maps; ++feat_map_id) {
223  for (int row = 0; row < num_rows; ++row) {
224  auto const conj_row_index = (row == 0 ? 0 : num_rows - row);
225 
226  // This is the part that gets copied directly.
227  auto const r2c_row_start =
228  r2c_input.LockedBuffer() +
229  sample_id * r2c_input.LDim() // Get to this sample
230  + feat_map_id * r2c_feat_map_size // Get to this feature map
231  + row * num_cols_r2c; // Get to this row
232  // This is the part that gets reverse-copied
233  auto const r2c_conj_row_start =
234  r2c_input.LockedBuffer() +
235  sample_id * r2c_input.LDim() // Get to this sample
236  + feat_map_id * r2c_feat_map_size // Get to this feature map
237  + conj_row_index * num_cols_r2c // Get to this row
238  + 1; // Get to the right col
239  auto const r2c_conj_row_end = r2c_conj_row_start + num_diff_cols;
240 
241  // Directly copy the row
242  output_start = std::transform(r2c_row_start,
243  r2c_row_start + num_cols_r2c,
244  output_start,
245  transform);
246 
247  // Reverse copy the conjugated bits
248  auto const r2c_conj_rbegin =
249  std::reverse_iterator<InT const*>(r2c_conj_row_end);
250  auto const r2c_conj_rend =
251  std::reverse_iterator<InT const*>(r2c_conj_row_start);
252  output_start = std::transform(r2c_conj_rbegin,
253  r2c_conj_rend,
254  output_start,
255  conj_transform);
256  }
257  }
258  }
259 }
260 
261 template <typename InT, typename OutT>
262 void r2c_to_full(El::Matrix<InT, El::Device::CPU> const& r2c_input,
263  El::Matrix<OutT, El::Device::CPU>& full_output,
264  std::vector<int> const& full_dims)
265 {
266  auto abs_val_func = [](InT const& in) { return std::abs(in); };
267  switch (full_dims.size()) {
268  case 0:
269  case 1:
270  LBANN_ERROR("Invalid dimension size. Remember: "
271  "The first entry in the dimension array MUST "
272  "be the number of feature maps.");
273  break;
274  case 2:
275  r2c_to_full_1d(r2c_input, full_output, full_dims, abs_val_func);
276  break;
277  case 3:
278  r2c_to_full_2d(r2c_input, full_output, full_dims, abs_val_func);
279  break;
280  default:
281  LBANN_ERROR("LBANN currently only supports 1D and 2D DFT algorithms. "
282  "Please open an issue on GitHub describing the use-case "
283  "for higher-dimensional DFTs.");
284  break;
285  }
286 }
287 
288 } // namespace fft
289 } // namespace lbann
290 #endif // LBANN_UTILS_FFT_COMMON_HPP_
void r2c_to_full_1d(El::Matrix< InT, El::Device::CPU > const &r2c_input, El::Matrix< OutT, El::Device::CPU > &full_output, std::vector< int > const &full_dims, TransformT transform)
Definition: fft_common.hpp:130
static auto input_dims(std::vector< int > const &full_dims)
Definition: fft_common.hpp:106
typename ToRealT< T >::type ToReal
Definition: fft_common.hpp:50
auto get_linear_size(std::vector< T > const &dims)
Definition: dim_helpers.hpp:59
typename ToComplexT< T >::type ToComplex
Definition: fft_common.hpp:65
#define LBANN_ERROR(...)
Definition: exception.hpp:37
static auto output_dims(std::vector< int > const &full_dims)
Definition: fft_common.hpp:123
void r2c_to_full(El::Matrix< InT, El::Device::CPU > const &r2c_input, El::Matrix< OutT, El::Device::CPU > &full_output, std::vector< int > const &full_dims)
Definition: fft_common.hpp:262
static auto input_dims(std::vector< int > const &full_dims)
Definition: fft_common.hpp:93
static auto output_dims(std::vector< int > const &full_dims)
Definition: fft_common.hpp:97
auto get_r2c_output_dims(std::vector< T > const &dims)
Definition: fft_common.hpp:69
El::Complex< T > type
Definition: fft_common.hpp:55
static auto output_dims(std::vector< int > const &full_dims)
Definition: fft_common.hpp:110
static auto input_dims(std::vector< int > const &full_dims)
Definition: fft_common.hpp:119
void r2c_to_full_2d(El::Matrix< InT, El::Device::CPU > const &r2c_input, El::Matrix< OutT, El::Device::CPU > &full_output, std::vector< int > const &full_dims, TransformT transform)
Definition: fft_common.hpp:193