Skip to content

Commit 0b724b1

Browse files
Fixed few bugs in RNN.py
1 parent ed799b0 commit 0b724b1

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

models/RNN.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __init__(self, batch_size, output_size, hidden_size, vocab_size, embedding_l
2727

2828
self.word_embeddings = nn.Embedding(vocab_size, embedding_length)
2929
self.word_embeddings.weight = nn.Parameter(weights, requires_grad=False)
30-
self.rnn = nn.RNN(embedding_length hidden_size, num_layers=2, bidirectional=True)
30+
self.rnn = nn.RNN(embedding_length, hidden_size, num_layers=2, bidirectional=True)
3131
self.label = nn.Linear(4*hidden_size, output_size)
3232

3333
def forward(self, input_sentences, batch_size=None):
@@ -52,7 +52,10 @@ def forward(self, input_sentences, batch_size=None):
5252
else:
5353
h_0 = Variable(torch.zeros(4, batch_size, self.hidden_size).cuda())
5454
output, h_n = self.rnn(input, h_0)
55+
# h_n.size() = (4, batch_size, hidden_size)
56+
h_n = h_n.permute(1, 0, 2) # h_n.size() = (batch_size, 4, hidden_size)
57+
h_n = h_n.contiguous().view(h_n.size()[0], h_n.size()[1]*h_n.size()[2])
5558
# h_n.size() = (batch_size, 4*hidden_size)
56-
logits = self.label(h_n)
59+
logits = self.label(h_n) # logits.size() = (batch_size, output_size)
5760

58-
return logits
61+
return logits

0 commit comments

Comments
 (0)