27 #ifndef LBANN_OPERATORS_OPERATOR_HPP_INCLUDED 28 #define LBANN_OPERATORS_OPERATOR_HPP_INCLUDED 39 #include "lbann/proto/operators.pb.h" 41 #include <h2/meta/Core.hpp> 42 #include <h2/meta/TypeList.hpp> 44 #include <google/protobuf/message.h> 52 #ifdef LBANN_HAS_GPU_FP16 84 template <
typename InputT,
typename OutputT, El::Device D>
100 h2::meta::tlist::MemberV<InputT, supported_operator_data_type>(),
101 "Must use a supported input type.");
103 h2::meta::tlist::MemberV<OutputT, supported_operator_data_type>(),
104 "Must use a supported output type.");
115 virtual std::string
get_type()
const = 0;
134 void write_proto(google::protobuf::Message& msg)
const;
139 template <
typename ArchiveT>
151 fp_compute(std::vector<ConstInputTensorType>
const& inputs,
152 std::vector<OutputTensorType>
const& outputs)
const = 0;
160 bp_compute(std::vector<ConstInputTensorType>
const& inputs,
161 std::vector<ConstOutputTensorType>
const& gradient_wrt_outputs,
162 std::vector<InputTensorType>
const& gradient_wrt_inputs)
const;
178 template <
typename InputT,
typename OutputT, El::Device D>
180 google::protobuf::Message& msg)
const 182 lbann_data::Operator operator_msg;
183 operator_msg.set_input_datatype(proto::ProtoDataType<InputT>);
184 operator_msg.set_output_datatype(proto::ProtoDataType<OutputT>);
185 operator_msg.set_device_allocation(proto::ProtoDevice<D>);
189 msg.CopyFrom(operator_msg);
192 template <
typename InputT,
typename OutputT, El::Device D>
200 desc.
add(
"Input data type", TypeName<InputT>());
201 desc.
add(
"Output data type", TypeName<OutputT>());
208 template <
typename InputT,
typename OutputT, El::Device D>
210 std::vector<ConstInputTensorType>
const&,
211 std::vector<ConstOutputTensorType>
const&,
212 std::vector<InputTensorType>
const&)
const 215 template <
typename InputT,
typename OutputT, El::Device D>
216 template <
typename ArchiveT>
221 #endif // LBANN_OPERATORS_OPERATOR_HPP_INCLUDED virtual void do_fill_description(Description &) const =0
Concrete operator description.
Inject polymorphic clone functions into hierarchies.
virtual int get_backprop_requirements() const
Returns the necessary tensors for computing backpropagation for this operator.
Represents a class that is describable in LBANN's protobuf specification.
virtual ~Operator()=default
Destructor.
Generates nicely formatted description messages.
virtual void fp_compute(std::vector< ConstInputTensorType > const &inputs, std::vector< OutputTensorType > const &outputs) const =0
Apply operator's forward operation.
virtual void bp_compute(std::vector< ConstInputTensorType > const &inputs, std::vector< ConstOutputTensorType > const &gradient_wrt_outputs, std::vector< InputTensorType > const &gradient_wrt_inputs) const
Compute operator's "backward" operation.
Operator & operator=(Operator &&other) noexcept=default
Operator()=default
Constructor.
Description get_description() const override
Get the description of the operator.
A class that can generate self-descriptions.
void serialize(ArchiveT &ar)
void add(std::string line)
void write_proto(google::protobuf::Message &msg) const
Write a protobuf description of the operator.
virtual std::string get_type() const =0
Get the operator type's name.
h2::meta::TL< float, double, El::Complex< float >, El::Complex< double > > supported_operator_data_type
virtual void set_proto_params(lbann_data::Operator &) const =0
Fill the concrete operator parameters.
Neural network tensor operation.