Commit 4095235

Fix attention mechanism
1 parent 2f6c162 commit 4095235

2 files changed: 18 additions & 5 deletions

models/lstm_decoder.py

Lines changed: 17 additions & 4 deletions
@@ -4,7 +4,7 @@
 
 
 class LSTMDecoder(nn.Module):
-    def __init__(self, hidden_size, output_size, device, attention=False):
+    def __init__(self, hidden_size, output_size, device, attention=False, pointer_network=False):
         super(LSTMDecoder, self).__init__()
         self.hidden_size = hidden_size
         self.output_size = output_size
@@ -13,6 +13,7 @@ def __init__(self, hidden_size, output_size, device, attention=False):
         self.out = nn.Linear(hidden_size, output_size).to(device)
         self.softmax = nn.LogSoftmax(dim=1)
         self.attention = attention
+        self.pointer_network = pointer_network
         self.attention_layer = nn.Linear(hidden_size * 2, 1).to(device)
         self.attention_combine = nn.Linear(hidden_size * 2, hidden_size).to(device)
         self.device = device
@@ -22,14 +23,26 @@ def forward(self, input, hidden, encoder_hiddens):
         output = self.embedding(input).view(1, 1, -1)
 
         if self.attention:
+            # Create a matrix of shape [batch_size, seq_len, 2 * hidden_dim] where the last
+            # dimension is a concatenation of the ith encoder hidden state and the current
+            # decoder hidden state
             hiddens = torch.cat((encoder_hiddens, hidden[0].repeat(1, encoder_hiddens.size(1), 1)),
                                 dim=2)
-            attention_coeff = self.attention_layer(hiddens)
-            context = torch.mm(torch.squeeze(encoder_hiddens, dim=0).t(), torch.squeeze(
-                attention_coeff, 2).t()).view(1, 1, -1)
+
+            # attention_coeff has shape [seq_len] and contains the attention coefficients for
+            # each encoder hidden state
+            attention_coeff = F.softmax(torch.squeeze(self.attention_layer(hiddens)), dim=0)
+
+            # Reshape encoder_hiddens to [hidden_dim, seq_len]; valid as long as batch size is 1
+            encoder_hiddens = torch.squeeze(encoder_hiddens, dim=0).t()
+
+            context = torch.matmul(encoder_hiddens, attention_coeff).view(1, 1, -1)
             output = torch.cat((output, context), 2)
             output = self.attention_combine(output)
 
+        elif self.pointer_network:
+            pass
+
         output = F.relu(output)
         output, hidden = self.gru(output, hidden)
         output = self.softmax(self.out(output[0]))
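For reference, a minimal standalone sketch of the corrected attention step, run on dummy tensors to check the shapes involved (hidden_size, seq_len, and the random inputs are illustrative assumptions, not values from the repo):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, seq_len = 8, 5

# Stand-ins for the encoder outputs [batch=1, seq_len, hidden_size] and
# the GRU decoder hidden state [num_layers=1, batch=1, hidden_size].
encoder_hiddens = torch.randn(1, seq_len, hidden_size)
hidden = torch.randn(1, 1, hidden_size)

attention_layer = nn.Linear(hidden_size * 2, 1)

# Pair every encoder hidden state with the (repeated) decoder hidden
# state: [1, seq_len, 2 * hidden_size].
hiddens = torch.cat((encoder_hiddens, hidden[0].repeat(1, seq_len, 1)), dim=2)

# Score each position and normalise; attention_coeff has shape [seq_len].
attention_coeff = F.softmax(torch.squeeze(attention_layer(hiddens)), dim=0)

# Weighted sum of encoder hidden states: [hidden_size, seq_len] times
# [seq_len] gives [hidden_size], reshaped to a [1, 1, hidden_size] context.
context = torch.matmul(torch.squeeze(encoder_hiddens, dim=0).t(),
                       attention_coeff).view(1, 1, -1)

assert attention_coeff.shape == (seq_len,)
assert torch.isclose(attention_coeff.sum(), torch.tensor(1.0))
assert context.shape == (1, 1, hidden_size)

Unlike the previous torch.mm over unnormalised scores, the coefficients now sum to 1, so the context vector is a proper convex combination of the encoder hidden states.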

models/lstm_to_lstm_full_training.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def train(input_tensor, target_tensor, seq2seq_model, optimizer, criterion):
     return loss.item(), pred
 
 
-def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0.05,
+def train_iters(seq2seq_model, n_iters, pairs, print_every=1000, learning_rate=0.01,
                 model_dir=None, lang=None):
     train_losses = []
     val_losses = []
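This change only lowers the default learning rate from 0.05 to 0.01. A hypothetical call sketch (seq2seq_model, pairs, lang, and the n_iters/model_dir values stand in for objects and settings built elsewhere in the repo):

# Picks up the new 0.01 default learning rate.
train_iters(seq2seq_model, n_iters=10000, pairs=pairs, print_every=1000,
            model_dir='checkpoints', lang=lang)

# The old rate can still be requested explicitly.
train_iters(seq2seq_model, n_iters=10000, pairs=pairs, learning_rate=0.05)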
