27 #ifndef LBANN_EXECUTION_ALGORITHMS_KFAC_KFAC_BLOCK_HPP_INCLUDED 28 #define LBANN_EXECUTION_ALGORITHMS_KFAC_KFAC_BLOCK_HPP_INCLUDED 37 class KFACExecutionContext;
41 #if defined AL_HAS_NCCL 43 #elif defined AL_HAS_HOST_TRANSFER 44 using BackendT = ::Al::HostTransferBackend;
/** @brief Alias for the request type exposed by the selected backend
 *  (BackendT::req_type). */
49 using ReqT =
typename BackendT::req_type;
54 template <El::Device Device>
63 size_t inverse_proc_rank,
64 bool enable_copy_errors,
65 bool enable_copy_activations,
70 m_inverse_proc_rank(inverse_proc_rank),
71 m_input_size(input_size),
72 m_output_size(output_size),
73 m_enable_copy_errors(enable_copy_errors),
74 m_enable_copy_activations(enable_copy_activations),
77 m_has_kronecker_inverse =
false;
/** @brief Pure virtual; subclasses return an int for this block's local
 *  memory consumption.
 *  NOTE(review): the unit (bytes vs. elements) is not visible from this
 *  declaration -- confirm against the subclass implementations. */
84 virtual int get_local_memory_consumption() = 0;
87 virtual void compute_local_kronecker_factors(
lbann_comm* comm,
89 bool print_matrix_summary);
/** @brief Returns, by value, a vector of raw El::AbstractMatrix<DataType>
 *  pointers.
 *  NOTE(review): presumably these are the block's local Kronecker factor
 *  buffers; ownership of the pointed-to matrices is not visible from this
 *  declaration -- confirm before storing or freeing them. */
92 virtual const std::vector<El::AbstractMatrix<DataType>*>
93 get_local_kronecker_buffers();
97 DataType kronecker_decay,
99 bool print_matrix_summary);
102 virtual void update_kronecker_inverse(
lbann_comm* comm,
104 DataType damping_act,
105 DataType damping_err,
106 DataType learning_rate_factor,
107 bool use_eigen_decomposition,
109 bool print_matrix_summary,
113 virtual void compute_preconditioned_gradients(
lbann_comm* comm,
114 DataType learning_rate_factor,
116 bool print_matrix_summary,
121 virtual void initialize_activations_and_errors(
lbann_comm* comm,
122 int num_local_activations,
123 int num_local_errors,
/** @brief Pure-virtual communication hooks; each takes the lbann_comm to use.
 *  They are declared as start/end pairs, one pair for the "forward end" and
 *  one pair for the "backward end".
 *  NOTE(review): presumably start_* launches transfers and end_* completes
 *  them (cf. m_requests_forward_end) -- confirm against the subclass
 *  implementations. */
126 virtual void start_communication_forward_end(
lbann_comm* comm) = 0;
127 virtual void end_communication_forward_end(
lbann_comm* comm) = 0;
128 virtual void start_communication_backward_end(
lbann_comm* comm) = 0;
129 virtual void end_communication_backward_end(
lbann_comm* comm) = 0;
/** @brief Returns, by value, a vector of raw El::AbstractMatrix<DataType>
 *  pointers.
 *  NOTE(review): presumably these are the buffers holding the preconditioned
 *  gradients; ownership of the pointed-to matrices is not visible from this
 *  declaration -- confirm. */
132 virtual const std::vector<El::AbstractMatrix<DataType>*>
133 get_preconditioned_grad_buffers();
136 virtual int get_inverse_matrices(El::Matrix<DataType, Device>& output,
/** @brief Pure virtual; subclasses return an int given the lbann_comm.
 *  NOTE(review): presumably the combined size of this block's inverse
 *  matrices; the unit (elements vs. bytes) is not visible here -- confirm. */
140 virtual int get_inverse_matrices_size(
lbann_comm* comm) = 0;
/** @brief Pure virtual; subclasses return a std::vector<int> given the
 *  lbann_comm.
 *  NOTE(review): presumably one size entry per inverse matrix of this block;
 *  the ordering and unit are not visible here -- confirm. */
143 virtual std::vector<int>
144 get_inverse_matrices_size_vector(
lbann_comm* comm) = 0;
/** @brief Pure virtual; takes a CPU matrix of doubles by non-const reference
 *  together with a block number.
 *  NOTE(review): presumably resizes this block's inverse matrices using the
 *  sizes stored for block_number in inverse_matrices_size; the layout of that
 *  matrix is not visible here -- confirm. */
147 virtual void resize_inverse_matrices_size(
148 El::Matrix<double, El::Device::CPU>& inverse_matrices_size,
149 int block_number) = 0;
152 virtual int set_inverse_matrices(El::Matrix<DataType, Device>& workspace,
161 std::ostringstream oss;
162 oss <<
"name=" << m_layer->get_name() <<
", id=" << m_layer_id
163 <<
", type=" << m_layer->get_type()
164 <<
", inverse_proc_rank=" << m_inverse_proc_rank;
/** @brief Returns the name of the target layer, i.e. the result of
 *  m_layer->get_name(). m_layer is dereferenced without a null check. */
168 std::string
get_name()
const {
return m_layer->get_name(); }
174 return m_parent_local_activations[index]->Buffer();
179 return m_child_local_errors[index]->Buffer();
184 return m_weight_values[index]->Buffer();
189 return m_weight_gradients[index]->Buffer();
/** @brief Returns a vector of (std::string, size_t, size_t) tuples.
 *  NOTE(review): presumably (matrix name, height, width) for each internal
 *  matrix of the block -- confirm against the implementation. */
201 virtual std::vector<std::tuple<std::string, size_t, size_t>>
202 get_internal_matrix_info()
const;
/** @brief Returns a reference to an El::Matrix<DataType, Device> selected by
 *  a string key and a requested height and width.
 *  NOTE(review): presumably the matrix lives in the execution context's
 *  workspace (m_context); its lifetime and whether it is created on first
 *  use are not visible from this declaration -- confirm. */
207 El::Matrix<DataType, Device>&
208 get_workspace_matrix(
const std::string& key,
size_t height,
size_t width);
224 std::vector<std::unique_ptr<AbsDistMat>> m_parent_local_activations,
256 #endif // LBANN_EXECUTION_ALGORITHMS_KFAC_KFAC_BLOCK_HPP_INCLUDED bool m_enable_copy_errors
Enable copying of errors to enhance async communication.
Layer * m_layer
The target layer.
const int m_inverse_proc_rank
The process ID that performs the inversion of the Kronecker factors.
::Al::MPIBackend BackendT
void set_current_batch_size(El::Int batch_size)
bool m_enable_copy_activations
Enable copying of activations to enhance async communication.
El::Int get_current_batch_size()
DataType * get_gradient_wrt_weight_buffer(int index)
Neural network tensor operation.
std::vector< std::unique_ptr< AbsDistMat > > m_weight_values
Translatebetweengrid function has a basic implementation for STAR,STAR distributed matrices...
typename BackendT::req_type ReqT
const size_t m_layer_id
The layer ID in the model. TODO: Remove this.
El::Int get_output_size()
kfac::KFACExecutionContext * m_context
The execution context that created this block. TODO: Use its own workspace and remove this pointer...
void update_kronecker_average(El::Matrix< DataType, Device > &Aave, const El::Matrix< DataType, Device > &A, size_t count, double decay, const El::SyncInfo< Device > &sync_info)
Update a Kronecker factor matrix using decay.
std::vector< kfac::ReqT > m_requests_forward_end
virtual std::string get_info() const
Get block's information in one line.
std::string get_name() const
DataType * get_local_activation_buffer(int index)
DataType * get_local_error_buffer(int index)
bool m_has_kronecker_inverse
Whether this block already has an inverse history.
DataType * get_weight_buffer(int index)
size_t get_inverse_proc_rank() const
virtual void on_forward_prop_end(lbann_comm *comm)
kfac_block(Layer *layer, kfac::KFACExecutionContext *context, size_t layer_id, size_t inverse_proc_rank, bool enable_copy_errors, bool enable_copy_activations, int input_size, int output_size)
std::vector< std::unique_ptr< AbsDistMat > > m_weight_gradients
El::SyncInfo< D > get_sync_info(El::Matrix< TensorDataType, D > const &m) noexcept
Get a SyncInfo from a Matrix.