27 #ifndef LBANN_OPERATORS_MATH_CLAMP_HPP_INCLUDED 28 #define LBANN_OPERATORS_MATH_CLAMP_HPP_INCLUDED 30 #include "lbann_config.hpp" 35 #include "lbann/proto/operators.pb.h" 37 #include <h2/meta/Core.hpp> 52 template <
typename DataT, El::Device D>
54 :
public Cloneable<ClampOperator<DataT, D>,
55 ElementwiseOperator<DataT, DataT, D>>
57 #ifdef LBANN_HAS_GPU_FP16 59 h2::meta::IfThenElse<std::is_same_v<DataT, fp16>, float, DataT>;
61 using CompareType = DataT;
73 typename BaseType::ConstLocalInputTensorType;
75 typename BaseType::ConstLocalOutputTensorType;
84 :
m_min{El::To<DataT>(min)},
m_max{El::To<DataT>(max)}
100 std::string
get_type() const final {
return "clamp"; }
112 template <
typename ArchiveT>
115 using OperatorType = ElementwiseOperator<DataT, DataT, D>;
116 ar(::cereal::make_nvp(
"ElementwiseOperator",
117 ::cereal::base_class<OperatorType>(
this)),
131 std::vector<LocalOutputTensorType> output)
const final;
135 std::vector<ConstLocalInputTensorType> input,
136 std::vector<ConstLocalOutputTensorType> gradient_wrt_output,
137 std::vector<LocalInputTensorType> gradient_wrt_input)
const final;
141 lbann_data::ClampOperator clamp_msg;
142 clamp_msg.set_min(
m_min);
143 clamp_msg.set_max(
m_max);
144 msg.mutable_parameters()->PackFrom(clamp_msg);
149 std::ostringstream oss;
151 desc.add(
"Range", oss.str());
161 #ifndef LBANN_CLAMP_OPERATOR_INSTANTIATE 162 #define PROTO_DEVICE(T, D) extern template class ClampOperator<T, D> 165 #endif // LBANN_CLAMP_OPERATOR_INSTANTIATE 169 #endif // LBANN_OPERATORS_MATH_CLAMP_HPP_INCLUDED void bp_compute_local(std::vector< ConstLocalInputTensorType > input, std::vector< ConstLocalOutputTensorType > gradient_wrt_output, std::vector< LocalInputTensorType > gradient_wrt_input) const final
Local backward compute function.
void fp_compute_local(std::vector< ConstLocalInputTensorType > input, std::vector< LocalOutputTensorType > output) const final
Local forward compute function.
Element-wise specific tensor operation sub-class.
int get_backprop_requirements() const final
Inject polymorphic clone functions into hierarchies.
ClampOperator(double min, double max)
Generates nicely formatted description messages.
Constrain values to a range.
typename BaseType::LocalOutputTensorType LocalOutputTensorType
DataT get_max() const noexcept
#define LBANN_ASSERT(cond)
typename BaseType::ConstLocalInputTensorType ConstLocalInputTensorType
std::string get_type() const final
friend class cereal::access
void do_fill_description(description &desc) const final
void serialize(ArchiveT &ar)
typename BaseType::LocalInputTensorType LocalInputTensorType
DataT get_min() const noexcept
typename BaseType::ConstLocalOutputTensorType ConstLocalOutputTensorType
void set_proto_params(lbann_data::Operator &msg) const final
ClampOperator & operator=(ClampOperator &&)=default