
Commit 0a86b8a (parent: 280662e)
Author: Vineet John
Commit message: Removed unused references to encoder_output in the attention decoder

File tree: 1 file changed (+4, -4 lines changed)


intermediate_source/seq2seq_translation_tutorial.py (+4, -4)
@@ -466,7 +466,7 @@ def __init__(self, hidden_size, output_size, n_layers=1, dropout_p=0.1, max_leng
         self.gru = nn.GRU(self.hidden_size, self.hidden_size)
         self.out = nn.Linear(self.hidden_size, self.output_size)
 
-    def forward(self, input, hidden, encoder_output, encoder_outputs):
+    def forward(self, input, hidden, encoder_outputs):
         embedded = self.embedding(input).view(1, 1, -1)
         embedded = self.dropout(embedded)
 
@@ -591,15 +591,15 @@ def train(input_variable, target_variable, encoder, decoder, encoder_optimizer,
         # Teacher forcing: Feed the target as the next input
         for di in range(target_length):
             decoder_output, decoder_hidden, decoder_attention = decoder(
-                decoder_input, decoder_hidden, encoder_output, encoder_outputs)
+                decoder_input, decoder_hidden, encoder_outputs)
             loss += criterion(decoder_output, target_variable[di])
             decoder_input = target_variable[di]  # Teacher forcing
 
     else:
         # Without teacher forcing: use its own predictions as the next input
         for di in range(target_length):
             decoder_output, decoder_hidden, decoder_attention = decoder(
-                decoder_input, decoder_hidden, encoder_output, encoder_outputs)
+                decoder_input, decoder_hidden, encoder_outputs)
             topv, topi = decoder_output.data.topk(1)
             ni = topi[0][0]
 
@@ -745,7 +745,7 @@ def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
 
     for di in range(max_length):
         decoder_output, decoder_hidden, decoder_attention = decoder(
-            decoder_input, decoder_hidden, encoder_output, encoder_outputs)
+            decoder_input, decoder_hidden, encoder_outputs)
         decoder_attentions[di] = decoder_attention.data
         topv, topi = decoder_output.data.topk(1)
         ni = topi[0][0]
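
For context, the change drops the unused encoder_output argument: the attention decoder's forward now takes only the current input token, the previous hidden state, and the stacked encoder_outputs it attends over. The sketch below is a minimal, self-contained reconstruction of that interface, not the exact class at this commit; the class internals, the tensor shapes (single-token batch, GRU hidden of shape (1, 1, hidden_size), encoder_outputs of shape (max_length, hidden_size)), and the illustrative sizes in the trailing call are assumptions based on the tutorial's general structure.

import torch
import torch.nn as nn
import torch.nn.functional as F

MAX_LENGTH = 10  # assumed; the tutorial defines its own MAX_LENGTH


class AttnDecoderRNN(nn.Module):
    """Minimal sketch of an attention decoder with the updated
    forward(input, hidden, encoder_outputs) signature."""

    def __init__(self, hidden_size, output_size, n_layers=1, dropout_p=0.1,
                 max_length=MAX_LENGTH):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers  # kept to mirror the diff's signature; unused here
        self.max_length = max_length

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attn = nn.Linear(hidden_size * 2, max_length)
        self.attn_combine = nn.Linear(hidden_size * 2, hidden_size)
        self.dropout = nn.Dropout(dropout_p)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, encoder_outputs):
        # encoder_output is gone: attention only needs the current input token,
        # the previous hidden state, and the full matrix of encoder_outputs.
        embedded = self.dropout(self.embedding(input).view(1, 1, -1))

        # Attention weights over the (padded) source positions, then a
        # weighted sum of the encoder outputs.
        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), dim=1)), dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        # Combine the embedded input with the attention context and run the GRU.
        output = torch.cat((embedded[0], attn_applied[0]), dim=1)
        output = F.relu(self.attn_combine(output).unsqueeze(0))
        output, hidden = self.gru(output, hidden)

        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights


# Hypothetical call, mirroring the updated call sites in train() and evaluate():
hidden_size, output_size = 256, 2803            # illustrative sizes
decoder = AttnDecoderRNN(hidden_size, output_size)
decoder_input = torch.tensor([[0]])             # e.g. an SOS token index
decoder_hidden = torch.zeros(1, 1, hidden_size)
encoder_outputs = torch.zeros(MAX_LENGTH, hidden_size)
decoder_output, decoder_hidden, decoder_attention = decoder(
    decoder_input, decoder_hidden, encoder_outputs)

The call at the end matches the three-argument form introduced by this commit; everything the attention mechanism needs is already contained in encoder_outputs, which is why the separate encoder_output argument could be dropped without changing behavior.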
