LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
miopen/utils.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 #ifndef LBANN_UTILS_DNN_LIB_MIOPEN_UTILS_HPP_
27 #define LBANN_UTILS_DNN_LIB_MIOPEN_UTILS_HPP_
28 
29 namespace lbann {
30 #if defined LBANN_HAS_MIOPEN
31 namespace dnn_lib {
32 
33 using namespace miopen;
34 
35 namespace internal {
36 
37 // Simple RAII class that sets the stream on creation, caches the old
38 // stream, and restores it on the way out.
39 class StreamManager
40 {
41 public:
42  StreamManager(miopenHandle_t handle, hipStream_t stream) : handle_(handle)
43  {
44  CHECK_MIOPEN(miopenGetStream(handle_, &old_stream_));
45  CHECK_MIOPEN(miopenSetStream(handle_, stream));
46  }
47 
48  ~StreamManager()
49  {
50  try {
51  if (handle_)
52  CHECK_MIOPEN(miopenSetStream(handle_, old_stream_));
53  }
54  catch (std::exception const& e) {
55  std::cerr << "Caught error in ~dnn_lib::StreamManager().\n\n e.what(): "
56  << e.what() << "\n\nCalling std::terminate()." << std::endl;
57  std::terminate();
58  }
59  }
60  StreamManager(StreamManager const& other) = delete;
61  StreamManager(StreamManager&& other)
62  : handle_{other.handle_}, old_stream_{other.old_stream_}
63  {
64  other.handle_ = nullptr;
65  other.old_stream_ = nullptr;
66  }
67  StreamManager& operator=(StreamManager const& other) = delete;
68  StreamManager& operator=(StreamManager&& other) = delete;
69 
70  miopenHandle_t get() const noexcept { return handle_; }
71 
72 private:
73  miopenHandle_t handle_;
74  hipStream_t old_stream_;
75 }; // struct StreamManager
76 
77 inline StreamManager
78 make_default_handle_manager(El::SyncInfo<El::Device::GPU> const& si)
79 {
80  return StreamManager(get_handle(), si.Stream());
81 }
82 
83 } // namespace internal
84 } // namespace dnn_lib
85 #endif // defined LBANN_HAS_MIOPEN
86 } // namespace lbann
87 #endif // LBANN_UTILS_DNN_LIB_MIOPEN_UTILS_HPP_