Skip to content

Commit

Permalink
Passing attention correctly to the cnn decode step.
Browse files Browse the repository at this point in the history
  • Loading branch information
bricksdont authored and tdomhan committed Oct 6, 2017
1 parent da6e860 commit dded8c0
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions sockeye/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,8 @@ def decode_sequence(self,
# target_embed: (batch_size, target_seq_len, num_target_embed)
target_embed, target_lengths, target_max_length = self.embedding.encode(target, target_lengths,
target_max_length)
target_hidden = self._step(source_encoded_lengths=source_encoded_lengths,
target_hidden = self._step(attention=attention,
source_encoded_lengths=source_encoded_lengths,
source_encoded_max_length=source_encoded_max_length,
target_hidden=target_embed,
target_lengths=target_lengths,
Expand Down Expand Up @@ -1019,7 +1020,8 @@ def decode_step(self,
target_max_length)

# (batch_size, target_max_length, num_hidden)
target_hidden = self._step(source_encoded_lengths=source_encoded_lengths,
target_hidden = self._step(attention=attention,
source_encoded_lengths=source_encoded_lengths,
source_encoded_max_length=source_encoded_max_length,
target_hidden=target_embed,
target_lengths=target_lengths,
Expand All @@ -1043,6 +1045,7 @@ def decode_step(self,


def _step(self,
attention: Callable,
source_encoded_lengths: mx.sym.Symbol,
source_encoded_max_length: mx.sym.Symbol,
target_hidden: mx.sym.Symbol,
Expand Down

0 comments on commit dded8c0

Please sign in to comment.