LBANN 0.103.0
Livermore Big Artificial Neural Network Toolkit
lbann::he_initializer< TensorDataType > Class Template Reference

Fill weights with variance of 2 / fan-in.

#include <variance_scaling_initializers.hpp>


Public Member Functions

 he_initializer (probability_distribution prob_dist)
 
std::string get_type () const override
 
void write_proto (lbann_data::Initializer &init) const final
 Add initializer data to prototext.
 
- Public Member Functions inherited from lbann::Cloneable< he_initializer< TensorDataType >, variance_scaling_initializer< TensorDataType > >
std::unique_ptr< he_initializer< TensorDataType > > clone () const
 Return an exception-safe, memory-safe copy of this object.
 

Private Types

using BaseType = Cloneable< he_initializer< TensorDataType >, variance_scaling_initializer< TensorDataType > >
 

Private Member Functions

TensorDataType get_variance (El::Int fan_in, El::Int fan_out) override
 

Detailed Description

template<typename TensorDataType>
class lbann::he_initializer< TensorDataType >

Fill weights with variance of 2 / fan-in.

Definition at line 116 of file variance_scaling_initializers.hpp.
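The factor 2 / fan-in matches the forward-pass variance argument of He et al. (2015) for ReLU networks, condensed below. It assumes i.i.d. zero-mean weights that are independent of the layer inputs, and zero-mean, symmetrically distributed pre-activations; n_in denotes the fan-in. This is a summary of the standard argument, not a description of how LBANN derives the value.

```latex
% y_l = W_l x_l,  x_l = ReLU(y_{l-1}),  n_in = fan-in of layer l
\begin{aligned}
\operatorname{Var}[y_l] &= n_{\text{in}}\,\operatorname{Var}[w_l]\,\mathbb{E}[x_l^2],
\qquad
\mathbb{E}[x_l^2] = \tfrac{1}{2}\operatorname{Var}[y_{l-1}] \\
\Rightarrow\quad
\operatorname{Var}[y_l] &= \tfrac{1}{2}\,n_{\text{in}}\,\operatorname{Var}[w_l]\,\operatorname{Var}[y_{l-1}] \\
\operatorname{Var}[y_l] = \operatorname{Var}[y_{l-1}]
\quad&\Longleftrightarrow\quad
\operatorname{Var}[w_l] = \frac{2}{n_{\text{in}}}
\end{aligned}
```

The halving term comes from the ReLU: for a zero-mean symmetric input, exactly half of the second moment survives the rectification.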

Member Typedef Documentation

◆ BaseType

template<typename TensorDataType >
using lbann::he_initializer< TensorDataType >::BaseType = Cloneable<he_initializer<TensorDataType>, variance_scaling_initializer<TensorDataType> >
private

Definition at line 121 of file variance_scaling_initializers.hpp.
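BaseType names Cloneable<he_initializer<TensorDataType>, variance_scaling_initializer<TensorDataType>>, a curiously recurring template pattern (CRTP) base that gives he_initializer a clone() returning std::unique_ptr<he_initializer<TensorDataType>>. The sketch below shows the general idiom under stated assumptions. The names initializer_base, cloneable_sketch and he_sketch are hypothetical, and this is not LBANN's lbann::Cloneable implementation.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical abstract base standing in for an initializer hierarchy.
class initializer_base {
public:
  virtual ~initializer_base() = default;
  virtual std::string get_type() const = 0;
  // Type-erased copy for callers that only hold a base reference.
  std::unique_ptr<initializer_base> clone() const {
    return std::unique_ptr<initializer_base>(do_clone_());
  }
private:
  // Each concrete class returns a heap-allocated copy of itself.
  virtual initializer_base* do_clone_() const = 0;
};

// Hypothetical CRTP helper: gives Derived a clone() that returns
// std::unique_ptr<Derived> rather than std::unique_ptr<Base>.
template <typename Derived, typename Base>
class cloneable_sketch : public Base {
public:
  std::unique_ptr<Derived> clone() const {
    // The raw pointer is handed to a unique_ptr right away, so the copy is owned.
    return std::unique_ptr<Derived>(static_cast<Derived*>(do_clone_()));
  }
private:
  // Covariant override: copy-construct the most-derived type.
  cloneable_sketch* do_clone_() const override {
    return new Derived(static_cast<Derived const&>(*this));
  }
};

// Hypothetical concrete class, analogous to he_initializer.
class he_sketch : public cloneable_sketch<he_sketch, initializer_base> {
public:
  std::string get_type() const override { return "he"; }
};
```

Calling clone() on a he_sketch yields a std::unique_ptr<he_sketch>, while calling it through an initializer_base reference yields a std::unique_ptr<initializer_base> that still points at a he_sketch. The private alias simply saves repeating the long base-class name inside the class.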

Constructor & Destructor Documentation

◆ he_initializer()

template<typename TensorDataType >
lbann::he_initializer< TensorDataType >::he_initializer ( probability_distribution prob_dist )
inline

Definition at line 124 of file variance_scaling_initializers.hpp.
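The constructor takes only a probability_distribution, which suggests that the variance is fixed by the He rule and the caller chooses the shape of the distribution. The following is a minimal sketch of how one variance could map to a zero-mean Gaussian or uniform draw, assuming the two common choices. The enum distribution_sketch and the function sample_weights are hypothetical and do not mirror LBANN's probability_distribution enumerators or internals.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical stand-in for a probability-distribution selector.
enum class distribution_sketch { gaussian, uniform };

// Draw `count` zero-mean weights with the requested variance.
//   Gaussian: N(0, variance), i.e. standard deviation sqrt(variance).
//   Uniform:  U(-a, a) has variance a^2 / 3, so a = sqrt(3 * variance).
std::vector<double> sample_weights(distribution_sketch dist, double variance,
                                   std::size_t count, std::mt19937& gen) {
  assert(variance > 0.0);
  std::vector<double> w(count);
  if (dist == distribution_sketch::gaussian) {
    std::normal_distribution<double> d(0.0, std::sqrt(variance));
    for (auto& x : w) x = d(gen);
  } else {
    const double a = std::sqrt(3.0 * variance);
    std::uniform_real_distribution<double> d(-a, a);
    for (auto& x : w) x = d(gen);
  }
  return w;
}
```

For He initialization, the variance passed in would be 2 / fan-in. Both branches then produce weights whose empirical second moment approaches that value. The uniform branch needs the sqrt(3) factor because matching the bound directly to the standard deviation would undershoot the target variance by a factor of 3.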

Member Function Documentation

◆ get_type()

template<typename TensorDataType >
std::string lbann::he_initializer< TensorDataType >::get_type ( ) const
inline override

Definition at line 125 of file variance_scaling_initializers.hpp.


◆ get_variance()

template<typename TensorDataType >
TensorDataType lbann::he_initializer< TensorDataType >::get_variance ( El::Int fan_in, El::Int fan_out )
override private
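Given the class description (variance of 2 / fan-in), get_variance presumably depends only on fan_in, although the override keeps the (fan_in, fan_out) signature of the base class. The free function below is a hypothetical analogue, not LBANN's code. It assumes std::int64_t as a stand-in for El::Int, whose width depends on how Elemental is built.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical free-function analogue of he_initializer::get_variance.
// He initialization depends only on fan-in; fan_out is accepted to match
// the (fan_in, fan_out) shape of the member function and is ignored.
template <typename T>
T he_variance(std::int64_t fan_in, std::int64_t /*fan_out*/) {
  assert(fan_in > 0);
  return T(2) / static_cast<T>(fan_in);
}
```

For example, a fully connected layer with 784 inputs gets a variance of 2 / 784 ≈ 0.00255, which corresponds to a standard deviation of about 0.0505, whatever its number of outputs.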

◆ write_proto()

template<typename TensorDataType >
void lbann::he_initializer< TensorDataType >::write_proto ( lbann_data::Initializer & init ) const
final

Add initializer data to prototext.


The documentation for this class was generated from the following file: