// NOTE(review): this text looks like a scraped source listing -- listing
// numbers are fused into the lines and the function-name line and braces of
// the definition below are not visible. Confirm against the real header
// before relying on these comments.
//
// Include guard, an (empty as shown) LBANN_HAS_DISTCONV section, and the
// start of a function template parameterized on <T, L, D>.
27 #ifndef LBANN_LAYER_ACTIVATION_RELU_IMPL_HPP_INCLUDED 28 #define LBANN_LAYER_ACTIVATION_RELU_IMPL_HPP_INCLUDED 32 #ifdef LBANN_HAS_DISTCONV 34 #endif // LBANN_HAS_DISTCONV 38 template <
typename T, data_layout L, El::Device D>
// Only visible statement of the body: sets the datatype on `proto` to
// proto::ProtoDataType<T>, i.e. the value selected by the template
// parameter T. Presumably this is the body of write_specific_proto -- TODO
// confirm; the signature is not visible here.
41 proto.set_datatype(proto::ProtoDataType<T>);
// Start of the LBANN_HAS_DISTCONV-only section.
//
// Fragment of a function template returning a mutable
// relu_distconv_adapter<TensorDataType, T_layout, Dev>&.
// NOTE(review): the function-name line, the braces and the expression the
// chained call is made on are not visible in this listing.
45 #ifdef LBANN_HAS_DISTCONV 46 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
47 relu_distconv_adapter<TensorDataType, T_layout, Dev>&
// The visible return statement const_casts the result of a
// `.get_distconv_adapter()` call to a non-const adapter reference.
// Presumably the call goes through a const view of *this so that this
// overload reuses the const one -- TODO confirm, the operand is missing.
50 return const_cast<relu_distconv_adapter<TensorDataType, T_layout, Dev>&
>(
52 .get_distconv_adapter());
// Fragment of a function template whose declared return type is
// const relu_distconv_adapter<TensorDataType, T_layout, Dev>&.
// NOTE(review): the function-name line, the kind of cast and the cast
// operand are not visible in this listing; presumably this is the const
// overload of get_distconv_adapter -- TODO confirm.
55 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
56 const relu_distconv_adapter<TensorDataType, T_layout, Dev>&
// The same const adapter reference type appears again as the target type of
// a cast expression (the `>(` closes its template argument list).
60 const relu_distconv_adapter<TensorDataType, T_layout, Dev>&
>(
// relu_distconv_adapter::setup_distributions (fragment).
// NOTE(review): the parameter list, the braces and everything after the four
// bindings below are not visible in this listing.
64 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
65 void relu_distconv_adapter<TensorDataType, T_layout, Dev>::setup_distributions(
// Bind references to the adapter's four distribution objects:
//   x  - previous activations
//   y  - activations
//   dx - error signals
//   dy - previous error signals
// How they are used afterwards is not shown here; presumably they are fed to
// the overlap constraints -- TODO confirm against the full source.
70 auto& x = this->get_prev_activations_dist();
71 auto& y = this->get_activations_dist();
72 auto& dx = this->get_error_signals_dist();
73 auto& dy = this->get_prev_error_signals_dist();
// relu_distconv_adapter::setup_layer: creates the distconv ReLU object and
// hands it the adapter's tensors.
//
// @param workspace_capacity  Not referenced by any of the visible statements.
//
// NOTE(review): the opening and closing braces of the body are not visible
// in this listing.
81 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
82 void relu_distconv_adapter<TensorDataType, T_layout, Dev>::setup_layer(
83 size_t workspace_capacity)
// Construct a dc::ReLU from the backend returned by dc::get_backend() and
// store it in m_relu (std::make_unique, so m_relu holds the unique_ptr).
85 m_relu = std::make_unique<dc::ReLU>(dc::get_backend());
// Call setup with, in order: previous activations, activations,
// error signals, previous error signals.
86 m_relu->setup(this->get_prev_activations(),
87 this->get_activations(),
88 this->get_error_signals(),
89 this->get_prev_error_signals());
91 #endif // LBANN_HAS_DISTCONV 95 #endif // LBANN_LAYER_ACTIVATION_RELU_IMPL_HPP_INCLUDED void write_specific_proto(lbann_data::Layer &proto) const final
virtual void setup_distributions(tensor_overlap_constraints &constraints)
void mark_equivalent(dc::Dist &d1, dc::Dist &d2)