27 #ifndef LBANN_WEIGHTS_INITIALIZER_HPP 28 #define LBANN_WEIGHTS_INITIALIZER_HPP 34 #include <google/protobuf/message.h> 44 :
public Cloneable<HasAbstractFunction<weights_initializer>>
51 virtual std::string get_type()
const = 0;
58 template <
typename TensorDataType>
61 HasAbstractFunction<data_type_weights_initializer<TensorDataType>>,
78 std::string
get_type()
const override {
return "data_type_weights"; }
84 virtual void write_proto(lbann_data::Initializer& proto)
const = 0;
88 template <
typename TensorDataType>
90 :
public Cloneable<constant_initializer<TensorDataType>,
91 data_type_weights_initializer<TensorDataType>>
104 std::string
get_type()
const override {
return "constant"; }
109 void write_proto(lbann_data::Initializer& init)
const final;
124 template <
typename TensorDataType>
126 :
public Cloneable<value_initializer<TensorDataType>,
127 data_type_weights_initializer<TensorDataType>>
140 : m_values{std::move(values)}
142 std::string
get_type()
const override {
return "value"; }
146 void write_proto(lbann_data::Initializer& init)
const final;
158 template <
typename TensorDataType>
160 :
public Cloneable<numpy_initializer<TensorDataType>,
161 data_type_weights_initializer<TensorDataType>>
174 std::string
get_type()
const override {
return "NumPy"; }
178 void write_proto(lbann_data::Initializer& init)
const final;
186 template <
typename TensorDataType>
188 :
public Cloneable<uniform_initializer<TensorDataType>,
189 data_type_weights_initializer<TensorDataType>>
202 TensorDataType max = El::To<TensorDataType>(1))
203 : m_min{min}, m_max{max}
205 std::string
get_type()
const override {
return "uniform"; }
210 void write_proto(lbann_data::Initializer& init)
const final;
220 template <
typename TensorDataType>
222 :
public Cloneable<normal_initializer<TensorDataType>,
223 data_type_weights_initializer<TensorDataType>>
236 TensorDataType mean = El::TypeTraits<TensorDataType>::Zero(),
237 TensorDataType standard_deviation = El::TypeTraits<TensorDataType>::One())
238 : m_mean{mean}, m_standard_deviation{standard_deviation}
240 std::string
get_type()
const override {
return "normal"; }
245 void write_proto(lbann_data::Initializer& init)
const final;
254 template <
typename TensorDataType>
255 std::unique_ptr<weights_initializer>
258 template <
typename TensorDataType>
259 std::unique_ptr<weights_initializer>
262 template <
typename TensorDataType>
263 std::unique_ptr<weights_initializer>
266 template <
typename TensorDataType>
267 std::unique_ptr<weights_initializer>
270 template <
typename TensorDataType>
271 std::unique_ptr<weights_initializer>
274 #ifndef LBANN_INITIALIZER_INSTANTIATE 276 extern template class data_type_weights_initializer<T>; \ 277 extern template class constant_initializer<T>; \ 278 extern template class value_initializer<T>; \ 279 extern template class numpy_initializer<T>; \ 280 extern template class uniform_initializer<T>; \ 281 extern template class normal_initializer<T> 283 #define LBANN_INSTANTIATE_CPU_HALF 284 #define LBANN_INSTANTIATE_GPU_HALF 287 #undef LBANN_INSTANTIATE_CPU_HALF 288 #undef LBANN_INSTANTIATE_GPU_HALF 289 #endif // LBANN_INITIALIZER_INSTANTIATE 293 #endif // LBANN_WEIGHTS_INITIALIZER_HPP std::unique_ptr< weights_initializer > build_uniform_initializer_from_pbuf(google::protobuf::Message const &msg)
Fill weights with values from a NumPy file.
void fill(std::istream &is, google::protobuf::Message &msg)
Fill the protobuf message from a binary stream.
normal_initializer(TensorDataType mean=El::TypeTraits< TensorDataType >::Zero(), TensorDataType standard_deviation=El::TypeTraits< TensorDataType >::One())
El::AbstractDistMatrix< TensorDataType > AbsDistMatrixType
The tensor type expected in this object.
El::AbstractDistMatrix< TensorDataType > AbsDistMatrixType
The tensor type expected in this object.
Inject polymorphic clone functions into hierarchies.
numpy_initializer(std::string file)
Scheme for initializing weight values.
Generates nicely formatted description messages.
value_initializer(std::vector< TensorDataType > values)
std::string get_type() const override
Fill weights with a single constant value.
std::unique_ptr< weights_initializer > build_constant_initializer_from_pbuf(google::protobuf::Message const &msg)
std::unique_ptr< weights_initializer > build_numpy_initializer_from_pbuf(google::protobuf::Message const &msg)
Fill weights with values from a list.
TensorDataType m_standard_deviation
Scheme for initializing weight values.
Draw weights values from a normal random distribution.
std::vector< TensorDataType > m_values
std::string get_type() const override
El::AbstractDistMatrix< TensorDataType > AbsDistMatrixType
The tensor type expected in this object.
std::string get_type() const override
El::AbstractDistMatrix< TensorDataType > AbsDistMatrixType
The tensor type expected in this object.
El::AbstractDistMatrix< TensorDataType > AbsDistMatrixType
The tensor type expected in this object.
std::unique_ptr< weights_initializer > build_normal_initializer_from_pbuf(google::protobuf::Message const &msg)
constant_initializer(TensorDataType value)
std::string get_type() const override
std::string get_type() const override
std::unique_ptr< weights_initializer > build_value_initializer_from_pbuf(google::protobuf::Message const &msg)