27 #ifndef LBANN_OPERATORS_ELEMENTWISE_OPERATOR_HPP_INCLUDED 28 #define LBANN_OPERATORS_ELEMENTWISE_OPERATOR_HPP_INCLUDED 32 #include <cereal/cereal.hpp> 36 #include <type_traits> 44 template <
typename InputT,
typename OutputT, El::Device D>
47 Operator<InputT, OutputT, D>>
76 template <
typename ArchiveT>
85 using BaseType::bp_compute;
87 using BaseType::fp_compute;
89 template <
typename TensorViewType>
92 using LocalViewType = std::decay_t<decltype(in[0].local_data())>;
94 std::vector<LocalViewType> local_views;
95 local_views.reserve(in.size());
96 std::transform(cbegin(in),
98 std::back_inserter(local_views),
99 [](
auto const& x) {
return x.local_data(); });
107 void fp_compute(std::vector<ConstInputTensorType>
const& inputs,
108 std::vector<OutputTensorType>
const& outputs)
const final 124 std::vector<ConstInputTensorType>
const& inputs,
125 std::vector<ConstOutputTensorType>
const& gradient_wrt_outputs,
126 std::vector<InputTensorType>
const& gradient_wrt_inputs)
const final 150 std::vector<LocalOutputTensorType> output)
const = 0;
154 std::vector<ConstLocalInputTensorType> input,
155 std::vector<ConstLocalOutputTensorType> gradient_wrt_output,
156 std::vector<LocalInputTensorType> gradient_wrt_input)
const = 0;
163 #endif // LBANN_OPERATORS_ELEMENTWISE_OPERATOR_HPP_INCLUDED
Element-wise specific tensor operation sub-class.
Inject polymorphic clone functions into hierarchies.
void bp_compute(std::vector< ConstInputTensorType > const &inputs, std::vector< ConstOutputTensorType > const &gradient_wrt_outputs, std::vector< InputTensorType > const &gradient_wrt_inputs) const final
Compute operator's "backward" operation.
void serialize(ArchiveT &ar)
virtual void bp_compute_local(std::vector< ConstLocalInputTensorType > input, std::vector< ConstLocalOutputTensorType > gradient_wrt_output, std::vector< LocalInputTensorType > gradient_wrt_input) const =0
Local backward compute function.
typename BaseType::ConstOutputTensorType ConstOutputTensorType
typename BaseType::ConstInputTensorType ConstInputTensorType
ElementwiseOperator()=default
void fp_compute(std::vector< ConstInputTensorType > const &inputs, std::vector< OutputTensorType > const &outputs) const final
Apply operator's forward operation.
ElementwiseOperator & operator=(ElementwiseOperator const &)=default
typename BaseType::OutputTensorType OutputTensorType
typename BaseType::InputTensorType InputTensorType
typename OperatorTraits< OpT >::input_tensor_type InputTensorType
virtual ~ElementwiseOperator()=default
virtual void fp_compute_local(std::vector< ConstLocalInputTensorType > input, std::vector< LocalOutputTensorType > output) const =0
Local forward compute function.
typename OperatorTraits< OpT >::output_tensor_type OutputTensorType
static auto get_local_tensor_views(std::vector< TensorViewType > const &in)
Neural network tensor operation.