LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
beta.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "Licensee"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_UTILS_BETA_HPP
28 #define LBANN_UTILS_BETA_HPP
29 
30 #include <cmath>
31 #include <istream>
32 #include <ostream>
33 #include <random>
34 
36 #include "lbann/utils/random.hpp"
37 
38 namespace lbann {
39 
50 template <typename RealType = double>
52 {
53 public:
/// Type of each sample returned by operator(); this is the RealType
/// template argument (double by default).
54  using result_type = RealType;
55 
56  class param_type
57  {
58  public:
60 
61  explicit param_type(RealType param_a, RealType param_b)
62  : m_a(param_a), m_b(param_b)
63  {
64  if (param_a <= RealType(0) || param_b <= RealType(0)) {
65  LBANN_ERROR("Beta distribution parameters must be positive");
66  }
67  }
68 
69  constexpr RealType a() const { return m_a; }
70  constexpr RealType b() const { return m_b; }
71 
72  bool operator==(const param_type& other) const
73  {
74  return m_a == other.m_a && m_b == other.m_b;
75  }
76  bool operator!=(const param_type& other) const
77  {
78  return m_a != other.m_a || m_b != other.m_b;
79  }
80 
81  private:
82  RealType m_a, m_b;
83  };
84 
85  explicit beta_distribution(RealType a, RealType b)
86  : m_params(a, b), m_gamma_a(a), m_gamma_b(b)
87  {}
88  explicit beta_distribution(const param_type& p)
89  : m_params(p), m_gamma_a(p.a()), m_gamma_b(p.b())
90  {}
91 
92  result_type a() const { return m_params.a(); }
93  result_type b() const { return m_params.b(); }
94 
95  void reset() {}
96 
97  param_type param() const { return m_params; }
98  void param(const param_type& p)
99  {
100  m_params = p;
101  m_gamma_a = gamma_dist(p.a());
102  m_gamma_b = gamma_dist(p.b());
103  }
104 
105  template <typename Generator>
106  result_type operator()(Generator& g)
107  {
108  return generate(g);
109  }
110  template <typename Generator>
111  result_type operator()(Generator& g, const param_type& p)
112  {
113  return generate(g, p);
114  }
115 
116  result_type min() const { return result_type(0); }
117  result_type max() const { return result_type(1); }
118 
120  {
121  return param() == other.param();
122  }
124  {
125  return param() != other.param();
126  }
127 
128 private:
130 
/// Sampler for Gamma(shape, 1) variates. generate_gamma forms a Beta sample
/// as Ga / (Ga + Gb) from one Gamma(a) draw and one Gamma(b) draw.
131  using gamma_dist = std::gamma_distribution<RealType>;
133 
134  // Generator for when we use the distribution's parameters.
135  template <typename Generator>
136  result_type generate(Generator& g)
137  {
138  if (a() <= result_type(1) && b() <= result_type(1)) {
139  return generate_johnk(g, m_params.a(), m_params.b());
140  }
141  else {
142  return generate_gamma(g, m_gamma_a, m_gamma_b);
143  }
144  }
145  // Generator for when we use specified parameters.
146  template <typename Generator>
147  result_type generate(Generator& g, const param_type& p)
148  {
149  if (p.a() <= result_type(1) && p.b() <= result_type(1)) {
150  return generate_johnk(g, p.a(), p.b());
151  }
152  else {
153  gamma_dist gamma_a(p.a()), gamma_b(p.b());
154  return generate_gamma(g, gamma_a, gamma_b);
155  }
156  }
157 
185  template <typename Generator>
187  {
188  while (true) {
189  const result_type U = random_uniform<result_type>(g);
190  const result_type V = random_uniform<result_type>(g);
191  const result_type X = std::pow(U, result_type(1) / a);
192  const result_type Y = std::pow(V, result_type(1) / b);
193  const result_type XplusY = X + Y;
194  if (XplusY <= result_type(1.0)) {
195  if (XplusY > result_type(0)) {
196  return X / XplusY;
197  }
198  else if (U != result_type(0) && V != result_type(0)) {
199  // Work with logs instead if a/b is too small.
200  result_type logX = std::log(U) / a;
201  result_type logY = std::log(V) / b;
202  const result_type log_max = std::max(logX, logY);
203  logX -= log_max;
204  logY -= log_max;
205  return std::exp(logX - std::log(std::exp(logX) + std::exp(logY)));
206  }
207  }
208  }
209  }
210 
217  template <typename Generator>
219  generate_gamma(Generator& g, gamma_dist& gamma_a, gamma_dist& gamma_b)
220  {
221  const result_type Ga = gamma_a(g);
222  const result_type Gb = gamma_b(g);
223  return Ga / (Ga + Gb);
224  }
225 };
226 
227 template <typename CharT, typename RealType>
228 std::basic_ostream<CharT>& operator<<(std::basic_ostream<CharT>& os,
230 {
231  os << "~Beta(" << d.a() << "," << d.b() << ")";
232  return os;
233 }
234 
235 template <typename CharT, typename RealType>
236 std::basic_istream<CharT>& operator>>(std::basic_istream<CharT>& is,
238 {
239  std::string s;
240  RealType a, b;
241  if (std::getline(is, s, '(') && s == "~Beta" && is >> a && is.get() == ',' &&
242  is >> b && is.get() == ')') {
244  }
245  else {
246  is.setstate(std::ios::failbit);
247  }
248  return is;
249 }
250 
251 } // namespace lbann
252 
253 #endif // LBANN_UTILS_BETA_HPP
param_type(RealType param_a, RealType param_b)
Definition: beta.hpp:61
result_type generate_johnk(Generator &g, result_type a, result_type b)
Definition: beta.hpp:186
result_type generate(Generator &g)
Definition: beta.hpp:136
constexpr RealType b() const
Definition: beta.hpp:70
constexpr RealType a() const
Definition: beta.hpp:69
#define LBANN_ERROR(...)
Definition: exception.hpp:37
result_type operator()(Generator &g)
Definition: beta.hpp:106
beta_distribution(const param_type &p)
Definition: beta.hpp:88
void param(const param_type &p)
Definition: beta.hpp:98
result_type max() const
Definition: beta.hpp:117
beta_distribution(RealType a, RealType b)
Definition: beta.hpp:85
result_type min() const
Definition: beta.hpp:116
RealType result_type
Definition: beta.hpp:54
std::basic_istream< CharT > & operator>>(std::basic_istream< CharT > &is, beta_distribution< RealType > &d)
Definition: beta.hpp:236
bool operator!=(const beta_distribution< result_type > &other) const
Definition: beta.hpp:123
result_type b() const
Definition: beta.hpp:93
result_type a() const
Definition: beta.hpp:92
result_type generate_gamma(Generator &g, gamma_dist &gamma_a, gamma_dist &gamma_b)
Definition: beta.hpp:219
param_type m_params
Definition: beta.hpp:129
bool operator==(const param_type &other) const
Definition: beta.hpp:72
result_type generate(Generator &g, const param_type &p)
Definition: beta.hpp:147
gamma_dist m_gamma_b
Definition: beta.hpp:132
bool operator!=(const param_type &other) const
Definition: beta.hpp:76
std::gamma_distribution< RealType > gamma_dist
Definition: beta.hpp:131
bool operator==(const beta_distribution< result_type > &other) const
Definition: beta.hpp:119
gamma_dist m_gamma_a
Definition: beta.hpp:132
param_type param() const
Definition: beta.hpp:97
result_type operator()(Generator &g, const param_type &p)
Definition: beta.hpp:111