// LBANN 0.103.0 - Livermore Big Artificial Neural Network Toolkit
// Source listing of hydrogen_utils.hpp (see the generated documentation
// for this file).
// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory.
// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
// the CONTRIBUTORS file. <lbann-dev@llnl.gov>
//
// LLNL-CODE-697807.
// All rights reserved.
//
// This file is part of LBANN: Livermore Big Artificial Neural Network
// Toolkit. For details, see http://software.llnl.gov/LBANN or
// https://github.com/LLNL/LBANN.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you
// may not use this file except in compliance with the License. You may
// obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License.

27 namespace lbann {
28 
29 #include <El.hpp>
30 #include <lbann/utils/memory.hpp>
31 
32 template <typename TensorDataType, typename EvalDataType>
34 {
35  static std::unique_ptr<El::AbstractMatrix<EvalDataType>>
36  get(El::AbstractMatrix<TensorDataType> const& x)
37  {
38  switch (x.GetDevice()) {
39  case El::Device::CPU:
40  return get(
41  static_cast<El::Matrix<TensorDataType, El::Device::CPU> const&>(x));
42 #ifdef LBANN_HAS_GPU
43  case El::Device::GPU:
44  return get(
45  static_cast<El::Matrix<TensorDataType, El::Device::GPU> const&>(x));
46 #endif
47  default:
48  return nullptr;
49  }
50  }
51  template <El::Device D>
52  static std::unique_ptr<El::Matrix<EvalDataType, D>>
53  get(El::Matrix<TensorDataType, D> const& x)
54  {
55  auto ret = std::make_unique<El::Matrix<EvalDataType, D>>();
56  El::Copy(x, *ret);
57  return ret;
58  }
59 };
60 
61 // Specialize for same data type -- make a view instead.
62 template <typename DataType>
63 struct ViewIfPossibleOrCopy<DataType, DataType>
64 {
65  static std::unique_ptr<El::AbstractMatrix<DataType>>
66  get(El::AbstractMatrix<DataType> const& x)
67  {
68  switch (x.GetDevice()) {
69  case El::Device::CPU:
70  return get(static_cast<El::Matrix<DataType, El::Device::CPU> const&>(x));
71 #ifdef LBANN_HAS_GPU
72  case El::Device::GPU:
73  return get(static_cast<El::Matrix<DataType, El::Device::GPU> const&>(x));
74 #endif
75  default:
76  return nullptr;
77  }
78  }
79  template <El::Device D>
80  static std::unique_ptr<El::Matrix<DataType, D>>
81  get(El::Matrix<DataType, D> const& x)
82  {
83  auto ret = std::make_unique<El::Matrix<DataType, D>>();
84  El::LockedView(*ret, x);
85  return ret;
86  }
87 };
88 
89 } // namespace lbann