Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 1130cea

Browse files
author
zhangshu
committed
fix workspace size
1 parent e9e26c5 commit 1130cea

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/operator/rnn-inl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ inline size_t GetRNNWorkspaceSize(int seq_length,
105105
LOG(FATAL) << "Only LSTM is supported at the moment";
106106
break;
107107
case rnn_enum::kLstm:
108-
size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 3
108+
size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 2
109109
+ seq_length * batch_size * hidden_size * direction;
110110
break;
111111
default:

src/operator/rnn_impl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ void LstmForwardInference(DType* ws,
250250
Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H));
251251
const int b_size = 2 * H * 4;
252252
const int cell_size = N * H;
253-
DType* y_tmp_ptr = ws + (T + 1) * cell_size * 4 + cell_size * 3;
253+
DType* y_tmp_ptr = ws + (T + 1) * cell_size * 4 + cell_size * 2;
254254
DType* y_cur_ptr = y_ptr;
255255
int idx = 0; // state & cell state's idx;
256256
bool flag = L % 2 ? false : true;
@@ -419,7 +419,7 @@ void LstmBackward(DType* ws,
419419
const int w_size1 = (I + H) * H * 4; // first layer
420420
const int w_size2 = (D * H + H) * H * 4; // other layers
421421
const int cell_size = N * H;
422-
DType* dy_tmp_ptr = ws + (T + 1) * cell_size * 4 + cell_size * 3;
422+
DType* dy_tmp_ptr = ws + T * cell_size * 4 + cell_size * 3;
423423
for (int i = L - 1; i >= 0; --i) {
424424
const int input_size = i ? H * D : I;
425425
const int w_size = i ? w_size2 : w_size1;

0 commit comments

Comments
 (0)