27 #ifndef LBANN_OPERATORS_MATH_ABS_HPP_INCLUDED 28 #define LBANN_OPERATORS_MATH_ABS_HPP_INCLUDED 30 #include "lbann_config.hpp" 36 #include "lbann/proto/operators.pb.h" 38 #include <h2/meta/Core.hpp> 40 #include <google/protobuf/message.h> 50 template <
typename DataT, El::Device D>
53 ElementwiseOperator<DataT, El::Base<DataT>, D>>
64 typename BaseType::ConstLocalInputTensorType;
66 typename BaseType::ConstLocalOutputTensorType;
/// Returns this operator's type name, which is always the fixed string "abs".
85 std::string
get_type() const final {
return "abs"; }
95 template <
typename ArchiveT>
99 ar(::cereal::make_nvp(
"DataTypeOperator",
100 ::cereal::base_class<OperatorType>(
this)));
109 std::vector<LocalOutputTensorType> output)
const final;
113 std::vector<ConstLocalInputTensorType> input,
114 std::vector<ConstLocalOutputTensorType> gradient_wrt_output,
115 std::vector<LocalInputTensorType> gradient_wrt_input)
const final;
119 msg.mutable_parameters()->PackFrom(lbann_data::AbsOperator{});
125 #ifndef LBANN_ABS_OPERATOR_INSTANTIATE 126 #define PROTO_DEVICE(T, D) extern template class AbsOperator<T, D> 129 #endif // LBANN_ABS_OPERATOR_INSTANTIATE 133 #endif // LBANN_OPERATORS_MATH_ABS_HPP_INCLUDED
Element-wise specific tensor operation sub-class.
Inject polymorphic clone functions into hierarchies.
typename BaseType::LocalInputTensorType LocalInputTensorType
typename BaseType::ConstLocalOutputTensorType ConstLocalOutputTensorType
Entrywise absolute value.
typename BaseType::LocalOutputTensorType LocalOutputTensorType
Generates nicely formatted description messages.
void bp_compute_local(std::vector< ConstLocalInputTensorType > input, std::vector< ConstLocalOutputTensorType > gradient_wrt_output, std::vector< LocalInputTensorType > gradient_wrt_input) const final
Local backward compute function.
void set_proto_params(lbann_data::Operator &msg) const final
virtual void fp_compute_local(std::vector< ConstLocalInputTensorType > input, std::vector< LocalOutputTensorType > output) const final
Local forward compute function.
void serialize(ArchiveT &ar)
int get_backprop_requirements() const final
AbsOperator & operator=(AbsOperator &&)=default
std::string get_type() const final
void do_fill_description(description &desc) const final
typename BaseType::ConstLocalInputTensorType ConstLocalInputTensorType