// Include guard for this header, followed by the distconv adapter for the
// input layer. The adapter is only declared when LBANN_HAS_DISTCONV is defined.
// NOTE(review): the numbers fused into this listing (27, 28, 37, ...) skip
// values, so some original lines (braces, access specifiers, comments) are
// not visible here.
//
// input_distconv_adapter is a class template over the tensor data type, the
// data layout and the device; it derives from
// data_type_distconv_adapter<TensorDataType>.
27 #ifndef LBANN_LAYERS_INPUT_LAYER_HPP_INCLUDED 28 #define LBANN_LAYERS_INPUT_LAYER_HPP_INCLUDED 37 #ifdef LBANN_HAS_DISTCONV 38 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
39 class input_distconv_adapter :
public data_type_distconv_adapter<TensorDataType>
// Shorthand for the host-side tensor and host-side tensor shuffler types,
// both instantiated for TensorDataType.
44 using TensorHost = dc::TensorHost<TensorDataType>;
45 using TensorHostShuffler = dc::TensorHostShuffler<TensorDataType>;
// Constructor: takes the layer this adapter belongs to and a flag saying
// whether shuffling is required.
// NOTE(review): listing line 48 is absent, so a parameter may be missing
// between `layer` and `shuffle_required` -- confirm against the real header.
47 input_distconv_adapter(Layer& layer,
49 const bool shuffle_required);
// Virtual, compiler-generated destructor.
50 virtual ~input_distconv_adapter() =
default;
// Overrides the base-class layer setup; receives the workspace capacity.
52 void setup_layer(
size_t workspace_capacity)
override;
// Returns a reference to a host-tensor shuffler for the given source and
// destination tensors (presumably one of the entries in m_shufflers -- the
// definition is not in this file).
54 TensorHostShuffler& get_shuffler(
const TensorHost& src,
55 const TensorHost& dst);
// Overrides the base-class forward-prop tensor setup.
56 void setup_fp_tensors()
override;
// Builds and returns an owned device tensor for the activations at `index`.
57 std::unique_ptr<TensorDevType> setup_activations_i(
int index)
const override;
// Shape queries for the activations at `index`: local shape and full shape.
58 dc::Shape get_activations_local_shape(
int index)
const override;
59 dc::Shape get_activations_shape(
int index)
const override;
// Sets up the buffers used when shuffling from `src` to `dst` (presumably
// m_shuffler_src_buf / m_shuffler_dst_buf -- the definition is not visible).
60 void setup_shuffler_buffers(
const TensorHost& src,
const TensorHost& dst);
// The following five overrides have empty bodies: every error-signal and
// back-prop tensor setup hook is a deliberate no-op in this adapter.
63 void setup_prev_error_signals()
override {}
64 void setup_original_prev_error_signals()
override {}
65 void setup_error_signals()
override {}
66 void setup_original_error_signals()
override {}
67 void setup_bp_tensors()
override {}
// Per-output queries: whether the child at `output_index` needs a copy or a
// shuffle. The decision logic lives in the out-of-line definitions.
69 bool child_copy_required(
size_t output_index)
const override;
70 bool child_shuffle_required(
size_t output_index)
const override;
// Empty override: the forward-prop setup hook is a no-op here.
73 void fp_setup()
override {}
// Flag with no in-class initializer (unlike the size members below).
// NOTE(review): confirm the constructor initializes it before first read.
80 bool m_is_input_processed;
// Owned host tensors.
81 std::unique_ptr<TensorHost> m_original_host_tensor;
82 std::unique_ptr<TensorHost> m_host_tensor;
// Immutable after construction.
84 const bool m_shuffle_required;
// Fixed array of four owned shufflers; why there are four slots is not
// evident from this listing.
85 std::array<std::unique_ptr<TensorHostShuffler>, 4> m_shufflers;
// Source/destination shuffle buffers with their sizes (sizes start at 0).
// NOTE(review): these are std::unique_ptr<T> (single-object default deleter)
// paired with a size field. If they hold arrays or malloc-style allocations
// the deleter would not match -- confirm against the allocation site.
86 std::unique_ptr<TensorDataType> m_shuffler_src_buf;
87 size_t m_shuffler_src_buf_size = 0;
88 std::unique_ptr<TensorDataType> m_shuffler_dst_buf;
89 size_t m_shuffler_dst_buf_size = 0;
// Raw pointer, null by default. The name suggests a pinned-memory staging
// buffer; ownership and release are not visible here -- TODO confirm.
92 TensorDataType* m_copy_pinned_buffer =
nullptr;
// End of the distconv-only adapter section; a second template begins here.
// The extern template at the end of the file indicates this is the
// input_layer class template.
// NOTE(review): listing lines 98-102 and 104-119 are absent, so the rest of
// the template parameter list, the class head and the constructor signature
// are not visible.
94 #endif // LBANN_HAS_DISTCONV 97 template <
typename TensorDataType,
// Diagnostic text; presumably the message of a static_assert on the layout
// template parameter -- the asserted condition is not visible here.
103 "input layer only supports DATA_PARALLEL data layout");
// The layer expects no parent layers and exactly one child layer.
120 this->m_expected_num_parent_layers = 0;
121 this->m_expected_num_child_layers = 1;
// Returns the layer type name, the fixed string "input".
128 std::string
get_type()
const override {
return "input"; }
// ONNX export hook, declared only when LBANN_HAS_ONNX is defined.
130 #ifdef LBANN_HAS_ONNX 131 void fill_onnx_node(onnx::GraphProto& graph)
const override;
// Overrides the dimension setup hook.
132 #endif // LBANN_HAS_ONNX 143 void setup_dims()
override;
// Overrides the data setup hook; receives the maximum mini-batch size.
145 void setup_data(
size_t max_mini_batch_size)
override;
// Forward-prop hooks: output setup, then the compute step.
150 void fp_setup_outputs()
override;
152 void fp_compute()
override;
// Supplies samples to the layer from a distributed matrix (presumably also
// sets m_samples_loaded -- the definition is not in this file).
157 void set_samples(
const El::AbstractDistMatrix<TensorDataType>& samples);
// Tail of a const member declaration whose head (listing lines 158-162) is
// absent; `child_index` defaults to 0.
163 int child_index = 0)
const;
// Head of a member template over an archive type; the declaration it
// introduces is not visible (presumably a cereal serialize function, given
// the cereal::access friendship below).
168 template <
typename ArchiveT>
// Writes this layer's specific fields into the protobuf Layer message;
// marked final, so it cannot be overridden further.
175 void write_specific_proto(lbann_data::Layer& proto)
const final;
// Lets cereal reach non-public members of this class.
178 friend cereal::access;
// Flag, false by default.
183 bool m_samples_loaded =
false;
// Distconv integration, compiled only when LBANN_HAS_DISTCONV is defined.
// distconv_adapter_type names the matching adapter instantiation, and that
// adapter is made a friend of this class.
187 #ifdef LBANN_HAS_DISTCONV 190 using distconv_adapter_type =
192 input_distconv_adapter<TensorDataType, T_layout, Dev>;
193 friend distconv_adapter_type;
// NOTE(review): the body of is_distconv_supported() (listing lines 197-199)
// is absent from this listing.
196 bool is_distconv_supported()
const override 200 void setup_distconv_adapter()
override;
// Accessors returning the adapter as distconv_adapter_type (mutable and
// const overloads).
201 distconv_adapter_type& get_distconv_adapter()
override;
202 const distconv_adapter_type& get_distconv_adapter()
const override;
// Per-index queries on whether the original outputs, or the original
// gradients with respect to outputs, are kept.
203 bool keep_original_outputs(
int index)
const override;
204 bool keep_original_gradient_wrt_outputs(
int index)
const override;
// Unless LBANN_INPUT_LAYER_INSTANTIATE is defined, PROTO_DEVICE(T, Device)
// expands to an extern template declaration of
// input_layer<T, data_layout::DATA_PARALLEL, Device>. The line ends by
// closing the header's include guard.
206 #endif // LBANN_HAS_DISTCONV 211 #ifndef LBANN_INPUT_LAYER_INSTANTIATE 213 #define PROTO_DEVICE(T, Device) \ 214 extern template class input_layer<T, data_layout::DATA_PARALLEL, Device> 219 #endif // LBANN_INPUT_LAYER_INSTANTIATE 223 #endif // LBANN_LAYERS_INPUT_LAYER_HPP_INCLUDED
void serialize(std::ostream &os, google::protobuf::Message const &msg)
Serialize the protobuf message to a stream.
constexpr El::Device Device
::distconv::tensor::Shape Shape
data_layout
Data layout that is optimized for different modes of parallelism.
std::string data_field_type
LBANN_DEFINE_LAYER_BUILDER(elu)
dc::TensorDev< OutputTensorDataType > TensorDevType