26 #ifndef LBANN_UTILS_FFT_COMMON_HPP_ 27 #define LBANN_UTILS_FFT_COMMON_HPP_ 55 using type = El::Complex<T>;
61 using type = El::Complex<T>;
71 std::vector<T> r2c_dims(dims);
72 r2c_dims.back() = r2c_dims.back() / 2 + 1;
87 template <
typename InT,
typename OutT>
90 template <
typename InOutT>
93 static auto input_dims(std::vector<int>
const& full_dims)
103 template <
typename RealT>
116 template <
typename RealT>
129 template <
typename InT,
typename OutT,
typename TransformT>
131 El::Matrix<OutT, El::Device::CPU>& full_output,
132 std::vector<int>
const& full_dims,
133 TransformT transform)
135 if (full_dims.size() != 2UL)
140 auto const feat_map_ndims = full_dims.size() - 1;
141 auto const r2c_feat_map_size =
143 auto const num_samples = r2c_input.Width();
144 auto const num_feat_maps = full_dims[0];
145 auto const num_entries_full = full_dims[1];
146 auto const num_entries_r2c = r2c_dims[1];
147 auto const num_diff_entries = num_entries_full - num_entries_r2c;
150 auto conj_transform = [&t = transform](InT
const& x) {
151 return t(El::Conj(x));
156 for (
int sample_id = 0; sample_id < num_samples; ++sample_id) {
157 auto output_start = full_output.Buffer() +
158 sample_id * full_output.LDim();
159 for (
int feat_map_id = 0; feat_map_id < num_feat_maps; ++feat_map_id) {
161 auto const r2c_fm_start =
162 r2c_input.LockedBuffer() +
163 sample_id * r2c_input.LDim()
164 + feat_map_id * r2c_feat_map_size;
166 auto const r2c_conj_fm_start =
167 r2c_input.LockedBuffer() +
168 sample_id * r2c_input.LDim()
169 + feat_map_id * r2c_feat_map_size
171 auto const r2c_conj_fm_end = r2c_conj_fm_start + num_diff_entries;
174 output_start = std::transform(r2c_fm_start,
175 r2c_fm_start + num_entries_r2c,
180 auto const r2c_conj_rbegin =
181 std::reverse_iterator<InT const*>(r2c_conj_fm_end);
182 auto const r2c_conj_rend =
183 std::reverse_iterator<InT const*>(r2c_conj_fm_start);
184 output_start = std::transform(r2c_conj_rbegin,
192 template <
typename InT,
typename OutT,
typename TransformT>
194 El::Matrix<OutT, El::Device::CPU>& full_output,
195 std::vector<int>
const& full_dims,
196 TransformT transform)
198 if (full_dims.size() != 3UL)
203 auto const feat_map_ndims = 2;
204 auto const r2c_feat_map_size =
206 auto const num_samples = r2c_input.Width();
207 auto const num_feat_maps = full_dims[0];
208 auto const num_rows = full_dims[1];
209 auto const num_cols_full = full_dims[2];
210 auto const num_cols_r2c = r2c_dims[2];
211 auto const num_diff_cols = num_cols_full - num_cols_r2c;
214 auto conj_transform = [&t = transform](InT
const& x) {
return t(Conj(x)); };
218 for (
int sample_id = 0; sample_id < num_samples; ++sample_id) {
220 auto output_start = full_output.Buffer() +
221 sample_id * full_output.LDim();
222 for (
int feat_map_id = 0; feat_map_id < num_feat_maps; ++feat_map_id) {
223 for (
int row = 0; row < num_rows; ++row) {
224 auto const conj_row_index = (row == 0 ? 0 : num_rows - row);
227 auto const r2c_row_start =
228 r2c_input.LockedBuffer() +
229 sample_id * r2c_input.LDim()
230 + feat_map_id * r2c_feat_map_size
231 + row * num_cols_r2c;
233 auto const r2c_conj_row_start =
234 r2c_input.LockedBuffer() +
235 sample_id * r2c_input.LDim()
236 + feat_map_id * r2c_feat_map_size
237 + conj_row_index * num_cols_r2c
239 auto const r2c_conj_row_end = r2c_conj_row_start + num_diff_cols;
242 output_start = std::transform(r2c_row_start,
243 r2c_row_start + num_cols_r2c,
248 auto const r2c_conj_rbegin =
249 std::reverse_iterator<InT const*>(r2c_conj_row_end);
250 auto const r2c_conj_rend =
251 std::reverse_iterator<InT const*>(r2c_conj_row_start);
252 output_start = std::transform(r2c_conj_rbegin,
261 template <
typename InT,
typename OutT>
262 void r2c_to_full(El::Matrix<InT, El::Device::CPU>
const& r2c_input,
263 El::Matrix<OutT, El::Device::CPU>& full_output,
264 std::vector<int>
const& full_dims)
266 auto abs_val_func = [](InT
const& in) {
return std::abs(in); };
267 switch (full_dims.size()) {
271 "The first entry in the dimension array MUST " 272 "be the number of feature maps.");
281 LBANN_ERROR(
"LBANN currently only supports 1D and 2D DFT algorithms. " 282 "Please open an issue on GitHub describing the use-case " 283 "for higher-dimensional DFTs.");
290 #endif // LBANN_UTILS_FFT_COMMON_HPP_ void r2c_to_full_1d(El::Matrix< InT, El::Device::CPU > const &r2c_input, El::Matrix< OutT, El::Device::CPU > &full_output, std::vector< int > const &full_dims, TransformT transform)
static auto input_dims(std::vector< int > const &full_dims)
typename ToRealT< T >::type ToReal
auto get_linear_size(std::vector< T > const &dims)
typename ToComplexT< T >::type ToComplex
static auto output_dims(std::vector< int > const &full_dims)
void r2c_to_full(El::Matrix< InT, El::Device::CPU > const &r2c_input, El::Matrix< OutT, El::Device::CPU > &full_output, std::vector< int > const &full_dims)
static auto input_dims(std::vector< int > const &full_dims)
static auto output_dims(std::vector< int > const &full_dims)
auto get_r2c_output_dims(std::vector< T > const &dims)
static auto output_dims(std::vector< int > const &full_dims)
static auto input_dims(std::vector< int > const &full_dims)
void r2c_to_full_2d(El::Matrix< InT, El::Device::CPU > const &r2c_input, El::Matrix< OutT, El::Device::CPU > &full_output, std::vector< int > const &full_dims, TransformT transform)