@@ -209,9 +209,11 @@ def _custom_getter(self):
209
209
if self .hparams .optimizer != "Adafactor" :
210
210
raise NotImplementedError (
211
211
"weight_dtype=bfloat16 only implemented with Adafactor optimizer" )
212
+ activation_dtype = tf .float32
213
+ if self .hparams .activation_dtype == "bfloat16" :
214
+ activation_dtype = tf .bfloat16
212
215
return quantization .EighthPowerEncoding ().custom_getter (
213
- activation_dtype = tf .bfloat16
214
- if self .hparams .activation_dtype == "bfloat16" else tf .float32 )
216
+ activation_dtype = activation_dtype )
215
217
elif self .hparams .activation_dtype == "bfloat16" :
216
218
return quantization .bfloat16_activations_var_getter
217
219
else :
@@ -834,8 +836,9 @@ def _greedy_infer(self, features, decode_length, use_tpu=False):
834
836
"losses": a dictionary: {loss-name (string): floating point `Scalar`}
835
837
}
836
838
"""
837
- return (self ._slow_greedy_infer_tpu (features , decode_length )
838
- if use_tpu else self ._slow_greedy_infer (features , decode_length ))
839
+ if use_tpu :
840
+ return self ._slow_greedy_infer_tpu (features , decode_length )
841
+ return self ._slow_greedy_infer (features , decode_length )
839
842
840
843
def _slow_greedy_infer_tpu (self , features , decode_length ):
841
844
"""A slow greedy inference method on TPU.
@@ -1383,8 +1386,9 @@ def estimator_model_fn(cls,
1383
1386
1384
1387
# TRAIN mode
1385
1388
assert mode == tf .estimator .ModeKeys .TRAIN
1386
- num_async_replicas = (1 if (use_tpu or not config ) else
1387
- config .t2t_device_info ["num_async_replicas" ])
1389
+ num_async_replicas = 1
1390
+ if config and not use_tpu :
1391
+ num_async_replicas = config .t2t_device_info ["num_async_replicas" ]
1388
1392
return model .estimator_spec_train (
1389
1393
loss , num_async_replicas = num_async_replicas , use_tpu = use_tpu )
1390
1394
@@ -1522,11 +1526,11 @@ def estimator_spec_eval(self, features, logits, labels, loss, losses_dict):
1522
1526
def estimator_spec_predict (self , features , use_tpu = False ):
1523
1527
"""Constructs `tf.estimator.EstimatorSpec` for PREDICT (inference) mode."""
1524
1528
decode_hparams = self ._decode_hparams
1529
+ top_beams = decode_hparams .beam_size if decode_hparams .return_beams else 1
1525
1530
infer_out = self .infer (
1526
1531
features ,
1527
1532
beam_size = decode_hparams .beam_size ,
1528
- top_beams = (decode_hparams .beam_size
1529
- if decode_hparams .return_beams else 1 ),
1533
+ top_beams = top_beams ,
1530
1534
alpha = decode_hparams .alpha ,
1531
1535
decode_length = decode_hparams .extra_length ,
1532
1536
use_tpu = use_tpu )
0 commit comments