This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Fix universal transformer decoding #1257

Merged: 10 commits, Dec 3, 2018
Changes from all commits
README.md (1 change: 0 additions & 1 deletion)
@@ -90,7 +90,6 @@ You can choose the bAbi task from the range [1,20] and the subset from 1k or
 10k. To combine test data from all tasks into a single test set, use
 `--problem=babi_qa_concat_all_tasks_10k`
 
-
 ### Image Classification
 
 For image classification, we have a number of standard data-sets:
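
For context on the README text above, dataset generation for this problem goes through the tensor2tensor problem registry; a minimal sketch, assuming the standard `problems.problem` lookup and `generate_data` API (the t2t-datagen CLI wraps the same call; the paths below are placeholders):

from tensor2tensor import problems

# Look up the registered bAbi problem and generate its datasets.
# "/tmp/t2t_data" and "/tmp/t2t_tmp" are placeholder directories.
babi = problems.problem("babi_qa_concat_all_tasks_10k")
babi.generate_data("/tmp/t2t_data", "/tmp/t2t_tmp")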
tensor2tensor/data_generators/babi_qa.py (5 changes: 2 additions & 3 deletions)
@@ -109,9 +109,9 @@ def _prepare_babi_data(tmp_dir, data_dir):
   tf.gfile.MakeDirs(data_dir)
 
   file_path = os.path.join(tmp_dir, _TAR)
-  headers = {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/63.0.3239.132 Safari/537.36"} # pylint: disable=line-too-long
+  headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/63.0.3239.132 Safari/537.36'}
   resp = requests.get(_URL, headers=headers)
-  with open(file_path, "wb") as f:
+  with open(file_path, 'wb') as f:
     f.write(resp.content)
 
   tar = tarfile.open(file_path)
@@ -459,7 +459,6 @@ def hparams(self, defaults, unused_model_hparams):
     if "context" in p.vocab_size:
       del p.vocab_size["context"]
 
-
 def _problems_to_register():
   """Problems for which we want to create datasets.
 
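
The hunks above only touch quote style and spacing in `_prepare_babi_data`; for readers skimming the diff, here is a self-contained sketch of the download-and-extract pattern that function uses (the helper name and arguments are illustrative, not part of tensor2tensor):

import os
import tarfile

import requests


def download_and_extract(url, tmp_dir, tar_name):
  # Some hosts reject the default python-requests User-Agent,
  # so send a browser-like one, as _prepare_babi_data does.
  headers = {'User-Agent': 'Mozilla/5.0'}
  resp = requests.get(url, headers=headers)
  resp.raise_for_status()
  file_path = os.path.join(tmp_dir, tar_name)
  with open(file_path, 'wb') as f:
    f.write(resp.content)
  with tarfile.open(file_path) as tar:
    tar.extractall(tmp_dir)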
tensor2tensor/models/research/universal_transformer.py (4 changes: 2 additions & 2 deletions)
@@ -243,7 +243,7 @@ def _greedy_infer(self, features, decode_length, use_tpu=False):
     return (self._slow_greedy_infer_tpu(features, decode_length) if use_tpu else
             self._slow_greedy_infer(features, decode_length))
 
-  def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
+  def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha, use_tpu=False):
     """Beam search decoding.
 
     Args:
@@ -266,7 +266,7 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
     # Caching is not enabled in Universal Transformer
     # TODO(dehghani): Support fast decoding for Universal Transformer
     return self._beam_decode_slow(features, decode_length, beam_size,
-                                  top_beams, alpha)
+                                  top_beams, alpha, use_tpu)
 
 
 @registry.register_model
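
The substance of the fix: the shared infer path can call `_beam_decode` with a `use_tpu` argument, and before this change the Universal Transformer override neither accepted nor forwarded it. A minimal sketch of the corrected override pattern; the class bodies below are simplified stand-ins, not the real implementations:

class Transformer(object):
  """Stand-in for the T2T base model; only the call shapes matter."""

  def _beam_decode_slow(self, features, decode_length, beam_size,
                        top_beams, alpha, use_tpu=False):
    # The real method runs (slow) beam search; elided in this sketch.
    return {"outputs": None, "scores": None}


class UniversalTransformer(Transformer):

  def _beam_decode(self, features, decode_length, beam_size, top_beams,
                   alpha, use_tpu=False):
    # Caching is not enabled in the Universal Transformer, so always
    # fall back to the slow beam search, forwarding the TPU flag.
    return self._beam_decode_slow(features, decode_length, beam_size,
                                  top_beams, alpha, use_tpu)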