LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
select.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_INCLUDE_LBANN_OPERATORS_SELLECT_HPP_INCLUDED
28 #define LBANN_INCLUDE_LBANN_OPERATORS_SELLECT_HPP_INCLUDED
29 
30 #include "lbann_config.hpp"
31 
34 
35 #include "lbann/proto/operators.pb.h"
36 
40 
41 #include "lbann/proto/operators.pb.h"
42 
43 namespace lbann {
44 
45 template <typename DataT, El::Device D>
46 class SelectOperator final
47  : public Cloneable<SelectOperator<DataT, D>,
48  ElementwiseOperator<DataT, DataT, D>>
49 {
50  using BaseType =
52  using LocalInputTensorType = typename BaseType::LocalInputTensorType;
53  using LocalOutputTensorType = typename BaseType::LocalOutputTensorType;
55  typename BaseType::ConstLocalInputTensorType;
57  typename BaseType::ConstLocalOutputTensorType;
58 
59 public:
// Construct the operator from double-precision settings.
//
// `value`, `value_if_true`, `value_if_false` and `epsilon` are converted to
// `DataT` with `El::To<DataT>` before being stored; the two boolean flags are
// stored unchanged. Every parameter has a default, so this also serves as the
// default constructor.
//
// NOTE(review): do_fill_description labels these members "Value",
// "If equal (constant)", "If unequal (constant)" and "Equality epsilon",
// which suggests an approximate-equality test against `value`; the comparison
// itself is not visible in this listing -- confirm in the compute kernels.
60  SelectOperator(double value = 0.,
61  bool constant_if_true = false,
62  bool constant_if_false = false,
63  double value_if_true = 0.,
64  double value_if_false = 0.,
65  double epsilon = 1e-5)
66  : m_value{El::To<DataT>(value)},
67  m_constant_if_true{constant_if_true},
68  m_constant_if_false{constant_if_false},
69  m_value_if_true{El::To<DataT>(value_if_true)},
70  m_value_if_false{El::To<DataT>(value_if_false)},
71  m_epsilon{El::To<DataT>(epsilon)}
72  {}
// Move construction, copy construction, copy assignment and destruction are
// all explicitly defaulted (compiler-generated, member-wise).
73  SelectOperator(SelectOperator&&) = default;
74  SelectOperator(SelectOperator const&) = default;
76  SelectOperator& operator=(SelectOperator const&) = default;
77  ~SelectOperator() = default;
// Returns the operator's type name, the fixed string "select".
78  std::string get_type() const final { return "select"; }
79  int get_backprop_requirements() const final
80  {
82  }
// Cereal serialization hook.
//
// Archives the ElementwiseOperator<DataT, DataT, D> base-class subobject under
// the name "ElementwiseOperator", followed by the six data members as
// name-value pairs (CEREAL_NVP uses the member's own identifier as the name).
// The same function is used for whichever direction `ArchiveT` implements.
83  template <typename ArchiveT>
84  void serialize(ArchiveT& ar)
85  {
86  using OperatorType = ElementwiseOperator<DataT, DataT, D>;
87  ar(::cereal::make_nvp("ElementwiseOperator",
88  ::cereal::base_class<OperatorType>(this)),
89  CEREAL_NVP(m_value),
90  CEREAL_NVP(m_constant_if_true),
91  CEREAL_NVP(m_constant_if_false),
92  CEREAL_NVP(m_value_if_true),
93  CEREAL_NVP(m_value_if_false),
94  CEREAL_NVP(m_epsilon));
95  }
96 
// Returns the stored `m_value` (the constructor's `value` converted to DataT).
// Const-qualified: it only reads a member, so it is callable on const operators.
97  DataT get_value() const { return m_value; }
// Returns the stored `m_epsilon` (the constructor's `epsilon` converted to DataT).
// Const-qualified: it only reads a member, so it is callable on const operators.
98  DataT get_epsilon() const { return m_epsilon; }
103 
104 private:
// Declaration only; the definition is not part of this listing.
// Takes the local input tensors (const views) and the local output tensors.
// NOTE(review): presumably the forward-pass kernel that writes `outputs` from
// `inputs` -- confirm against the implementation file.
105  void fp_compute_local(std::vector<ConstLocalInputTensorType> inputs,
106  std::vector<LocalOutputTensorType> outputs) const final;
// Declaration only; the definition is not part of this listing.
// Takes the local inputs, the gradients with respect to the outputs (both as
// const views) and the tensors receiving the gradients with respect to the
// inputs.
// NOTE(review): presumably the backward-pass kernel -- confirm against the
// implementation file.
107  void bp_compute_local(
108  std::vector<ConstLocalInputTensorType> inputs,
109  std::vector<ConstLocalOutputTensorType> grads_wrt_outputs,
110  std::vector<LocalInputTensorType> grads_wrt_inputs) const final;
// Writes this operator's settings into the protobuf description `msg`.
//
// Builds a lbann_data::SelectOperator message from the six data members and
// packs it into `msg`'s `parameters` field with PackFrom.
111  void set_proto_params(lbann_data::Operator& msg) const final
112  {
113  lbann_data::SelectOperator op_msg;
114  op_msg.set_value(m_value);
115  op_msg.set_constant_if_true(m_constant_if_true);
116  op_msg.set_constant_if_false(m_constant_if_false);
117  op_msg.set_value_if_true(m_value_if_true);
118  op_msg.set_value_if_false(m_value_if_false);
119  op_msg.set_epsilon(m_epsilon);
120  msg.mutable_parameters()->PackFrom(op_msg);
121  }
// Adds this operator's settings to the human-readable description `desc`.
//
// "Value" and "Equality epsilon" are always added. "If equal (constant)" and
// "If unequal (constant)" are added only when the corresponding
// m_constant_if_true / m_constant_if_false flag is set. Each value is
// formatted by streaming it into a fresh std::ostringstream.
122  void do_fill_description(description& desc) const final
123  {
// Scoped block so each entry gets its own, empty stream named `oss`.
124  {
125  std::ostringstream oss;
126  oss << m_value;
127  desc.add("Value", oss.str());
128  }
// Constant for the "equal" case is reported only when that case is constant.
129  if (m_constant_if_true) {
130  std::ostringstream oss;
131  oss << m_value_if_true;
132  desc.add("If equal (constant)", oss.str());
133  }
// Likewise for the "unequal" case.
134  if (m_constant_if_false) {
135  std::ostringstream oss;
136  oss << m_value_if_false;
137  desc.add("If unequal (constant)", oss.str());
138  }
139  {
140  std::ostringstream oss;
141  oss << m_epsilon;
142  desc.add("Equality epsilon", oss.str());
143  }
144  }
145 
146 private:
147  DataT m_value;
150 };
151 
152 } // namespace lbann
153 #endif // LBANN_INCLUDE_LBANN_OPERATORS_SELLECT_HPP_INCLUDED
Element-wise specific tensor operation sub-class.
DataT get_constant_true_case()
Definition: select.hpp:101
Inject polymorphic clone functions into hierarchies.
Definition: cloneable.hpp:94
void fp_compute_local(std::vector< ConstLocalInputTensorType > inputs, std::vector< LocalOutputTensorType > outputs) const final
DataT get_constant_false_case()
Definition: select.hpp:102
bool is_true_case_constant()
Definition: select.hpp:99
void bp_compute_local(std::vector< ConstLocalInputTensorType > inputs, std::vector< ConstLocalOutputTensorType > grads_wrt_outputs, std::vector< LocalInputTensorType > grads_wrt_inputs) const final
Generates nicely formatted description messages.
Definition: description.hpp:49
SelectOperator(double value=0., bool constant_if_true=false, bool constant_if_false=false, double value_if_true=0., double value_if_false=0., double epsilon=1e-5)
Definition: select.hpp:60
void do_fill_description(description &desc) const final
Definition: select.hpp:122
bool is_false_case_constant()
Definition: select.hpp:100
typename BaseType::LocalOutputTensorType LocalOutputTensorType
Definition: select.hpp:53
void set_proto_params(lbann_data::Operator &msg) const final
Definition: select.hpp:111
SelectOperator & operator=(SelectOperator &&)=default
std::string get_type() const final
Definition: select.hpp:78
typename BaseType::ConstLocalOutputTensorType ConstLocalOutputTensorType
Definition: select.hpp:57
int get_backprop_requirements() const final
Definition: select.hpp:79
void serialize(ArchiveT &ar)
Definition: select.hpp:84
typename BaseType::LocalInputTensorType LocalInputTensorType
Definition: select.hpp:52
typename BaseType::ConstLocalInputTensorType ConstLocalInputTensorType
Definition: select.hpp:55