// NOTE(review): the leading integers look like line numbers carried over from
// a rendered listing. Several numbered lines (e.g. 39, 41, 44-45, 47-48) are
// not present here, so the member-function name and the branch conditions are
// not visible in this fragment.
//
// Include guard, followed by a member template that fills in a
// lbann_data::Layer message. Presumably this is write_specific_proto (a
// matching signature is listed at the end of this listing) -- TODO confirm.
27 #ifndef LBANN_LAYERS_ACTIVATIONS_SOFTMAX_IMPL_HPP_INCLUDED 28 #define LBANN_LAYERS_ACTIVATIONS_SOFTMAX_IMPL_HPP_INCLUDED 32 #ifdef LBANN_HAS_DISTCONV 34 #endif // LBANN_HAS_DISTCONV 38 template <
typename T, data_layout L, El::Device D>
40 lbann_data::Layer& proto)
// Record the tensor data type T on the layer message.
const 42 proto.set_datatype(proto::ProtoDataType<T>);
// Obtain the mutable softmax sub-message; the mode string is written into it.
43 auto* msg = proto.mutable_softmax();
// Three alternative mode strings are written below: "instance", "channel" and
// "invalid". The conditions that select between them are not visible in this
// fragment -- TODO confirm against the complete source.
46 msg->set_softmax_mode(
"instance");
49 msg->set_softmax_mode(
"channel");
52 msg->set_softmax_mode(
"invalid");
// Start of the LBANN_HAS_DISTCONV-only section.
//
// Accessor fragment returning a non-const reference to
// softmax_distconv_adapter<TensorDataType, T_layout, Dev>. The result of a
// get_distconv_adapter() call is const_cast to that non-const reference type.
// NOTE(review): the qualified function name and the object the call is made on
// are not visible here; presumably this forwards to a const overload --
// confirm against the complete source.
56 #ifdef LBANN_HAS_DISTCONV 57 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
58 softmax_distconv_adapter<TensorDataType, T_layout, Dev>&
61 return const_cast<softmax_distconv_adapter<TensorDataType, T_layout, Dev>&
>(
63 .get_distconv_adapter());
// Accessor fragment whose return type is a const reference to
// softmax_distconv_adapter<TensorDataType, T_layout, Dev>. The same const
// reference type is also the target type of a cast in the body.
// NOTE(review): the cast keyword and its operand are missing from this
// fragment; presumably this is the const overload of get_distconv_adapter --
// TODO confirm.
66 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
67 const softmax_distconv_adapter<TensorDataType, T_layout, Dev>&
71 const softmax_distconv_adapter<TensorDataType, T_layout, Dev>&
>(
// Member of softmax_distconv_adapter that walks four distribution containers
// in turn: previous activations, activations, previous error signals and
// error signals. Each element is bound by non-const reference.
// NOTE(review): the member name, parameter list and loop bodies are not
// visible in this fragment. Presumably this is setup_distributions and the
// bodies adjust each distribution and register it with the constraints object
// -- confirm against the complete source.
75 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
76 void softmax_distconv_adapter<TensorDataType, T_layout, Dev>::
81 for (
auto& d : this->m_prev_activations_dists) {
86 for (
auto& d : this->m_activations_dists) {
91 for (
auto& d : this->m_prev_error_signals_dists) {
96 for (
auto& d : this->m_error_signals_dists) {
// Builds and configures the dc::Softmax object owned by this adapter.
//
// workspace_capacity: accepted by the signature; no use of it is visible in
// this fragment.
103 template <
typename TensorDataType, data_layout T_layout, El::Device Dev>
104 void softmax_distconv_adapter<TensorDataType, T_layout, Dev>::setup_layer(
105 size_t workspace_capacity)
// Create the softmax object on the backend returned by dc::get_backend().
109 m_softmax = std::make_unique<dc::Softmax>(dc::get_backend());
// Choose between INSTANCE and CHANNEL mode with a ternary.
// NOTE(review): the condition and the declaration of `mode` are on a line not
// present here -- confirm against the complete source.
111 ? ::distconv::SoftmaxMode::INSTANCE
112 : ::distconv::SoftmaxMode::CHANNEL;
// Configure the softmax object with the previous-activations tensor and the
// selected mode.
113 m_softmax->
setup(this->get_prev_activations(), mode);
// Closes the LBANN_HAS_DISTCONV section and the include guard.
115 #endif // LBANN_HAS_DISTCONV 119 #endif // LBANN_LAYERS_ACTIVATIONS_SOFTMAX_IMPL_HPP_INCLUDED
void write_specific_proto(lbann_data::Layer &proto) const final
void mark_updated(const dc::Dist &d)
virtual void setup(size_t max_mini_batch_size, const std::vector< El::Grid *> &grids)
Setup layer members.
virtual void setup_distributions(tensor_overlap_constraints &constraints)
void mark_invariant(const dc::Dist &d)