27 #ifndef LBANN_INCLUDE_LBANN_OPERATORS_BINARY_WITH_CONSTANT_HPP_INCLUDED 28 #define LBANN_INCLUDE_LBANN_OPERATORS_BINARY_WITH_CONSTANT_HPP_INCLUDED 30 #include "lbann_config.hpp" 35 #include "lbann/proto/operators.pb.h" 51 #include "lbann/proto/operators.pb.h" 55 #define LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(OP_NAME, \ 58 template <typename DataT, El::Device D> \ 59 class OP_NAME##Operator final \ 60 : public Cloneable<OP_NAME##Operator<DataT, D>, \ 61 ElementwiseOperator<DataT, DataT, D>> \ 63 using BaseType = Cloneable<OP_NAME##Operator<DataT, D>, \ 64 ElementwiseOperator<DataT, DataT, D>>; \ 65 using LocalInputTensorType = typename BaseType::LocalInputTensorType; \ 66 using LocalOutputTensorType = typename BaseType::LocalOutputTensorType; \ 67 using ConstLocalInputTensorType = \ 68 typename BaseType::ConstLocalInputTensorType; \ 69 using ConstLocalOutputTensorType = \ 70 typename BaseType::ConstLocalOutputTensorType; \ 73 OP_NAME##Operator(double constant = 0.) \ 74 : m_constant{El::To<DataT>(constant)} \ 76 OP_NAME##Operator(OP_NAME##Operator&&) = default; \ 77 OP_NAME##Operator(OP_NAME##Operator const&) = default; \ 78 OP_NAME##Operator& operator=(OP_NAME##Operator&&) = default; \ 79 OP_NAME##Operator& operator=(OP_NAME##Operator const&) = default; \ 80 ~OP_NAME##Operator() = default; \ 81 std::string get_type() const final { return OP_STRING; } \ 82 int get_backprop_requirements() const final \ 84 return ((NEEDS_PREVACTS) ? 
(ERROR_SIGNALS | PREV_ACTIVATIONS) \ 87 template <typename ArchiveT> \ 88 void serialize(ArchiveT& ar) \ 90 using OperatorType = ElementwiseOperator<DataT, DataT, D>; \ 91 ar(::cereal::make_nvp("ElementwiseOperator", \ 92 ::cereal::base_class<OperatorType>(this)), \ 93 CEREAL_NVP(m_constant)); \ 95 DataT get_constant() const noexcept { return m_constant; } \ 99 fp_compute_local(std::vector<ConstLocalInputTensorType> inputs, \ 100 std::vector<LocalOutputTensorType> outputs) const final; \ 101 void bp_compute_local( \ 102 std::vector<ConstLocalInputTensorType> inputs, \ 103 std::vector<ConstLocalOutputTensorType> grads_wrt_outputs, \ 104 std::vector<LocalInputTensorType> grads_wrt_inputs) const final; \ 105 void set_proto_params(lbann_data::Operator& msg) const final \ 107 lbann_data::OP_NAME##Operator op_msg; \ 108 op_msg.set_constant(m_constant); \ 109 msg.mutable_parameters()->PackFrom(op_msg); \ 111 void do_fill_description(description& desc) const final \ 113 std::ostringstream oss; \ 115 desc.add("Constant", oss.str()); \ 138 "subtract from constant",
148 "not equals constant",
151 "less-equals constant",
154 "less than constant",
157 "greater-equals constant",
160 "greater than constant",
164 #endif // LBANN_INCLUDE_LBANN_OPERATORS_BINARY_WITH_CONSTANT_HPP_INCLUDED
// Declares the AddConstantOperator class template (type string "add constant").
// The third macro argument is `false`, so get_backprop_requirements() takes the
// non-PREV_ACTIVATIONS branch of the macro's ternary.
// NOTE(review): this invocation appears after the include guard's #endif, so it
// is not protected against double inclusion and sits outside any namespace
// visible here. The generated class uses Cloneable / ElementwiseOperator /
// description unqualified, which presumably requires an enclosing namespace
// (likely lbann) -- confirm placement.
// NOTE(review): there is no trailing ';' here. Verify whether the macro body
// already terminates the class definition, and whether AddConstant is also
// declared elsewhere in this header, which would cause a redefinition.
LBANN_DECLARE_BINARY_WITH_CONSTANT_OPERATOR(AddConstant, "add constant", false)