// Include guard for this header (the matching #endif is at the end of the
// listing).
26 #ifndef LBANN_UTILS_DIM_HELPERS_HPP_ 27 #define LBANN_UTILS_DIM_HELPERS_HPP_ 40 template <
typename Out,
typename In>
// Fragment of the container form of get_linear_size_as (name taken from the
// symbol index at the end of this listing; the signature line itself is not
// visible here -- TODO confirm against the real header).
//
// The visible expression multiplies the entries of `dims` together with
// std::multiplies<Out>, so the running product is carried in type `Out`.
// The `dims.size() ? ... : ...` conditional picks a separate result when the
// container is empty; that alternative, as well as the end iterator and the
// initial value passed to std::accumulate, fall on lines missing from this
// excerpt.
43 return (dims.size() ? std::accumulate(cbegin(dims),
46 std::multiplies<Out>())
50 template <
typename Out,
typename In>
// Fragment of the (ndims, dims) pointer form of get_linear_size_as
// (presumably -- the signature line is not visible; the call at original
// line 67 uses this name with two arguments).
//
// When `ndims` is non-zero, the product of the `ndims` values in the range
// [dims, dims + ndims) is computed, starting from Out{1} and multiplying with
// std::multiplies<Out>. The value produced when `ndims` is zero is on a line
// missing from this excerpt.
54 ndims ? std::accumulate(dims, dims + ndims, Out{1}, std::multiplies<Out>())
// Body of a convenience wrapper (per the trailing symbol index, presumably
// get_linear_size(std::vector<T> const& dims) -- TODO confirm): forwards to
// get_linear_size_as with the output type fixed to T.
61 return get_linear_size_as<T>(dims);
// Body of the (ndims, dims) counterpart of the wrapper above (enclosing
// signature not visible in this excerpt): forwards both arguments to
// get_linear_size_as with the output type fixed to T.
67 return get_linear_size_as<T>(ndims, dims);
// Compute a stride for each of `ndims` dimensions from a raw array of
// dimension sizes.
//
//   ndims          number of entries in `dims` and in the result
//   dims           pointer to the dimension sizes
//   lowest_stride  stride assigned to the last dimension
//
// The last entry keeps `lowest_stride`; every earlier entry is the stride of
// its successor multiplied by the successor's dimension size. dims[0] is
// never read by the visible code. The template header, the condition that
// guards the LBANN_ERROR call and the return statement are on lines missing
// from this excerpt.
71 auto get_strides(
size_t ndims, T
const* dims, T
const& lowest_stride)
// Every slot starts out as lowest_stride; only the last slot keeps that value
// untouched by the loop below.
73 std::vector<T> strides(ndims, lowest_stride);
// Walk from the last index down to 1, filling strides[ii - 1] on each step.
// NOTE(review): `ndims - 1` wraps around for ndims == 0 (size_t arithmetic).
// Unless a guard exists on one of the lines not shown here, the loop would
// then index past the end of `strides` -- confirm against the full source.
75 for (
size_t ii = ndims - 1; ii != 0; --ii) {
// Per the message text, this reports a dimension whose size is zero; the
// check that triggers it is not visible in this excerpt.
77 LBANN_ERROR(
"Zero-sized dimension not allowed. Dims[", ii,
"] = 0.");
// Stride of dimension ii-1 = stride of dimension ii times size of dimension ii.
78 strides[ii - 1] = strides[ii] * dims[ii];
// std::vector convenience overload of get_strides: passes the vector's size
// and its underlying data pointer, together with `lowest_stride`, to the
// (ndims, dims, lowest_stride) overload and returns that result.
85 auto get_strides(std::vector<T>
const& dims, T
const& lowest_stride)
87 return get_strides(dims.size(), dims.data(), lowest_stride);
102 template <
typename To,
typename From>
// Fragment of vector_cast (per the trailing symbol index:
// `auto vector_cast(std::vector<From> const& from)`; the signature line is
// not visible here).
//
// Builds a new std::vector<To> from the iterator range
// [from.cbegin(), from.cend()), so each element is converted From -> To as it
// is copied. Brace-initialisation with two iterators selects the
// iterator-range constructor as long as an iterator is not itself
// convertible to `To`.
105 return std::vector<To>{from.cbegin(), from.cend()};
// Forward declaration of the scalar overload of accumulate_dims.
//
// The std::enable_if_t return type removes this overload from consideration
// unless T is an integral type; when it is viable the return type is void
// (enable_if_t's default). `acc` is taken by non-const reference, `x` is a
// single value of type T, and `rest` is a forwarding-reference pack of any
// further arguments. What the function does with them is defined elsewhere
// and not visible at this declaration.
110 template <
typename T,
typename... ArgTs>
111 std::enable_if_t<std::is_integral_v<T>>
112 accumulate_dims(std::vector<size_t>& acc, T
const& x, ArgTs&&... rest);
// Fragment of a second forward declaration, this one taking a
// `std::vector<T> const& x`. As with the scalar overload it is constrained to
// integral T and returns void. The function name, the first parameter and the
// end of the parameter list are on lines missing from this excerpt;
// presumably this is the vector overload of accumulate_dims -- TODO confirm.
114 template <
typename T,
typename... ArgTs>
115 std::enable_if_t<std::is_integral_v<T>>
117 std::vector<T>
const& x,
// Template header and return type (void, enabled only for integral T) of a
// function whose declarator and body are missing from this excerpt.
// Presumably the definition of the scalar accumulate_dims overload declared
// earlier -- TODO confirm against the full source.
122 template <
typename T,
typename... ArgTs>
123 std::enable_if_t<std::is_integral_v<T>>
// Fragment of a definition that takes a `std::vector<T> const& x`
// (constrained to integral T, returning void); presumably the vector overload
// of accumulate_dims -- the function name and remaining parameters are on
// lines missing from this excerpt.
130 template <
typename T,
typename... ArgTs>
131 std::enable_if_t<std::is_integral_v<T>>
133 std::vector<T>
const& x,
// Append every element of `x`, in order, to the end of `acc`. Anything the
// function does afterwards is not visible here.
136 acc.insert(end(acc), cbegin(x), cend(x));
// Fragment of splice_dims (per the trailing symbol index:
// `std::vector<size_t> splice_dims(ArgTs&&... args)`; the signature line is
// not visible here).
141 template <
typename... ArgTs>
// Local vector that starts out empty. Presumably it is filled from `args`
// and returned, but those lines are missing from this excerpt -- TODO confirm.
144 std::vector<size_t> dims;
150 #endif // LBANN_UTILS_DIM_HELPERS_HPP_ auto get_strides(ColMajorDims< DimT > const &dims)
Compute packed strides of the given dimensions.
auto get_linear_size_as(std::vector< In > const &dims)
Compute the linear size of the given dimensions with a specific type.
std::vector< size_t > splice_dims(ArgTs &&... args)
auto get_linear_size(std::vector< T > const &dims)
auto vector_cast(std::vector< From > const &from)
std::enable_if_t< std::is_integral_v< T > > accumulate_dims(std::vector< size_t > &acc, T const &x, ArgTs &&... rest)
auto get_packed_strides(size_t ndims, T const *dims)