@@ -27,7 +27,7 @@ def __init__(self, batch_size, output_size, hidden_size, vocab_size, embedding_l
        self.word_embeddings = nn.Embedding(vocab_size, embedding_length)
        self.word_embeddings.weight = nn.Parameter(weights, requires_grad=False)
-       self.rnn = nn.RNN(embedding_length hidden_size, num_layers=2, bidirectional=True)
+       self.rnn = nn.RNN(embedding_length, hidden_size, num_layers=2, bidirectional=True)
        self.label = nn.Linear(4*hidden_size, output_size)

    def forward(self, input_sentences, batch_size=None):
@@ -52,7 +52,10 @@ def forward(self, input_sentences, batch_size=None):
        else:
            h_0 = Variable(torch.zeros(4, batch_size, self.hidden_size).cuda())
        output, h_n = self.rnn(input, h_0)
+       # h_n.size() = (4, batch_size, hidden_size)
+       h_n = h_n.permute(1, 0, 2)  # h_n.size() = (batch_size, 4, hidden_size)
+       h_n = h_n.contiguous().view(h_n.size()[0], h_n.size()[1]*h_n.size()[2])
        # h_n.size() = (batch_size, 4*hidden_size)
-       logits = self.label(h_n)
+       logits = self.label(h_n)  # logits.size() = (batch_size, output_size)

-       return logits
+       return logits
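For the shape bookkeeping above, here is a minimal, self-contained sketch of why h_n comes out as (4, batch_size, hidden_size) for a 2-layer bidirectional nn.RNN and how the permute/contiguous().view pair flattens it to match nn.Linear(4*hidden_size, output_size). The sizes are illustrative only, and it uses plain CPU tensors instead of the pretrained embeddings, Variable, and .cuda() calls in the patch:

import torch
import torch.nn as nn

# Illustrative sizes only (not taken from the patch).
batch_size, seq_len, embedding_length, hidden_size, output_size = 3, 5, 10, 8, 2

rnn = nn.RNN(embedding_length, hidden_size, num_layers=2, bidirectional=True)
label = nn.Linear(4 * hidden_size, output_size)

# nn.RNN expects (seq_len, batch, input_size) when batch_first=False (the default).
embedded = torch.randn(seq_len, batch_size, embedding_length)
h_0 = torch.zeros(4, batch_size, hidden_size)   # 4 = num_layers * num_directions

output, h_n = rnn(embedded, h_0)
print(h_n.size())     # torch.Size([4, 3, 8])  -> (4, batch_size, hidden_size)

h_n = h_n.permute(1, 0, 2)                      # (batch_size, 4, hidden_size)
h_n = h_n.contiguous().view(h_n.size(0), -1)    # (batch_size, 4*hidden_size)
print(h_n.size())     # torch.Size([3, 32])

logits = label(h_n)
print(logits.size())  # torch.Size([3, 2])      -> (batch_size, output_size)

The 4 in nn.Linear(4*hidden_size, output_size) comes from num_layers * num_directions = 2 * 2, which is why the flattened final hidden state lines up with the classifier's input size.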