LBANN  0.103.0
LivermoreBigArtificialNeuralNetworkToolkit
elementwise_operator.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the License.
26 
27 #ifndef LBANN_OPERATORS_ELEMENTWISE_OPERATOR_HPP_INCLUDED
28 #define LBANN_OPERATORS_ELEMENTWISE_OPERATOR_HPP_INCLUDED
29 
31 
32 #include <cereal/cereal.hpp>
33 
34 #include <functional>
35 #include <iterator>
36 #include <type_traits>
37 
38 namespace lbann {
39 
44 template <typename InputT, typename OutputT, El::Device D>
46  : public AbstractCloneableBase<ElementwiseOperator<InputT, OutputT, D>,
47  Operator<InputT, OutputT, D>>
48 {
49 public:
51 
53  using BaseType =
56 
59  using ConstInputTensorType = typename BaseType::ConstInputTensorType;
60  using ConstOutputTensorType = typename BaseType::ConstOutputTensorType;
61 
66 
68 
69 public:
70  ElementwiseOperator() = default;
71  virtual ~ElementwiseOperator() = default;
72 
74 
76  template <typename ArchiveT>
77  void serialize(ArchiveT& ar)
78  {
79  ar(cereal::base_class<Operator<InputT, OutputT, D>>(this));
80  };
81 
83 
85  using BaseType::bp_compute;
87  using BaseType::fp_compute;
88 
/** @brief Collect the local data of every tensor view in a list.
 *
 *  Calls `local_data()` on each element of @c in, in order, and stores
 *  each result by value.
 *
 *  @tparam TensorViewType Element type; must provide a `local_data()`
 *          member callable on a const object.
 *  @param[in] in Tensor views to query. May be empty.
 *  @return A vector with one entry per element of @c in. Its element
 *          type is the decayed return type of `local_data()`.
 */
template <typename TensorViewType>
static auto get_local_tensor_views(std::vector<TensorViewType> const& in)
{
  // decltype is an unevaluated context: in[0] is never actually
  // indexed, so an empty input is safe.
  using LocalViewType = std::decay_t<decltype(in[0].local_data())>;

  std::vector<LocalViewType> local_views;
  local_views.reserve(in.size());
  // A plain range-for needs neither <algorithm> (std::transform) nor
  // ADL lookup of unqualified cbegin/cend.
  for (auto const& view : in) {
    local_views.push_back(view.local_data());
  }
  return local_views;
}
102 
107  void fp_compute(std::vector<ConstInputTensorType> const& inputs,
108  std::vector<OutputTensorType> const& outputs) const final
109  {
111  get_local_tensor_views(outputs));
112  }
113 
114  // ===========================================================
115  // Back prop compute function
116  // ===========================================================
117 
124  std::vector<ConstInputTensorType> const& inputs,
125  std::vector<ConstOutputTensorType> const& gradient_wrt_outputs,
126  std::vector<InputTensorType> const& gradient_wrt_inputs) const final
127  {
129  get_local_tensor_views(gradient_wrt_outputs),
130  get_local_tensor_views(gradient_wrt_inputs));
131  }
132 
134 
135 protected:
137  ElementwiseOperator(ElementwiseOperator const&) = default;
143 
145 
/** @brief Local forward compute function.
 *
 *  Pure virtual, const-qualified hook that a concrete operator must
 *  override. Both vectors are taken by value.
 *
 *  NOTE(review): presumably invoked by fp_compute() with the results
 *  of get_local_tensor_views() for its inputs and outputs -- the line
 *  holding that call is missing from this listing; confirm.
 *
 *  @param input  Local views of the input tensors.
 *  @param output Local views of the output tensors.
 */
148  virtual void
149  fp_compute_local(std::vector<ConstLocalInputTensorType> input,
150  std::vector<LocalOutputTensorType> output) const = 0;
151 
/** @brief Local backward compute function.
 *
 *  Pure virtual, const-qualified hook that a concrete operator must
 *  override. All three vectors are taken by value.
 *
 *  NOTE(review): presumably invoked by bp_compute() with the results
 *  of get_local_tensor_views() for its arguments -- the line holding
 *  that call is missing from this listing; confirm.
 *
 *  @param input               Local views of the input tensors.
 *  @param gradient_wrt_output Local views of the gradients with
 *                             respect to the outputs (const element
 *                             type).
 *  @param gradient_wrt_input  Local views of the gradients with
 *                             respect to the inputs.
 */
153  virtual void bp_compute_local(
154  std::vector<ConstLocalInputTensorType> input,
155  std::vector<ConstLocalOutputTensorType> gradient_wrt_output,
156  std::vector<LocalInputTensorType> gradient_wrt_input) const = 0;
157 
159 
160 }; // class ElementwiseOperator
161 
162 } // namespace lbann
163 #endif // LBANN_OPERATORS_ELEMENTWISE_OPERATOR_HPP_INCLUDED
Element-wise specific tensor operation sub-class.
Inject polymorphic clone functions into hierarchies.
Definition: cloneable.hpp:94
void bp_compute(std::vector< ConstInputTensorType > const &inputs, std::vector< ConstOutputTensorType > const &gradient_wrt_outputs, std::vector< InputTensorType > const &gradient_wrt_inputs) const final
Compute operator's "backward" operation.
virtual void bp_compute_local(std::vector< ConstLocalInputTensorType > input, std::vector< ConstLocalOutputTensorType > gradient_wrt_output, std::vector< LocalInputTensorType > gradient_wrt_input) const =0
Local backward compute function.
typename BaseType::ConstOutputTensorType ConstOutputTensorType
typename BaseType::ConstInputTensorType ConstInputTensorType
void fp_compute(std::vector< ConstInputTensorType > const &inputs, std::vector< OutputTensorType > const &outputs) const final
Apply operator's forward operation.
ElementwiseOperator & operator=(ElementwiseOperator const &)=default
typename BaseType::OutputTensorType OutputTensorType
typename BaseType::InputTensorType InputTensorType
typename OperatorTraits< OpT >::input_tensor_type InputTensorType
virtual ~ElementwiseOperator()=default
virtual void fp_compute_local(std::vector< ConstLocalInputTensorType > input, std::vector< LocalOutputTensorType > output) const =0
Local forward compute function.
typename OperatorTraits< OpT >::output_tensor_type OutputTensorType
static auto get_local_tensor_views(std::vector< TensorViewType > const &in)
Neural network tensor operation.
Definition: operator.hpp:85