LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
permuteimpl.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the License.
26 #ifndef LBANN_SRC_LAYERS_TRANSFORM_PERMUTE_PERMUTEIMPL_HPP_INCLUDED
27 #define LBANN_SRC_LAYERS_TRANSFORM_PERMUTE_PERMUTEIMPL_HPP_INCLUDED
28 
30 
31 #ifdef LBANN_HAS_CUTENSOR
32 #include "cutensor_permuteimpl.hpp"
33 #endif
34 
35 #if defined(LBANN_HAS_CUTT) || defined(LBANN_HAS_HIPTT)
36 #include "cutt_permuteimpl.hpp"
37 #endif
38 
39 #include "tensor_dims_utils.hpp"
40 
41 #include <cereal/cereal.hpp>
42 
43 namespace lbann {
44 
45 template <typename T>
47 {
48 public:
49 #ifdef LBANN_HAS_CUTENSOR
50  using DeviceImplType = cuTENSOR_PermuteImpl;
51 #elif defined(LBANN_HAS_CUTT) || defined(LBANN_HAS_HIPTT)
52  using DeviceImplType = cuTT_PermuteImpl;
53 #endif // LBANN_HAS_CU{TT,TENSOR}
54  using MatType = El::Matrix<T, El::Device::GPU>;
55  using DimsType = typename DeviceImplType::DimsType;
56 
57 public:
58  // LBANN uses row-major tensor ordering.
59  PermuteImpl(std::vector<int> const& perm_row_major);
60  PermuteImpl(PermuteImpl const& other) = default;
61  PermuteImpl(PermuteImpl&& other) = default;
62 
63  // Returns the row-major output dims.
64  std::vector<int> setup_dims(std::vector<int> const& input_dims);
65 
66  void forward_prop(MatType const& prev_acts, MatType& acts) const;
67 
68  // Activations don't actually matter here...
69  void backward_prop(MatType const& grad_wrt_out, MatType& grad_wrt_in);
70 
71  std::vector<int> get_perm() const;
72  std::string describe_perm() const;
73 
74  void swap(PermuteImpl& other);
75 
76  // Serialization
77 
78  template <typename ArchiveT>
79  void save(ArchiveT& ar) const;
80 
81  template <typename ArchiveT>
82  void load(ArchiveT& ar);
83 
84  template <typename ArchiveT>
85  static void load_and_construct(
86  ArchiveT& ar,
87  cereal::construct<PermuteLayer<T>::PermuteImpl>& construct);
88 
89 private:
90  DeviceImplType m_device_impl;
91 
92 }; // class PermuteImpl
93 
94 } // namespace lbann
95 #endif // LBANN_SRC_LAYERS_TRANSFORM_PERMUTE_PERMUTEIMPL_HPP_INCLUDED
void swap(PermuteLayer &other)
El::Matrix< T, El::Device::GPU > MatType
Definition: permuteimpl.hpp:54
void setup_dims() final
Setup tensor dimensions. Called by the 'setup' function. If there are any input tensors, the base method sets all uninitialized output tensor dimensions equal to the first input tensor dimensions.
cuTT-based implementation of tensor permute.
void load(std::string const &pbuf_filename, google::protobuf::Message &msg)
Fill the protobuf message from a binary file.
typename DeviceImplType::DimsType DimsType
Definition: permuteimpl.hpp:55
void save(ArchiveT &ar, ::El::AbstractMatrix< T > const &mat)
Save a matrix to a text-based archive.
cuTENSOR-based implementation of tensor permute.
Permute the indices of a tensor.
Definition: permute.hpp:49