This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 8c23cbb

T2T Team authored and copybara-github committed
Reversible Transformer
PiperOrigin-RevId: 253881109
1 parent eb6d825 commit 8c23cbb

File tree

5 files changed: +479 -0 lines changed

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+import tensor2tensor.trax.inputs
+import tensor2tensor.trax.models
+import tensor2tensor.trax.optimizers
+import tensor2tensor.trax.trax
+
+# Parameters for batch_fun:
+# ==============================================================================
+batch_fun.batch_size_per_device = 32
+batch_fun.eval_batch_size = 8
+batch_fun.max_eval_length = 12288  # 64 * 64 * 3
+
+# Parameters for inputs:
+# ==============================================================================
+inputs.data_dir = None
+inputs.dataset_name = 't2t_image_imagenet64_gen_flat_rev'
+inputs.input_name = 'targets'
+inputs.n_chunks = 64
+
+# Parameters for MultifactorSchedule:
+# ==============================================================================
+MultifactorSchedule.constant = 0.3
+MultifactorSchedule.factors = 'constant * linear_warmup'
+MultifactorSchedule.warmup_steps = 8000
+
+# Parameters for train:
+# ==============================================================================
+train.eval_frequency = 100
+train.eval_steps = 10
+train.inputs = @trax.inputs.inputs
+train.model = @trax.models.TransformerRevnetLM
+train.optimizer = @trax.optimizers.SM3
+train.train_steps = 500000
+train.trainer_class = @MemoryEfficientTrainer
+
+# Parameters for TransformerRevnetLM:
+# ==============================================================================
+TransformerRevnetLM.d_feature = 512
+TransformerRevnetLM.d_feedforward = 2048
+TransformerRevnetLM.dropout = 0.1
+TransformerRevnetLM.max_len = 12288  # 64 * 64 * 3
+TransformerRevnetLM.mode = 'train'
+TransformerRevnetLM.n_heads = 8
+TransformerRevnetLM.n_layers = 6
+TransformerRevnetLM.vocab_size = 256
+TransformerRevnetLM.n_chunks = 64
+TransformerRevnetLM.n_attention_chunks = 64
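This config trains on ImageNet-64 as a flat byte sequence: each 64x64 RGB image becomes 64 * 64 * 3 = 12288 tokens over a 256-entry byte vocabulary, and n_chunks = 64 splits that sequence into 192-token chunks for the memory-efficient trainer. A quick sanity check of that arithmetic (a standalone sketch, not part of the commit; the variable names are ad hoc):

# Standalone sketch: chunk arithmetic implied by the ImageNet-64 config above.
height, width, channels = 64, 64, 3
max_len = height * width * channels   # 12288 flattened RGB bytes per image
n_chunks = 64
chunk_len = max_len // n_chunks       # 192 positions per chunk
vocab_size = 256                      # one token per byte value
assert max_len == 12288 and max_len % n_chunks == 0 and chunk_len == 192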
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+import tensor2tensor.trax.inputs
+import tensor2tensor.trax.models
+import tensor2tensor.trax.optimizers
+import tensor2tensor.trax.trax
+
+# Parameters for batch_fun:
+# ==============================================================================
+batch_fun.batch_size_per_device = 4096
+batch_fun.eval_batch_size = 128
+batch_fun.max_eval_length = 2048
+
+# Parameters for inputs:
+# ==============================================================================
+inputs.data_dir = None
+inputs.dataset_name = 't2t_languagemodel_lm1b32k_packed'
+inputs.input_name = 'targets'
+inputs.n_chunks = 32
+
+# Parameters for mask:
+# ==============================================================================
+masked_mean.mask_id = 0
+
+# Parameters for MultifactorSchedule:
+# ==============================================================================
+MultifactorSchedule.constant = 0.3
+MultifactorSchedule.factors = 'constant * linear_warmup'
+MultifactorSchedule.warmup_steps = 8000
+
+# Parameters for preprocess_fun:
+# ==============================================================================
+shuffle_and_batch_data.preprocess_fun=@trax.inputs.lm1b_preprocess
+lm1b_preprocess.max_target_length = 512
+lm1b_preprocess.max_eval_target_length = 2048
+
+# Parameters for train:
+# ==============================================================================
+train.eval_frequency = 100
+train.eval_steps = 10
+train.inputs = @trax.inputs.inputs
+train.model = @trax.models.TransformerRevnetLM
+train.optimizer = @trax.optimizers.SM3
+train.train_steps = 500000
+train.trainer_class = @MemoryEfficientTrainer
+
+# Parameters for TransformerRevnetLM:
+# ==============================================================================
+TransformerRevnetLM.d_feature = 512
+TransformerRevnetLM.d_feedforward = 2048
+TransformerRevnetLM.dropout = 0.1
+TransformerRevnetLM.max_len = 2048
+TransformerRevnetLM.mode = 'train'
+TransformerRevnetLM.n_heads = 8
+TransformerRevnetLM.n_layers = 6
+TransformerRevnetLM.vocab_size = 32000
+TransformerRevnetLM.n_chunks = 32
+TransformerRevnetLM.n_attention_chunks = 8
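Both files are gin configurations for the trax training loop; they bind train.model to the new TransformerRevnetLM and train.trainer_class to MemoryEfficientTrainer. A minimal usage sketch follows; it assumes the gin-config package, a hypothetical file name for the config above, and that trax.train reads everything except the output directory from the gin bindings:

# Hypothetical usage sketch, not part of this commit.
import gin
from tensor2tensor.trax import trax

gin.parse_config_file('transformer_revnet_lm1b.gin')    # hypothetical path to the config above
trax.train(output_dir='/tmp/transformer_revnet_lm1b')   # remaining arguments come from the gin bindings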

tensor2tensor/trax/layers/combinators.py

Lines changed: 11 additions & 0 deletions
@@ -303,6 +303,13 @@ def Add(x, **unused_kwargs):
   return _binary_op(x, op=sum)


+@base.layer(stack_items_to_pass=0)
+def SubtractTop(x, **unused_kwargs):
+  """Subtract the first element on the stack from the second element."""
+  # Here x is a list of tensors of the same shape, or nested structures.
+  return _binary_op(x, op=lambda xs: xs[1] - xs[0])
+
+
 @base.layer(stack_items_to_pass=0)
 def Multiply(x, **unused_kwargs):
   """Multiply first and second element on the stack."""
@@ -372,6 +379,10 @@ def stack_items_to_pass(self):

   def call(self, inputs, params=(), **kwargs):
     rngs = _pop_rng_and_split(kwargs, self._nlayers)
+    # Note that zip silently truncates its result if lengths don't match.
+    assert len(inputs) == self._nlayers
+    assert len(params) == self._nlayers
+    assert len(rngs) == self._nlayers
     return tuple(layer(x, params=p, rng=r, **kwargs)
                  for layer, x, p, r in zip(self._layers, inputs, params, rngs))
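SubtractTop is the inverse of the existing Add combinator on the layer stack, which is exactly what a reversible (RevNet-style) residual block needs: activations can be recomputed on the backward pass instead of being stored. The following is an illustrative NumPy sketch of that reversibility idea, not the trax implementation from this commit; f and g stand in for the attention and feed-forward branches.

# Illustrative sketch of the reversible-residual idea behind Add/SubtractTop.
import numpy as np

def f(x):
  return np.tanh(x)

def g(x):
  return 0.5 * x

def reversible_forward(x1, x2):
  y1 = x1 + f(x2)        # additions, as an Add-style layer would perform
  y2 = x2 + g(y1)
  return y1, y2

def reversible_inverse(y1, y2):
  x2 = y2 - g(y1)        # subtractions, the role SubtractTop plays
  x1 = y1 - f(x2)
  return x1, x2

x1, x2 = np.random.randn(4), np.random.randn(4)
y1, y2 = reversible_forward(x1, x2)
r1, r2 = reversible_inverse(y1, y2)
assert np.allclose(x1, r1) and np.allclose(x2, r2)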

tensor2tensor/trax/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,7 @@
 from tensor2tensor.trax.models import resnet
 from tensor2tensor.trax.models import transformer
 from tensor2tensor.trax.models.research import position_lookup_transformer
+from tensor2tensor.trax.models.research import transformer_revnet


 # Ginify
@@ -42,4 +43,5 @@ def model_configure(*args, **kwargs):
 Transformer = model_configure(transformer.Transformer)
 TransformerEncoder = model_configure(transformer.TransformerEncoder)
 TransformerLM = model_configure(transformer.TransformerLM)
+TransformerRevnetLM = model_configure(transformer_revnet.TransformerRevnetLM)
 WideResnet = model_configure(resnet.WideResnet)
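Registering the new class through model_configure is what lets the gin files above refer to it as @trax.models.TransformerRevnetLM. A rough sketch of that registration pattern, assuming the gin-config package and using a hypothetical toy model rather than the real implementation:

# Sketch of the gin registration pattern; ToyLM is a hypothetical stand-in.
import gin

def model_configure(*args, **kwargs):
  # Expose a constructor to gin under the 'trax.models' namespace so configs
  # can reference it as @trax.models.<Name>.
  kwargs['module'] = 'trax.models'
  return gin.external_configurable(*args, **kwargs)

class ToyLM(object):
  def __init__(self, vocab_size=256, d_feature=512, mode='train'):
    self.vocab_size = vocab_size
    self.d_feature = d_feature
    self.mode = mode

ToyLM = model_configure(ToyLM)  # now bindable from a .gin file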
