
Commit fff1606

Fix bug in squeezing attention coefficients
1 parent a9b0e19 commit fff1606

File tree: 1 file changed, +16 -3 lines


models/lstm_decoder.py

Lines changed: 16 additions & 3 deletions
@@ -18,7 +18,7 @@ def __init__(self, hidden_size, output_size, device, attention=False, pointer_ne
         self.attention_combine = nn.Linear(hidden_size * 2, hidden_size).to(device)
         self.device = device
 
-    def forward(self, input, hidden, encoder_hiddens):
+    def forward(self, input, hidden, encoder_hiddens, input_seq=None):
         # encoder_hiddens has shape [batch_size, seq_len, hidden_dim]
         output = self.embedding(input).view(1, 1, -1)
 
@@ -31,7 +31,11 @@ def forward(self, input, hidden, encoder_hiddens):
 
             # attention_coeff has shape [seq_len] and contains the attention coeffiecients for
             # each encoder hidden state
-            attention_coeff = F.softmax(torch.squeeze(self.attention_layer(hiddens)), dim=0)
+            # attention_coeff has shape [batch_size, seq_len, 1]
+            attention_coeff = self.attention_layer(hiddens)
+            attention_coeff = torch.squeeze(attention_coeff, dim=2)
+            attention_coeff = torch.squeeze(attention_coeff, dim=0)
+            attention_coeff = F.softmax(attention_coeff, dim=0)
 
             # Make encoder_hiddens of shape [hidden_dim, seq_len] as long as batch size is 1
             encoder_hiddens = torch.squeeze(encoder_hiddens, dim=0).t()
@@ -41,7 +45,16 @@ def forward(self, input, hidden, encoder_hiddens):
             output = self.attention_combine(output)
 
         elif self.pointer_network:
-            pass
+            # Create a matrix of shape [batch_size, seq_len, 2 * hidden_dim] where the last
+            # dimension is a concatenation of the ith encoder hidden state and the current decoder
+            # hidden
+            hiddens = torch.cat((encoder_hiddens, hidden[0].repeat(1, encoder_hiddens.size(1), 1)),
+                                dim=2)
+
+            # attention_coeff has shape [seq_len] and contains the attention coeffiecients for
+            # each encoder hidden state
+            attention_coeff = F.softmax(torch.squeeze(self.attention_layer(hiddens)), dim=0)
+            # TODO: This is the output already
 
         output = F.relu(output)
         output, hidden = self.gru(output, hidden)
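Why the explicit dims matter: torch.squeeze with no dim argument removes every size-1 dimension, so for a length-1 input sequence the [batch_size, seq_len, 1] score tensor collapses to a 0-d scalar instead of a [seq_len] vector of coefficients. A minimal sketch of the difference, assuming batch size 1 and the shapes noted in the diff comments (the random values are purely illustrative):

import torch
import torch.nn.functional as F

# Hypothetical attention scores for a length-1 input: [batch_size=1, seq_len=1, 1],
# the shape noted in the new comment in the diff.
scores = torch.randn(1, 1, 1)

# Old behaviour: a dim-less squeeze drops every singleton dimension, leaving a
# 0-d scalar rather than a [seq_len] vector.
old = torch.squeeze(scores)
print(old.shape)  # torch.Size([])

# New behaviour: squeeze only the score dimension (dim=2) and the batch
# dimension (dim=0), so the sequence dimension survives even when seq_len == 1.
new = torch.squeeze(torch.squeeze(scores, dim=2), dim=0)
print(new.shape)  # torch.Size([1])

# Normalising over dim=0 now yields one coefficient per encoder hidden state.
attention_coeff = F.softmax(new, dim=0)

For seq_len > 1 both variants produce the same [seq_len] tensor; the explicit dims only change behaviour in the length-1 edge case, which appears to be the scenario this fix addresses.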
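The pointer-network branch added in the last hunk builds its scoring input by repeating the decoder hidden state across the sequence and concatenating it with the encoder hidden states. A rough shape check, assuming batch size 1, an illustrative seq_len of 4 and hidden_dim of 8, and an attention_layer of the form Linear(2 * hidden_dim, 1) (the layer definition is an assumption; it is not shown in this diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

seq_len, hidden_dim = 4, 8                              # illustrative sizes
encoder_hiddens = torch.randn(1, seq_len, hidden_dim)   # [batch_size, seq_len, hidden_dim]
decoder_hidden = torch.randn(1, 1, hidden_dim)          # stand-in for hidden[0]

# Repeat the decoder hidden state along the sequence dimension and concatenate it
# with each encoder hidden state: [batch_size, seq_len, 2 * hidden_dim].
hiddens = torch.cat((encoder_hiddens,
                     decoder_hidden.repeat(1, encoder_hiddens.size(1), 1)), dim=2)
print(hiddens.shape)  # torch.Size([1, 4, 16])

# Assumed scoring layer: one scalar score per sequence position.
attention_layer = nn.Linear(2 * hidden_dim, 1)
scores = attention_layer(hiddens)                       # [1, seq_len, 1]

# As committed, this branch still uses the dim-less squeeze, so the same
# seq_len == 1 caveat fixed in the attention branch would apply here too.
attention_coeff = F.softmax(torch.squeeze(scores), dim=0)
print(attention_coeff.shape)  # torch.Size([4])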
