class TrainTaskConfig(object):
    use_gpu = True
    # the epoch number to train.
    pass_num = 30
    # the number of sequences contained in a mini-batch.
    batch_size = 32
    # the hyper parameters for Adam optimizer.
    # This learning_rate is a scale factor multiplied with the scheduler-derived
    # rate to get the final learning rate (see the sketch after this class).
    learning_rate = 1
    beta1 = 0.9
    beta2 = 0.98
    eps = 1e-9
    # the parameters for learning rate scheduling.
    warmup_steps = 4000
    # the flag indicating to use average loss or sum loss when training.
    use_avg_cost = True
    # the weight used to mix up the ground-truth distribution and the fixed
    # uniform distribution in label smoothing when training.
    # Set this as zero if label smoothing is not wanted.
    label_smooth_eps = 0.1
    # the directory for saving trained models.
    model_dir = "trained_models"
    # the directory for saving checkpoints.
    ckpt_dir = "trained_ckpts"
    # the directory for loading a checkpoint.
    # If provided, training continues from the checkpoint.
    ckpt_path = None
    # the step used to initialize the learning rate scheduler.
    # It should be provided when resuming from a checkpoint, since the
    # checkpoint doesn't currently include the training step counter.
    start_step = 0

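# The two helpers below are a minimal illustrative sketch, not part of the
# original config: they show how warmup_steps and learning_rate above are
# typically combined in the Transformer's Noam learning rate schedule, and how
# label_smooth_eps mixes the ground-truth with a uniform distribution. The
# names _example_noam_lr and _example_smooth_label are made up.
def _example_noam_lr(step, d_model):
    # lr = learning_rate * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    step = max(step, 1)
    warmup = TrainTaskConfig.warmup_steps
    return TrainTaskConfig.learning_rate * d_model**-0.5 * min(
        step**-0.5, step * warmup**-1.5)


def _example_smooth_label(one_hot, vocab_size):
    # (1 - eps) * one_hot + eps * uniform(1 / vocab_size)
    eps = TrainTaskConfig.label_smooth_eps
    return [(1. - eps) * p + eps / vocab_size for p in one_hot]
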
class InferTaskConfig(object):
    use_gpu = True
    # the number of examples in one run for sequence generation.
    batch_size = 10
    # the parameters for beam search.
    beam_size = 5
    max_length = 30
    # the number of decoded sentences to output.
    n_best = 1
    # the flags indicating whether to output the special tokens.
    output_bos = False
    output_eos = False
    output_unk = False
    # the directory for loading the trained model.
    model_path = "trained_models/pass_1.infer.model"


class ModelHyperParams(object):
    # ...
    # <unk> token has already been added. As for the <pad> token, any token
    # included in dict can be used to pad, since the paddings' loss will be
    # masked out and have no effect on parameter gradients.
    # size of source word dictionary.
    src_vocab_size = 10000
    # size of target word dictionary.
    trg_vocab_size = 10000
    # index for <bos> token
    bos_idx = 0
    # index for <eos> token
    eos_idx = 1
    # index for <unk> token
    unk_idx = 2
    # max length of sequences.
    # The size of the position encoding table should be at least max_length + 1,
    # since the sinusoid position encoding starts from 1 and 0 can be used as
    # the padding token for position encoding.
    max_length = 50
    # the dimension for word embeddings, which is also the last dimension of
    # the input and output of multi-head attention, position-wise feed-forward
    # networks, encoder and decoder.
    d_model = 512
    # size of the hidden layer in position-wise feed-forward networks.
    d_inner_hid = 1024
    # ...
    dropout = 0.1

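# An illustrative sketch, not part of the original file: the sinusoid position
# encoding table that pos_enc_param_names below refer to, with max_length + 1
# rows so that row 0 can serve as the padding position, assuming the standard
# formula PE(pos, 2i) = sin(pos / 10000^(2i / d_model)). The helper name is
# made up.
def _example_position_encoding_table(max_length, d_model):
    import numpy as np
    channels = np.arange(d_model)
    # 10000^(-2i / d_model), shared by the sin/cos pair at channels 2i, 2i+1
    rates = 1. / np.power(10000., (channels // 2 * 2.) / d_model)
    table = np.arange(max_length + 1).reshape(-1, 1) * rates.reshape(1, -1)
    table[:, 0::2] = np.sin(table[:, 0::2])
    table[:, 1::2] = np.cos(table[:, 1::2])
    return table.astype("float32")
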
def merge_cfg_from_list(cfg_list, g_cfgs):
    """
    Set the above global configurations using the cfg_list.
    """
    assert len(cfg_list) % 2 == 0
    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
        for g_cfg in g_cfgs:
            if hasattr(g_cfg, key):
                try:
                    # interpret numbers, booleans, tuples, etc.
                    value = eval(value)
                except (SyntaxError, NameError):
                    # keep file paths and other plain strings as-is
                    pass
                setattr(g_cfg, key, value)
                # only override the first config class that has this key
                break

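# Illustrative usage of merge_cfg_from_list (the values and the helper name
# below are hypothetical): cfg_list is a flat key/value list, typically
# collected from command line arguments.
def _example_merge_usage():
    merge_cfg_from_list(["batch_size", "64", "warmup_steps", "8000"],
                        [TrainTaskConfig, ModelHyperParams])
    assert TrainTaskConfig.batch_size == 64
    assert TrainTaskConfig.warmup_steps == 8000
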
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholders and are set to pass the shape inference
# at compile time (see the sketch after this dict for one way to consume them).
input_descs = {
    # The actual data shape of src_word is:
    # [batch_size * max_src_len_in_batch, 1]
    "src_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
    # The actual data shape of src_pos is:
    # [batch_size * max_src_len_in_batch, 1]
    "src_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
    # This input is used to remove attention weights on paddings in the
    # encoder.
    # The actual data shape of src_slf_attn_bias is:
    # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
    "src_slf_attn_bias":
    [(1, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
      (ModelHyperParams.max_length + 1)), "float32"],
    # This shape input is used to reshape the output of embedding layer.
    "src_data_shape": [(3L, ), "int32"],
    # This shape input is used to reshape before softmax in self attention.
    "src_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
    # This shape input is used to reshape after softmax in self attention.
    "src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
    # The actual data shape of trg_word is:
    # [batch_size * max_trg_len_in_batch, 1]
    "trg_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
    # The actual data shape of trg_pos is:
    # [batch_size * max_trg_len_in_batch, 1]
    "trg_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
    # This input is used to remove attention weights on paddings and
    # subsequent words in the decoder.
    # The actual data shape of trg_slf_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
    "trg_slf_attn_bias": [(1, ModelHyperParams.n_head,
                           (ModelHyperParams.max_length + 1),
                           (ModelHyperParams.max_length + 1)), "float32"],
    # This input is used to remove attention weights on paddings of the source
    # input in the encoder-decoder attention.
    # The actual data shape of trg_src_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
    "trg_src_attn_bias": [(1, ModelHyperParams.n_head,
                           (ModelHyperParams.max_length + 1),
                           (ModelHyperParams.max_length + 1)), "float32"],
    # This shape input is used to reshape the output of embedding layer.
    "trg_data_shape": [(3L, ), "int32"],
    # This shape input is used to reshape before softmax in self attention.
    "trg_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
    # This shape input is used to reshape after softmax in self attention.
    "trg_slf_attn_post_softmax_shape": [(4L, ), "int32"],
    # This shape input is used to reshape before softmax in encoder-decoder
    # attention.
    "trg_src_attn_pre_softmax_shape": [(2L, ), "int32"],
    # This shape input is used to reshape after softmax in encoder-decoder
    # attention.
    "trg_src_attn_post_softmax_shape": [(4L, ), "int32"],
    # This input is used in the independent decoder program for inference.
    # The actual data shape of enc_output is:
    # [batch_size, max_src_len_in_batch, d_model]
    "enc_output": [(1, (ModelHyperParams.max_length + 1),
                    ModelHyperParams.d_model), "float32"],
    # The actual data shape of label_word is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
    # This input is used to mask out the loss of padding tokens.
    # The actual data shape of label_weight is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
}
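# A sketch of how input_descs entries could be turned into data layers,
# assuming the Fluid static-graph API of that era (this usage and the helper
# name are assumptions, not from the original file). append_batch_size=False
# because the shapes above already include a placeholder batch dimension.
def _example_make_inputs(input_fields):
    import paddle.fluid as fluid
    return [
        fluid.layers.data(
            name=name,
            shape=list(input_descs[name][0]),
            dtype=input_descs[name][1],
            append_batch_size=False) for name in input_fields
    ]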

# Names of the position encoding tables, which will be initialized externally.
pos_enc_param_names = (
    "src_pos_enc_table",
    "trg_pos_enc_table", )
# Input fields grouped by usage.
encoder_data_input_fields = (
    "src_word",
    "src_pos",
    "src_slf_attn_bias", )
encoder_util_input_fields = (
    "src_data_shape",
    "src_slf_attn_pre_softmax_shape",
    "src_slf_attn_post_softmax_shape", )
decoder_data_input_fields = (
    "trg_word",
    "trg_pos",
    "trg_slf_attn_bias",
    "trg_src_attn_bias",
    "enc_output", )
decoder_util_input_fields = (
    "trg_data_shape",
    "trg_slf_attn_pre_softmax_shape",
    "trg_slf_attn_post_softmax_shape",
    "trg_src_attn_pre_softmax_shape",
    "trg_src_attn_post_softmax_shape", )
label_data_input_fields = (
    "lbl_word",
    "lbl_weight", )