27 #ifndef LBANN_EXECUTION_ALGORITHMS_KFAC_KFAC_BLOCK_BN_HPP_INCLUDED 28 #define LBANN_EXECUTION_ALGORITHMS_KFAC_KFAC_BLOCK_BN_HPP_INCLUDED 36 namespace kfac_bn_util {
40 template <El::Device Device>
42 const El::Matrix<DataType, Device>& errors,
43 const El::Matrix<DataType, Device>& scales,
44 const El::Matrix<DataType, Device>& biases,
45 El::Matrix<DataType, Device>& cols,
49 const El::SyncInfo<Device>& sync_info);
55 template <El::Device Device>
64 size_t inverse_proc_rank,
65 bool enable_copy_errors,
66 bool enable_copy_activations,
74 enable_copy_activations,
79 const bool is_after_fc =
82 Device>*>(parent) !=
nullptr);
86 Device>*>(parent) !=
nullptr);
87 if (!is_after_fc && !m_is_after_conv) {
88 std::stringstream err;
89 err <<
"The K-FAC only supports batch-normalization layers after " 90 <<
"fully-connected layers or convolutional layers." 92 <<
" parent type: " << parent->get_type();
97 const auto& dtl_parent =
99 const El::AbstractMatrix<DataType>& local_activations =
101 m_num_channels = local_activations.Height();
106 m_num_channels = input_dims[0];
109 for (
auto i = input_dims.begin() + 1; i != input_dims.end(); i++)
110 m_spatial_prod *= *i;
119 total_size += m_fisher_buf.Height() * m_fisher_buf.Width();
120 total_size += m_fisher_average.Height() * m_fisher_average.Width();
121 total_size += m_fisher_inverse.Height() * m_fisher_inverse.Width();
125 void compute_local_kronecker_factors(
lbann_comm* comm,
127 bool print_matrix_summary)
override;
129 const std::vector<El::AbstractMatrix<DataType>*>
132 std::vector<El::AbstractMatrix<DataType>*> ret = {&m_fisher_buf};
137 DataType kronecker_decay,
139 bool print_matrix_summary)
override;
141 void update_kronecker_inverse(
lbann_comm* comm,
143 DataType damping_act,
144 DataType damping_err,
145 DataType learning_rate_factor,
146 bool use_eigen_decomposition,
148 bool print_matrix_summary,
149 bool print_time)
override;
151 void compute_preconditioned_gradients(
lbann_comm* comm,
152 DataType learning_rate_factor,
154 bool print_matrix_summary,
155 bool print_time)
override;
157 void start_communication_forward_end(
lbann_comm* comm)
override;
158 void end_communication_forward_end(
lbann_comm* comm)
override;
159 void start_communication_backward_end(
lbann_comm* comm)
override;
160 void end_communication_backward_end(
lbann_comm* comm)
override;
162 const std::vector<El::AbstractMatrix<DataType>*>
163 get_preconditioned_grad_buffers()
override;
165 std::vector<std::tuple<std::string, size_t, size_t>>
166 get_internal_matrix_info()
const override;
170 std::ostringstream oss;
171 oss << kfac_block<Device>::get_info()
172 <<
", is_after_conv=" << m_is_after_conv;
176 int get_inverse_matrices(El::Matrix<DataType, Device>& output,
177 int offset)
override;
180 int get_inverse_matrices_size(
lbann_comm* comm)
override;
185 LBANN_ERROR(
"Sub-grid parallelism is not implemented for BN layer");
190 El::Matrix<double, El::Device::CPU>& inverse_matrices_size,
191 int block_number)
override 193 LBANN_ERROR(
"Sub-grid parallelism is not implemented for BN layer");
197 int set_inverse_matrices(El::Matrix<DataType, Device>& workspace,
218 #endif // LBANN_EXECUTION_ALGORITHMS_KFAC_KFAC_BLOCK_BN_HPP_INCLUDED bool m_is_after_conv
Whether the parent layer is convolutional (rather than fully-connected); needed to perform the K-FAC computation.
void resize_inverse_matrices_size(El::Matrix< double, El::Device::CPU > &inverse_matrices_size, int block_number) override
Resize the inverse matrices size vector (sub-grid parallelism; not implemented for BN layers).
int get_local_memory_consumption() override
Get local memory consumption.
std::vector< int > get_input_dims(size_t input_index=0) const
Get input tensor dimensions.
Neural network tensor operation.
const std::vector< El::AbstractMatrix< DataType > * > get_local_kronecker_buffers() override
Get buffers of Kronecker factors for reduce-scatter.
constexpr El::Device Device
OutputAbsMatrixType & get_local_activations(int child_index=0)
std::string get_info() const override
Get block's information in one line.
kfac_block_bn(Layer *layer, kfac::KFACExecutionContext *context, size_t layer_id, size_t inverse_proc_rank, bool enable_copy_errors, bool enable_copy_activations, int input_size, int output_size)
El::Matrix< DataType, Device > m_fisher_inverse
Inverse of the average Fisher matrix.
El::Matrix< DataType, Device > m_fisher_average
Exponential moving average of the Fisher matrix.
void update_kronecker_average(El::Matrix< DataType, Device > &Aave, const El::Matrix< DataType, Device > &A, size_t count, double decay, const El::SyncInfo< Device > &sync_info)
Update a Kronecker factor matrix using decay.
std::vector< const Layer * > get_parent_layers() const
std::string get_name() const
Get the layer instance's name.
void compute_bn_factor_data2col(const El::Matrix< DataType, Device > &activations, const El::Matrix< DataType, Device > &errors, const El::Matrix< DataType, Device > &scales, const El::Matrix< DataType, Device > &biases, El::Matrix< DataType, Device > &cols, size_t batch_size, size_t num_channels, size_t spatial_prod, const El::SyncInfo< Device > &sync_info)
The memory copy part of compute_bn_factor. Combined with GEMM.
El::Matrix< DataType, Device > m_fisher_buf
Lower triangle buffers of the Fisher block.
std::vector< int > get_inverse_matrices_size_vector(lbann_comm *comm) override
Get inverse matrices size vector.