LBANN
0.103.0
Livermore Big Artificial Neural Network Toolkit
declare_stateless_op.hpp
Go to the documentation of this file.
1
// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3
// Produced at the Lawrence Livermore National Laboratory.
4
// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5
// the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6
//
7
// LLNL-CODE-697807.
8
// All rights reserved.
9
//
10
// This file is part of LBANN: Livermore Big Artificial Neural Network
11
// Toolkit. For details, see http://software.llnl.gov/LBANN or
12
// https://github.com/LLNL/LBANN.
13
//
14
// Licensed under the Apache License, Version 2.0 (the "License"); you
15
// may not use this file except in compliance with the License. You may
16
// obtain a copy of the License at:
17
//
18
// http://www.apache.org/licenses/LICENSE-2.0
19
//
20
// Unless required by applicable law or agreed to in writing, software
21
// distributed under the License is distributed on an "AS IS" BASIS,
22
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23
// implied. See the License for the specific language governing
24
// permissions and limitations under the license.
26
#ifndef LBANN_INCLUDE_LBANN_OPERATORS_DECLARE_STATELESS_OP_HPP_INCLUDED
27
#define LBANN_INCLUDE_LBANN_OPERATORS_DECLARE_STATELESS_OP_HPP_INCLUDED
28
29
#include "lbann/operators/elementwise_operator.hpp"
#include "lbann/operators/operator.hpp"
#include "lbann/utils/cloneable.hpp"

#include "lbann/proto/operators.pb.h"

#include <string>
#include <vector>

// These are all single-type operators: the input and output tensors of every
// operator declared by the macros in this file share one data type (DataT).

/** @brief Declare a stateless, single-type operator class template.
 *
 *  Expands to
 *  @code
 *  template <typename DataT, El::Device D>
 *  class OP_NAME##Operator final
 *    : public Cloneable<OP_NAME##Operator<DataT, D>, Operator<DataT, DataT, D>>
 *  { ... }
 *  @endcode
 *  The generated class has no data members of its own. serialize() forwards
 *  to the Operator base. set_proto_params() packs an empty
 *  lbann_data::OP_NAME##Operator message. fp_compute() and bp_compute() are
 *  only declared here; their definitions must be provided elsewhere for each
 *  OP_NAME.
 *
 *  @param OP_NAME        Class-name prefix. lbann_data::OP_NAME##Operator
 *                        must name a protobuf message type.
 *  @param OP_STRING      Expression returned by get_type() (converted to
 *                        std::string).
 *  @param NEEDS_PREVACTS Condition selecting the backprop requirements:
 *                        ERROR_SIGNALS | PREV_ACTIVATIONS when true,
 *                        ERROR_SIGNALS alone otherwise.
 *
 *  The expansion ends with the class's closing brace, so the use site must
 *  supply the terminating semicolon.
 *
 *  Keep comments out of the replacement list below. A `//` comment on a line
 *  that ends in a backslash swallows the following line.
 */
#define LBANN_DECLARE_STATELESS_OPERATOR(OP_NAME, OP_STRING, NEEDS_PREVACTS)   \
  template <typename DataT, El::Device D>                                      \
  class OP_NAME##Operator final                                                \
    : public Cloneable<OP_NAME##Operator<DataT, D>, Operator<DataT, DataT, D>> \
  {                                                                            \
    using BaseType =                                                           \
      Cloneable<OP_NAME##Operator<DataT, D>, Operator<DataT, DataT, D>>;       \
    using InputTensorType = typename BaseType::InputTensorType;                \
    using OutputTensorType = typename BaseType::OutputTensorType;              \
    using ConstInputTensorType = typename BaseType::ConstInputTensorType;      \
    using ConstOutputTensorType = typename BaseType::ConstOutputTensorType;    \
                                                                               \
  public:                                                                      \
    OP_NAME##Operator() = default;                                             \
    OP_NAME##Operator(OP_NAME##Operator&&) = default;                          \
    OP_NAME##Operator(OP_NAME##Operator const&) = default;                     \
    OP_NAME##Operator& operator=(OP_NAME##Operator&&) = default;               \
    OP_NAME##Operator& operator=(OP_NAME##Operator const&) = default;          \
    ~OP_NAME##Operator() = default;                                            \
    std::string get_type() const final { return OP_STRING; }                   \
    int get_backprop_requirements() const final                                \
    {                                                                          \
      return ((NEEDS_PREVACTS) ? (ERROR_SIGNALS | PREV_ACTIVATIONS)            \
                               : ERROR_SIGNALS);                               \
    }                                                                          \
    template <typename ArchiveT>                                               \
    void serialize(ArchiveT& ar)                                               \
    {                                                                          \
      using OperatorType = Operator<DataT, DataT, D>;                          \
      ar(::cereal::make_nvp("Operator",                                        \
                            ::cereal::base_class<OperatorType>(this)));        \
    }                                                                          \
    void fp_compute(std::vector<ConstInputTensorType> const& inputs,           \
                    std::vector<OutputTensorType> const& outputs) const final; \
    void bp_compute(                                                           \
      std::vector<ConstInputTensorType> const& inputs,                         \
      std::vector<ConstOutputTensorType> const& gradient_wrt_outputs,          \
      std::vector<InputTensorType> const& gradient_wrt_inputs) const final;    \
                                                                               \
  private:                                                                     \
    void set_proto_params(lbann_data::Operator& msg) const final               \
    {                                                                          \
      msg.mutable_parameters()->PackFrom(lbann_data::OP_NAME##Operator{});     \
    }                                                                          \
    void do_fill_description(description&) const final {}                      \
  }

/** @brief Declare a stateless, single-type elementwise operator class
 *         template.
 *
 *  Expands to
 *  @code
 *  template <typename DataT, El::Device D>
 *  class OP_NAME##Operator final
 *    : public Cloneable<OP_NAME##Operator<DataT, D>,
 *                       ElementwiseOperator<DataT, DataT, D>>
 *  { ... }
 *  @endcode
 *  Input and output tensors share the single data type DataT. The generated
 *  class has no data members of its own. serialize() forwards to the
 *  ElementwiseOperator base. set_proto_params() packs an empty
 *  lbann_data::OP_NAME##Operator message. Only the private fp_compute_local()
 *  and bp_compute_local() overrides are declared here; their definitions must
 *  be provided elsewhere for each OP_NAME.
 *
 *  @param OP_NAME        Class-name prefix. lbann_data::OP_NAME##Operator
 *                        must name a protobuf message type.
 *  @param OP_STRING      Expression returned by get_type() (converted to
 *                        std::string).
 *  @param NEEDS_PREVACTS Condition selecting the backprop requirements:
 *                        ERROR_SIGNALS | PREV_ACTIVATIONS when true,
 *                        ERROR_SIGNALS alone otherwise.
 *
 *  The expansion ends with the class's closing brace, so the use site must
 *  supply the terminating semicolon.
 *
 *  Keep comments out of the replacement list below. A `//` comment on a line
 *  that ends in a backslash swallows the following line.
 */
#define LBANN_DECLARE_STATELESS_ELEMENTWISE_OPERATOR(OP_NAME,                  \
                                                     OP_STRING,                \
                                                     NEEDS_PREVACTS)           \
  template <typename DataT, El::Device D>                                      \
  class OP_NAME##Operator final                                                \
    : public Cloneable<OP_NAME##Operator<DataT, D>,                            \
                       ElementwiseOperator<DataT, DataT, D>>                   \
  {                                                                            \
    using BaseType = Cloneable<OP_NAME##Operator<DataT, D>,                    \
                               ElementwiseOperator<DataT, DataT, D>>;          \
    using LocalInputTensorType = typename BaseType::LocalInputTensorType;      \
    using LocalOutputTensorType = typename BaseType::LocalOutputTensorType;    \
    using ConstLocalInputTensorType =                                          \
      typename BaseType::ConstLocalInputTensorType;                            \
    using ConstLocalOutputTensorType =                                         \
      typename BaseType::ConstLocalOutputTensorType;                           \
                                                                               \
  public:                                                                      \
    OP_NAME##Operator() = default;                                             \
    OP_NAME##Operator(OP_NAME##Operator&&) = default;                          \
    OP_NAME##Operator(OP_NAME##Operator const&) = default;                     \
    OP_NAME##Operator& operator=(OP_NAME##Operator&&) = default;               \
    OP_NAME##Operator& operator=(OP_NAME##Operator const&) = default;          \
    ~OP_NAME##Operator() = default;                                            \
    std::string get_type() const final { return OP_STRING; }                   \
    int get_backprop_requirements() const final                                \
    {                                                                          \
      return ((NEEDS_PREVACTS) ? (ERROR_SIGNALS | PREV_ACTIVATIONS)            \
                               : ERROR_SIGNALS);                               \
    }                                                                          \
    template <typename ArchiveT>                                               \
    void serialize(ArchiveT& ar)                                               \
    {                                                                          \
      using OperatorType = ElementwiseOperator<DataT, DataT, D>;               \
      ar(::cereal::make_nvp("ElementwiseOperator",                             \
                            ::cereal::base_class<OperatorType>(this)));        \
    }                                                                          \
                                                                               \
  private:                                                                     \
    void                                                                       \
    fp_compute_local(std::vector<ConstLocalInputTensorType> inputs,            \
                     std::vector<LocalOutputTensorType> outputs) const final;  \
    void bp_compute_local(                                                     \
      std::vector<ConstLocalInputTensorType> inputs,                           \
      std::vector<ConstLocalOutputTensorType> grads_wrt_outputs,               \
      std::vector<LocalInputTensorType> grads_wrt_inputs) const final;         \
    void set_proto_params(lbann_data::Operator& msg) const final               \
    {                                                                          \
      msg.mutable_parameters()->PackFrom(lbann_data::OP_NAME##Operator{});     \
    }                                                                          \
    void do_fill_description(description&) const final {}                      \
  }

#endif // LBANN_INCLUDE_LBANN_OPERATORS_DECLARE_STATELESS_OP_HPP_INCLUDED
operator.hpp
cloneable.hpp
elementwise_operator.hpp
include
lbann
operators
declare_stateless_op.hpp
Generated on Wed Oct 11 2023 20:49:37 for LBANN by Doxygen 1.8.13