33
33
import json
34
34
35
35
import tensorflow_tts
36
- from examples .fastspeech2_libritts .fastspeech2_dataset import \
37
- CharactorDurationF0EnergyMelDataset
36
+ from examples .fastspeech2_libritts .fastspeech2_dataset import (
37
+ CharactorDurationF0EnergyMelDataset ,
38
+ )
38
39
from tensorflow_tts .configs import FastSpeech2Config
39
40
from tensorflow_tts .models import TFFastSpeech2
40
41
from tensorflow_tts .optimizers import AdamWeightDecay , WarmUp
41
42
from tensorflow_tts .trainers import Seq2SeqBasedTrainer
42
- from tensorflow_tts .utils import (calculate_2d_loss , calculate_3d_loss ,
43
- return_strategy , TFGriffinLim )
43
+ from tensorflow_tts .utils import (
44
+ calculate_2d_loss ,
45
+ calculate_3d_loss ,
46
+ return_strategy ,
47
+ TFGriffinLim ,
48
+ )
44
49
45
50
46
51
class FastSpeech2Trainer (Seq2SeqBasedTrainer ):
47
52
"""FastSpeech2 Trainer class based on FastSpeechTrainer."""
48
53
49
54
def __init__ (
50
- self , config , strategy , steps = 0 , epochs = 0 , is_mixed_precision = False , stats_path : str = "" ,
51
- dataset_config : str = ""
55
+ self ,
56
+ config ,
57
+ strategy ,
58
+ steps = 0 ,
59
+ epochs = 0 ,
60
+ is_mixed_precision = False ,
61
+ stats_path : str = "" ,
62
+ dataset_config : str = "" ,
52
63
):
53
64
"""Initialize trainer.
54
65
Args:
@@ -78,7 +89,9 @@ def __init__(
78
89
self .use_griffin = config .get ("use_griffin" , False )
79
90
self .griffin_lim_tf = None
80
91
if self .use_griffin :
81
- logging .info (f"Load griff stats from { stats_path } and config from { dataset_config } " )
92
+ logging .info (
93
+ f"Load griff stats from { stats_path } and config from { dataset_config } "
94
+ )
82
95
self .griff_conf = yaml .load (open (dataset_config ), Loader = yaml .Loader )
83
96
self .prepare_grim (stats_path , self .griff_conf )
84
97
@@ -160,7 +173,9 @@ def generate_and_save_intermediate_result(self, batch):
160
173
161
174
# check directory
162
175
if self .use_griffin :
163
- griff_dir_name = os .path .join (self .config ["outdir" ], f"predictions/{ self .steps } _wav" )
176
+ griff_dir_name = os .path .join (
177
+ self .config ["outdir" ], f"predictions/{ self .steps } _wav"
178
+ )
164
179
if not os .path .exists (griff_dir_name ):
165
180
os .makedirs (griff_dir_name )
166
181
@@ -171,23 +186,31 @@ def generate_and_save_intermediate_result(self, batch):
171
186
for idx , (mel_gt , mel_before , mel_after ) in enumerate (
172
187
zip (mel_gts , mels_before , mels_after ), 0
173
188
):
174
-
175
-
189
+
176
190
if self .use_griffin :
177
191
utt_id = utt_ids [idx ]
178
- grif_before = self .griffin_lim_tf (tf .reshape (mel_before , [- 1 , 80 ])[tf .newaxis , :], n_iter = 32 )
179
- grif_after = self .griffin_lim_tf (tf .reshape (mel_after , [- 1 , 80 ])[tf .newaxis , :], n_iter = 32 )
180
- grif_gt = self .griffin_lim_tf (tf .reshape (mel_gt , [- 1 , 80 ])[tf .newaxis , :], n_iter = 32 )
181
- self .griffin_lim_tf .save_wav (grif_before , griff_dir_name , f"{ utt_id } _before" )
182
- self .griffin_lim_tf .save_wav (grif_after , griff_dir_name , f"{ utt_id } _after" )
192
+ grif_before = self .griffin_lim_tf (
193
+ tf .reshape (mel_before , [- 1 , 80 ])[tf .newaxis , :], n_iter = 32
194
+ )
195
+ grif_after = self .griffin_lim_tf (
196
+ tf .reshape (mel_after , [- 1 , 80 ])[tf .newaxis , :], n_iter = 32
197
+ )
198
+ grif_gt = self .griffin_lim_tf (
199
+ tf .reshape (mel_gt , [- 1 , 80 ])[tf .newaxis , :], n_iter = 32
200
+ )
201
+ self .griffin_lim_tf .save_wav (
202
+ grif_before , griff_dir_name , f"{ utt_id } _before"
203
+ )
204
+ self .griffin_lim_tf .save_wav (
205
+ grif_after , griff_dir_name , f"{ utt_id } _after"
206
+ )
183
207
self .griffin_lim_tf .save_wav (grif_gt , griff_dir_name , f"{ utt_id } _gt" )
184
-
208
+
185
209
utt_id = utt_ids [idx ]
186
210
mel_gt = tf .reshape (mel_gt , (- 1 , 80 )).numpy () # [length, 80]
187
211
mel_before = tf .reshape (mel_before , (- 1 , 80 )).numpy () # [length, 80]
188
212
mel_after = tf .reshape (mel_after , (- 1 , 80 )).numpy () # [length, 80]
189
213
190
-
191
214
# plit figure and save it
192
215
figname = os .path .join (dirname , f"{ utt_id } .png" )
193
216
fig = plt .figure (figsize = (10 , 8 ))
@@ -229,10 +252,7 @@ def main():
229
252
"--use-norm" , default = 1 , type = int , help = "usr norm-mels for train or raw."
230
253
)
231
254
parser .add_argument (
232
- "--f0-stat" ,
233
- default = "./dump/stats_f0.npy" ,
234
- type = str ,
235
- help = "f0-stat path." ,
255
+ "--f0-stat" , default = "./dump/stats_f0.npy" , type = str , help = "f0-stat path." ,
236
256
)
237
257
parser .add_argument (
238
258
"--energy-stat" ,
@@ -266,26 +286,20 @@ def main():
266
286
help = "using mixed precision for generator or not." ,
267
287
)
268
288
parser .add_argument (
269
- "--dataset_config" ,
270
- default = "preprocess/libritts_preprocess.yaml" ,
271
- type = str ,
289
+ "--dataset_config" , default = "preprocess/libritts_preprocess.yaml" , type = str ,
272
290
)
273
291
parser .add_argument (
274
- "--dataset_stats" ,
275
- default = "dump/stats.npy" ,
276
- type = str ,
292
+ "--dataset_stats" , default = "dump/stats.npy" , type = str ,
277
293
)
278
294
parser .add_argument (
279
- "--dataset_mapping" ,
280
- default = "dump/libritts_mapper.npy" ,
281
- type = str ,
295
+ "--dataset_mapping" , default = "dump/libritts_mapper.npy" , type = str ,
282
296
)
283
297
parser .add_argument (
284
298
"--pretrained" ,
285
299
default = "" ,
286
300
type = str ,
287
301
nargs = "?" ,
288
- help = ' pretrained weights .h5 file to load weights from. Auto-skips non-matching layers' ,
302
+ help = " pretrained weights .h5 file to load weights from. Auto-skips non-matching layers" ,
289
303
)
290
304
args = parser .parse_args ()
291
305
@@ -362,7 +376,9 @@ def main():
362
376
363
377
# Check n_speakers matches number of speakers in speakers_map
364
378
n_speakers = config ["fastspeech2_params" ]["n_speakers" ]
365
- assert n_speakers == len (speakers_map ), f"Number of speakers in dataset does not match n_speakers in config"
379
+ assert n_speakers == len (
380
+ speakers_map
381
+ ), f"Number of speakers in dataset does not match n_speakers in config"
366
382
367
383
# define train/valid dataset
368
384
train_dataset = CharactorDurationF0EnergyMelDataset (
@@ -375,11 +391,13 @@ def main():
375
391
f0_stat = args .f0_stat ,
376
392
energy_stat = args .energy_stat ,
377
393
mel_length_threshold = mel_length_threshold ,
378
- speakers_map = speakers_map
394
+ speakers_map = speakers_map ,
379
395
).create (
380
396
is_shuffle = config ["is_shuffle" ],
381
397
allow_cache = config ["allow_cache" ],
382
- batch_size = config ["batch_size" ] * STRATEGY .num_replicas_in_sync ,
398
+ batch_size = config ["batch_size" ]
399
+ * STRATEGY .num_replicas_in_sync
400
+ * config ["gradient_accumulation_steps" ],
383
401
)
384
402
385
403
valid_dataset = CharactorDurationF0EnergyMelDataset (
@@ -392,7 +410,7 @@ def main():
392
410
f0_stat = args .f0_stat ,
393
411
energy_stat = args .energy_stat ,
394
412
mel_length_threshold = mel_length_threshold ,
395
- speakers_map = speakers_map
413
+ speakers_map = speakers_map ,
396
414
).create (
397
415
is_shuffle = config ["is_shuffle" ],
398
416
allow_cache = config ["allow_cache" ],
@@ -407,7 +425,7 @@ def main():
407
425
epochs = 0 ,
408
426
is_mixed_precision = args .mixed_precision ,
409
427
stats_path = args .dataset_stats ,
410
- dataset_config = args .dataset_config
428
+ dataset_config = args .dataset_config ,
411
429
)
412
430
413
431
with STRATEGY .scope ():
@@ -417,11 +435,12 @@ def main():
417
435
)
418
436
fastspeech ._build ()
419
437
fastspeech .summary ()
420
-
438
+
421
439
if len (args .pretrained ) > 1 :
422
440
fastspeech .load_weights (args .pretrained , by_name = True , skip_mismatch = True )
423
- logging .info (f"Successfully loaded pretrained weight from { args .pretrained } ." )
424
-
441
+ logging .info (
442
+ f"Successfully loaded pretrained weight from { args .pretrained } ."
443
+ )
425
444
426
445
# AdamW for fastspeech
427
446
learning_rate_fn = tf .keras .optimizers .schedules .PolynomialDecay (
0 commit comments