Commit 5d14d1c
Add reader, ParallelExecutor and refine for Transformer
1 parent e7684f0 commit 5d14d1c

6 files changed: +847 −322 lines
fluid/neural_machine_translation/transformer/config.py

Lines changed: 114 additions & 32 deletions
@@ -1,43 +1,49 @@
 class TrainTaskConfig(object):
-    use_gpu = False
+    use_gpu = True
     # the epoch number to train.
-    pass_num = 2
-
+    pass_num = 30
     # the number of sequences contained in a mini-batch.
-    batch_size = 64
-
+    batch_size = 32
     # the hyper parameters for Adam optimizer.
-    learning_rate = 0.001
+    # This learning_rate is used to derive the final learning rate.
+    learning_rate = 1
     beta1 = 0.9
     beta2 = 0.98
     eps = 1e-9
-
     # the parameters for learning rate scheduling.
     warmup_steps = 4000
-
     # the flag indicating whether to use average loss or sum loss when training.
-    use_avg_cost = False
-
+    use_avg_cost = True
+    # the weight used to mix up the ground-truth distribution and the fixed
+    # uniform distribution in label smoothing when training.
+    # Set this to zero if label smoothing is not wanted.
+    label_smooth_eps = 0.1
     # the directory for saving trained models.
     model_dir = "trained_models"
+    # the directory for saving checkpoints.
+    ckpt_dir = "trained_ckpts"
+    # the directory for loading a checkpoint.
+    # If provided, continue training from the checkpoint.
+    ckpt_path = None
+    # the parameter to initialize the learning rate scheduler.
+    # It should be provided when using checkpoints, since the checkpoint doesn't
+    # currently include the training step counter.
+    start_step = 0


 class InferTaskConfig(object):
-    use_gpu = False
+    use_gpu = True
     # the number of examples in one run for sequence generation.
     batch_size = 10
-
     # the parameters for beam search.
     beam_size = 5
     max_length = 30
     # the number of decoded sentences to output.
     n_best = 1
-
     # the flags indicating whether to output the special tokens.
     output_bos = False
     output_eos = False
     output_unk = False
-
     # the directory for loading the trained model.
     model_path = "trained_models/pass_1.infer.model"
@@ -47,30 +53,24 @@ class ModelHyperParams(object):
     # <unk> token has already been added. As for the <pad> token, any token
     # included in dict can be used to pad, since the paddings' loss will be
     # masked out and have no effect on parameter gradients.
-
     # size of source word dictionary.
     src_vocab_size = 10000
-
     # size of target word dictionary.
     trg_vocab_size = 10000
-
     # index for <bos> token
     bos_idx = 0
     # index for <eos> token
     eos_idx = 1
     # index for <unk> token
     unk_idx = 2
-
     # max length of sequences.
     # The size of the position encoding table should be at least max_length + 1,
     # since the sinusoid position encoding starts from 1 and 0 can be used as
     # the padding token for position encoding.
     max_length = 50
-
     # the dimension for word embeddings, which is also the last dimension of
     # the input and output of multi-head attention, position-wise feed-forward
     # networks, encoder and decoder.
-
     d_model = 512
     # size of the hidden layer in position-wise feed-forward networks.
     d_inner_hid = 1024
@@ -86,34 +86,116 @@ class ModelHyperParams(object):
     dropout = 0.1


+def merge_cfg_from_list(cfg_list, g_cfgs):
+    """
+    Set the above global configurations using the cfg_list.
+    """
+    assert len(cfg_list) % 2 == 0
+    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
+        for g_cfg in g_cfgs:
+            if hasattr(g_cfg, key):
+                try:
+                    value = eval(value)
+                except SyntaxError:  # for file path
+                    pass
+                setattr(g_cfg, key, value)
+                break
+
+
+# Here we list the data shapes and data types of all inputs.
+# The shapes here act as placeholders and are only set to pass the infer-shape
+# at compile time.
+input_descs = {
+    # The actual data shape of src_word is:
+    # [batch_size * max_src_len_in_batch, 1]
+    "src_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    # The actual data shape of src_pos is:
+    # [batch_size * max_src_len_in_batch, 1]
+    "src_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    # This input is used to remove attention weights on paddings in the
+    # encoder.
+    # The actual data shape of src_slf_attn_bias is:
+    # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
+    "src_slf_attn_bias":
+    [(1, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
+      (ModelHyperParams.max_length + 1)), "float32"],
+    # This shape input is used to reshape the output of embedding layer.
+    "src_data_shape": [(3L, ), "int32"],
+    # This shape input is used to reshape before softmax in self attention.
+    "src_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
+    # This shape input is used to reshape after softmax in self attention.
+    "src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
+    # The actual data shape of trg_word is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "trg_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    # The actual data shape of trg_pos is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "trg_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    # This input is used to remove attention weights on paddings and
+    # subsequent words in the decoder.
+    # The actual data shape of trg_slf_attn_bias is:
+    # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
+    "trg_slf_attn_bias": [(1, ModelHyperParams.n_head,
+                           (ModelHyperParams.max_length + 1),
+                           (ModelHyperParams.max_length + 1)), "float32"],
+    # This input is used to remove attention weights on paddings of the source
+    # input in the encoder-decoder attention.
+    # The actual data shape of trg_src_attn_bias is:
+    # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
+    "trg_src_attn_bias": [(1, ModelHyperParams.n_head,
+                           (ModelHyperParams.max_length + 1),
+                           (ModelHyperParams.max_length + 1)), "float32"],
+    # This shape input is used to reshape the output of embedding layer.
+    "trg_data_shape": [(3L, ), "int32"],
+    # This shape input is used to reshape before softmax in self attention.
+    "trg_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
+    # This shape input is used to reshape after softmax in self attention.
+    "trg_slf_attn_post_softmax_shape": [(4L, ), "int32"],
+    # This shape input is used to reshape before softmax in encoder-decoder
+    # attention.
+    "trg_src_attn_pre_softmax_shape": [(2L, ), "int32"],
+    # This shape input is used to reshape after softmax in encoder-decoder
+    # attention.
+    "trg_src_attn_post_softmax_shape": [(4L, ), "int32"],
+    # This input is used in the independent decoder program for inference.
+    # The actual data shape of enc_output is:
+    # [batch_size, max_src_len_in_batch, d_model]
+    "enc_output": [(1, (ModelHyperParams.max_length + 1),
+                    ModelHyperParams.d_model), "float32"],
+    # The actual data shape of label_word is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "lbl_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    # This input is used to mask out the loss of padding tokens.
+    # The actual data shape of label_weight is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
+}
+
 # Names of position encoding table which will be initialized externally.
 pos_enc_param_names = (
     "src_pos_enc_table",
     "trg_pos_enc_table", )
-
-# Names of all data layers in encoder listed in order.
-encoder_input_data_names = (
+# Separated inputs for different usages.
+encoder_data_input_fields = (
     "src_word",
     "src_pos",
-    "src_slf_attn_bias",
+    "src_slf_attn_bias", )
+encoder_util_input_fields = (
     "src_data_shape",
     "src_slf_attn_pre_softmax_shape",
     "src_slf_attn_post_softmax_shape", )
-
-# Names of all data layers in decoder listed in order.
-decoder_input_data_names = (
+decoder_data_input_fields = (
     "trg_word",
     "trg_pos",
     "trg_slf_attn_bias",
     "trg_src_attn_bias",
+    "enc_output", )
+decoder_util_input_fields = (
     "trg_data_shape",
     "trg_slf_attn_pre_softmax_shape",
     "trg_slf_attn_post_softmax_shape",
     "trg_src_attn_pre_softmax_shape",
-    "trg_src_attn_post_softmax_shape",
-    "enc_output", )
-
-# Names of label related data layers listed in order.
-label_data_names = (
+    "trg_src_attn_post_softmax_shape", )
+label_data_input_fields = (
     "lbl_word",
     "lbl_weight", )

fluid/neural_machine_translation/transformer/infer.py

Lines changed: 69 additions & 13 deletions
@@ -1,3 +1,4 @@
+import argparse
 import numpy as np

 import paddle
@@ -6,9 +7,52 @@
 import model
 from model import wrap_encoder as encoder
 from model import wrap_decoder as decoder
-from config import InferTaskConfig, ModelHyperParams, \
-    encoder_input_data_names, decoder_input_data_names
+from config import *
 from train import pad_batch_data
+import reader
+
+
+def parse_args():
+    parser = argparse.ArgumentParser("Inference for Transformer.")
+    parser.add_argument(
+        "--src_vocab_fpath",
+        type=str,
+        required=True,
+        help="The path of vocabulary file of source language.")
+    parser.add_argument(
+        "--trg_vocab_fpath",
+        type=str,
+        required=True,
+        help="The path of vocabulary file of target language.")
+    parser.add_argument(
+        "--test_file_pattern",
+        type=str,
+        required=True,
+        help="The pattern to match test data files.")
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=50,
+        help="The number of examples in one run for sequence generation.")
+    parser.add_argument(
+        "--pool_size",
+        type=int,
+        default=10000,
+        help="The buffer size to pool data.")
+    parser.add_argument(
+        "--special_token",
+        type=str,
+        default=["<s>", "<e>", "<unk>"],
+        nargs=3,
+        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
+    parser.add_argument(
+        'opts',
+        help='See config.py for all options',
+        default=None,
+        nargs=argparse.REMAINDER)
+    args = parser.parse_args()
+    merge_cfg_from_list(args.opts, [InferTaskConfig, ModelHyperParams])
+    return args


 def translate_batch(exe,
@@ -243,7 +287,7 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map):
     return seqs, scores[:, :n_best].tolist()


-def main():
+def infer(args):
     place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)

@@ -292,13 +336,23 @@ def main():
     decoder_program = fluid.io.get_inference_program(
         target_vars=[predict], main_program=decoder_program)

-    test_data = paddle.batch(
-        paddle.dataset.wmt16.test(ModelHyperParams.src_vocab_size,
-                                  ModelHyperParams.trg_vocab_size),
-        batch_size=InferTaskConfig.batch_size)
+    test_data = reader.DataReader(
+        src_vocab_fpath=args.src_vocab_fpath,
+        trg_vocab_fpath=args.trg_vocab_fpath,
+        fpattern=args.test_file_pattern,
+        batch_size=args.batch_size,
+        use_token_batch=False,
+        pool_size=args.pool_size,
+        sort_type=reader.SortType.NONE,
+        shuffle=False,
+        shuffle_batch=False,
+        start_mark=args.special_token[0],
+        end_mark=args.special_token[1],
+        unk_mark=args.special_token[2],
+        clip_last_batch=False)

-    trg_idx2word = paddle.dataset.wmt16.get_dict(
-        "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
+    trg_idx2word = test_data._load_dict(
+        dict_path=args.trg_vocab_fpath, reverse=True)

     def post_process_seq(seq,
                          bos_idx=ModelHyperParams.bos_idx,
@@ -320,15 +374,16 @@ def post_process_seq(seq,
             (output_eos or idx != eos_idx),
             seq)

-    for batch_id, data in enumerate(test_data()):
+    for batch_id, data in enumerate(test_data.batch_generator()):
         batch_seqs, batch_scores = translate_batch(
             exe,
             [item[0] for item in data],
             encoder_program,
-            encoder_input_data_names,
+            encoder_data_input_fields + encoder_util_input_fields,
             [enc_output.name],
             decoder_program,
-            decoder_input_data_names,
+            decoder_data_input_fields[:-1] + decoder_util_input_fields +
+            (decoder_data_input_fields[-1], ),
             [predict.name],
             InferTaskConfig.beam_size,
             InferTaskConfig.max_length,
@@ -351,4 +406,5 @@ def post_process_seq(seq,


 if __name__ == "__main__":
-    main()
+    args = parse_args()
+    infer(args)
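Note: the split field tuples in config.py replace the old flat name lists, and infer.py reassembles them before calling translate_batch so the feed order stays the same as before, with "enc_output" kept last. A short sketch, for illustration only, of the tuples that the expressions above produce:

# Reassembling the feed-name tuples from the new split field lists.
from config import (encoder_data_input_fields, encoder_util_input_fields,
                    decoder_data_input_fields, decoder_util_input_fields)

enc_names = encoder_data_input_fields + encoder_util_input_fields
# -> ("src_word", "src_pos", "src_slf_attn_bias", "src_data_shape",
#     "src_slf_attn_pre_softmax_shape", "src_slf_attn_post_softmax_shape")

dec_names = (decoder_data_input_fields[:-1] + decoder_util_input_fields +
             (decoder_data_input_fields[-1], ))
# -> ("trg_word", "trg_pos", "trg_slf_attn_bias", "trg_src_attn_bias",
#     "trg_data_shape", "trg_slf_attn_pre_softmax_shape",
#     "trg_slf_attn_post_softmax_shape", "trg_src_attn_pre_softmax_shape",
#     "trg_src_attn_post_softmax_shape", "enc_output")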
