27 #ifndef LBANN_LAYERS_MISC_ARGMIN_HPP_INCLUDED 28 #define LBANN_LAYERS_MISC_ARGMIN_HPP_INCLUDED 39 template <
typename TensorDataType, data_layout Layout, El::Device Device>
43 "argmin layer only supports data parallel layout");
44 static_assert(
Device == El::Device::CPU,
"argmin layer only supports CPU");
53 template <
typename ArchiveT>
58 std::string
get_type()
const override {
return "argmin"; }
76 #ifndef LBANN_ARGMIN_LAYER_INSTANTIATE 78 extern template class argmin_layer<T, \ 79 data_layout::DATA_PARALLEL, \ 82 #define LBANN_INSTANTIATE_CPU_HALF 85 #undef LBANN_INSTANTIATE_CPU_HALF 86 #endif // LBANN_ARGMIN_LAYER_INSTANTIATE 89 #endif // LBANN_LAYERS_MISC_ARGMIN_HPP_INCLUDED
void fp_compute() override
Apply layer operation. Called by the 'forward_prop' function. Given the input tensors, the output tensors are populated with computed values.
void setup_dims() override
Set up tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the dimensions of the first input tensor.
std::string get_type() const override
Get the layer type's name.
bool can_run_inplace() const override
If true, the computation can run in-place (feeding each input activations tensor as the corresponding...
data_layout get_data_layout() const override
Get data layout of the data tensors. We assume that the data layouts of the previous activations...
void serialize(ArchiveT &ar)
constexpr El::Device Device
El::Device get_device_allocation() const override
Get the device allocation for the data tensors. We assume that the device allocation of the previous ...
void write_specific_proto(lbann_data::Layer &proto) const final
argmin_layer(lbann_comm *comm)
Get the index of the minimum-value tensor entry.
data_layout
Data layout that is optimized for different modes of parallelism.
friend class cereal::access
argmin_layer * copy() const override
Copy function. This function dynamically allocates memory for a layer instance and instantiates a copy...
int get_backprop_requirements() const override
Returns the necessary tensors for computing backpropagation.