Commit 2a09779

googlehjxkpe authored and committed
internal merge of PR tensorflow#1389
PiperOrigin-RevId: 230774856
1 parent d3081bb · commit 2a09779

2 files changed: +4 -11 lines

tensor2tensor/data_generators/problem.py

Lines changed: 1 addition & 1 deletion
@@ -877,7 +877,7 @@ def serving_input_fn(self, hparams):
         dtype=tf.string, shape=[None], name="serialized_example")
     dataset = tf.data.Dataset.from_tensor_slices(serialized_example)
     dataset = dataset.map(self.decode_example)
-    dataset = dataset.map(lambda ex: self.preprocess_example(ex, mode, hparams))
+    dataset = dataset.map(lambda ex: self.preprocess_example(ex, mode, hparams))
     dataset = dataset.map(data_reader.cast_ints_to_int32)
     dataset = dataset.padded_batch(
         tf.shape(serialized_example, out_type=tf.int64)[0],
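
For context, the touched line sits inside Problem.serving_input_fn, which feeds serialized tf.Example protos through a tf.data pipeline and applies the problem's preprocessing with mode and hparams captured in a lambda. Below is a minimal standalone sketch of that pattern; decode_fn and preprocess_fn are placeholder names standing in for self.decode_example and self.preprocess_example, not tensor2tensor API.

import tensorflow as tf

def make_serving_dataset(serialized_example, decode_fn, preprocess_fn,
                         mode, hparams):
  # Mirrors the pipeline in the diff above: slice the batch of serialized
  # protos, decode each one, then preprocess. Extra arguments (mode, hparams)
  # are closed over by the lambda rather than passed through map().
  dataset = tf.data.Dataset.from_tensor_slices(serialized_example)
  dataset = dataset.map(decode_fn)
  dataset = dataset.map(lambda ex: preprocess_fn(ex, mode, hparams))
  return dataset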

tensor2tensor/models/research/universal_transformer.py

Lines changed: 3 additions & 10 deletions
@@ -438,21 +438,14 @@ def update_hparams_for_universal_transformer(hparams):

 @registry.register_hparams
 def universal_transformer_base():
-  hparams = transformer.transformer_base()
-  # To have a similar capacity to the transformer_base with 6 layers,
-  # we need to increase the size of the UT's layer
-  # since, in fact, UT has a single layer repeating multiple times.
-  hparams.hidden_size = 1024
-  hparams.filter_size = 4096
-  hparams.num_heads = 16
-  hparams.layer_prepostprocess_dropout = 0.3
+  hparams = transformer.transformer_big()
   hparams = update_hparams_for_universal_transformer(hparams)
   return hparams


 @registry.register_hparams
 def universal_transformer_base_tpu():
-  hparams = universal_transformer_base()
+  hparams = transformer.transformer_big()
   hparams = update_hparams_for_universal_transformer(hparams)
   transformer.update_hparams_for_tpu(hparams)
   hparams.add_step_timing_signal = False
@@ -461,7 +454,7 @@ def universal_transformer_base_tpu():

 @registry.register_hparams
 def universal_transformer_big():
-  hparams = universal_transformer_base()
+  hparams = transformer.transformer_big()
   hparams = update_hparams_for_universal_transformer(hparams)
   hparams.hidden_size = 2048
   hparams.filter_size = 8192
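
The removed lines show why the swap is safe: the hand-set overrides (hidden_size 1024, filter_size 4096, num_heads 16, layer_prepostprocess_dropout 0.3) match what transformer.transformer_big() layers on top of transformer_base(), and the deleted comment gives the motivation: the Universal Transformer applies a single layer repeatedly, so that layer must be wider to match the capacity of a 6-layer transformer_base. Building universal_transformer_base_tpu and universal_transformer_big on transformer_big() directly also means update_hparams_for_universal_transformer now runs exactly once per registration, instead of twice through the old universal_transformer_base() chain. A quick sanity check of the equivalence, assuming a tensor2tensor checkout at roughly this revision:

from tensor2tensor.models import transformer

# The values restate the deleted overrides; that transformer_big() carries the
# same ones is an assumption about the upstream defaults at this revision.
big = transformer.transformer_big()
assert big.hidden_size == 1024
assert big.filter_size == 4096
assert big.num_heads == 16
assert big.layer_prepostprocess_dropout == 0.3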
