
Commit b40d5da

T2T Team authored and copybara-github committed

Add testing gin configs.

PiperOrigin-RevId: 246423640

1 parent 0a4e912 · commit b40d5da

File tree: 2 files changed, 96 additions and 0 deletions
File 1 of 2 (new file): 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import tensor2tensor.trax.inputs
import tensor2tensor.trax.learning_rate
import tensor2tensor.trax.models
import tensor2tensor.trax.optimizers
import tensor2tensor.trax.trax

# Parameters for batch_fun:
# ==============================================================================
batch_fun.batch_size_per_device = 32
batch_fun.bucket_length = 32
batch_fun.buckets = None
batch_fun.eval_batch_size = 32

# Parameters for inputs:
# ==============================================================================
inputs.data_dir = None
inputs.dataset_name = 't2t_image_imagenet224'

# Parameters for MultifactorSchedule:
# ==============================================================================
EvalAdjustingSchedule.constant = 1.0
MultifactorSchedule.factors = 'constant * linear_warmup'
MultifactorSchedule.warmup_steps = 400

# Parameters for momentum:
# ==============================================================================
momentum.mass = 0.9

# Parameters for Resnet50:
# ==============================================================================
Resnet50.hidden_size = 64
Resnet50.num_output_classes = 1001

# Parameters for train:
# ==============================================================================
train.eval_frequency = 2000
train.eval_steps = 20
train.inputs = @trax.inputs.inputs
train.model = @trax.models.Resnet50
train.optimizer = @trax.optimizers.momentum
train.train_steps = 100000
train.lr_schedule = @learning_rate.EvalAdjustingSchedule
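Together, these bindings define a ResNet-50-on-ImageNet test run. In gin syntax, plain values are literals while @-prefixed values are references to other configurables, so train.model resolves to the Resnet50 constructor and train.optimizer to the momentum optimizer. A minimal launcher sketch follows, assuming the standard gin-config API; the config file name and output_dir value are hypothetical, not taken from this commit:

import gin
from tensor2tensor.trax import trax

# Parse the config so that train(), inputs(), momentum(), etc. pick up
# the bindings above (e.g. train.model = @trax.models.Resnet50).
gin.parse_config_file('resnet50_imagenet.gin')  # hypothetical file name

# With the bindings applied, only unbound arguments remain to be passed.
trax.train(output_dir='/tmp/resnet50_test')  # output_dir is an assumption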
File 2 of 2 (new file): 52 additions & 0 deletions
@@ -0,0 +1,52 @@
import tensor2tensor.trax.inputs
import tensor2tensor.trax.models
import tensor2tensor.trax.optimizers
import tensor2tensor.trax.trax

# Parameters for batch_fun:
# ==============================================================================
batch_fun.batch_size_per_device = 128
batch_fun.eval_batch_size = 128
batch_fun.max_eval_length = 2048

# Parameters for inputs:
# ==============================================================================
inputs.data_dir = None
inputs.dataset_name = 't2t_languagemodel_lm1b32k'
inputs.input_name = 'targets'

# Parameters for mask:
# ==============================================================================
masked_mean.mask_id = 0

# Parameters for MultifactorSchedule:
# ==============================================================================
MultifactorSchedule.constant = 0.1
MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay'
MultifactorSchedule.warmup_steps = 8000

# Parameters for preprocess_fun:
# ==============================================================================
shuffle_and_batch_data.preprocess_fun = @trax.inputs.lm1b_preprocess
lm1b_preprocess.max_target_length = 512
lm1b_preprocess.max_eval_target_length = 2048

# Parameters for train:
# ==============================================================================
train.eval_frequency = 1000
train.eval_steps = 10
train.inputs = @trax.inputs.inputs
train.model = @trax.models.TransformerLM
train.run_debug_step = False
train.train_steps = 100000

# Parameters for TransformerLM:
# ==============================================================================
TransformerLM.dropout = 0.1
TransformerLM.feature_depth = 512
TransformerLM.feedforward_depth = 2048
TransformerLM.max_len = 2048
TransformerLM.mode = 'train'
TransformerLM.num_heads = 8
TransformerLM.num_layers = 6
TransformerLM.vocab_size = 32000
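The factors string 'constant * linear_warmup * rsqrt_decay' composes the learning rate as a product of three schedule terms, using the constant = 0.1 and warmup_steps = 8000 bound above. A sketch of the conventional definitions of those factors; the exact behavior at the warmup boundary is an assumption here, not copied from the trax implementation:

import math

def multifactor_lr(step, constant=0.1, warmup_steps=8000):
    # Sketch of 'constant * linear_warmup * rsqrt_decay': linear_warmup
    # ramps from 0 to 1 over warmup_steps, and rsqrt_decay holds flat
    # through warmup before decaying as 1/sqrt(step). These are the
    # conventional definitions, assumed rather than taken from trax.
    linear_warmup = min(1.0, step / warmup_steps)
    rsqrt_decay = 1.0 / math.sqrt(max(step, warmup_steps))
    return constant * linear_warmup * rsqrt_decay

# For example, multifactor_lr(8000) = 0.1 / sqrt(8000), about 1.1e-3,
# the peak rate reached at the end of warmup.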
