Added greedy decoder input #6

Closed
wants to merge 2 commits into from
44 changes: 37 additions & 7 deletions 5-1.Transformer/Transformer-Torch.py
@@ -120,8 +120,8 @@ def __init__(self):
         self.dec_enc_attn = MultiHeadAttention()
         self.pos_ffn = PoswiseFeedForwardNet()

-    def forward(self, dec_inputs, enc_outputs, enc_attn_mask):
-        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, None)
+    def forward(self, dec_inputs, enc_outputs, enc_attn_mask, dec_attn_mask=None):
+        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_attn_mask)
         dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, enc_attn_mask)
         dec_outputs = self.pos_ffn(dec_outputs)
         return dec_outputs, dec_self_attn, dec_enc_attn
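The dec_attn_mask threaded through DecoderLayer above is the usual look-ahead mask over target positions. A minimal sketch of a helper that builds it, assuming the name get_attn_subsequent_mask and the byte dtype (neither is part of this diff), following this file's convention that a 1 marks a position attention must not see:

def get_attn_subsequent_mask(seq):
    # seq: [batch_size x tgt_len] -> mask: [batch_size x tgt_len x tgt_len];
    # 1 strictly above the diagonal marks the future positions each step
    # may not attend to. For tgt_len = 5 each batch entry is:
    #   0 1 1 1 1
    #   0 0 1 1 1
    #   0 0 0 1 1
    #   0 0 0 0 1
    #   0 0 0 0 0
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    return torch.from_numpy(np.triu(np.ones(attn_shape), k=1)).byte()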
@@ -148,13 +148,15 @@ def __init__(self):
         self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(tgt_len+1, d_model), freeze=True)
         self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

-    def forward(self, dec_inputs, enc_inputs, enc_outputs): # dec_inputs : [batch_size x target_len]
+    def forward(self, dec_inputs, enc_inputs, enc_outputs, dec_attn_mask=None): # dec_inputs : [batch_size x target_len]
         dec_outputs = self.tgt_emb(dec_inputs) + self.pos_emb(torch.LongTensor([[1,2,3,4,5]]))
         dec_enc_attn_pad_mask = get_attn_pad_mask(dec_inputs, enc_inputs)

         dec_self_attns, dec_enc_attns = [], []
         for layer in self.layers:
-            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, enc_attn_mask=dec_enc_attn_pad_mask)
+            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs,
+                                                             enc_attn_mask=dec_enc_attn_pad_mask,
+                                                             dec_attn_mask=dec_attn_mask)
             dec_self_attns.append(dec_self_attn)
             dec_enc_attns.append(dec_enc_attn)
         return dec_outputs, dec_self_attns, dec_enc_attns
@@ -165,12 +167,38 @@ def __init__(self):
         self.encoder = Encoder()
         self.decoder = Decoder()
         self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)
-    def forward(self, enc_inputs, dec_inputs):
+    def forward(self, enc_inputs, dec_inputs, decoder_mask=None):
         enc_outputs, enc_self_attns = self.encoder(enc_inputs)
-        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
+        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs, decoder_mask)
         dec_logits = self.projection(dec_outputs) # dec_logits : [batch_size x tgt_len x tgt_vocab_size]
         return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns


+def greedy_decoder(model, enc_input, start_symbol):
+    """
+    For simplicity, a greedy decoder is beam search with K=1. This is necessary for inference, since we don't
+    know the target sequence input; we therefore generate the target input word by word, feeding each partial
+    result back into the transformer.
+    Starting reference: http://nlp.seas.harvard.edu/2018/04/03/attention.html#greedy-decoding
+    :param model: Transformer model
+    :param enc_input: the encoder input
+    :param start_symbol: the start symbol. In this example it is 'S', which corresponds to index 4
+    :return: the target input
+    """
+    memory, attention = model.encoder(enc_input)
+    dec_input = torch.zeros(1, 5).type_as(enc_input.data)  # simpler than torch.ones(1, 5).fill_(0)
+    # 1 marks the future positions a step may not attend to; this file's attention
+    # masks where the mask is one, so no trailing "== 0" as in the Annotated
+    # Transformer, whose masked_fill convention is the opposite.
+    dec_mask = torch.from_numpy(np.triu(np.ones((1, 5, 5)), k=1).astype('uint8'))
+    next_symbol = start_symbol
+    for i in range(5):
+        dec_input[0][i] = next_symbol
+        # Variable() wrapper dropped: it has been a no-op since PyTorch 0.4
+        out = model.decoder(dec_input, enc_input, memory, dec_mask)
+        projected = model.projection(out[0])
+        prob = projected.view(-1, projected.size(-1))
+        best_indices = prob.max(dim=1, keepdim=True)[1]  # greedy: argmax token at every position
+        next_symbol = best_indices[i].item()
+    return dec_input


model = Transformer()

criterion = nn.CrossEntropyLoss()
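Since greedy_decoder always runs exactly five steps, a natural refinement is to stop once the end symbol 'E' appears. A hedged sketch, not part of this PR; the end_symbol index (e.g. tgt_vocab['E']) depends on how the vocabulary is built earlier in the file:

def greedy_decoder_early_stop(model, enc_input, start_symbol, end_symbol, max_len=5):
    # Hypothetical variant of greedy_decoder: same loop, but break as soon as
    # the end symbol is emitted. max_len must stay 5 here because
    # Decoder.forward hard-codes the positions [[1,2,3,4,5]].
    memory, _ = model.encoder(enc_input)
    dec_input = torch.zeros(1, max_len).type_as(enc_input.data)
    dec_mask = torch.from_numpy(np.triu(np.ones((1, max_len, max_len)), k=1).astype('uint8'))
    next_symbol = start_symbol
    for i in range(max_len):
        dec_input[0][i] = next_symbol
        out = model.decoder(dec_input, enc_input, memory, dec_mask)
        projected = model.projection(out[0])
        next_symbol = projected.view(-1, projected.size(-1)).max(dim=1)[1][i].item()
        if next_symbol == end_symbol:  # e.g. tgt_vocab['E'] (assumed name)
            break
    return dec_input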
@@ -196,8 +224,10 @@ def showgraph(attn):
     ax.set_yticklabels(['']+sentences[2].split(), fontdict={'fontsize': 14})
     plt.show()

+
 # Test
-predict, _, _, _ = model(enc_inputs, dec_inputs)
+greedy_dec_input = greedy_decoder(model, enc_inputs, start_symbol=4)
+predict, _, _, _ = model(enc_inputs, greedy_dec_input)
 predict = predict.data.max(1, keepdim=True)[1]
 print(sentences[0], '->', [number_dict[n.item()] for n in predict.squeeze()])
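One loose end: Transformer.forward now accepts decoder_mask, but it defaults to None, so if the training loop elsewhere in the file still calls model(enc_inputs, dec_inputs) without a mask, teacher-forced training can attend to future target words. A hedged sketch of passing the mask there too, reusing the shape from greedy_decoder (target_batch is the label tensor assumed from the training loop, not shown in this diff):

# Hypothetical training-time call, masking future target positions as well:
dec_mask = torch.from_numpy(np.triu(np.ones((1, 5, 5)), k=1).astype('uint8'))
logits, _, _, _ = model(enc_inputs, dec_inputs, decoder_mask=dec_mask)
loss = criterion(logits, target_batch.contiguous().view(-1))  # target_batch: assumed name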
