LBANN 0.103.0
Livermore Big Artificial Neural Network Toolkit
math_builders_impl.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 #ifndef LBANN_INCLUDE_LBANN_OPERATORS_MATH_MATH_BUILDERS_IMPL_HPP_INCLUDED
27 #define LBANN_INCLUDE_LBANN_OPERATORS_MATH_MATH_BUILDERS_IMPL_HPP_INCLUDED
28 
30 
37 
39 #include "lbann/proto/operators.pb.h"
40 
41 template <typename DataT, El::Device D>
42 std::unique_ptr<lbann::Operator<DataT, DataT, D>>
43 lbann::build_clamp_operator(lbann_data::Operator const& op)
44 {
45  details::AssertConsistentTypeParameters<DataT, DataT, D>(op);
46  lbann_data::ClampOperator params;
47  LBANN_ASSERT(op.parameters().UnpackTo(&params));
48  return std::make_unique<ClampOperator<DataT, D>>(params.min(), params.max());
49 }
50 
51 template <typename DataT, El::Device D>
52 std::unique_ptr<lbann::Operator<DataT, El::Base<DataT>, D>>
53 lbann::build_abs_operator(lbann_data::Operator const& op)
54 {
55  details::AssertConsistentTypeParameters<DataT, El::Base<DataT>, D>(op);
56  return std::make_unique<AbsOperator<DataT, D>>();
57 }
58 
59 template <typename DataT, El::Device D>
60 std::unique_ptr<lbann::Operator<DataT, DataT, D>>
61 lbann::build_select_operator(lbann_data::Operator const& op)
62 {
63  details::AssertConsistentTypeParameters<DataT, DataT, D>(op);
64  lbann_data::SelectOperator params;
65  LBANN_ASSERT(op.parameters().UnpackTo(&params));
66  return std::make_unique<SelectOperator<DataT, D>>(params.value(),
67  params.constant_if_true(),
68  params.constant_if_false(),
69  params.value_if_true(),
70  params.value_if_false(),
71  params.epsilon());
72 }
73 
/** @brief Define `lbann::build_<OP_LOWER_NAME>_operator` for an operator whose
 *  protobuf parameter message holds a single @c constant field.
 *
 *  @param OP_NAME        CamelCase stem. It is pasted into both the protobuf
 *                        message type `lbann_data::OP_NAME##Operator` and the
 *                        operator class `OP_NAME##Operator<DataT, D>`.
 *  @param OP_LOWER_NAME  snake_case stem used for the builder's name
 *                        `build_##OP_LOWER_NAME##_operator`.
 *
 *  The generated builder checks the op's type parameters as
 *  <DataT, DataT, D>. It then unpacks the parameter message as its own
 *  statement, so the side effect is not inside LBANN_ASSERT, and constructs
 *  the operator from `params.constant()`. Comments cannot go inside the
 *  macro body, because a `//` comment would swallow the line continuation.
 */
#define LBANN_DEFINE_BIN_WITH_CONSTANT_BUILDER(OP_NAME, OP_LOWER_NAME)          \
  template <typename DataT, El::Device D>                                      \
  std::unique_ptr<lbann::Operator<DataT, DataT, D>>                            \
  lbann::build_##OP_LOWER_NAME##_operator(lbann_data::Operator const& op)      \
  {                                                                            \
    details::AssertConsistentTypeParameters<DataT, DataT, D>(op);              \
    lbann_data::OP_NAME##Operator params;                                      \
    [[maybe_unused]] bool const params_unpacked =                              \
      op.parameters().UnpackTo(&params);                                       \
    LBANN_ASSERT(params_unpacked);                                             \
    return std::make_unique<OP_NAME##Operator<DataT, D>>(params.constant());   \
  }
84 
85 LBANN_DEFINE_BIN_WITH_CONSTANT_BUILDER(AddConstant, add_constant)
86 LBANN_DEFINE_BIN_WITH_CONSTANT_BUILDER(ConstantSubtract, constant_subtract)
87 LBANN_DEFINE_BIN_WITH_CONSTANT_BUILDER(EqualConstant, equal_constant)
88 LBANN_DEFINE_BIN_WITH_CONSTANT_BUILDER(GreaterConstant, greater_constant)
89 LBANN_DEFINE_BIN_WITH_CONSTANT_BUILDER(GreaterEqualConstant,
90  greater_equal_constant)
91 LBANN_DEFINE_BIN_WITH_CONSTANT_BUILDER(LessConstant, less_constant)
92 LBANN_DEFINE_BIN_WITH_CONSTANT_BUILDER(LessEqualConstant, less_equal_constant)
93 LBANN_DEFINE_BIN_WITH_CONSTANT_BUILDER(MaxConstant, max_constant)
94 LBANN_DEFINE_BIN_WITH_CONSTANT_BUILDER(MinConstant, min_constant)
95 LBANN_DEFINE_BIN_WITH_CONSTANT_BUILDER(NotEqualConstant, not_equal_constant)
97 LBANN_DEFINE_BIN_WITH_CONSTANT_BUILDER(SubtractConstant, subtract_constant)
98 
99 #undef LBANN_DEFINE_BIN_WITH_CONSTANT_BUILDER
100 
102 LBANN_DEFINE_OPERATOR_BUILDER(acosh, Acosh)
105 LBANN_DEFINE_OPERATOR_BUILDER(asinh, Asinh)
107 LBANN_DEFINE_OPERATOR_BUILDER(atanh, Atanh)
111 LBANN_DEFINE_OPERATOR_BUILDER(divide, Divide)
112 LBANN_DEFINE_OPERATOR_BUILDER(equal, Equal)
114 LBANN_DEFINE_OPERATOR_BUILDER(erfinv, ErfInv)
116 LBANN_DEFINE_OPERATOR_BUILDER(expm1, Expm1)
117 LBANN_DEFINE_OPERATOR_BUILDER(floor, Floor)
119 LBANN_DEFINE_OPERATOR_BUILDER(greater, Greater)
120 LBANN_DEFINE_OPERATOR_BUILDER(greater_equal, GreaterEqual)
122 LBANN_DEFINE_OPERATOR_BUILDER(less_equal, LessEqual)
124 LBANN_DEFINE_OPERATOR_BUILDER(log1p, Log1p)
125 LBANN_DEFINE_OPERATOR_BUILDER(logical_and, LogicalAnd)
126 LBANN_DEFINE_OPERATOR_BUILDER(logical_not, LogicalNot)
127 LBANN_DEFINE_OPERATOR_BUILDER(logical_or, LogicalOr)
128 LBANN_DEFINE_OPERATOR_BUILDER(logical_xor, LogicalXor)
132 LBANN_DEFINE_OPERATOR_BUILDER(multiply, Multiply)
133 LBANN_DEFINE_OPERATOR_BUILDER(negative, Negative)
134 LBANN_DEFINE_OPERATOR_BUILDER(not_equal, NotEqual)
136 LBANN_DEFINE_OPERATOR_BUILDER(reciprocal, Reciprocal)
137 LBANN_DEFINE_OPERATOR_BUILDER(round, Round)
138 LBANN_DEFINE_OPERATOR_BUILDER(rsqrt, Rsqrt)
139 LBANN_DEFINE_OPERATOR_BUILDER(safe_divide, SafeDivide)
140 LBANN_DEFINE_OPERATOR_BUILDER(safe_reciprocal, SafeReciprocal)
145 LBANN_DEFINE_OPERATOR_BUILDER(square, Square)
146 LBANN_DEFINE_OPERATOR_BUILDER(squared_difference, SquaredDifference)
147 LBANN_DEFINE_OPERATOR_BUILDER(subtract, Subtract)
150 #endif // LBANN_INCLUDE_LBANN_OPERATORS_MATH_MATH_BUILDERS_IMPL_HPP_INCLUDED
#define LBANN_DEFINE_OPERATOR_BUILDER(OP_LOWER, OP_NAME)
A utility macro for easily adding a default builder with dynamic type-checking assertions.
#define LBANN_ASSERT(cond)
Definition: exception.hpp:97
#define LBANN_DEFINE_BIN_WITH_CONSTANT_BUILDER(OP_NAME, OP_LOWER_NAME)
std::unique_ptr< Operator< DataT, El::Base< DataT >, D > > build_abs_operator(lbann_data::Operator const &op)