|
| std::string | lbann::kfac_gru_util::get_matrix_type_name (const weight_type &matrix_type) |
| | Get the name of a GRU weight matrix. More...
|
| |
| bool | lbann::kfac_gru_util::is_matrix_height_hidden (const weight_type &matrix_type) |
| | Return whether the height of a GRU weight matrix is the hidden size. More...
|
| |
| std::pair< int, int > | lbann::kfac_gru_util::get_gru_weight_offset (weight_type matrix_type) |
| | Get the weight ID and the row offset ID of a GRU weight matrix. More...
|
| |
| template<El::Device Device> |
| void | lbann::kfac_gru_util::unpack_reserve_space (const DataType *reserve_space_fwd, El::Matrix< DataType, Device > &r, El::Matrix< DataType, Device > &i, size_t hidden_size, size_t seq_length, size_t local_batch_size, const El::SyncInfo< Device > &sync_info) |
| | Copy r_t and i_t from the reserve space after the forward pass. More...
|
| |
| template<El::Device Device> |
| void | lbann::kfac_gru_util::gru_gate_forward (const El::Matrix< DataType, Device > &W_y, const El::Matrix< DataType, Device > &R_y, const El::Matrix< DataType, Device > &b_Wy, const El::Matrix< DataType, Device > &b_Ry, const El::Matrix< DataType, Device > &x_t, const El::Matrix< DataType, Device > &hprev_t, const El::Matrix< DataType, Device > &biases_ones, El::Matrix< DataType, Device > &y_t) |
| | Compute internal GRU gate state (r or i). More...
|
| |
| template<El::Device Device> |
| void | lbann::kfac_gru_util::get_g (const El::Matrix< DataType, Device > &h, const El::Matrix< DataType, Device > &h0, const El::Matrix< DataType, Device > &dh, const El::Matrix< DataType, Device > &hfc, const El::Matrix< DataType, Device > &r, const El::Matrix< DataType, Device > &i, El::Matrix< DataType, Device > &g_Rr, El::Matrix< DataType, Device > &g_Ri, El::Matrix< DataType, Device > &g_Rh, El::Matrix< DataType, Device > &g_Wr, El::Matrix< DataType, Device > &g_Wi, El::Matrix< DataType, Device > &g_Wh, size_t hidde_size, size_t seq_length, size_t local_batch_size, const El::SyncInfo< Device > &sync_info) |
| | Compute d h_t / d g_t. More...
|
| |