LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
random_number_generators.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_UTILS_RNG_HPP
28 #define LBANN_UTILS_RNG_HPP
29 
30 #include "lbann/comm.hpp"
32 #include <atomic>
33 #include <random>
34 #include <thread>
35 
36 namespace lbann {
37 
// Engine types used throughout LBANN's random-number utilities.
38 using rng_gen = std::mt19937; // 32-bit Mersenne Twister (higher quality, larger state)
39 using fast_rng_gen = std::minstd_rand; // "Minimum standard" linear congruential engine (small state, cheap)
40 
// Bundle of I/O random number generators (`generator` and
// `fast_generator`) together with the id of the thread currently using it.
// A default-constructed std::thread::id in `active_thread_id` represents
// "no thread", i.e. the bundle is not currently held.
41 struct io_rng_t
42 {
45  // Track the owner so that it is easy to ensure the right thread is
46  // using this structure.
47  std::atomic<std::thread::id> active_thread_id;
48
  // Default construction: both generators are seeded with the fixed value
  // 42 and the bundle starts with no owning thread.
  // NOTE(review): presumably the generators are re-seeded later (e.g. by
  // init_io_random()); confirm before relying on the 42 seed.
50  : generator(42ULL),
51  fast_generator(42ULL),
52  active_thread_id(std::thread::id())
53  {}
54 
  // Copy constructor, written out because std::atomic is not
  // copy-constructible: the owner id is copied explicitly via load().
  // NOTE(review): the copy inherits the source's current owner id, so a
  // copy taken while the source is held will itself appear held; confirm
  // this is intended.
55  io_rng_t(const io_rng_t& other)
56  : generator(other.generator),
57  fast_generator(other.fast_generator),
58  active_thread_id(other.active_thread_id.load())
59  {}
60 };
61 
// Reference to an io_rng_t (through the pointer `rng_`) whose ownership is
// tracked in the target's `active_thread_id`. The acquiring body below
// stores the calling thread's id and raises an error if the RNG was
// already held. The releasing body clears the id and emits a diagnostic
// if the caller was not the recorded owner.
// NOTE(review): no copy/move control is visible here; if instances can be
// copied, every copy runs the release logic. Confirm copying is prevented.
63 {
66  {
  // Claim the RNG for this thread; `prev_tid` is whoever held it before.
67  std::thread::id prev_tid =
68  rng_->active_thread_id.exchange(std::this_thread::get_id());
  // A non-empty previous id means some thread (possibly this one) already
  // held the RNG.
  // NOTE(review): exchange() has already overwritten the previous owner's
  // id by the time this check fails, so that owner's later release will
  // report a mismatch. A compare_exchange_strong against std::thread::id()
  // would avoid clobbering it. The message also says "isn't owned by this
  // thread" when the condition is really "already owned". If LBANN_ERROR
  // throws, this object's release logic never runs and the id stays set to
  // this thread. Confirm this is acceptable.
69  if (prev_tid != std::thread::id()) {
70  LBANN_ERROR("Acquired a \'locked\' RNG that isn't owned by this thread");
71  }
72  }
  // Access the referenced io_rng_t. The operator is explicit, so callers
  // must write static_cast<io_rng_t&>(ref).
73  explicit operator io_rng_t&() { return *rng_; }
75  {
  // Release: reset the owner id to "no thread" and verify that the thread
  // releasing the RNG is the one that held it.
76  std::thread::id prev_tid =
77  rng_->active_thread_id.exchange(std::thread::id());
78  if (prev_tid != std::this_thread::get_id()) {
80  "Releasing a \'locked\' RNG that isn't owned by this thread");
81  }
82  }
84 };
85 
91 
98 
104 
111 
114 
117 
124 
131 
/// Initialize the random number generators (with optional seed).
///
/// @param seed        Seed value; defaults to -1. NOTE(review): presumably
///                    -1 means "choose a seed automatically"; confirm in
///                    the implementation.
/// @param num_io_RNGs Presumably the number of I/O generators to provision
///                    (cf. get_num_io_generators()); defaults to 1.
/// @param comm        Communicator; defaults to nullptr, so it is optional.
141 void init_random(int seed = -1,
142  int num_io_RNGs = 1,
143  lbann_comm* comm = nullptr);
144 
/// Initialize the data-sequence random number generator.
/// NOTE(review): presumably this seeds the engine returned by
/// get_data_seq_generator(), and -1 selects a seed automatically; confirm.
/// @param seed Seed value; defaults to -1.
152 void init_data_seq_random(int seed = -1);
153 
/// Initialize the LTFB random number generator.
/// NOTE(review): presumably this seeds the engine returned by
/// get_ltfb_generator(), and -1 selects a seed automatically; confirm.
/// @param seed Seed value; defaults to -1.
159 void init_ltfb_random(int seed = -1);
160 
/// Initialize the I/O random number generators.
/// NOTE(review): presumably this seeds the engines reached through
/// get_io_generator() / get_fast_io_generator(), and -1 selects a seed
/// automatically; confirm.
/// @param seed        Seed value; defaults to -1.
/// @param num_io_RNGs Presumably the number of I/O generators to provision
///                    (cf. get_num_io_generators()); defaults to 1.
169 void init_io_random(int seed = -1, int num_io_RNGs = 1);
170 
171 } // namespace lbann
172 
173 #endif // LBANN_UTILS_RNG_HPP
void init_io_random(int seed=-1, int num_io_RNGs=1)
fast_rng_gen & get_fast_generator()
locked_io_rng_ref set_io_generators_local_index(size_t idx)
Sets the local index for a thread to access the correct I/O RNGs.
lbann::fast_rng_gen fast_generator
#define LBANN_ERROR(...)
Definition: exception.hpp:37
std::minstd_rand fast_rng_gen
fast_rng_gen & get_ltfb_generator()
void load(std::string const &pbuf_filename, google::protobuf::Message &msg)
Fill the protobuf message from a binary file.
std::atomic< std::thread::id > active_thread_id
rng_gen & get_io_generator()
void init_data_seq_random(int seed=-1)
std::mt19937 rng_gen
void init_ltfb_random(int seed=-1)
io_rng_t(const io_rng_t &other)
fast_rng_gen & get_fast_io_generator()
#define LBANN_WARNING(...)
Definition: exception.hpp:53
rng_gen & get_data_seq_generator()
rng_gen & get_generator()
void init_random(int seed=-1, int num_io_RNGs=1, lbann_comm *comm=nullptr)
Initialize the random number generator (with optional seed).
int get_num_io_generators()
Returns the number of provisioned I/O generators.