27 #ifndef LBANN_WEIGHTS_VARIANCE_SCALING_INITIALIZER_HPP 28 #define LBANN_WEIGHTS_VARIANCE_SCALING_INITIALIZER_HPP 46 template <
typename TensorDataType>
49 HasAbstractFunction<variance_scaling_initializer<TensorDataType>>,
50 data_type_weights_initializer<TensorDataType>>
/** @brief Pure virtual hook that derived initializers must implement.
 *
 *  Takes the fan-in and fan-out as El::Int values and returns a
 *  TensorDataType. Going by its name, the returned value is the
 *  variance to use -- confirm against the code that calls it.
 *
 *  NOTE(review): the leading "79" looks like a line number left over
 *  from a rendered source listing, not C++ -- confirm and strip.
 */
79 virtual TensorDataType
get_variance(El::Int fan_in, El::Int fan_out) = 0;
94 template <
typename TensorDataType>
96 :
public Cloneable<glorot_initializer<TensorDataType>,
97 variance_scaling_initializer<TensorDataType>>
/** @brief Name of this initializer type.
 *  @return The string "Glorot".
 *
 *  NOTE(review): the leading "105" looks like a line number left over
 *  from a rendered source listing, not C++ -- confirm and strip.
 */
105 std::string
get_type()
const override {
return "Glorot"; }
/** @brief Declaration only; the definition is not visible here.
 *
 *  Receives an lbann_data::Initializer by non-const reference, so the
 *  message can be modified. Presumably it is filled in with this
 *  initializer's settings -- TODO confirm against the definition.
 *  Declared const and final.
 *
 *  NOTE(review): the leading "108" looks like a line number left over
 *  from a rendered source listing, not C++ -- confirm and strip.
 */
108 void write_proto(lbann_data::Initializer& init)
const final;
/** @brief Override of get_variance; declaration only, the definition
 *  is not visible here.
 *
 *  Takes the fan-in and fan-out and returns a TensorDataType.
 *  Presumably 2 / (fan_in + fan_out) for the Glorot scheme -- TODO
 *  confirm against the definition.
 *
 *  NOTE(review): the leading "111" looks like a line number left over
 *  from a rendered source listing, not C++ -- confirm and strip.
 */
111 TensorDataType
get_variance(El::Int fan_in, El::Int fan_out)
override;
115 template <
typename TensorDataType>
117 :
public Cloneable<he_initializer<TensorDataType>,
118 variance_scaling_initializer<TensorDataType>>
/** @brief Name of this initializer type.
 *  @return The string "He".
 *
 *  NOTE(review): the leading "125" looks like a line number left over
 *  from a rendered source listing, not C++ -- confirm and strip.
 */
125 std::string
get_type()
const override {
return "He"; }
/** @brief Declaration only; the definition is not visible here.
 *
 *  Receives an lbann_data::Initializer by non-const reference, so the
 *  message can be modified. Presumably it is filled in with this
 *  initializer's settings -- TODO confirm against the definition.
 *  Declared const and final.
 *
 *  NOTE(review): the leading "128" looks like a line number left over
 *  from a rendered source listing, not C++ -- confirm and strip.
 */
128 void write_proto(lbann_data::Initializer& init)
const final;
/** @brief Override of get_variance; declaration only, the definition
 *  is not visible here.
 *
 *  Takes the fan-in and fan-out and returns a TensorDataType.
 *  Presumably 2 / fan_in for the He scheme -- TODO confirm against
 *  the definition.
 *
 *  NOTE(review): the leading "131" looks like a line number left over
 *  from a rendered source listing, not C++ -- confirm and strip.
 */
131 TensorDataType
get_variance(El::Int fan_in, El::Int fan_out)
override;
135 template <
typename TensorDataType>
137 :
public Cloneable<lecun_initializer<TensorDataType>,
138 variance_scaling_initializer<TensorDataType>>
/** @brief Name of this initializer type.
 *  @return The string "LeCun".
 *
 *  NOTE(review): the leading "145" looks like a line number left over
 *  from a rendered source listing, not C++ -- confirm and strip.
 */
145 std::string
get_type()
const override {
return "LeCun"; }
/** @brief Declaration only; the definition is not visible here.
 *
 *  Receives an lbann_data::Initializer by non-const reference, so the
 *  message can be modified. Presumably it is filled in with this
 *  initializer's settings -- TODO confirm against the definition.
 *  Declared const and final.
 *
 *  NOTE(review): the leading "148" looks like a line number left over
 *  from a rendered source listing, not C++ -- confirm and strip.
 */
148 void write_proto(lbann_data::Initializer& init)
const final;
/** @brief Override of get_variance; declaration only, the definition
 *  is not visible here.
 *
 *  Takes the fan-in and fan-out and returns a TensorDataType.
 *  Presumably 1 / fan_in for the LeCun scheme -- TODO confirm against
 *  the definition.
 *
 *  NOTE(review): the leading "151" looks like a line number left over
 *  from a rendered source listing, not C++ -- confirm and strip.
 */
151 TensorDataType
get_variance(El::Int fan_in, El::Int fan_out)
override;
157 template <
typename TensorDataType>
158 std::unique_ptr<weights_initializer>
160 template <
typename TensorDataType>
161 std::unique_ptr<weights_initializer>
163 template <
typename TensorDataType>
164 std::unique_ptr<weights_initializer>
167 #ifndef LBANN_VARIANCE_SCALING_INITIALIZER_INSTANTIATE 169 extern template class glorot_initializer<T>; \ 170 extern template class he_initializer<T>; \ 171 extern template class lecun_initializer<T> 173 #define LBANN_INSTANTIATE_CPU_HALF 174 #define LBANN_INSTANTIATE_GPU_HALF 177 #undef LBANN_INSTANTIATE_CPU_HALF 178 #undef LBANN_INSTANTIATE_GPU_HALF 179 #endif // LBANN_VARIANCE_SCALING_INITIALIZER_INSTANTIATE 183 #endif // LBANN_WEIGHTS_VARIANCE_SCALING_INITIALIZER_HPP Fill weights with variance of 2 / fan-in.
std::string get_type() const override
std::unique_ptr< weights_initializer > build_he_initializer_from_pbuf(google::protobuf::Message const &msg)
El::AbstractDistMatrix< TensorDataType > AbsDistMatrixType
The tensor type expected in this object.
Inject polymorphic clone functions into hierarchies.
Fill weights with variance of 1 / fan-in.
void fill(AbsDistMatrixType &matrix) override
Generates nicely formatted description messages.
glorot_initializer(probability_distribution prob_dist)
void set_fan_in(El::Int fan_in)
lecun_initializer(probability_distribution prob_dist)
std::unique_ptr< weights_initializer > build_glorot_initializer_from_pbuf(google::protobuf::Message const &msg)
virtual TensorDataType get_variance(El::Int fan_in, El::Int fan_out)=0
variance_scaling_initializer(probability_distribution dist)
Scheme for initializing weight values.
probability_distribution m_prob_dist
he_initializer(probability_distribution prob_dist)
std::string get_type() const override
Fill weights with variance of 2 / (fan-in + fan-out).
std::string get_type() const override
probability_distribution get_prob_dist() const noexcept
void set_fan_out(El::Int fan_out)
std::unique_ptr< weights_initializer > build_lecun_initializer_from_pbuf(google::protobuf::Message const &msg)
description get_description() const override
Generalization of "Xavier" initialization.