26 #ifndef LBANN_SRC_LAYERS_TRANSFORM_PERMUTE_PERMUTEIMPL_HPP_INCLUDED 27 #define LBANN_SRC_LAYERS_TRANSFORM_PERMUTE_PERMUTEIMPL_HPP_INCLUDED 31 #ifdef LBANN_HAS_CUTENSOR 35 #if defined(LBANN_HAS_CUTT) || defined(LBANN_HAS_HIPTT) 41 #include <cereal/cereal.hpp> 49 #ifdef LBANN_HAS_CUTENSOR 51 #elif defined(LBANN_HAS_CUTT) || defined(LBANN_HAS_HIPTT) 53 #endif // LBANN_HAS_CU{TT,TENSOR} 54 using MatType = El::Matrix<T, El::Device::GPU>;
55 using DimsType =
typename DeviceImplType::DimsType;
59 PermuteImpl(std::vector<int>
const& perm_row_major);
64 std::vector<int>
setup_dims(std::vector<int>
const& input_dims);
69 void backward_prop(
MatType const& grad_wrt_out,
MatType& grad_wrt_in);
71 std::vector<int> get_perm()
const;
72 std::string describe_perm()
const;
78 template <
typename ArchiveT>
79 void save(ArchiveT& ar)
const;
81 template <
typename ArchiveT>
82 void load(ArchiveT& ar);
84 template <
typename ArchiveT>
85 static void load_and_construct(
95 #endif // LBANN_SRC_LAYERS_TRANSFORM_PERMUTE_PERMUTEIMPL_HPP_INCLUDED void swap(PermuteLayer &other)
El::Matrix< T, El::Device::GPU > MatType
void setup_dims() final
Setup tensor dimensions Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
cuTT-based implementation of tensor permute.
void load(std::string const &pbuf_filename, google::protobuf::Message &msg)
Fill the protobuf message from a binary file.
void forward_prop() final
typename DeviceImplType::DimsType DimsType
void save(ArchiveT &ar, ::El::AbstractMatrix< T > const &mat)
Save a matrix to a text-based archive.
cuTENSOR-based implementation of tensor permute.
DeviceImplType m_device_impl
Permute the indices of a tensor.