This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 49e7cf5

stefan-falk authored and afrozenator committed
Exposing batch_shuffle_size as hparam (#1231)
* Pass data_dir to feature_encoders
* Fixing error passing wrong data_dir
* Exposing batch_shuffle_size as hparam
* Checking d_input since d_input may be None
1 parent 57db0b7 commit 49e7cf5
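
The net effect of the change set: batch_shuffle_size moves from an input_fn argument into the default hyperparameter set, so it can be configured like any other hparam. A minimal usage sketch (not part of this commit), assuming tensor2tensor is installed:

from tensor2tensor.layers import common_hparams

hparams = common_hparams.basic_params1()
# batch_shuffle_size defaults to 512 after this change; a falsy value such as 0
# skips the batch-level shuffle in Problem.input_fn.
hparams.batch_shuffle_size = 1024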

File tree

3 files changed, +8 -8 lines changed


tensor2tensor/data_generators/problem.py

Lines changed: 4 additions & 6 deletions
@@ -802,8 +802,7 @@ def input_fn(self,
               config=None,
               force_repeat=False,
               prevent_repeat=False,
-              dataset_kwargs=None,
-              batch_shuffle_size=512):
    """Builds input pipeline for problem.

    Args:
@@ -818,8 +817,6 @@ def input_fn(self,
        Overrides force_repeat.
      dataset_kwargs: dict, if passed, will pass as kwargs to self.dataset
        method when called
-      batch_shuffle_size: int, the size of the buffer to shuffle batches.
-        if none, the batches will not be shuffled.

    Returns:
      (features_dict<str name, Tensor feature>, Tensor targets)
@@ -969,8 +966,9 @@ def define_shapes(example):
    # buffer size for record shuffling is smaller than the batch size. In such
    # cases, adding batch shuffling ensures that the data is in random order
    # during training
-    if is_training and batch_shuffle_size:
-      dataset = dataset.shuffle(batch_shuffle_size)
+    if hasattr(hparams, 'batch_shuffle_size'):
+      if is_training and hparams.batch_shuffle_size:
+        dataset = dataset.shuffle(hparams.batch_shuffle_size)

    def prepare_for_output(example):
      if not config or not config.use_tpu:
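
For reference, a self-contained sketch of the guarded shuffle added above, on a toy tf.data pipeline (assumes TensorFlow and tensor2tensor are installed); the hasattr check keeps older hparams objects that lack batch_shuffle_size working:

import tensorflow as tf
from tensor2tensor.layers import common_hparams

hparams = common_hparams.basic_params1()
is_training = True

dataset = tf.data.Dataset.range(100).batch(8)
if hasattr(hparams, "batch_shuffle_size"):
  if is_training and hparams.batch_shuffle_size:
    # Shuffle whole batches, which helps when the record-level shuffle buffer
    # is smaller than the batch size (see the comment in input_fn above).
    dataset = dataset.shuffle(hparams.batch_shuffle_size)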

tensor2tensor/layers/common_hparams.py

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ def basic_params1():
      # of tokens per batch per GPU or per TPU core. Otherwise, this is
      # the number of examples per GPU or per TPU core.
      batch_size=4096,
+      batch_shuffle_size=512,
      # If True, then if the features are of variable length, the batch_size is
      # used as the actual batch size (and not tokens per batch).
      use_fixed_batch_size=False,
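
Because the default now lives in basic_params1, derived hparams sets inherit it and can override it. A hypothetical example (the set name below is illustrative, not from the repo):

from tensor2tensor.layers import common_hparams
from tensor2tensor.utils import registry

@registry.register_hparams
def basic_params1_no_batch_shuffle():
  """Hypothetical hparams set that disables batch-level shuffling."""
  hparams = common_hparams.basic_params1()
  hparams.batch_shuffle_size = 0  # falsy, so input_fn skips dataset.shuffle
  return hparams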

tensor2tensor/utils/decoding.py

Lines changed: 3 additions & 2 deletions
@@ -319,8 +319,9 @@ def decode_once(estimator,
    if decode_to_file:
      for i, (d_input, d_output, d_target) in enumerate(decoded_outputs):
        # Skip if all padding
-        if re.match("^({})+$".format(text_encoder.PAD), d_input):
-          continue
+        if d_input:
+          if re.match("^({})+$".format(text_encoder.PAD), d_input):
+            continue
        beam_score_str = ""
        if decode_hp.write_beam_scores:
          beam_score_str = "\t%.2f" % decoded_scores[i]
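
To see the decoding guard in isolation: the commit message notes that d_input may be None, and the new outer check only applies the all-padding skip to non-empty strings. A standalone sketch, assuming text_encoder.PAD is the padding token string:

import re
from tensor2tensor.data_generators import text_encoder

for d_input in [None, text_encoder.PAD * 3, "some real input"]:
  if d_input:
    if re.match("^({})+$".format(text_encoder.PAD), d_input):
      continue  # input is nothing but padding, skip writing this example
  print("would be written to the decode file:", d_input)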
