Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 7384eeb

Browse files
theorm authored and Copybara-Service committed
internal merge of PR #1284
PiperOrigin-RevId: 225288942
1 parent 969386a commit 7384eeb

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

tensor2tensor/data_generators/problem.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ def eval_metrics(self):
368368
]
369369

370370
def eval_hooks(self, features, logits, hparams):
371+
del features, logits, hparams
371372
return []
372373

373374
@property
@@ -854,9 +855,9 @@ def tpu_valid_size(example):
854855

855856
def gpu_valid_size(example):
856857
drop_long_sequences = is_training or hparams.eval_drop_long_sequences
858+
max_validate_length = max_length if drop_long_sequences else 10**9
857859
return data_reader.example_valid_size(example, hparams.min_length,
858-
max_length
859-
if drop_long_sequences else 10**9)
860+
max_validate_length)
860861

861862
def define_shapes(example):
862863
batch_size = config and config.use_tpu and params["batch_size"]

tensor2tensor/utils/t2t_model.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,11 @@ def _custom_getter(self):
209209
if self.hparams.optimizer != "Adafactor":
210210
raise NotImplementedError(
211211
"weight_dtype=bfloat16 only implemented with Adafactor optimizer")
212+
activation_dtype = tf.float32
213+
if self.hparams.activation_dtype == "bfloat16":
214+
activation_dtype = tf.bfloat16
212215
return quantization.EighthPowerEncoding().custom_getter(
213-
activation_dtype=tf.bfloat16
214-
if self.hparams.activation_dtype == "bfloat16" else tf.float32)
216+
activation_dtype=activation_dtype)
215217
elif self.hparams.activation_dtype == "bfloat16":
216218
return quantization.bfloat16_activations_var_getter
217219
else:
@@ -834,8 +836,9 @@ def _greedy_infer(self, features, decode_length, use_tpu=False):
834836
"losses": a dictionary: {loss-name (string): floating point `Scalar`}
835837
}
836838
"""
837-
return (self._slow_greedy_infer_tpu(features, decode_length)
838-
if use_tpu else self._slow_greedy_infer(features, decode_length))
839+
if use_tpu:
840+
return self._slow_greedy_infer_tpu(features, decode_length)
841+
return self._slow_greedy_infer(features, decode_length)
839842

840843
def _slow_greedy_infer_tpu(self, features, decode_length):
841844
"""A slow greedy inference method on TPU.
@@ -1383,8 +1386,9 @@ def estimator_model_fn(cls,
13831386

13841387
# TRAIN mode
13851388
assert mode == tf.estimator.ModeKeys.TRAIN
1386-
num_async_replicas = (1 if (use_tpu or not config) else
1387-
config.t2t_device_info["num_async_replicas"])
1389+
num_async_replicas = 1
1390+
if config and not use_tpu:
1391+
num_async_replicas = config.t2t_device_info["num_async_replicas"]
13881392
return model.estimator_spec_train(
13891393
loss, num_async_replicas=num_async_replicas, use_tpu=use_tpu)
13901394

@@ -1522,11 +1526,11 @@ def estimator_spec_eval(self, features, logits, labels, loss, losses_dict):
15221526
def estimator_spec_predict(self, features, use_tpu=False):
15231527
"""Constructs `tf.estimator.EstimatorSpec` for PREDICT (inference) mode."""
15241528
decode_hparams = self._decode_hparams
1529+
top_beams = decode_hparams.beam_size if decode_hparams.return_beams else 1
15251530
infer_out = self.infer(
15261531
features,
15271532
beam_size=decode_hparams.beam_size,
1528-
top_beams=(decode_hparams.beam_size
1529-
if decode_hparams.return_beams else 1),
1533+
top_beams=top_beams,
15301534
alpha=decode_hparams.alpha,
15311535
decode_length=decode_hparams.extra_length,
15321536
use_tpu=use_tpu)

0 commit comments

Comments (0)