
Commit b27193d

Merge pull request #894 from guoshengCS/add-transformer-data-util
Add reader, ParallelExecutor and refine for Transformer
2 parents 1cbbddd + a3ed9b0 commit b27193d

6 files changed: +884 additions, -323 deletions
fluid/neural_machine_translation/transformer/config.py

Lines changed: 115 additions & 32 deletions
@@ -1,43 +1,50 @@
 class TrainTaskConfig(object):
-    use_gpu = False
+    use_gpu = True
     # the epoch number to train.
-    pass_num = 2
-
+    pass_num = 30
     # the number of sequences contained in a mini-batch.
-    batch_size = 64
-
+    batch_size = 32
     # the hyper parameters for Adam optimizer.
-    learning_rate = 0.001
+    # This static learning_rate will be multiplied to the LearningRateScheduler
+    # derived learning rate the to get the final learning rate.
+    learning_rate = 1
     beta1 = 0.9
     beta2 = 0.98
     eps = 1e-9
-
     # the parameters for learning rate scheduling.
     warmup_steps = 4000
-
     # the flag indicating to use average loss or sum loss when training.
-    use_avg_cost = False
-
+    use_avg_cost = True
+    # the weight used to mix up the ground-truth distribution and the fixed
+    # uniform distribution in label smoothing when training.
+    # Set this as zero if label smoothing is not wanted.
+    label_smooth_eps = 0.1
     # the directory for saving trained models.
     model_dir = "trained_models"
+    # the directory for saving checkpoints.
+    ckpt_dir = "trained_ckpts"
+    # the directory for loading checkpoint.
+    # If provided, continue training from the checkpoint.
+    ckpt_path = None
+    # the parameter to initialize the learning rate scheduler.
+    # It should be provided if use checkpoints, since the checkpoint doesn't
+    # include the training step counter currently.
+    start_step = 0


 class InferTaskConfig(object):
-    use_gpu = False
+    use_gpu = True
     # the number of examples in one run for sequence generation.
     batch_size = 10
-
     # the parameters for beam search.
     beam_size = 5
     max_length = 30
     # the number of decoded sentences to output.
     n_best = 1
-
     # the flags indicating whether to output the special tokens.
     output_bos = False
     output_eos = False
     output_unk = False
-
     # the directory for loading the trained model.
     model_path = "trained_models/pass_1.infer.model"
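Note on the new label_smooth_eps above: label smoothing mixes the one-hot ground-truth distribution with a fixed uniform distribution over the target vocabulary, as the added comments describe. A minimal NumPy sketch of that mixing (the helper name and shapes here are illustrative, not taken from the commit):

import numpy as np

def smooth_labels(label_ids, vocab_size, eps=0.1):
    # One-hot ground-truth distribution, one row per target token.
    one_hot = np.eye(vocab_size, dtype="float32")[label_ids]
    # Mix with the fixed uniform distribution; eps == 0 recovers plain one-hot.
    return one_hot * (1. - eps) + eps / vocab_size

# smooth_labels(np.array([3, 1]), vocab_size=8, eps=0.1) gives rows that put
# 0.9125 on the true token and 0.0125 on every other token.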
@@ -47,30 +54,24 @@ class ModelHyperParams(object):
     # <unk> token has alreay been added. As for the <pad> token, any token
     # included in dict can be used to pad, since the paddings' loss will be
     # masked out and make no effect on parameter gradients.
-
     # size of source word dictionary.
     src_vocab_size = 10000
-
     # size of target word dictionay
     trg_vocab_size = 10000
-
     # index for <bos> token
     bos_idx = 0
     # index for <eos> token
     eos_idx = 1
     # index for <unk> token
     unk_idx = 2
-
     # max length of sequences.
     # The size of position encoding table should at least plus 1, since the
     # sinusoid position encoding starts from 1 and 0 can be used as the padding
     # token for position encoding.
     max_length = 50
-
     # the dimension for word embeddings, which is also the last dimension of
     # the input and output of multi-head attention, position-wise feed-forward
     # networks, encoder and decoder.
-
     d_model = 512
     # size of the hidden layer in position-wise feed-forward networks.
     d_inner_hid = 1024
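For reference, the learning_rate comment in the first hunk says the static value is multiplied by a rate derived from a LearningRateScheduler with warmup_steps = 4000; schedulers of this kind typically implement the warmup-then-decay rule from the Transformer paper, which also uses d_model. A sketch under that assumption, not necessarily the exact formula in this repo's scheduler:

def scheduled_lr(step, d_model=512, warmup_steps=4000, static_lr=1.0):
    # Linear warmup for the first warmup_steps steps, then inverse-square-root
    # decay, scaled by the static learning_rate from TrainTaskConfig.
    # step counts from 1.
    return static_lr * d_model ** -0.5 * min(step ** -0.5,
                                             step * warmup_steps ** -1.5)

# scheduled_lr(1) ~= 1.7e-07, scheduled_lr(4000) ~= 7.0e-04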
@@ -86,34 +87,116 @@ class ModelHyperParams(object):
     dropout = 0.1


+def merge_cfg_from_list(cfg_list, g_cfgs):
+    """
+    Set the above global configurations using the cfg_list.
+    """
+    assert len(cfg_list) % 2 == 0
+    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
+        for g_cfg in g_cfgs:
+            if hasattr(g_cfg, key):
+                try:
+                    value = eval(value)
+                except SyntaxError:  # for file path
+                    pass
+                setattr(g_cfg, key, value)
+                break
+
+
+# Here list the data shapes and data types of all inputs.
+# The shapes here act as placeholder and are set to pass the infer-shape in
+# compile time.
+input_descs = {
+    # The actual data shape of src_word is:
+    # [batch_size * max_src_len_in_batch, 1]
+    "src_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    # The actual data shape of src_pos is:
+    # [batch_size * max_src_len_in_batch, 1]
+    "src_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    # This input is used to remove attention weights on paddings in the
+    # encoder.
+    # The actual data shape of src_slf_attn_bias is:
+    # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
+    "src_slf_attn_bias":
+    [(1, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
+      (ModelHyperParams.max_length + 1)), "float32"],
+    # This shape input is used to reshape the output of embedding layer.
+    "src_data_shape": [(3L, ), "int32"],
+    # This shape input is used to reshape before softmax in self attention.
+    "src_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
+    # This shape input is used to reshape after softmax in self attention.
+    "src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
+    # The actual data shape of trg_word is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "trg_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    # The actual data shape of trg_pos is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "trg_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    # This input is used to remove attention weights on paddings and
+    # subsequent words in the decoder.
+    # The actual data shape of trg_slf_attn_bias is:
+    # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
+    "trg_slf_attn_bias": [(1, ModelHyperParams.n_head,
+                           (ModelHyperParams.max_length + 1),
+                           (ModelHyperParams.max_length + 1)), "float32"],
+    # This input is used to remove attention weights on paddings of the source
+    # input in the encoder-decoder attention.
+    # The actual data shape of trg_src_attn_bias is:
+    # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
+    "trg_src_attn_bias": [(1, ModelHyperParams.n_head,
+                           (ModelHyperParams.max_length + 1),
+                           (ModelHyperParams.max_length + 1)), "float32"],
+    # This shape input is used to reshape the output of embedding layer.
+    "trg_data_shape": [(3L, ), "int32"],
+    # This shape input is used to reshape before softmax in self attention.
+    "trg_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
+    # This shape input is used to reshape after softmax in self attention.
+    "trg_slf_attn_post_softmax_shape": [(4L, ), "int32"],
+    # This shape input is used to reshape before softmax in encoder-decoder
+    # attention.
+    "trg_src_attn_pre_softmax_shape": [(2L, ), "int32"],
+    # This shape input is used to reshape after softmax in encoder-decoder
+    # attention.
+    "trg_src_attn_post_softmax_shape": [(4L, ), "int32"],
+    # This input is used in independent decoder program for inference.
+    # The actual data shape of enc_output is:
+    # [batch_size, max_src_len_in_batch, d_model]
+    "enc_output": [(1, (ModelHyperParams.max_length + 1),
+                    ModelHyperParams.d_model), "float32"],
+    # The actual data shape of label_word is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "lbl_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    # This input is used to mask out the loss of paddding tokens.
+    # The actual data shape of label_weight is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
+}
+
 # Names of position encoding table which will be initialized externally.
 pos_enc_param_names = (
     "src_pos_enc_table",
     "trg_pos_enc_table", )
-
-# Names of all data layers in encoder listed in order.
-encoder_input_data_names = (
+# separated inputs for different usages.
+encoder_data_input_fields = (
     "src_word",
     "src_pos",
-    "src_slf_attn_bias",
+    "src_slf_attn_bias", )
+encoder_util_input_fields = (
     "src_data_shape",
     "src_slf_attn_pre_softmax_shape",
     "src_slf_attn_post_softmax_shape", )
-
-# Names of all data layers in decoder listed in order.
-decoder_input_data_names = (
+decoder_data_input_fields = (
     "trg_word",
     "trg_pos",
     "trg_slf_attn_bias",
     "trg_src_attn_bias",
+    "enc_output", )
+decoder_util_input_fields = (
     "trg_data_shape",
     "trg_slf_attn_pre_softmax_shape",
     "trg_slf_attn_post_softmax_shape",
     "trg_src_attn_pre_softmax_shape",
-    "trg_src_attn_post_softmax_shape",
-    "enc_output", )
-
-# Names of label related data layers listed in order.
-label_data_names = (
+    "trg_src_attn_post_softmax_shape", )
+label_data_input_fields = (
     "lbl_word",
     "lbl_weight", )

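Usage note on merge_cfg_from_list: it walks a flat key/value list and sets each key on the first config class in g_cfgs that already defines that attribute; eval turns numeric and boolean strings back into typed values, and the SyntaxError guard is there (per the inline comment) to leave values such as file paths untouched. For example:

from config import TrainTaskConfig, InferTaskConfig, merge_cfg_from_list

merge_cfg_from_list(["batch_size", "64", "use_gpu", "False", "beam_size", "4"],
                    [TrainTaskConfig, InferTaskConfig])
# batch_size and use_gpu exist on TrainTaskConfig, so that class is updated
# first (batch_size -> 64 as int, use_gpu -> False as bool); beam_size only
# exists on InferTaskConfig, so InferTaskConfig.beam_size -> 4.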
fluid/neural_machine_translation/transformer/infer.py

Lines changed: 79 additions & 13 deletions
@@ -1,3 +1,4 @@
+import argparse
 import numpy as np

 import paddle
@@ -6,9 +7,62 @@
 import model
 from model import wrap_encoder as encoder
 from model import wrap_decoder as decoder
-from config import InferTaskConfig, ModelHyperParams, \
-    encoder_input_data_names, decoder_input_data_names
+from config import *
 from train import pad_batch_data
+import reader
+
+
+def parse_args():
+    parser = argparse.ArgumentParser("Training for Transformer.")
+    parser.add_argument(
+        "--src_vocab_fpath",
+        type=str,
+        required=True,
+        help="The path of vocabulary file of source language.")
+    parser.add_argument(
+        "--trg_vocab_fpath",
+        type=str,
+        required=True,
+        help="The path of vocabulary file of target language.")
+    parser.add_argument(
+        "--test_file_pattern",
+        type=str,
+        required=True,
+        help="The pattern to match test data files.")
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=50,
+        help="The number of examples in one run for sequence generation.")
+    parser.add_argument(
+        "--pool_size",
+        type=int,
+        default=10000,
+        help="The buffer size to pool data.")
+    parser.add_argument(
+        "--special_token",
+        type=str,
+        default=["<s>", "<e>", "<unk>"],
+        nargs=3,
+        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
+    parser.add_argument(
+        'opts',
+        help='See config.py for all options',
+        default=None,
+        nargs=argparse.REMAINDER)
+    args = parser.parse_args()
+    # Append args related to dict
+    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
+    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
+    dict_args = [
+        "src_vocab_size", str(len(src_dict)), "trg_vocab_size",
+        str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]),
+        "eos_idx", str(src_dict[args.special_token[1]]), "unk_idx",
+        str(src_dict[args.special_token[2]])
+    ]
+    merge_cfg_from_list(args.opts + dict_args,
+                        [InferTaskConfig, ModelHyperParams])
+    return args


 def translate_batch(exe,
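Only the named flags above are fixed; everything after them lands in the opts remainder as bare key/value pairs (for instance "beam_size 4"), and parse_args appends dictionary-derived pairs before handing the whole list to merge_cfg_from_list, so vocabulary sizes and special-token indices always track the supplied dictionaries rather than the defaults in config.py. A small illustration of the appended pairs with a made-up toy dictionary:

toy_src_dict = {"<s>": 0, "<e>": 1, "<unk>": 2, "hello": 3, "world": 4}
special_token = ["<s>", "<e>", "<unk>"]
dict_args = [
    "src_vocab_size", str(len(toy_src_dict)),
    "bos_idx", str(toy_src_dict[special_token[0]]),
    "eos_idx", str(toy_src_dict[special_token[1]]),
    "unk_idx", str(toy_src_dict[special_token[2]]),
]
# dict_args == ['src_vocab_size', '5', 'bos_idx', '0', 'eos_idx', '1',
#               'unk_idx', '2'], which merge_cfg_from_list then writes onto
# ModelHyperParams (the target dictionary contributes trg_vocab_size the
# same way).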
@@ -243,7 +297,7 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map):
     return seqs, scores[:, :n_best].tolist()


-def main():
+def infer(args):
     place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
@@ -292,13 +346,23 @@ def main():
     decoder_program = fluid.io.get_inference_program(
         target_vars=[predict], main_program=decoder_program)

-    test_data = paddle.batch(
-        paddle.dataset.wmt16.test(ModelHyperParams.src_vocab_size,
-                                  ModelHyperParams.trg_vocab_size),
-        batch_size=InferTaskConfig.batch_size)
+    test_data = reader.DataReader(
+        src_vocab_fpath=args.src_vocab_fpath,
+        trg_vocab_fpath=args.trg_vocab_fpath,
+        fpattern=args.test_file_pattern,
+        batch_size=args.batch_size,
+        use_token_batch=False,
+        pool_size=args.pool_size,
+        sort_type=reader.SortType.NONE,
+        shuffle=False,
+        shuffle_batch=False,
+        start_mark=args.special_token[0],
+        end_mark=args.special_token[1],
+        unk_mark=args.special_token[2],
+        clip_last_batch=False)

-    trg_idx2word = paddle.dataset.wmt16.get_dict(
-        "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
+    trg_idx2word = test_data.load_dict(
+        dict_path=args.trg_vocab_fpath, reverse=True)

     def post_process_seq(seq,
                          bos_idx=ModelHyperParams.bos_idx,
@@ -320,15 +384,16 @@ def post_process_seq(seq,
                 (output_eos or idx != eos_idx),
             seq)

-    for batch_id, data in enumerate(test_data()):
+    for batch_id, data in enumerate(test_data.batch_generator()):
         batch_seqs, batch_scores = translate_batch(
             exe,
             [item[0] for item in data],
             encoder_program,
-            encoder_input_data_names,
+            encoder_data_input_fields + encoder_util_input_fields,
             [enc_output.name],
             decoder_program,
-            decoder_input_data_names,
+            decoder_data_input_fields[:-1] + decoder_util_input_fields +
+            (decoder_data_input_fields[-1], ),
             [predict.name],
             InferTaskConfig.beam_size,
             InferTaskConfig.max_length,
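With the tuples defined in config.py earlier in this commit, the rebuilt decoder feed list evaluates to the following, which keeps the same name order as the removed decoder_input_data_names (enc_output last, after the utility shape inputs):

from config import decoder_data_input_fields, decoder_util_input_fields

feed_names = (decoder_data_input_fields[:-1] + decoder_util_input_fields +
              (decoder_data_input_fields[-1], ))
# feed_names == ("trg_word", "trg_pos", "trg_slf_attn_bias", "trg_src_attn_bias",
#                "trg_data_shape", "trg_slf_attn_pre_softmax_shape",
#                "trg_slf_attn_post_softmax_shape",
#                "trg_src_attn_pre_softmax_shape",
#                "trg_src_attn_post_softmax_shape", "enc_output")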
@@ -351,4 +416,5 @@ def post_process_seq(seq,


 if __name__ == "__main__":
-    main()
+    args = parse_args()
+    infer(args)

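For completeness, the *_input_fields tuples added in config.py are just ordered name lists that key into input_descs, so each program can assemble its own set of inputs and feed list. A sketch of how such descriptors could back data-layer creation; the fluid.layers.data call and its keyword arguments reflect the Fluid API of that era and are an assumption here, not code from this commit:

import paddle.fluid as fluid
from config import (input_descs, encoder_data_input_fields,
                    encoder_util_input_fields)

def make_inputs(field_names):
    # One data layer per field, using the placeholder shape and dtype recorded
    # in input_descs; the actual shapes come from the data fed at run time.
    return [
        fluid.layers.data(
            name=name,
            shape=list(input_descs[name][0]),
            dtype=input_descs[name][1],
            append_batch_size=False) for name in field_names
    ]

enc_inputs = make_inputs(encoder_data_input_fields + encoder_util_input_fields)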