LBANN 0.103.0
Livermore Big Artificial Neural Network Toolkit
variance_scaling_initializers.hpp
Go to the documentation of this file.
1 // Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
3 // Produced at the Lawrence Livermore National Laboratory.
4 // Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5 // the CONTRIBUTORS file. <lbann-dev@llnl.gov>
6 //
7 // LLNL-CODE-697807.
8 // All rights reserved.
9 //
10 // This file is part of LBANN: Livermore Big Artificial Neural Network
11 // Toolkit. For details, see http://software.llnl.gov/LBANN or
12 // https://github.com/LLNL/LBANN.
13 //
14 // Licensed under the Apache License, Version 2.0 (the "License"); you
15 // may not use this file except in compliance with the License. You may
16 // obtain a copy of the License at:
17 //
18 // http://www.apache.org/licenses/LICENSE-2.0
19 //
20 // Unless required by applicable law or agreed to in writing, software
21 // distributed under the License is distributed on an "AS IS" BASIS,
22 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23 // implied. See the License for the specific language governing
24 // permissions and limitations under the license.
26 
27 #ifndef LBANN_WEIGHTS_VARIANCE_SCALING_INITIALIZER_HPP
28 #define LBANN_WEIGHTS_VARIANCE_SCALING_INITIALIZER_HPP
29 
31 #include "lbann/utils/random.hpp"
33 
34 namespace lbann {
35 
46 template <typename TensorDataType>
48  : public Cloneable<
49  HasAbstractFunction<variance_scaling_initializer<TensorDataType>>,
50  data_type_weights_initializer<TensorDataType>>
51 {
52 public:
54 
57  using AbsDistMatrixType = El::AbstractDistMatrix<TensorDataType>;
58 
60 
61 public:
63  description get_description() const override;
64  void fill(AbsDistMatrixType& matrix) override;
65 
67  void set_fan_in(El::Int fan_in) { m_fan_in = fan_in; }
69  void set_fan_out(El::Int fan_out) { m_fan_out = fan_out; }
70 
73  {
74  return m_prob_dist;
75  }
76 
77 private:
79  virtual TensorDataType get_variance(El::Int fan_in, El::Int fan_out) = 0;
80 
81 private:
85  El::Int m_fan_in;
87  El::Int m_fan_out;
88 };
89 
94 template <typename TensorDataType>
96  : public Cloneable<glorot_initializer<TensorDataType>,
97  variance_scaling_initializer<TensorDataType>>
98 {
101 
102 public:
104  {}
105  std::string get_type() const override { return "Glorot"; }
106 
108  void write_proto(lbann_data::Initializer& init) const final;
109 
110 private:
111  TensorDataType get_variance(El::Int fan_in, El::Int fan_out) override;
112 };
113 
115 template <typename TensorDataType>
117  : public Cloneable<he_initializer<TensorDataType>,
118  variance_scaling_initializer<TensorDataType>>
119 {
122 
123 public:
125  std::string get_type() const override { return "He"; }
126 
128  void write_proto(lbann_data::Initializer& init) const final;
129 
130 private:
131  TensorDataType get_variance(El::Int fan_in, El::Int fan_out) override;
132 };
133 
135 template <typename TensorDataType>
137  : public Cloneable<lecun_initializer<TensorDataType>,
138  variance_scaling_initializer<TensorDataType>>
139 {
142 
143 public:
145  std::string get_type() const override { return "LeCun"; }
146 
148  void write_proto(lbann_data::Initializer& init) const final;
149 
150 private:
151  TensorDataType get_variance(El::Int fan_in, El::Int fan_out) override;
152 };
153 
154 void set_fan_in(weights_initializer& initializer, double value);
155 void set_fan_out(weights_initializer& initializer, double value);
156 
157 template <typename TensorDataType>
158 std::unique_ptr<weights_initializer>
159 build_glorot_initializer_from_pbuf(google::protobuf::Message const& msg);
160 template <typename TensorDataType>
161 std::unique_ptr<weights_initializer>
162 build_he_initializer_from_pbuf(google::protobuf::Message const& msg);
163 template <typename TensorDataType>
164 std::unique_ptr<weights_initializer>
165 build_lecun_initializer_from_pbuf(google::protobuf::Message const& msg);
166 
167 #ifndef LBANN_VARIANCE_SCALING_INITIALIZER_INSTANTIATE
168 #define PROTO(T) \
169  extern template class glorot_initializer<T>; \
170  extern template class he_initializer<T>; \
171  extern template class lecun_initializer<T>
172 
173 #define LBANN_INSTANTIATE_CPU_HALF
174 #define LBANN_INSTANTIATE_GPU_HALF
176 #undef PROTO
177 #undef LBANN_INSTANTIATE_CPU_HALF
178 #undef LBANN_INSTANTIATE_GPU_HALF
179 #endif // LBANN_VARIANCE_SCALING_INITIALIZER_INSTANTIATE
180 
181 } // namespace lbann
182 
183 #endif // LBANN_WEIGHTS_VARIANCE_SCALING_INITIALIZER_HPP
Fill weights with variance of 2 / fan-in.
std::string get_type() const override
std::unique_ptr< weights_initializer > build_he_initializer_from_pbuf(google::protobuf::Message const &msg)
El::AbstractDistMatrix< TensorDataType > AbsDistMatrixType
The tensor type expected in this object.
Inject polymorphic clone functions into hierarchies.
Definition: cloneable.hpp:94
Fill weights with variance of 1 / fan-in.
void fill(AbsDistMatrixType &matrix) override
Generates nicely formatted description messages.
Definition: description.hpp:49
glorot_initializer(probability_distribution prob_dist)
lecun_initializer(probability_distribution prob_dist)
std::unique_ptr< weights_initializer > build_glorot_initializer_from_pbuf(google::protobuf::Message const &msg)
virtual TensorDataType get_variance(El::Int fan_in, El::Int fan_out)=0
variance_scaling_initializer(probability_distribution dist)
Scheme for initializing weight values.
Definition: initializer.hpp:43
he_initializer(probability_distribution prob_dist)
std::string get_type() const override
Fill weights with variance of 2 / (fan-in + fan-out).
std::string get_type() const override
probability_distribution get_prob_dist() const noexcept
std::unique_ptr< weights_initializer > build_lecun_initializer_from_pbuf(google::protobuf::Message const &msg)
probability_distribution
Definition: random.hpp:39
description get_description() const override
Generalization of "Xavier" initialization.