|
| 1 | +import tensor2tensor.trax.inputs |
| 2 | +import tensor2tensor.trax.models |
| 3 | +import tensor2tensor.trax.optimizers |
| 4 | +import tensor2tensor.trax.trax |
| 5 | + |
| 6 | +# Parameters for batch_fun: |
| 7 | +# ============================================================================== |
| 8 | +batch_fun.batch_size_per_device = 128 |
| 9 | +batch_fun.eval_batch_size = 128 |
| 10 | +batch_fun.max_eval_length = 2048 |
| 11 | + |
| 12 | +# Parameters for inputs: |
| 13 | +# ============================================================================== |
| 14 | +inputs.data_dir = None |
| 15 | +inputs.dataset_name = 't2t_languagemodel_lm1b32k' |
| 16 | +inputs.input_name = 'targets' |
| 17 | + |
| 18 | +# Parameters for masked_mean:
| 19 | +# ============================================================================== |
| 20 | +masked_mean.mask_id = 0 |
| 21 | + |
| 22 | +# Parameters for MultifactorSchedule: |
| 23 | +# ============================================================================== |
| 24 | +MultifactorSchedule.constant = 0.1 |
| 25 | +MultifactorSchedule.factors = 'constant * linear_warmup * rsqrt_decay' |
| 26 | +MultifactorSchedule.warmup_steps = 8000 |
| 27 | + |
| 28 | +# Parameters for shuffle_and_batch_data and lm1b_preprocess:
| 29 | +# ============================================================================== |
| 30 | +shuffle_and_batch_data.preprocess_fun = @trax.inputs.lm1b_preprocess
| 31 | +lm1b_preprocess.max_target_length = 512 |
| 32 | +lm1b_preprocess.max_eval_target_length = 2048 |
| 33 | + |
| 34 | +# Parameters for train: |
| 35 | +# ============================================================================== |
| 36 | +train.eval_frequency = 1000 |
| 37 | +train.eval_steps = 10 |
| 38 | +train.inputs = @trax.inputs.inputs |
| 39 | +train.model = @trax.models.TransformerLM |
| 40 | +train.run_debug_step = False |
| 41 | +train.train_steps = 100000 |
| 42 | + |
| 43 | +# Parameters for TransformerLM: |
| 44 | +# ============================================================================== |
| 45 | +TransformerLM.dropout = 0.1 |
| 46 | +TransformerLM.feature_depth = 512 |
| 47 | +TransformerLM.feedforward_depth = 2048 |
| 48 | +TransformerLM.max_len = 2048 |
| 49 | +TransformerLM.mode = 'train' |
| 50 | +TransformerLM.num_heads = 8 |
| 51 | +TransformerLM.num_layers = 6 |
| 52 | +TransformerLM.vocab_size = 32000 |
0 commit comments