 class TrainTaskConfig(object):
-    use_gpu = False
+    use_gpu = True
     # the epoch number to train.
-    pass_num = 2
-
+    pass_num = 30
     # the number of sequences contained in a mini-batch.
-    batch_size = 64
-
+    batch_size = 32
     # the hyper parameters for Adam optimizer.
-    learning_rate = 0.001
+    # This static learning_rate will be multiplied by the
+    # LearningRateScheduler-derived learning rate to get the final learning rate.
+    learning_rate = 1
     beta1 = 0.9
     beta2 = 0.98
     eps = 1e-9
-
     # the parameters for learning rate scheduling.
     warmup_steps = 4000
-
     # the flag indicating to use average loss or sum loss when training.
-    use_avg_cost = False
-
+    use_avg_cost = True
+    # the weight used to mix up the ground-truth distribution and the fixed
+    # uniform distribution in label smoothing when training.
+    # Set this as zero if label smoothing is not wanted.
+    label_smooth_eps = 0.1
     # the directory for saving trained models.
     model_dir = "trained_models"
+    # the directory for saving checkpoints.
+    ckpt_dir = "trained_ckpts"
+    # the directory for loading a checkpoint.
+    # If provided, training continues from the checkpoint.
+    ckpt_path = None
+    # the parameter to initialize the learning rate scheduler.
+    # It should be provided when using checkpoints, since the checkpoint
+    # doesn't currently include the training step counter.
+    start_step = 0
 
 
 class InferTaskConfig(object):
-    use_gpu = False
+    use_gpu = True
     # the number of examples in one run for sequence generation.
     batch_size = 10
-
     # the parameters for beam search.
     beam_size = 5
     max_length = 30
     # the number of decoded sentences to output.
     n_best = 1
-
     # the flags indicating whether to output the special tokens.
     output_bos = False
     output_eos = False
     output_unk = False
-
     # the directory for loading the trained model.
     model_path = "trained_models/pass_1.infer.model"
 
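The comments added to TrainTaskConfig above say the static learning_rate is multiplied by the value produced by the LearningRateScheduler. A minimal sketch of that relationship, assuming the scheduler follows the warmup/decay rule from the original Transformer paper (the helper name transformer_lr and the loop are illustrative only, not part of this commit):

    def transformer_lr(step, d_model=512, warmup_steps=4000):
        # lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
        return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

    # The final rate fed to Adam would be the scheduler output scaled by the
    # static TrainTaskConfig.learning_rate (1 by default, so a no-op here).
    static_lr = 1
    for step in (1, 4000, 100000):
        print("step %d -> lr %.6f" % (step, static_lr * transformer_lr(step)))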
@@ -47,30 +54,24 @@ class ModelHyperParams(object):
     # <unk> token has already been added. As for the <pad> token, any token
     # included in dict can be used to pad, since the paddings' loss will be
     # masked out and have no effect on parameter gradients.
-
     # size of source word dictionary.
     src_vocab_size = 10000
-
     # size of target word dictionary.
     trg_vocab_size = 10000
-
     # index for <bos> token
     bos_idx = 0
     # index for <eos> token
     eos_idx = 1
     # index for <unk> token
     unk_idx = 2
-
     # max length of sequences.
     # The size of the position encoding table should be at least max_length + 1,
     # since the sinusoid position encoding starts from 1 and 0 can be used as
     # the padding token for position encoding.
     max_length = 50
-
     # the dimension for word embeddings, which is also the last dimension of
     # the input and output of multi-head attention, position-wise feed-forward
     # networks, encoder and decoder.
-
     d_model = 512
     # size of the hidden layer in position-wise feed-forward networks.
     d_inner_hid = 1024
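ModelHyperParams.max_length above notes that the position encoding table needs max_length + 1 rows because sinusoid positions start at 1 and row 0 is kept for padding. A minimal numpy sketch of building such a table under that convention (the function name position_encoding_init is illustrative; the real table is initialized elsewhere in the model code):

    import numpy as np

    def position_encoding_init(n_position, d_model):
        # Row p holds the sinusoid encoding of position p; row 0 stays all
        # zeros so it can act as the padding position.
        pos_enc = np.zeros((n_position, d_model), dtype="float32")
        for pos in range(1, n_position):
            for i in range(d_model):
                angle = pos / np.power(10000.0, 2 * (i // 2) / float(d_model))
                pos_enc[pos, i] = np.sin(angle) if i % 2 == 0 else np.cos(angle)
        return pos_enc

    # (max_length + 1) rows cover positions 0..max_length, i.e. (51, 512) here.
    table = position_encoding_init(ModelHyperParams.max_length + 1,
                                   ModelHyperParams.d_model)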
@@ -86,34 +87,116 @@ class ModelHyperParams(object):
     dropout = 0.1
 
 
+def merge_cfg_from_list(cfg_list, g_cfgs):
+    """
+    Set the above global configurations using the cfg_list.
+    """
+    assert len(cfg_list) % 2 == 0
+    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
+        for g_cfg in g_cfgs:
+            if hasattr(g_cfg, key):
+                try:
+                    value = eval(value)
+                except SyntaxError:  # for file path
+                    pass
+                setattr(g_cfg, key, value)
+                break
+
+
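merge_cfg_from_list above lets a flat key/value list (for example, parsed from a command-line option) override the class attributes defined earlier. A small usage sketch; note that values go through eval, so numbers and booleans are the safest overrides, and the break means only the first class in g_cfgs that defines a key gets updated:

    args_list = ["batch_size", "64", "use_gpu", "False", "d_model", "256"]
    merge_cfg_from_list(args_list,
                        [TrainTaskConfig, InferTaskConfig, ModelHyperParams])

    print(TrainTaskConfig.batch_size)   # 64 (TrainTaskConfig matches first)
    print(InferTaskConfig.batch_size)   # still 10, because of the break
    print(ModelHyperParams.d_model)     # 256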
+# Here list the data shapes and data types of all inputs.
+# The shapes here act as placeholders and are set to pass shape inference at
+# compile time.
+input_descs = {
+    # The actual data shape of src_word is:
+    # [batch_size * max_src_len_in_batch, 1]
+    "src_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    # The actual data shape of src_pos is:
+    # [batch_size * max_src_len_in_batch, 1]
+    "src_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    # This input is used to remove attention weights on paddings in the
+    # encoder.
+    # The actual data shape of src_slf_attn_bias is:
+    # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
+    "src_slf_attn_bias":
+    [(1, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
+      (ModelHyperParams.max_length + 1)), "float32"],
+    # This shape input is used to reshape the output of embedding layer.
+    "src_data_shape": [(3L, ), "int32"],
+    # This shape input is used to reshape before softmax in self attention.
+    "src_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
+    # This shape input is used to reshape after softmax in self attention.
+    "src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
+    # The actual data shape of trg_word is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "trg_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    # The actual data shape of trg_pos is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "trg_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    # This input is used to remove attention weights on paddings and
+    # subsequent words in the decoder.
+    # The actual data shape of trg_slf_attn_bias is:
+    # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
+    "trg_slf_attn_bias": [(1, ModelHyperParams.n_head,
+                           (ModelHyperParams.max_length + 1),
+                           (ModelHyperParams.max_length + 1)), "float32"],
+    # This input is used to remove attention weights on paddings of the source
+    # input in the encoder-decoder attention.
+    # The actual data shape of trg_src_attn_bias is:
+    # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
+    "trg_src_attn_bias": [(1, ModelHyperParams.n_head,
+                           (ModelHyperParams.max_length + 1),
+                           (ModelHyperParams.max_length + 1)), "float32"],
+    # This shape input is used to reshape the output of embedding layer.
+    "trg_data_shape": [(3L, ), "int32"],
+    # This shape input is used to reshape before softmax in self attention.
+    "trg_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
+    # This shape input is used to reshape after softmax in self attention.
+    "trg_slf_attn_post_softmax_shape": [(4L, ), "int32"],
+    # This shape input is used to reshape before softmax in encoder-decoder
+    # attention.
+    "trg_src_attn_pre_softmax_shape": [(2L, ), "int32"],
+    # This shape input is used to reshape after softmax in encoder-decoder
+    # attention.
+    "trg_src_attn_post_softmax_shape": [(4L, ), "int32"],
+    # This input is used in the independent decoder program for inference.
+    # The actual data shape of enc_output is:
+    # [batch_size, max_src_len_in_batch, d_model]
+    "enc_output": [(1, (ModelHyperParams.max_length + 1),
+                    ModelHyperParams.d_model), "float32"],
+    # The actual data shape of label_word is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "lbl_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
+    # This input is used to mask out the loss of padding tokens.
+    # The actual data shape of label_weight is:
+    # [batch_size * max_trg_len_in_batch, 1]
+    "lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
+}
+
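The shapes in input_descs carry a nominal batch size of 1 and the maximum sequence length only so that compile-time shape inference passes; the comments above give the actual run-time shapes. A small illustrative calculation for one hypothetical batch (the batch_size and length values here are made up for the example):

    batch_size = 32            # TrainTaskConfig.batch_size
    max_src_len_in_batch = 20  # longest padded source sequence in this batch

    # src_word/src_pos are flattened to one token per row:
    src_word_shape = (batch_size * max_src_len_in_batch, 1)   # (640, 1)
    # src_slf_attn_bias keeps one mask value per head and per query/key pair:
    src_slf_attn_bias_shape = (batch_size, ModelHyperParams.n_head,
                               max_src_len_in_batch, max_src_len_in_batch)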
 # Names of position encoding table which will be initialized externally.
 pos_enc_param_names = (
     "src_pos_enc_table",
     "trg_pos_enc_table", )
-
-# Names of all data layers in encoder listed in order.
-encoder_input_data_names = (
+# separated inputs for different usages.
+encoder_data_input_fields = (
     "src_word",
     "src_pos",
-    "src_slf_attn_bias",
+    "src_slf_attn_bias", )
+encoder_util_input_fields = (
     "src_data_shape",
     "src_slf_attn_pre_softmax_shape",
     "src_slf_attn_post_softmax_shape", )
-
-# Names of all data layers in decoder listed in order.
-decoder_input_data_names = (
+decoder_data_input_fields = (
     "trg_word",
     "trg_pos",
     "trg_slf_attn_bias",
     "trg_src_attn_bias",
+    "enc_output", )
+decoder_util_input_fields = (
     "trg_data_shape",
     "trg_slf_attn_pre_softmax_shape",
     "trg_slf_attn_post_softmax_shape",
     "trg_src_attn_pre_softmax_shape",
-    "trg_src_attn_post_softmax_shape",
-    "enc_output", )
-
-# Names of label related data layers listed in order.
-label_data_names = (
+    "trg_src_attn_post_softmax_shape", )
+label_data_input_fields = (
     "lbl_word",
     "lbl_weight", )