Reformer training for the English-Polish pair (ParaCrawl).
PiperOrigin-RevId: 321256624
henrykmichalewski authored and copybara-github committed Jul 14, 2020
1 parent fe0fa78 commit 5fb8aa8
Showing 1 changed file with 108 additions and 0 deletions.
trax/supervised/configs/reformer_pc_enpl.gin
@@ -0,0 +1,108 @@
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import trax.models
import trax.optimizers
import trax.supervised.tf_inputs
import trax.supervised.trainer_lib

import t5.data.preprocessors
import t5.data.sentencepiece_vocabulary

max_length = 512
sequence_length = {'inputs': 512, 'targets': 512}
mean_noise_span_length = 3.0
noise_density = 0.15
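
# Top-level gin macros; other bindings can reference these as %max_length
# etc. The T5-style span-corruption settings (noise_density,
# mean_noise_span_length) do not appear to be consumed by any binding in
# this translation config.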

# Parameters for batcher:
# ==============================================================================
batcher.data_streams = @tf_inputs.data_streams
batcher.batch_size_per_device = 256
batcher.eval_batch_size = 64
batcher.max_eval_length = 512
batcher.bucket_length = 32
batcher.buckets_include_inputs_in_length = True
batcher.id_to_mask = 0
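
# Examples are bucketed by length for batching (boundaries derived from
# bucket_length); id_to_mask = 0 removes padding (token id 0) from the loss.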

# Parameters for data_streams:
# ==============================================================================
data_streams.data_dir = None
data_streams.eval_holdout_size = 0.1
data_streams.dataset_name = 'para_crawl/enpl_plain_text'
data_streams.bare_preprocess_fn = @trax.supervised.tf_inputs.generic_text_dataset_preprocess_fn
data_streams.input_name = 'inputs'
data_streams.target_name = 'targets'
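
# Reads the 'para_crawl/enpl_plain_text' dataset from TFDS (data_dir = None
# falls back to the default TFDS directory) and holds out 10% of the
# training data for eval.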

# Parameters for filter_dataset_on_len:
# ==============================================================================
filter_dataset_on_len.len_map = {'inputs': (1, 512), 'targets': (1, 512)}
filter_dataset_on_len.filter_on_eval = True
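
# Drops examples whose tokenized 'inputs' or 'targets' fall outside the
# configured 1-512 token range, during training and (via filter_on_eval)
# eval.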

# Parameters for truncate_dataset_on_len:
# ==============================================================================
truncate_dataset_on_len.len_map = {'inputs': 512, 'targets': 512}
truncate_dataset_on_len.truncate_on_eval = True
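
# Truncation runs before the length filter in token_preprocess_fns below
# (list order), so after truncation to 512 tokens the filter mainly removes
# empty examples.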

# Parameters for generic_text_dataset_preprocess_fn:
# ==============================================================================
generic_text_dataset_preprocess_fn.text_preprocess_fns = [
@rekey/get_t5_preprocessor_by_name()
]
generic_text_dataset_preprocess_fn.token_preprocess_fns = [
@trax.supervised.tf_inputs.truncate_dataset_on_len,
@trax.supervised.tf_inputs.filter_dataset_on_len,
]
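
# text_preprocess_fns run on the raw text features before tokenization
# (here only the 'rekey' step); token_preprocess_fns then run on the
# tokenized examples in list order.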

# Parameters for get_t5_preprocessor_by_name:
# ==============================================================================
rekey/get_t5_preprocessor_by_name.name = 'rekey'
rekey/get_t5_preprocessor_by_name.fn_kwargs = {'key_map': {'inputs': 'en', 'targets': 'pl'}}
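
# T5's 'rekey' preprocessor maps the raw ParaCrawl features so that English
# ('en') becomes the model input and Polish ('pl') the target.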

# Parameters for multifactor:
# ==============================================================================
# 0.088 ≈ 2 * 512^-0.5 = 2 * d_model^-0.5
multifactor.constant = 0.088
multifactor.factors = 'constant * linear_warmup * rsqrt_decay'
multifactor.warmup_steps = 8000
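
# With these factors the schedule is approximately
#   lr(step) = 0.088 * min(step / 8000, 1) / sqrt(max(step, 8000)),
# peaking at about 0.088 / sqrt(8000) ≈ 1e-3 at the end of warmup.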

# Parameters for Adam:
# ==============================================================================
Adam.b1 = 0.9
Adam.b2 = 0.98
Adam.eps = 1e-9
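
# Matches the Adam hyperparameters from the original Transformer paper
# (b2 = 0.98, eps = 1e-9), tuned for warmup schedules, rather than the
# usual defaults.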

# Parameters for train:
# ==============================================================================
train.eval_frequency = 1000
train.eval_steps = 10
train.model = @trax.models.Reformer
train.optimizer = @trax.optimizers.Adam
train.steps = 500000
train.save_graphs = False
train.checkpoints_at = [100000, 200000, 300000, 400000, 500000]
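
# Evaluates on 10 held-out batches every 1000 steps; permanent checkpoints
# are kept every 100k steps across the 500k-step run.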

# Parameters for Reformer:
# ==============================================================================
Reformer.d_model = 512
Reformer.d_ff = 2048
Reformer.dropout = 0.1
Reformer.ff_activation = @trax.layers.Relu
Reformer.ff_dropout = 0.1
Reformer.max_len = 2048
Reformer.mode = 'train'
Reformer.n_heads = 8
Reformer.n_encoder_layers = 6
Reformer.n_decoder_layers = 6
Reformer.input_vocab_size = 8064
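
# A 6-layer encoder / 6-layer decoder Reformer with the d_model = 512,
# d_ff = 2048, 8-head footprint of the base Transformer.
#
# A minimal sketch (not part of this commit) of how a config like this is
# typically consumed from Python; the output directory below is an
# assumption for illustration:
#
#   import gin
#   from trax.supervised import trainer_lib
#
#   gin.parse_config_file('trax/supervised/configs/reformer_pc_enpl.gin')
#   trainer_lib.train(output_dir='/tmp/reformer_enpl')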
