This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 9433a92

cfiken authored and Copybara-Service committed
internal merge of PR #1213
PiperOrigin-RevId: 221821207
1 parent 49e7cf5 commit 9433a92

3 files changed: +8 -8 lines changed


tensor2tensor/data_generators/problem.py

Lines changed: 6 additions & 4 deletions
@@ -802,7 +802,8 @@ def input_fn(self,
                config=None,
                force_repeat=False,
                prevent_repeat=False,
-               dataset_kwargs=None):
+               dataset_kwargs=None,
+               batch_shuffle_size=512):
     """Builds input pipeline for problem.
 
     Args:
@@ -817,6 +818,8 @@ def input_fn(self,
         Overrides force_repeat.
       dataset_kwargs: dict, if passed, will pass as kwargs to self.dataset
         method when called
+      batch_shuffle_size: int, the size of the buffer to shuffle batches.
+        if none, the batches will not be shuffled.
 
     Returns:
       (features_dict<str name, Tensor feature>, Tensor targets)
@@ -966,9 +969,8 @@ def define_shapes(example):
     # buffer size for record shuffling is smaller than the batch size. In such
     # cases, adding batch shuffling ensures that the data is in random order
     # during training
-    if hasattr(hparams, 'batch_shuffle_size'):
-      if is_training and hparams.batch_shuffle_size:
-        dataset = dataset.shuffle(hparams.batch_shuffle_size)
+    if is_training and batch_shuffle_size:
+      dataset = dataset.shuffle(batch_shuffle_size)
 
     def prepare_for_output(example):
       if not config or not config.use_tpu:
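A minimal sketch of the batch-level shuffling this change exposes through the new input_fn argument, assuming a standard tf.data pipeline; the helper name maybe_shuffle_batches and the toy dataset are illustrative, not part of the commit:

import tensorflow as tf

def maybe_shuffle_batches(dataset, is_training, batch_shuffle_size=512):
  # Mirrors the updated logic: shuffling whole batches now depends on the
  # explicit batch_shuffle_size argument rather than an hparams field.
  if is_training and batch_shuffle_size:
    # Useful when the record-level shuffle buffer is smaller than the batch
    # size, so records within a batch would otherwise stay in file order.
    dataset = dataset.shuffle(batch_shuffle_size)
  return dataset

# Illustrative usage on a toy batched dataset.
batched = tf.data.Dataset.range(100).batch(8)
batched = maybe_shuffle_batches(batched, is_training=True, batch_shuffle_size=4)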

tensor2tensor/layers/common_hparams.py

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@ def basic_params1():
       # of tokens per batch per GPU or per TPU core. Otherwise, this is
       # the number of examples per GPU or per TPU core.
       batch_size=4096,
-      batch_shuffle_size=512,
       # If True, then if the features are of variable length, the batch_size is
       # used as the actual batch size (and not tokens per batch).
       use_fixed_batch_size=False,
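With the hparam removed, batch shuffling is configured at the input_fn call site rather than through the hparams set. A rough before/after sketch, assuming TF 1.x's tf.contrib.training.HParams and showing only the fields visible in this diff:

from tensorflow.contrib.training import HParams

# Before: the shuffle buffer size rode along in basic_params1.
hparams_before = HParams(batch_size=4096, batch_shuffle_size=512)

# After: hparams no longer define it; callers pass batch_shuffle_size to
# Problem.input_fn directly (default 512 per the problem.py change above).
hparams_after = HParams(batch_size=4096)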

tensor2tensor/utils/decoding.py

Lines changed: 2 additions & 3 deletions
@@ -319,9 +319,8 @@ def decode_once(estimator,
     if decode_to_file:
       for i, (d_input, d_output, d_target) in enumerate(decoded_outputs):
         # Skip if all padding
-        if d_input:
-          if re.match("^({})+$".format(text_encoder.PAD), d_input):
-            continue
+        if re.match("^({})+$".format(text_encoder.PAD), d_input):
+          continue
         beam_score_str = ""
         if decode_hp.write_beam_scores:
           beam_score_str = "\t%.2f" % decoded_scores[i]
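A small sketch of the all-padding check that remains after this simplification; text_encoder.PAD is tensor2tensor's padding token string, and the literal "<pad>" below is only an assumption for illustration:

import re

PAD = "<pad>"  # stand-in for text_encoder.PAD

def is_all_padding(decoded_input):
  # True only when the decoded input is one or more PAD tokens and nothing else.
  return bool(re.match("^({})+$".format(PAD), decoded_input))

print(is_all_padding("<pad><pad><pad>"))  # True: skip this example
print(is_all_padding("hello <pad>"))      # False
print(is_all_padding(""))                 # False: the dropped `if d_input` guard was redundant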
