LBANN  0.103.0
Livermore Big Artificial Neural Network Toolkit
lbann::glorot_initializer< TensorDataType > Class Template Reference

Fill weights with variance of 2 / (fan-in + fan-out).

#include <variance_scaling_initializers.hpp>

Public Member Functions

 glorot_initializer (probability_distribution prob_dist)
 
std::string get_type () const override
 
void write_proto (lbann_data::Initializer &init) const final
 Add initializer data to prototext.
 
Public Member Functions inherited from lbann::Cloneable< glorot_initializer< TensorDataType >, variance_scaling_initializer< TensorDataType > >
std::unique_ptr< glorot_initializer< TensorDataType > > clone () const
 Return an exception-safe, memory-safe copy of this object.
 

Private Types

using BaseType = Cloneable< glorot_initializer< TensorDataType >, variance_scaling_initializer< TensorDataType > >
 

Private Member Functions

TensorDataType get_variance (El::Int fan_in, El::Int fan_out) override
 

Detailed Description

template<typename TensorDataType>
class lbann::glorot_initializer< TensorDataType >

Fill weights with variance of 2 / (fan-in + fan-out).

Also called Xavier initialization.

Definition at line 95 of file variance_scaling_initializers.hpp.
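
The variance rule above can be illustrated outside of LBANN with a short standalone sketch. Everything in it is an assumption made for illustration: the helper names `glorot_variance` and `glorot_normal_fill` are hypothetical, `std::int64_t` stands in for `El::Int`, and the normal distribution is only one possible choice of `probability_distribution`. It is not LBANN's implementation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical helper (not part of LBANN): the Glorot/Xavier variance,
// 2 / (fan_in + fan_out), as stated in the class description.
template <typename T>
T glorot_variance(std::int64_t fan_in, std::int64_t fan_out) {
  assert(fan_in > 0 && fan_out > 0);
  return T(2) / static_cast<T>(fan_in + fan_out);
}

// Hypothetical helper (not part of LBANN): draw fan_in * fan_out weights
// from a zero-mean normal distribution whose variance is the Glorot variance.
// T must be float, double, or long double (std::normal_distribution requirement).
template <typename T>
std::vector<T> glorot_normal_fill(std::int64_t fan_in,
                                  std::int64_t fan_out,
                                  std::uint32_t seed) {
  const T stddev = std::sqrt(glorot_variance<T>(fan_in, fan_out));
  std::mt19937 gen(seed);
  std::normal_distribution<T> dist(T(0), stddev);
  std::vector<T> weights(static_cast<std::size_t>(fan_in * fan_out));
  for (auto& w : weights) {
    w = dist(gen);
  }
  return weights;
}
```

Calling `glorot_normal_fill<float>(256, 128, 0u)` would return 32768 samples with a target variance of 2 / 384. Passing the seed explicitly keeps the sketch reproducible; how LBANN seeds and distributes its own random number generators is not covered by this page.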

Member Typedef Documentation

◆ BaseType

template<typename TensorDataType >
using lbann::glorot_initializer< TensorDataType >::BaseType = Cloneable<glorot_initializer<TensorDataType>, variance_scaling_initializer<TensorDataType> >
private

Definition at line 100 of file variance_scaling_initializers.hpp.

Constructor & Destructor Documentation

◆ glorot_initializer()

template<typename TensorDataType >
lbann::glorot_initializer< TensorDataType >::glorot_initializer ( probability_distribution prob_dist )
inline

Definition at line 103 of file variance_scaling_initializers.hpp.

Member Function Documentation

◆ get_type()

template<typename TensorDataType >
std::string lbann::glorot_initializer< TensorDataType >::get_type ( ) const
inline override

Definition at line 105 of file variance_scaling_initializers.hpp.

◆ get_variance()

template<typename TensorDataType >
TensorDataType lbann::glorot_initializer< TensorDataType >::get_variance ( El::Int fan_in, El::Int fan_out )
override private
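
As a worked example of what the returned variance implies for a uniform distribution: a uniform distribution on $[-a, a]$ has variance $a^2/3$, so matching the Glorot variance fixes the bound $a$. This is standard probability, not something stated on this page; whether the base class derives its uniform bound exactly this way is an assumption.

```latex
\sigma^2 = \frac{2}{\text{fan\_in} + \text{fan\_out}},
\qquad
\operatorname{Var}\bigl(U(-a, a)\bigr) = \frac{a^2}{3}
\;\Longrightarrow\;
a = \sqrt{3\sigma^2} = \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}
```

For instance, with fan-in 3 and fan-out 5 the variance is $2/8 = 0.25$ and the matching uniform bound is $a = \sqrt{0.75} \approx 0.866$.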

◆ write_proto()

template<typename TensorDataType >
void lbann::glorot_initializer< TensorDataType >::write_proto ( lbann_data::Initializer & init ) const
final

Add initializer data to prototext.


The documentation for this class was generated from the following file: