Skip to content

Commit

Permalink
annotated rnn
Browse files Browse the repository at this point in the history
  • Loading branch information
tiendung committed Nov 23, 2022
1 parent 5fbd58d commit 33decbc
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
31 changes: 19 additions & 12 deletions kim/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,26 +416,33 @@ def forward(self, X, h0=None):
"""

seq_len, bs, input_size = X.shape
output = [] # recorded output sequence
h = [] # current hidden states
outputs = [] # the output sequence
hiddens = [] # hidden states, started at h0

# Init h from h0
for layer in range(self.num_layers):
if h0 is None: h.append(None)
else: h.append(h0[layer,:,:].reshape((bs, self.hidden_size)).compact())
# Init hiddens from h0 (a.k.a started at h0)
if h0 is None:
for layer in range(self.num_layers):
hiddens.append(init.zeros(bs, self.hidden_size, dtype=self.dtype, device=self.device))
else:
for layer in range(self.num_layers):
array = h0.cached_data[layer,:,:].reshape((bs, self.hidden_size)).compact()
hiddens.append(Tensor(array, device=self.device))


for i in range(seq_len):
x = X.cached_data[i,:,:].reshape((bs, input_size)).compact()
curr_input = Tensor(X.cached_data[i,:,:].reshape((bs, input_size)).compact(), device=self.device)

# Calculate new hiddens from current input and self.rnn_cells
for layer in range(self.num_layers):
rnn_cell = self.rnn_cells[layer]
x = rnn_cell(Tensor(x, device=self.device), h[layer])
h[layer] = x
prev_hidden = hiddens[layer]
curr_rnn_cell = self.rnn_cells[layer]
curr_hidden = curr_rnn_cell(curr_input, prev_hidden)
hiddens[layer] = curr_hidden
curr_input = curr_hidden

output.append(x)
outputs.append(curr_hidden) # hidden ouput from last layer

return ops.stack(output, 0), ops.stack(h, 0)
return ops.stack(outputs, 0), ops.stack(hiddens, 0)



Expand Down
5 changes: 2 additions & 3 deletions run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@ python3 -m mugrade submit _r1VOvEAgPZvLXFJ18agr -k "language_model"
python3 -m pytest -l -v -k "language_model_implementation"
python3 -m pytest -l -v -k "language_model_training"


# python3 -m pytest tests/test_sequence_models.py
python3 -m pytest tests/test_sequence_models.py -k "rnn_cell"
# python3 -m pytest tests/test_sequence_models.py -k "rnn_cell"
KIM_DEVICE=cuda python3 -m pytest tests/test_sequence_models.py -k "test_rnn and cuda-relu-False-False-12-11-15-2-13"
KIM_DEVICE=cuda python3 -m pytest tests/test_sequence_models.py -k "test_rnn"
python3 -m pytest tests/test_sequence_models.py


# python3 -m mugrade submit _r1VOvEAgPZvLXFJ18agr -k "conv_forward"
Expand Down

0 comments on commit 33decbc

Please sign in to comment.