Skip to content

Commit 2f6c162

Browse files
committed
Fix bug in dimensions
1 parent 4153c78 commit 2f6c162

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

models/lstm_decoder.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,13 +18,14 @@ def __init__(self, hidden_size, output_size, device, attention=False):
1818
self.device = device
1919

2020
def forward(self, input, hidden, encoder_hiddens):
21+
# encoder_hiddens has shape [batch_size, seq_len, hidden_dim]
2122
output = self.embedding(input).view(1, 1, -1)
2223

2324
if self.attention:
2425
hiddens = torch.cat((encoder_hiddens, hidden[0].repeat(1, encoder_hiddens.size(1), 1)),
25-
dim=1)
26+
dim=2)
2627
attention_coeff = self.attention_layer(hiddens)
27-
context = torch.mm(torch.squeeze(encoder_hiddens, dim=1).t(), torch.squeeze(
28+
context = torch.mm(torch.squeeze(encoder_hiddens, dim=0).t(), torch.squeeze(
2829
attention_coeff, 2).t()).view(1, 1, -1)
2930
output = torch.cat((output, context), 2)
3031
output = self.attention_combine(output)

0 commit comments

Comments
 (0)