27 #ifndef LBANN_LAYERS_REGULARIZERS_CHANNELWISE_SOFTMAX_IMPL_HPP_INCLUDED 28 #define LBANN_LAYERS_REGULARIZERS_CHANNELWISE_SOFTMAX_IMPL_HPP_INCLUDED 35 template <
typename TensorDataType, data_layout Layout, El::Device Device>
// Setup of tensor dimensions for the channelwise softmax layer.
// The visible code does three things:
//   1. validates that m_dim lies in [-dims, dims), where dims is the input rank;
//   2. outside single-dim mode, accepts only the first or last dimension;
//   3. sets the output dimensions equal to the input dimensions.
// NOTE(review): this is an extracted fragment. The function signature and
// several body lines are not visible here, so do not infer their contents.
39 int64_t dims =
static_cast<int64_t
>(this->get_input_dims().size());
// Reject dimensions outside [-dims, dims). Negative values presumably
// count from the last dimension (TODO confirm).
40 if (this->m_dim < -dims || this->m_dim >= dims) {
43 " is out of bounds for Channelwise " 44 "Softmax layer on tensor with ",
// When m_single_dim_mode is false, only the first dimension (0 or -dims)
// or the last dimension (dims - 1 or -1) is supported.
48 if (!this->m_single_dim_mode && this->m_dim != 0 && this->m_dim != -dims &&
49 this->m_dim != (dims - 1) && this->m_dim != -1) {
50 LBANN_ERROR(
"Channelwise softmax with all dimensions is only supported for " 51 "the first or last tensor dimensions. Got dimension ",
// Softmax preserves shape: the output has the same dimensions as the input.
55 this->set_output_dims(this->get_input_dims());
// Computes the out-parameters channel_size, channel_stride and num_channels
// from the input dimensions and m_dim.
// NOTE(review): this is an extracted fragment. Part of the signature, any
// normalization of negative `dim`, the initial value of channel_stride, and
// the `else` separating the two branches below are not visible here.
58 template <
typename TensorDataType, data_layout Layout, El::Device Device>
61 El::Int& channel_stride,
62 El::Int& num_channels)
const 64 auto const& input_dims = this->get_input_dims();
65 int dims =
static_cast<int>(input_dims.size());
// NOTE(review): m_dim may be negative (the bound check allows [-dims, dims)).
// Confirm that the missing lines after this one map `dim` into [0, dims)
// before it is used as an index below.
66 int dim = this->m_dim;
// Total number of elements in the input tensor (product of all dimensions).
70 size_t total_size = 1;
71 for (
int i = 0; i < dims; ++i) {
72 total_size *= input_dims[i];
// Single-dim mode: one channel spans input_dims[dim] elements, and the rest
// of the tensor is divided into total_size / channel_size channels.
81 if (m_single_dim_mode) {
82 channel_size = input_dims[dim];
83 num_channels = total_size / channel_size;
// Multiply channel_stride by every dimension from the last one down to
// `dim` (inclusive). Its starting value is set on a line not visible here
// (TODO confirm).
86 for (
int i = dims - 1; i >= dim; --i) {
87 channel_stride *= input_dims[i];
// Presumably the non-single-dim branch (the `else` is not visible): the
// number of channels is input_dims[dim], and each channel holds the
// remaining total_size / input_dims[dim] elements.
93 channel_size = total_size / input_dims[dim];
94 num_channels = input_dims[dim];
// Channels are laid out back to back here, so the stride equals the channel
// size. The enclosing condition is not visible (TODO confirm).
98 channel_stride = channel_size;
108 #endif // LBANN_LAYERS_REGULARIZERS_CHANNELWISE_SOFTMAX_IMPL_HPP_INCLUDED virtual void setup_dims()
Setup tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
void get_channel_size_and_stride(El::Int &channel_size, El::Int &channel_stride, El::Int &num_channels) const
void setup_dims() override
Setup tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.