26 #ifndef SRC_LAYERS_TRANSFORM_TENSOR_DIMS_UTILS_HPP_INCLUDED 27 #define SRC_LAYERS_TRANSFORM_TENSOR_DIMS_UTILS_HPP_INCLUDED 38 template <
typename T,
typename Tag>
50 explicit NamedVector(std::vector<T>&& v) : m_data{std::move(v)} {}
57 template <
typename U,
typename UTag>
63 template <
typename U,
typename UTag>
70 std::vector<T>&
get() noexcept {
return m_data; }
71 std::vector<T>
const&
get()
const noexcept {
return m_data; }
74 auto size() const noexcept {
return m_data.size(); }
78 template <
typename IndexT>
81 template <
typename IndexT>
84 template <
typename IndexT>
87 template <
typename IndexT>
97 template <
typename IndexT>
100 tgt.
get().assign(crbegin(src.
get()), crend(src.
get()));
103 template <
typename IndexT>
106 tgt.
get().assign(crbegin(src.
get()), crend(src.
get()));
110 template <
typename IndexT>
113 tgt.
get().assign(crbegin(src.
get()), crend(src.
get()));
116 template <
typename IndexT>
119 tgt.
get().assign(crbegin(src.
get()), crend(src.
get()));
125 std::vector<int>& out)
127 int const ndims =
static_cast<int>(in.size());
129 std::transform(crbegin(in), crend(in), begin(out), [ndims](
int const& a) {
130 return ndims - a - 1;
148 template <
typename OutT,
typename InT>
151 return std::vector<OutT>{cbegin(in), cend(in)};
157 template <
typename IndexT>
163 template <
typename IndexT>
169 template <
typename IndexT>
175 template <
typename IndexT>
181 template <
typename IndexT>
187 template <
typename IndexT>
193 template <
typename IndexT>
199 template <
typename IndexT>
205 template <
typename IndexT>
211 template <
typename IndexT>
227 template <
typename Str
ideT,
typename DimT>
230 std::vector<DimT>
const& dim_vec = dims.
get();
231 size_t const ndims = dim_vec.size();
233 std::vector<StrideT> strides;
234 strides.reserve(ndims);
235 strides.push_back(StrideT{1});
236 for (
size_t ii = 0UL; ii < ndims - 1; ++ii)
237 strides.push_back(strides[ii] * static_cast<StrideT>(dim_vec[ii]));
247 template <
typename DimT>
250 return get_strides_as<DimT>(dims);
263 template <
typename T>
266 size_t const ndims = perm.size();
267 std::sort(begin(perm), end(perm));
268 for (
size_t ii = 0; ii < ndims; ++ii)
269 if (static_cast<size_t>(perm[ii]) != ii)
275 template <
typename T>
278 size_t const size = perm.size();
279 std::vector<T> out(size);
280 for (
size_t ii = 0; ii <
size; ++ii)
293 template <
typename IndexT,
typename PermT>
294 auto permute_impl(std::vector<IndexT>
const& in, std::vector<PermT>
const& perm)
296 if (perm.size() == 0UL)
299 size_t const ndims = in.size();
300 size_t const nperm = perm.size();
303 std::vector<IndexT> out;
305 for (
size_t ii = 0UL; ii < nperm; ++ii)
306 out.push_back(in[perm[ii]]);
307 for (
size_t ii = nperm; ii < ndims; ++ii)
308 out.push_back(in[ii]);
340 template <
typename IndexT>
346 template <
typename IndexT>
355 #endif // SRC_LAYERS_TRANSFORM_TENSOR_DIMS_UTILS_HPP_INCLUDED NamedVector(std::vector< T > const &v)
auto get_strides(ColMajorDims< DimT > const &dims)
Compute packed strides of the given dimensions.
auto permute_impl(std::vector< IndexT > const &in, std::vector< PermT > const &perm)
NamedVector(NamedVector< U, UTag > const &other)
#define LBANN_ASSERT_DEBUG(cond)
NamedVector(std::vector< T > &&v)
std::vector< int > vector_type
typename vector_type::value_type value_type
NamedVector & operator=(NamedVector const &other)=default
auto RowMajor(std::vector< IndexT > &&ds)
std::vector< T > & get() noexcept
auto size() const noexcept
auto invert_perm_impl(std::vector< T > const &perm)
Returns the inverse of the given permutation.
auto vec_convert(std::vector< InT > const &in)
Copy the input vector to a new type.
void switch_perm_majorness(std::vector< int > const &in, std::vector< int > &out)
auto get_strides_as(ColMajorDims< DimT > const &dims)
Compute packed strides of the given dimensions.
RowMajorPerm invert(RowMajorPerm const &in)
bool is_valid(RowMajorPerm const &perm)
void convert(RowMajorDims< IndexT > const &src, ColMajorDims< IndexT > &tgt)
auto ColMajor(std::vector< IndexT > &&ds)
void swap(NamedVector &other)
auto permute_dims(RowMajorDims< IndexT > const &in, RowMajorPerm const &perm)
NamedVector & operator=(NamedVector< U, UTag > const &other)
bool check_perm_impl(std::vector< T > perm)
Checks that the permutation is valid.