@@ -438,21 +438,14 @@ def update_hparams_for_universal_transformer(hparams):
 
 @registry.register_hparams
 def universal_transformer_base():
-  hparams = transformer.transformer_base()
-  # To have a similar capacity to the transformer_base with 6 layers,
-  # we need to increase the size of the UT's layer
-  # since, in fact, UT has a single layer repeating multiple times.
-  hparams.hidden_size = 1024
-  hparams.filter_size = 4096
-  hparams.num_heads = 16
-  hparams.layer_prepostprocess_dropout = 0.3
+  hparams = transformer.transformer_big()
   hparams = update_hparams_for_universal_transformer(hparams)
   return hparams
 
 
 @registry.register_hparams
 def universal_transformer_base_tpu():
-  hparams = universal_transformer_base()
+  hparams = transformer.transformer_big()
   hparams = update_hparams_for_universal_transformer(hparams)
   transformer.update_hparams_for_tpu(hparams)
   hparams.add_step_timing_signal = False
@@ -461,7 +454,7 @@ def universal_transformer_base_tpu():
 
 @registry.register_hparams
 def universal_transformer_big():
-  hparams = universal_transformer_base()
+  hparams = transformer.transformer_big()
   hparams = update_hparams_for_universal_transformer(hparams)
   hparams.hidden_size = 2048
   hparams.filter_size = 8192
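The first hunk reads as a pure refactor: in tensor2tensor, `transformer_big()` is `transformer_base()` plus exactly the four overrides being deleted here (`hidden_size=1024`, `filter_size=4096`, `num_heads=16`, `layer_prepostprocess_dropout=0.3`). A minimal sketch of that equivalence, assuming a tensor2tensor install where `transformer.transformer_big` still matches that definition (the `_before`/`_after` helper names are hypothetical, for illustration only):

```python
# Hypothetical equivalence check for the first hunk; assumes
# transformer_big() == transformer_base() + the four overrides
# removed above. Not part of this PR.
from tensor2tensor.models import transformer


def universal_transformer_base_before():
  hparams = transformer.transformer_base()
  hparams.hidden_size = 1024
  hparams.filter_size = 4096
  hparams.num_heads = 16
  hparams.layer_prepostprocess_dropout = 0.3
  return hparams


def universal_transformer_base_after():
  return transformer.transformer_big()


before = universal_transformer_base_before()
after = universal_transformer_base_after()
for key in ("hidden_size", "filter_size", "num_heads",
            "layer_prepostprocess_dropout"):
  assert getattr(before, key) == getattr(after, key), key
```

`update_hparams_for_universal_transformer` is applied identically on both sides of the diff, so it is omitted from the sketch; equal hparams going in means equal hparams coming out.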