27 #ifndef LBANN_EXECUTION_ALGORITHMS_KFAC_KFAC_BLOCK_FC_CONV_HPP_INCLUDED 28 #define LBANN_EXECUTION_ALGORITHMS_KFAC_KFAC_BLOCK_FC_CONV_HPP_INCLUDED 35 namespace kfac_fc_conv_util {
38 template <El::Device Device>
40 const El::Matrix<DataType, Device>& A,
41 const El::SyncInfo<Device>& sync_info);
44 template <El::Device Device>
45 void conv_transpose(
const El::Matrix<DataType, Device>& activations,
46 El::Matrix<DataType, Device>& act_columns,
47 size_t mini_batch_size,
50 const El::SyncInfo<Device>& sync_info);
53 template <El::Device Device>
54 void im2col(
const El::Matrix<DataType, Device>& im,
55 El::Matrix<DataType, Device>& col,
56 const int num_channels,
57 const int im_num_dims,
60 const int* window_dims,
61 const int* window_strides,
63 const El::SyncInfo<Device>& sync_info);
70 template <El::Device Device>
78 const size_t layer_id,
79 const size_t inverse_proc_rank,
80 const bool enable_copy_errors,
81 const bool enable_copy_activations,
83 const int output_size,
90 enable_copy_activations,
94 m_has_bias(layer->num_weights() > 1)
97 m_conv_input_spatial_prod = 1;
99 for (
auto i = input_dims.begin() + 1; i != input_dims.end(); i++) {
100 m_conv_input_spatial_prod *= *i;
101 m_conv_input_spatial_dims.push_back(*i);
104 m_conv_output_spatial_prod = 1;
106 for (
auto i = output_dims.begin() + 1; i != output_dims.end(); i++) {
107 m_conv_output_spatial_prod *= *i;
108 m_conv_output_spatial_dims.push_back(*i);
111 if (input_dims.size() != 3 && input_dims.size() != 4) {
112 std::stringstream err;
113 err <<
"The K-FAC only supports 2D or 3D tensors." 114 <<
" layer: " << layer->
get_name() <<
", input_dims: ";
115 for (
auto i = input_dims.begin(); i != input_dims.end(); i++)
116 err << (std::distance(input_dims.begin(), i) > 0 ?
"," :
"") << *i;
121 if (m_is_conv && m_has_bias) {
122 std::stringstream err;
123 err <<
"The K-FAC does not currently support biases for convolutional " 137 m_kronecker_inverse_A.Height() * m_kronecker_inverse_A.Width();
139 m_kronecker_inverse_G.Height() * m_kronecker_inverse_G.Width();
141 m_kronecker_average_A.Height() * m_kronecker_average_A.Width();
143 m_kronecker_average_G.Height() * m_kronecker_average_G.Width();
145 m_kronecker_factor_buf_A.Height() * m_kronecker_factor_buf_A.Width();
147 m_kronecker_factor_buf_G.Height() * m_kronecker_factor_buf_G.Width();
148 total_size += m_grad_buffer_v.Height() * m_grad_buffer_v.Width();
152 void compute_local_kronecker_factors(
lbann_comm* comm,
154 bool print_matrix_summary)
override;
156 const std::vector<El::AbstractMatrix<DataType>*>
159 std::vector<El::AbstractMatrix<DataType>*> ret = {
160 &m_kronecker_factor_buf_A,
161 &m_kronecker_factor_buf_G};
166 DataType kronecker_decay,
168 bool print_matrix_summary)
override;
170 void update_kronecker_inverse(
lbann_comm* comm,
172 DataType damping_act,
173 DataType damping_err,
174 DataType learning_rate_factor,
175 bool use_eigen_decomposition,
177 bool print_matrix_summary,
178 bool print_time)
override;
180 void compute_preconditioned_gradients(
lbann_comm* comm,
181 DataType learning_rate_factor,
183 bool print_matrix_summary,
184 bool print_time)
override;
186 void initialize_activations_and_errors(
lbann_comm* comm,
187 int num_local_activations,
188 int num_local_errors,
189 int num_weights)
override;
191 void start_communication_forward_end(
lbann_comm* comm)
override;
192 void end_communication_forward_end(
lbann_comm* comm)
override;
193 void start_communication_backward_end(
lbann_comm* comm)
override;
194 void end_communication_backward_end(
lbann_comm* comm)
override;
196 const std::vector<El::AbstractMatrix<DataType>*>
197 get_preconditioned_grad_buffers()
override;
199 int get_inverse_matrices(El::Matrix<DataType, Device>& output,
200 int offset)
override;
202 int get_inverse_matrices_size(
lbann_comm* comm)
override;
204 std::vector<int> get_inverse_matrices_size_vector(
lbann_comm* comm)
override;
206 void resize_inverse_matrices_size(
207 El::Matrix<double, El::Device::CPU>& inverse_matrices_size,
208 int block_number)
override;
210 int set_inverse_matrices(El::Matrix<DataType, Device>& workspace,
216 std::ostringstream oss;
217 oss << kfac_block<Device>::get_info() <<
", is_conv=" << m_is_conv;
224 get_kronecker_factor_fc(El::AbstractMatrix<DataType>& factor,
225 const El::AbstractMatrix<DataType>& activations,
229 static void get_kronecker_factor_conv(
230 El::Matrix<DataType, Device>& factor,
231 El::Matrix<DataType, Device>& Acol,
232 const El::Matrix<DataType, Device>& activations,
234 size_t local_batch_size,
236 const std::vector<int>& spatial_dims,
240 const El::SyncInfo<Device>& sync_info);
243 static double compute_pi(
const El::Matrix<DataType, Device>& A,
244 const El::Matrix<DataType, Device>& G,
245 El::Matrix<DataType, Device>& ws,
246 const El::SyncInfo<Device>& sync_info);
257 std::vector<std::tuple<std::string, size_t, size_t>>
258 get_internal_matrix_info()
const override;
266 El::Matrix<DataType, Device> m_kronecker_factor_buf_A,
279 size_t m_Ainv_height = 0, m_Ainv_width = 0, m_Ginv_height = 0,
288 #endif // LBANN_EXECUTION_ALGORITHMS_KFAC_KFAC_BLOCK_FC_CONV_HPP_INCLUDED std::string get_info() const override
Get block's information in one line.
size_t m_conv_output_spatial_prod
El::Matrix< DataType, Device > m_grad_buffer_v
Vectorized gradient buffer (only for fully-connecter layers).
convolution_layer< DataType, data_layout::DATA_PARALLEL, Device > * get_conv_layer()
Get the pointer to its convolution_layer.
int get_local_memory_consumption() override
Get local Memory Consumption.
std::vector< int > get_input_dims(size_t input_index=0) const
Get input tensor dimensions.
Neural network tensor operation.
std::vector< int > m_conv_output_spatial_dims
constexpr El::Device Device
El::Matrix< DataType, Device > m_kronecker_factor_buf_G
const bool m_is_conv
Information to perform its computation.
void update_kronecker_average(El::Matrix< DataType, Device > &Aave, const El::Matrix< DataType, Device > &A, size_t count, double decay, const El::SyncInfo< Device > &sync_info)
Update a Kronecker factor matrix using decay.
El::Matrix< DataType, Device > m_kronecker_average_G
El::Matrix< DataType, Device > m_kronecker_inverse_G
std::string get_name() const
Get the layer instance's name.
void conv_transpose(const El::Matrix< DataType, Device > &activations, El::Matrix< DataType, Device > &act_columns, size_t mini_batch_size, size_t num_channels, size_t spatial_prod, const El::SyncInfo< Device > &sync_info)
Transpose NC(D)HW matrix to N(D)HWC.
void im2col(const El::Matrix< DataType, Device > &im, El::Matrix< DataType, Device > &col, const int num_channels, const int im_num_dims, const int *im_dims, const int *im_pads, const int *window_dims, const int *window_strides, const int batch_size, const El::SyncInfo< Device > &sync_info)
im2col.
kfac_block_fc_conv(Layer *layer, kfac::KFACExecutionContext *context, const size_t layer_id, const size_t inverse_proc_rank, const bool enable_copy_errors, const bool enable_copy_activations, const int input_size, const int output_size, const bool is_conv)
void get_diagonal(El::Matrix< DataType, Device > &diag, const El::Matrix< DataType, Device > &A, const El::SyncInfo< Device > &sync_info)
Get diagonal elements of a matrix.
const std::vector< El::AbstractMatrix< DataType > * > get_local_kronecker_buffers() override
Get buffers of Kronecker factors for reduce-scatter.
std::vector< int > get_output_dims(size_t output_index=0) const
Get output tensor dimensions.