27 #ifndef LBANN_EXECUTION_ALGORITHMS_KFAC_KFAC_BLOCK_GRU_HPP_INCLUDED 28 #define LBANN_EXECUTION_ALGORITHMS_KFAC_KFAC_BLOCK_GRU_HPP_INCLUDED 35 namespace kfac_gru_util {
71 template <El::Device Device>
73 El::Matrix<DataType, Device>& r,
74 El::Matrix<DataType, Device>& i,
77 size_t local_batch_size,
78 const El::SyncInfo<Device>& sync_info);
81 template <El::Device Device>
83 const El::Matrix<DataType, Device>& R_y,
84 const El::Matrix<DataType, Device>& b_Wy,
85 const El::Matrix<DataType, Device>& b_Ry,
86 const El::Matrix<DataType, Device>& x_t,
87 const El::Matrix<DataType, Device>& hprev_t,
88 const El::Matrix<DataType, Device>& biases_ones,
89 El::Matrix<DataType, Device>& y_t);
92 template <El::Device Device>
93 void get_g(
const El::Matrix<DataType, Device>& h,
94 const El::Matrix<DataType, Device>& h0,
95 const El::Matrix<DataType, Device>& dh,
96 const El::Matrix<DataType, Device>& hfc,
97 const El::Matrix<DataType, Device>& r,
98 const El::Matrix<DataType, Device>& i,
99 El::Matrix<DataType, Device>& g_Rr,
100 El::Matrix<DataType, Device>& g_Ri,
101 El::Matrix<DataType, Device>& g_Rh,
102 El::Matrix<DataType, Device>& g_Wr,
103 El::Matrix<DataType, Device>& g_Wi,
104 El::Matrix<DataType, Device>& g_Wh,
107 size_t local_batch_size,
108 const El::SyncInfo<Device>& sync_info);
114 template <El::Device Device>
123 size_t inverse_proc_rank,
124 bool enable_copy_errors,
125 bool enable_copy_activations,
133 enable_copy_activations,
138 check_dnn_lib_spec();
140 const auto num_layers = get_gru_layer()->get_num_layers();
141 if (num_layers > 1) {
142 std::stringstream err;
143 err <<
"The K-FAC only supports single-layer GRU layer." 144 <<
" layer: " << layer->
get_name() <<
", num_layers: " << num_layers;
151 void on_forward_prop_end(
lbann_comm* comm)
override;
153 const std::vector<El::AbstractMatrix<DataType>*>
154 get_local_kronecker_buffers()
override;
159 LBANN_ERROR(
"this function is not implemented for GRU layer.");
162 void compute_local_kronecker_factors(
lbann_comm* comm,
164 bool print_matrix_summary)
override;
167 DataType kronecker_decay,
169 bool print_matrix_summary)
override;
171 void update_kronecker_inverse(
lbann_comm* comm,
173 DataType damping_act,
174 DataType damping_err,
175 DataType learning_rate_factor,
176 bool use_eigen_decomposition,
178 bool print_matrix_summary,
179 bool print_time)
override;
181 void compute_preconditioned_gradients(
lbann_comm* comm,
182 DataType learning_rate_factor,
184 bool print_matrix_summary,
185 bool print_time)
override;
187 void initialize_activations_and_errors(
lbann_comm* comm,
188 int num_local_activations,
189 int num_local_errors,
190 int num_weights)
override;
193 int get_inverse_matrices(El::Matrix<DataType, Device>& output,
194 int offset)
override;
197 int get_inverse_matrices_size(
lbann_comm* comm)
override;
199 int set_inverse_matrices(El::Matrix<DataType, Device>& workspace,
205 void start_communication_forward_end(
lbann_comm* comm)
override;
206 void end_communication_forward_end(
lbann_comm* comm)
override;
207 void start_communication_backward_end(
lbann_comm* comm)
override;
208 void end_communication_backward_end(
lbann_comm* comm)
override;
213 LBANN_ERROR(
"This function is not yet implemented for GRU layer");
218 El::Matrix<double, El::Device::CPU>& inverse_matrices_size,
219 int block_number)
override 221 LBANN_ERROR(
"This function is not yet implemented for GRU layer");
224 const std::vector<El::AbstractMatrix<DataType>*>
225 get_preconditioned_grad_buffers()
override;
228 void check_dnn_lib_spec()
const;
232 void get_r_i(El::Matrix<DataType, Device>& r,
233 El::Matrix<DataType, Device>& i,
234 const El::Matrix<DataType, Device>& biases_ones,
235 const El::Matrix<DataType, Device>& local_inputs,
236 const El::Matrix<DataType, Device>& local_outputs,
237 const El::Matrix<DataType, Device>& h0,
238 size_t local_batch_size,
239 const El::SyncInfo<Device>& sync_info);
244 El::Matrix<DataType, Device>& view);
247 El::Matrix<DataType, Device>& view);
249 El::Matrix<DataType, Device>& view);
251 std::vector<std::tuple<std::string, size_t, size_t>>
252 get_internal_matrix_info()
const override;
263 const auto input_dims = this->m_layer->get_input_dims();
264 return input_dims[1];
269 const auto input_dims = this->m_layer->get_input_dims();
270 return input_dims[0];
273 void send_recv_reserve_space(
lbann_comm* comm);
279 El::Matrix<DataType, Device> m_kronecker_factor_buf_A_h,
281 std::unordered_map<kfac_gru_util::weight_type, El::Matrix<DataType, Device>>
286 std::unordered_map<kfac_gru_util::weight_type, El::Matrix<DataType, Device>>
291 std::unordered_map<kfac_gru_util::weight_type, El::Matrix<DataType, Device>>
295 std::unordered_map<kfac_gru_util::weight_type, El::Matrix<DataType, Device>>
298 size_t m_reserve_space_fwd_size = 0;
305 #endif // LBANN_EXECUTION_ALGORITHMS_KFAC_KFAC_BLOCK_GRU_HPP_INCLUDED
size_t get_hidden_size() const
std::unordered_map< kfac_gru_util::weight_type, El::Matrix< DataType, Device > > m_kronecker_average_G
std::pair< int, int > get_gru_weight_offset(weight_type matrix_type)
Get the weight ID and the row offset ID of a GRU weight matrix.
bool is_matrix_height_hidden(const weight_type &matrix_type)
Return whether the height of a GRU weight matrix is the hidden size.
El::Matrix< DataType, Device > m_kronecker_average_A_x
std::vector< kfac::ReqT > m_requests_workspace
void gru_gate_forward(const El::Matrix< DataType, Device > &W_y, const El::Matrix< DataType, Device > &R_y, const El::Matrix< DataType, Device > &b_Wy, const El::Matrix< DataType, Device > &b_Ry, const El::Matrix< DataType, Device > &x_t, const El::Matrix< DataType, Device > &hprev_t, const El::Matrix< DataType, Device > &biases_ones, El::Matrix< DataType, Device > &y_t)
Compute internal GRU gate state (r or i).
El::Matrix< DataType, Device > m_kronecker_factor_buf_A_x
El::Matrix< DataType, Device > m_kronecker_inverse_A_x
El::Matrix< DataType, Device > m_grad_buffer_A_x
Neural network tensor operation.
std::string get_matrix_type_name(const weight_type &matrix_type)
Get the name of a GRU weight matrix.
constexpr El::Device Device
std::unordered_map< kfac_gru_util::weight_type, El::Matrix< DataType, Device > > m_kronecker_inverse_G
void unpack_reserve_space(const DataType *reserve_space_fwd, El::Matrix< DataType, Device > &r, El::Matrix< DataType, Device > &i, size_t hidden_size, size_t seq_length, size_t local_batch_size, const El::SyncInfo< Device > &sync_info)
Copy r_t and i_t from the reserve space after the forward pass.
kfac_block_gru(Layer *layer, kfac::KFACExecutionContext *context, size_t layer_id, size_t inverse_proc_rank, bool enable_copy_errors, bool enable_copy_activations, int input_size, int output_size)
std::vector< int > get_inverse_matrices_size_vector(lbann_comm *comm) override
Get inverse matrices size vector.
void resize_inverse_matrices_size(El::Matrix< double, El::Device::CPU > &inverse_matrices_size, int block_number) override
Resize the inverse matrices size vector.
void update_kronecker_average(El::Matrix< DataType, Device > &Aave, const El::Matrix< DataType, Device > &A, size_t count, double decay, const El::SyncInfo< Device > &sync_info)
Update a Kronecker factor matrix using decay.
std::unordered_map< kfac_gru_util::weight_type, El::Matrix< DataType, Device > > m_grad_buffer_G
size_t get_input_size() const
hydrogen::simple_buffer< El::byte, Device > m_reserve_space_fwd
A copy of the reserve space after forward passes.
std::unordered_map< kfac_gru_util::weight_type, El::Matrix< DataType, Device > > m_kronecker_factor_buf_G
Stacked gated recurrent unit.
std::string get_name() const
Get the layer instance's name.
void get_g(const El::Matrix< DataType, Device > &h, const El::Matrix< DataType, Device > &h0, const El::Matrix< DataType, Device > &dh, const El::Matrix< DataType, Device > &hfc, const El::Matrix< DataType, Device > &r, const El::Matrix< DataType, Device > &i, El::Matrix< DataType, Device > &g_Rr, El::Matrix< DataType, Device > &g_Ri, El::Matrix< DataType, Device > &g_Rh, El::Matrix< DataType, Device > &g_Wr, El::Matrix< DataType, Device > &g_Wi, El::Matrix< DataType, Device > &g_Wh, size_t hidden_size, size_t seq_length, size_t local_batch_size, const El::SyncInfo< Device > &sync_info)
Compute d h_t / d g_t.
const std::vector< weight_type > LEARNABLE_MATRICES
gru_layer< DataType, data_layout::DATA_PARALLEL, Device > * get_gru_layer() const
Get the pointer to its GRU_layer.
int get_local_memory_consumption() override
Get local Memory Consumption.
size_t get_seq_length() const