class LSTMDecoder(nn.Module):
-    def __init__(self, hidden_size, output_size, device, attention=False):
+    def __init__(self, hidden_size, output_size, device, attention=False, pointer_network=False):
        super(LSTMDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
@@ -13,6 +13,7 @@ def __init__(self, hidden_size, output_size, device, attention=False):
        self.out = nn.Linear(hidden_size, output_size).to(device)
        self.softmax = nn.LogSoftmax(dim=1)
        self.attention = attention
+        self.pointer_network = pointer_network
        self.attention_layer = nn.Linear(hidden_size * 2, 1).to(device)
        self.attention_combine = nn.Linear(hidden_size * 2, hidden_size).to(device)
        self.device = device
@@ -22,14 +23,26 @@ def forward(self, input, hidden, encoder_hiddens):
        output = self.embedding(input).view(1, 1, -1)

        if self.attention:
+            # Create a matrix of shape [batch_size, seq_len, 2 * hidden_dim] where the last
+            # dimension is a concatenation of the ith encoder hidden state and the current
+            # decoder hidden state
            hiddens = torch.cat((encoder_hiddens, hidden[0].repeat(1, encoder_hiddens.size(1), 1)),
                                dim=2)
-            attention_coeff = self.attention_layer(hiddens)
-            context = torch.mm(torch.squeeze(encoder_hiddens, dim=0).t(), torch.squeeze(
-                attention_coeff, 2).t()).view(1, 1, -1)
+
+            # attention_coeff has shape [seq_len] and contains the attention coefficients for
+            # each encoder hidden state
+            attention_coeff = F.softmax(torch.squeeze(self.attention_layer(hiddens)), dim=0)
+
+            # Make encoder_hiddens of shape [hidden_dim, seq_len] as long as the batch size is 1
+            encoder_hiddens = torch.squeeze(encoder_hiddens, dim=0).t()
+
+            context = torch.matmul(encoder_hiddens, attention_coeff).view(1, 1, -1)
            output = torch.cat((output, context), 2)
            output = self.attention_combine(output)

+        elif self.pointer_network:
+            pass
+
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
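To sanity-check the shapes in the new attention branch, here is a minimal standalone sketch of the same context computation on dummy tensors; the sizes hidden_dim = 8 and seq_len = 5 are arbitrary, and a batch size of 1 is assumed, as in the decoder above:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim, seq_len = 8, 5
encoder_hiddens = torch.randn(1, seq_len, hidden_dim)   # [batch=1, seq_len, hidden_dim]
decoder_hidden = torch.randn(1, 1, hidden_dim)          # current decoder hidden state

attention_layer = nn.Linear(hidden_dim * 2, 1)

# [1, seq_len, 2 * hidden_dim]: each encoder state paired with the decoder state
hiddens = torch.cat((encoder_hiddens, decoder_hidden.repeat(1, seq_len, 1)), dim=2)

# One score per source position, softmax-normalised over the sequence -> shape [seq_len]
attention_coeff = F.softmax(torch.squeeze(attention_layer(hiddens)), dim=0)

# Weighted sum of encoder states: [hidden_dim, seq_len] @ [seq_len] -> [1, 1, hidden_dim]
context = torch.matmul(torch.squeeze(encoder_hiddens, dim=0).t(), attention_coeff).view(1, 1, -1)

print(attention_coeff.shape, context.shape)  # torch.Size([5]) torch.Size([1, 1, 8])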
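The pointer_network branch is only a stub in this commit. For context, a pointer network (in the style of Vinyals et al.) drops the fixed output vocabulary and uses the attention distribution over source positions directly as the output distribution. The sketch below is purely illustrative of that idea, not the repository's implementation, and reuses the dummy tensors and layer shape from the snippet above:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim, seq_len = 8, 5
encoder_hiddens = torch.randn(1, seq_len, hidden_dim)
decoder_hidden = torch.randn(1, 1, hidden_dim)

# Same scoring-layer shape as the decoder's attention_layer above.
scorer = nn.Linear(hidden_dim * 2, 1)
pairs = torch.cat((encoder_hiddens, decoder_hidden.repeat(1, seq_len, 1)), dim=2)

# Hypothetical pointer step: log-probabilities over the seq_len source positions,
# used in place of self.softmax(self.out(...)) over a fixed vocabulary.
log_pointer_dist = F.log_softmax(torch.squeeze(scorer(pairs)), dim=0)
print(log_pointer_dist.shape)  # torch.Size([5])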