Skip to content
This repository has been archived by the owner on Mar 23, 2021. It is now read-only.

Commit

Permalink
Work around a TensorFlow bug when only one answer is extracted
Browse files Browse the repository at this point in the history
  • Loading branch information
Tavian Barnes committed May 25, 2018
1 parent 707087e commit 0001f72
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion qgen/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,14 @@

batch = expand_answers(batch, answers)

helper = seq2seq.GreedyEmbeddingHelper(embedding, tf.fill([batch["size"]], START_TOKEN), END_TOKEN)
# Work around https://github.com/tensorflow/nmt/issues/117
class FixedHelper(seq2seq.GreedyEmbeddingHelper):
def sample(self, *args, **kwargs):
result = super().sample(*args, **kwargs)
result.set_shape([batch["size"]])
return result

helper = FixedHelper(embedding, tf.fill([batch["size"]], START_TOKEN), END_TOKEN)
decoder = seq2seq.BasicDecoder(decoder_cell, helper, encoder_state, output_layer=projection)

if batch["size"] > 0:
Expand Down

0 comments on commit 0001f72

Please sign in to comment.