27 #ifndef LBANN_LAYERS_MATH_DFT_ABS_HPP_INCLUDED 28 #define LBANN_LAYERS_MATH_DFT_ABS_HPP_INCLUDED 32 #include "lbann/proto/layers.pb.h" 33 #include "lbann_config.hpp" 41 template <
typename T, El::Device D>
68 template <
typename TensorDataType, El::Device Device>
69 class dft_abs_layer :
public data_type_layer<TensorDataType>
74 dft_abs_layer(lbann_comm*
const comm);
76 dft_abs_layer* copy()
const override {
return new dft_abs_layer(*
this); }
81 template <
typename ArchiveT>
86 std::string get_type()
const override {
return "DFT Abs"; }
87 data_layout get_data_layout()
const override {
return Layout; }
89 bool can_run_inplace()
const override {
return false; }
90 int get_backprop_requirements()
const override {
return ERROR_SIGNALS; }
92 description get_description()
const override 99 void write_specific_proto(lbann_data::Layer& proto)
const final;
101 friend class cereal::access;
102 dft_abs_layer() : dft_abs_layer(nullptr) {}
104 dft_abs_layer(dft_abs_layer
const&);
105 void setup_dims()
override;
106 void fp_compute()
override;
107 void bp_compute()
override;
110 using impl_type = dft_abs_impl<TensorDataType, Device>;
111 std::unique_ptr<impl_type> pimpl_;
114 template <
typename T, El::Device D>
115 void dft_abs_layer<T, D>::write_specific_proto(lbann_data::Layer& proto)
const 117 proto.set_datatype(proto::ProtoDataType<T>);
118 proto.mutable_dft_abs();
121 #ifndef LBANN_DFT_ABS_LAYER_INSTANTIATE 123 #ifdef LBANN_HAS_FFTW_FLOAT 124 extern template class dft_abs_layer<float, El::Device::CPU>;
125 #endif // LBANN_HAS_FFTW_FLOAT 126 #ifdef LBANN_HAS_FFTW_DOUBLE 127 extern template class dft_abs_layer<double, El::Device::CPU>;
128 #endif // LBANN_HAS_FFTW_DOUBLE 132 extern template class dft_abs_layer<float, El::Device::GPU>;
133 extern template class dft_abs_layer<double, El::Device::GPU>;
134 #endif // LBANN_HAS_GPU 136 #endif // LBANN_DFT_ABS_LAYER_INSTANTIATE 139 #endif // LBANN_HAS_FFTW 140 #endif // LBANN_LAYERS_MATH_DFT_ABS_HPP_INCLUDED
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
virtual description get_description() const
Human-readable description.
constexpr El::Device Device
data_layout
Data layout that is optimized for different modes of parallelism.