-
Notifications
You must be signed in to change notification settings - Fork 3
/
run.py
383 lines (322 loc) · 10.9 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
# coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import os
import random
import socket
import numpy as np
import tensorflow as tf
import tensorflow.contrib as tc
import models
import main as graph
from vocab import Vocab
from utils.recorder import Recorder
from utils import dtype, util
# TensorFlow's Python logger. Propagation is switched off so its records are
# not also handled by ancestor (root) loggers' handlers — presumably to avoid
# duplicated log lines.
logger = tf.get_logger()
logger.propagate = False
# Define the global default hyper-parameters. main() layers the saved
# param.json, the --config file and the --parameters string on top of these
# (in increasing priority).
global_params = tc.training.HParams(
    # whether to share the source and target word embeddings
    shared_source_target_embedding=False,
    # whether to share the target word embedding with the softmax embedding
    shared_target_softmax_embedding=True,
    # maximum decoding length: source length + decode_length
    decode_length=50,
    # beam size
    beam_size=4,
    # length penalty parameters used during beam search
    decode_alpha=0.6,
    decode_beta=1./6.,
    # noisy beam search with Gumbel noise
    enable_noise_beam_search=False,
    # beam search temperature (controls how sharp or flat the predictions are)
    beam_search_temperature=1.0,
    # number of top beams to return (not used)
    top_beams=1,
    # which beam search implementation to use: "cache" or "dev"
    search_mode="cache",
    # distance considered for PDP
    pdp_r=512,
    # speech feature extraction settings
    # NOTE(review): an older comment said the mel features have dimension 40
    # and grow to 120 after applying deltas, but audio_num_mel_bins below is
    # 80 — confirm the actual feature dimension.
    audio_sample_rate=16000,
    audio_preemphasis=0.97,
    # note: disable dithering after training
    audio_dither=1.0 / np.iinfo(np.int16).max,
    audio_frame_length=25.0,
    audio_frame_step=10.0,
    audio_lower_edge_hertz=20.0,
    audio_upper_edge_hertz=8000.0,
    audio_num_mel_bins=80,
    audio_add_delta_deltas=True,
    # path of a pretrained ASR model
    asr_pretrain="",
    # whether to filter variables when initializing from the ASR model,
    # e.g. to avoid initializing the global step
    filter_variables=False,
    # learning rate decay
    # number of shards
    nstable=4,
    # warmup steps: the step at which the learning rate stops increasing
    warmup_steps=4000,
    # learning rate strategy: noam, gnmt+, epoch, score or vanilla
    lrate_strategy="noam",
    # learning rate decay rate
    lrate_decay=0.5,
    # cosine learning rate schedule period
    cosine_period=5000,
    # cosine factor
    cosine_factor=1,
    # early stopping patience
    estop_patience=100,
    # initialization
    # type of initializer
    initializer="uniform",
    # initializer range control
    initializer_gain=0.08,
    # parameters for rnnsearch
    # encoder and decoder hidden size
    hidden_size=1000,
    # source and target embedding size
    embed_size=620,
    # dropout rates
    dropout=0.1,
    relu_dropout=0.1,
    residual_dropout=0.1,
    # label smoothing value
    label_smooth=0.1,
    # model name
    model_name="transformer",
    # scope name
    scope_name="transformer",
    # feed-forward filter size for the transformer
    filter_size=2048,
    # attention dropout
    attention_dropout=0.1,
    # the number of encoder layers, valid for deep nmt
    num_encoder_layer=6,
    # the number of decoder layers, valid for deep nmt
    num_decoder_layer=6,
    # the number of attention heads
    num_heads=8,
    # maximum audio frame length (original note: sample rate * N / 100)
    max_frame_len=100,
    # maximum text length
    max_text_len=100,
    # constant batch size in 'batch' mode (batch-based batching)
    batch_size=80,
    # constant token budget in 'token' mode (token-based batching)
    token_size=3000,
    # data iterator batching mode: 'token' or 'batch'
    batch_or_token='token',
    # batch size for decoding, i.e. number of source sentences decoded at the same time
    eval_batch_size=32,
    # whether to shuffle batches during training
    shuffle_batch=True,
    # data leak buffer threshold
    data_leak_ratio=0.5,
    # number of processes used for data reading
    # (presumably 1 means no multiprocessing — TODO confirm)
    process_num=1,
    # number of sentences read into the buffer at one time
    buffer_size=100,
    # queue sizes for the multi-threaded reading process
    input_queue_size=100,
    output_queue_size=100,
    # source vocabulary
    src_vocab_file="",
    # target vocabulary
    tgt_vocab_file="",
    # source train file
    src_train_path="",
    src_train_file="",
    # target train file
    tgt_train_file="",
    # ctc train file
    ctc_train_file="",
    # source development file
    src_dev_path="",
    src_dev_file="",
    # target development file
    tgt_dev_file="",
    # source test file
    src_test_path="",
    src_test_file="",
    # target test file
    tgt_test_file="",
    # output directory (param.json and record.json are kept here)
    output_dir="",
    # output during testing
    test_output="",
    # adam optimizer hyperparameters
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-9,
    # gradient clipping value
    clip_grad_norm=5.0,
    # upper bound on the gradient norm, guarding against weirdly large
    # gradient norms; only used in safe-nan mode
    gnorm_upper_bound=1e20,
    # initial learning rate
    lrate=1e-5,
    # minimum learning rate
    min_lrate=0.0,
    # maximum learning rate
    max_lrate=1.0,
    # maximum epochs
    epoches=10,
    # the effective batch size is: batch/token size * update_cycle * num_gpus
    # sequential update cycle
    update_cycle=1,
    # the gpus to use
    gpus=[0],
    # enable safe handling of nan
    safe_nan=False,
    # exponential moving average for stability, disabled by default
    ema_decay=-1.,
    # enable training deep transformer
    deep_transformer_init=False,
    # print information every disp_freq training steps
    disp_freq=100,
    # evaluate on the development file every eval_freq steps
    eval_freq=10000,
    # save the model parameters every save_freq steps
    save_freq=5000,
    # print sample translations every sample_freq steps
    sample_freq=1000,
    # number of saved checkpoints
    checkpoints=5,
    best_checkpoints=1,
    # the maximum training steps; training stops once either epoches or
    # max_training_steps is reached
    max_training_steps=1000,
    # random seed; seeding does not make tensorflow fully deterministic
    random_seed=1234,
    # whether or not to continue training from a checkpoint
    train_continue=True,
    # interface for changing the default datatype
    default_dtype="float32",
    dtype_epsilon=1e-8,
    dtype_inf=1e8,
    loss_scale=1.0,
    # speech-specific settings
    sinusoid_posenc=True,
    max_poslen=2048,
    ctc_repeated=False,
    ctc_enable=False,
    ctc_alpha=0.3,  # ctc loss factor
    enc_localize="log",
    dec_localize="none",
    encdec_localize="none",
    # cola ctc settings
    # -1: disable cola ctc; in our paper we set it to 256.
    cola_ctc_L=-1,
    # neural acoustic feature modeling
    use_nafm=False,
    nafm_alpha=0.05,
)
# Command-line flags; main() reads them through flags.FLAGS.
flags = tf.flags
# path of a config file whose content main() evaluates and merges into params
flags.DEFINE_string("config", "", "Additional Mergable Parameters")
# hparam override string handed to params.parse() in main()
flags.DEFINE_string("parameters", "", "Command Line Refinable Parameters")
flags.DEFINE_string("name", "model", "Description of the training process for distinguishing")
# main() dispatches on exactly "train", "test" and "score"; any other value is
# reported as an invalid mode, so the help text lists only those three.
flags.DEFINE_string("mode", "train", "train or test or score")
# saving model configuration
def save_parameters(params, output_dir):
    """Write ``params`` as JSON into ``<output_dir>/param.json``.

    The output directory, including any missing parent directories, is
    created first when it does not exist yet.

    Args:
      params: hyper-parameter object exposing ``to_json()``.
      output_dir: directory that receives ``param.json``.
    """
    if not tf.gfile.Exists(output_dir):
        # MakeDirs (unlike MkDir) also creates missing parent directories, so
        # nested output paths such as "exp/run1" work on a fresh setup.
        tf.gfile.MakeDirs(output_dir)
    param_name = os.path.join(output_dir, "param.json")
    with tf.gfile.Open(param_name, "w") as writer:
        tf.logging.info("Saving parameters into {}"
                        .format(param_name))
        writer.write(params.to_json())
# load model configuration
def load_parameters(params, output_dir):
    """Override ``params`` with the JSON stored in ``<output_dir>/param.json``.

    When the file does not exist, ``params`` is returned unchanged.

    Args:
      params: hyper-parameter object exposing ``parse_json()``.
      output_dir: directory that may contain a previously saved ``param.json``.

    Returns:
      The same ``params`` object.
    """
    param_name = os.path.abspath(os.path.join(output_dir, "param.json"))
    if tf.gfile.Exists(param_name):
        tf.logging.info("Loading parameters from {}"
                        .format(param_name))
        with tf.gfile.Open(param_name, 'r') as reader:
            # Read the whole file: readline() would only see the first line and
            # break on pretty-printed or hand-edited multi-line JSON.
            json_str = reader.read()
        params.parse_json(json_str)
    return params
# build training process recorder
def setup_recorder(params):
    """Attach a training-progress ``Recorder`` to ``params`` as ``recorder``.

    The recorder is initialised with fresh counters; if
    ``<output_dir>/record.json`` exists, ``recorder.load_from_json`` is then
    called on it (presumably restoring the saved training state).

    Returns:
      The same ``params`` object, now carrying the ``recorder`` hparam.
    """
    recorder = Recorder()
    initial_state = (
        ("bad_counter", 0),        # early-stopping counter (currently unused)
        ("estop", False),
        ("lidx", -1),              # local data index
        ("step", 0),               # global step, starts from 0
        ("epoch", 1),              # epoch number, starts from 1
        ("lrate", params.lrate),   # running learning rate
        ("history_scores", []),
        ("valid_script_scores", []),
    )
    for attr_name, attr_value in initial_state:
        setattr(recorder, attr_name, attr_value)

    # try to resume from a previously saved recorder
    record_path = os.path.abspath(
        os.path.join(params.output_dir, "record.json"))
    if tf.gfile.Exists(record_path):
        recorder.load_from_json(record_path)

    params.add_hparam('recorder', recorder)
    return params
# print model configuration
def print_parameters(params):
    """Log every hyper-parameter as a left-aligned ``name<TAB>value`` row."""
    info = tf.logging.info
    info("The Used Configuration:")
    for hp_name, hp_value in params.values().items():
        info("%s\t%s", hp_name.ljust(20), str(hp_value).ljust(20))
    # blank line separates the table from subsequent log output
    info("")
def _read_config_overrides(config_path):
    """Evaluate the ``--config`` file and return the resulting object.

    The result is expected to be a dict of hparam overrides. Returns None
    when ``config_path`` is not an existing path (including the default
    empty string).
    """
    if not os.path.exists(config_path):
        return None
    # NOTE(security): the file content is run through eval(), so --config must
    # point at a trusted file. ast.literal_eval would be safer but rejects
    # expressions such as 1./6. that existing configs may rely on.
    with open(config_path) as reader:
        return eval(reader.read())


def main(_):
    """Build the final configuration and run the selected ``--mode``.

    Parameter priority, lowest to highest: the defaults in ``global_params``,
    ``param.json`` saved in ``output_dir``, the ``--config`` file, and the
    ``--parameters`` string. The config file and command line are also
    applied once *before* loading ``param.json`` because they may set
    ``output_dir``, which decides where ``param.json`` is looked up.

    Returns:
      None on success; 1 for an unknown ``--mode`` (tf.app.run forwards the
      return value to sys.exit, so the process exits with a non-zero status).
    """
    # set up logger
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.logging.info("Welcome Using Zero :)")
    pid = os.getpid()
    tf.logging.info("Your pid is {0} and use the following command to force kill your running:\n"
                    "'pkill -9 -P {0}; kill -9 {0}'".format(pid))
    # On clusters, this tells which machine the job is running on
    tf.logging.info("Your running machine name is {}".format(socket.gethostname()))

    # load registered models
    util.dynamic_load_module(models, prefix="models")

    params = global_params
    # first pass: command line / config file may define output_dir, which is
    # needed to locate the saved param.json
    params.parse(flags.FLAGS.parameters)
    # read (and close) the config file once; it is reused for both passes
    config_overrides = _read_config_overrides(flags.FLAGS.config)
    if config_overrides is not None:
        params.override_from_dict(config_overrides)
    params = load_parameters(params, params.output_dir)
    # second pass: config file and command line override the saved values
    if config_overrides is not None:
        params.override_from_dict(config_overrides)
    params.parse(flags.FLAGS.parameters)

    # set up random seed
    random.seed(params.random_seed)
    np.random.seed(params.random_seed)
    tf.set_random_seed(params.random_seed)

    # loading vocabulary
    tf.logging.info("Begin Loading Vocabulary")
    start_time = time.time()
    params.src_vocab = Vocab(params.src_vocab_file)
    params.tgt_vocab = Vocab(params.tgt_vocab_file)
    tf.logging.info("End Loading Vocabulary, Source Vocab Size {}, "
                    "Target Vocab Size {}, within {} seconds"
                    .format(params.src_vocab.size(), params.tgt_vocab.size(),
                            time.time() - start_time))

    # print parameters
    print_parameters(params)

    # set up the default datatype
    dtype.set_floatx(params.default_dtype)
    dtype.set_epsilon(params.dtype_epsilon)
    dtype.set_inf(params.dtype_inf)

    mode = flags.FLAGS.mode
    if mode == "train":
        # save parameters
        save_parameters(params, params.output_dir)
        # load the recorder
        params = setup_recorder(params)
        graph.train(params)
    elif mode == "test":
        graph.evaluate(params)
    elif mode == "score":
        graph.scorer(params)
    else:
        tf.logging.error("Invalid mode: {}".format(mode))
        # signal failure to the shell instead of exiting with status 0
        return 1
if __name__ == '__main__':
    # Hand control to TensorFlow's app runner: it parses the command-line
    # flags declared in this file and then invokes main().
    tf.app.run(main=main)