27 #ifndef LBANN_INCLUDE_LBANN_OPERATORS_SELLECT_HPP_INCLUDED 28 #define LBANN_INCLUDE_LBANN_OPERATORS_SELLECT_HPP_INCLUDED 30 #include "lbann_config.hpp" 35 #include "lbann/proto/operators.pb.h" 41 #include "lbann/proto/operators.pb.h" 45 template <
typename DataT, El::Device D>
47 :
public Cloneable<SelectOperator<DataT, D>,
48 ElementwiseOperator<DataT, DataT, D>>
55 typename BaseType::ConstLocalInputTensorType;
57 typename BaseType::ConstLocalOutputTensorType;
61 bool constant_if_true =
false,
62 bool constant_if_false =
false,
63 double value_if_true = 0.,
64 double value_if_false = 0.,
65 double epsilon = 1e-5)
66 :
m_value{El::To<DataT>(value)},
78 std::string
get_type() const final {
return "select"; }
83 template <
typename ArchiveT>
86 using OperatorType = ElementwiseOperator<DataT, DataT, D>;
87 ar(::cereal::make_nvp(
"ElementwiseOperator",
88 ::cereal::base_class<OperatorType>(
this)),
106 std::vector<LocalOutputTensorType> outputs)
const final;
108 std::vector<ConstLocalInputTensorType> inputs,
109 std::vector<ConstLocalOutputTensorType> grads_wrt_outputs,
110 std::vector<LocalInputTensorType> grads_wrt_inputs)
const final;
113 lbann_data::SelectOperator op_msg;
120 msg.mutable_parameters()->PackFrom(op_msg);
125 std::ostringstream oss;
127 desc.add(
"Value", oss.str());
130 std::ostringstream oss;
132 desc.add(
"If equal (constant)", oss.str());
135 std::ostringstream oss;
137 desc.add(
"If unequal (constant)", oss.str());
140 std::ostringstream oss;
142 desc.add(
"Equality epsilon", oss.str());
153 #endif // LBANN_INCLUDE_LBANN_OPERATORS_SELLECT_HPP_INCLUDED
Element-wise specific tensor operation sub-class.
DataT get_constant_true_case()
Inject polymorphic clone functions into hierarchies.
void fp_compute_local(std::vector< ConstLocalInputTensorType > inputs, std::vector< LocalOutputTensorType > outputs) const final
DataT get_constant_false_case()
bool is_true_case_constant()
void bp_compute_local(std::vector< ConstLocalInputTensorType > inputs, std::vector< ConstLocalOutputTensorType > grads_wrt_outputs, std::vector< LocalInputTensorType > grads_wrt_inputs) const final
Generates nicely formatted description messages.
SelectOperator(double value=0., bool constant_if_true=false, bool constant_if_false=false, double value_if_true=0., double value_if_false=0., double epsilon=1e-5)
void do_fill_description(description &desc) const final
bool is_false_case_constant()
typename BaseType::LocalOutputTensorType LocalOutputTensorType
void set_proto_params(lbann_data::Operator &msg) const final
~SelectOperator()=default
SelectOperator & operator=(SelectOperator &&)=default
std::string get_type() const final
typename BaseType::ConstLocalOutputTensorType ConstLocalOutputTensorType
int get_backprop_requirements() const final
void serialize(ArchiveT &ar)
typename BaseType::LocalInputTensorType LocalInputTensorType
typename BaseType::ConstLocalInputTensorType ConstLocalInputTensorType