@@ -18,7 +18,7 @@ def __init__(self, hidden_size, output_size, device, attention=False, pointer_ne
         self.attention_combine = nn.Linear(hidden_size * 2, hidden_size).to(device)
         self.device = device
 
-    def forward(self, input, hidden, encoder_hiddens):
+    def forward(self, input, hidden, encoder_hiddens, input_seq=None):
         # encoder_hiddens has shape [batch_size, seq_len, hidden_dim]
         output = self.embedding(input).view(1, 1, -1)
 
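The only functional change in this hunk is the new input_seq=None parameter. Defaulting it to None keeps the existing attention-only call sites working unchanged, while the pointer-network branch added further down can receive the raw source tokens it presumably needs to turn pointed-at positions into outputs. A minimal sketch of the calling pattern, with a stub standing in for the real decoder:

def forward(input, hidden, encoder_hiddens, input_seq=None):
    # Stub with the new signature: only the pointer-network path needs input_seq.
    return "pointer path" if input_seq is not None else "attention path"

print(forward("tok", None, None))                         # old 3-argument call still works
print(forward("tok", None, None, input_seq=[11, 42, 7]))  # new pointer-network call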
@@ -31,7 +31,11 @@ def forward(self, input, hidden, encoder_hiddens):
 
             # attention_coeff has shape [seq_len] and contains the attention coefficients for
             # each encoder hidden state
-            attention_coeff = F.softmax(torch.squeeze(self.attention_layer(hiddens)), dim=0)
+            # attention_coeff has shape [batch_size, seq_len, 1] before the squeezes
+            attention_coeff = self.attention_layer(hiddens)
+            attention_coeff = torch.squeeze(attention_coeff, dim=2)
+            attention_coeff = torch.squeeze(attention_coeff, dim=0)
+            attention_coeff = F.softmax(attention_coeff, dim=0)
 
             # Make encoder_hiddens of shape [hidden_dim, seq_len] as long as batch size is 1
             encoder_hiddens = torch.squeeze(encoder_hiddens, dim=0).t()
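The rewrite into explicit squeezes fixes a shape edge case: torch.squeeze with no dim argument drops every size-1 dimension, so for a length-1 input sequence the old one-liner collapsed the seq_len axis together with the batch and feature axes, and the softmax no longer normalized over sequence positions. A small self-contained check of the difference (the [1, 1, 1] tensor stands in for self.attention_layer(hiddens) in that edge case):

import torch
import torch.nn.functional as F

scores = torch.randn(1, 1, 1)  # [batch_size, seq_len, 1] with seq_len == 1

# Old behaviour: squeeze() drops all size-1 dims, leaving a 0-d tensor.
print(torch.squeeze(scores).shape)  # torch.Size([])

# New behaviour: squeeze only the known singleton dims, so seq_len survives.
coeff = torch.squeeze(torch.squeeze(scores, dim=2), dim=0)
print(coeff.shape)              # torch.Size([1]) -- one coefficient per position
print(F.softmax(coeff, dim=0))  # tensor([1.]) -- a valid distribution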
@@ -41,7 +45,16 @@ def forward(self, input, hidden, encoder_hiddens):
             output = self.attention_combine(output)
 
         elif self.pointer_network:
-            pass
+            # Create a matrix of shape [batch_size, seq_len, 2 * hidden_dim] where the last
+            # dimension is a concatenation of the ith encoder hidden state and the current
+            # decoder hidden state
+            hiddens = torch.cat((encoder_hiddens, hidden[0].repeat(1, encoder_hiddens.size(1), 1)),
+                                dim=2)
+
+            # attention_coeff has shape [seq_len] and contains the attention coefficients for
+            # each encoder hidden state
+            attention_coeff = F.softmax(torch.squeeze(self.attention_layer(hiddens)), dim=0)
+            # TODO: This is the output already
 
         output = F.relu(output)
         output, hidden = self.gru(output, hidden)
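Two things in the new pointer-network branch are worth spelling out: hidden[0] is 2-D ([batch, hidden]), so repeat(1, seq_len, 1) first promotes it to [1, batch, hidden], which only lines up with encoder_hiddens when the batch size is 1 (consistent with the comments elsewhere in the file); and, per the TODO, the resulting attention distribution over input positions is itself the decoder's output, which is presumably why forward() now accepts input_seq. A hedged sketch under those assumptions (the Linear stand-in for self.attention_layer and the greedy-decoding step are illustrations, not code from this commit):

import torch

batch_size, seq_len, hidden_size = 1, 5, 8  # batch size must be 1 here

encoder_hiddens = torch.randn(batch_size, seq_len, hidden_size)
hidden = torch.randn(1, batch_size, hidden_size)  # GRU hidden: [num_layers, batch, hidden]

# hidden[0] is [batch, hidden]; repeat() with three factors treats it as
# [1, batch, hidden], giving [1, seq_len, hidden] since batch_size == 1.
hiddens = torch.cat((encoder_hiddens, hidden[0].repeat(1, seq_len, 1)), dim=2)
assert hiddens.shape == (batch_size, seq_len, 2 * hidden_size)

# Stand-in for self.attention_layer, assumed to be nn.Linear(2 * hidden_size, 1).
# Note: the same no-dim squeeze caveat from the attention branch applies here too.
attention_layer = torch.nn.Linear(2 * hidden_size, 1)
attention_coeff = torch.softmax(torch.squeeze(attention_layer(hiddens)), dim=0)  # [seq_len]

# The distribution over input positions is the output: greedy decoding just
# emits the input token at the pointed-at position.
input_seq = torch.tensor([11, 42, 7, 99, 3])
next_token = input_seq[torch.argmax(attention_coeff)]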