
Commit ac4508a

🦟 Update global batch-size for train-dataloader.
1 parent 1ca581f commit ac4508a

File tree

9 files changed: +122 −90 lines changed
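
The change shared by the trainers below multiplies the train dataloader's batch size by config["gradient_accumulation_steps"] in addition to the replica count, so a single fetched batch covers every replica and every accumulation micro-step. A minimal sketch of the arithmetic, with illustrative numbers that are not taken from any repo config:

# Global batch-size arithmetic behind this commit (values are made up).
config = {"batch_size": 16, "gradient_accumulation_steps": 2}
num_replicas_in_sync = 4  # e.g. a tf.distribute.MirroredStrategy over 4 GPUs

# Before: one fetched batch fed exactly one optimizer step per replica.
old_global_batch = config["batch_size"] * num_replicas_in_sync  # 64

# After: one fetched batch also spans all accumulation micro-steps.
new_global_batch = (
    config["batch_size"]
    * num_replicas_in_sync
    * config["gradient_accumulation_steps"]
)  # 128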

examples/fastspeech/train_fastspeech.py

Lines changed: 9 additions & 7 deletions
@@ -36,8 +36,7 @@
 from tensorflow_tts.models import TFFastSpeech
 from tensorflow_tts.optimizers import AdamWeightDecay, WarmUp
 from tensorflow_tts.trainers import Seq2SeqBasedTrainer
-from tensorflow_tts.utils import (calculate_2d_loss, calculate_3d_loss,
-                                  return_strategy)
+from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy
 
 
 class FastSpeechTrainer(Seq2SeqBasedTrainer):
@@ -218,7 +217,7 @@ def main():
         default="",
         type=str,
         nargs="?",
-        help='pretrained checkpoint file to load weights from. Auto-skips non-matching layers',
+        help="pretrained checkpoint file to load weights from. Auto-skips non-matching layers",
     )
     args = parser.parse_args()
 
@@ -302,7 +301,9 @@ def main():
     ).create(
         is_shuffle=config["is_shuffle"],
         allow_cache=config["allow_cache"],
-        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
+        batch_size=config["batch_size"]
+        * STRATEGY.num_replicas_in_sync
+        * config["gradient_accumulation_steps"],
     )
 
     valid_dataset = CharactorDurationMelDataset(
@@ -335,11 +336,12 @@ def main():
         )
         fastspeech._build()
         fastspeech.summary()
-
+
         if len(args.pretrained) > 1:
             fastspeech.load_weights(args.pretrained, by_name=True, skip_mismatch=True)
-            logging.info(f"Successfully loaded pretrained weight from {args.pretrained}.")
-
+            logging.info(
+                f"Successfully loaded pretrained weight from {args.pretrained}."
+            )
 
         # AdamW for fastspeech
         learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
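
The --pretrained handling above uses Keras' partial weight loading. A self-contained sketch of that pattern, assuming a hypothetical pretrained.h5 and made-up layer names; with by_name=True layers are matched by name, and skip_mismatch=True skips any layer whose saved weights have a different shape, which is what "Auto-skips non-matching layers" refers to:

import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(64, name="encoder"),
        tf.keras.layers.Dense(10, name="head"),
    ]
)
model.build(input_shape=(None, 32))

# "pretrained.h5" is a placeholder path, not a file shipped with the repo.
model.load_weights("pretrained.h5", by_name=True, skip_mismatch=True)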

examples/fastspeech2/train_fastspeech2.py

Lines changed: 9 additions & 8 deletions
@@ -33,15 +33,13 @@
 from tqdm import tqdm
 
 import tensorflow_tts
-from examples.fastspeech2.fastspeech2_dataset import \
-    CharactorDurationF0EnergyMelDataset
+from examples.fastspeech2.fastspeech2_dataset import CharactorDurationF0EnergyMelDataset
 from examples.fastspeech.train_fastspeech import FastSpeechTrainer
 from tensorflow_tts.configs import FastSpeech2Config
 from tensorflow_tts.models import TFFastSpeech2
 from tensorflow_tts.optimizers import AdamWeightDecay, WarmUp
 from tensorflow_tts.trainers import Seq2SeqBasedTrainer
-from tensorflow_tts.utils import (calculate_2d_loss, calculate_3d_loss,
-                                  return_strategy)
+from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy
 
 
 class FastSpeech2Trainer(Seq2SeqBasedTrainer):
@@ -244,9 +242,8 @@ def main():
         default="",
         type=str,
         nargs="?",
-        help='pretrained weights .h5 file to load weights from. Auto-skips non-matching layers',
+        help="pretrained weights .h5 file to load weights from. Auto-skips non-matching layers",
     )
-
 
     args = parser.parse_args()
 
@@ -330,7 +327,9 @@ def main():
     ).create(
         is_shuffle=config["is_shuffle"],
         allow_cache=config["allow_cache"],
-        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
+        batch_size=config["batch_size"]
+        * STRATEGY.num_replicas_in_sync
+        * config["gradient_accumulation_steps"],
     )
 
     valid_dataset = CharactorDurationF0EnergyMelDataset(
@@ -367,7 +366,9 @@ def main():
         fastspeech.summary()
         if len(args.pretrained) > 1:
             fastspeech.load_weights(args.pretrained, by_name=True, skip_mismatch=True)
-            logging.info(f"Successfully loaded pretrained weight from {args.pretrained}.")
+            logging.info(
+                f"Successfully loaded pretrained weight from {args.pretrained}."
+            )
 
         # AdamW for fastspeech
         learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
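
Only the train dataloader grows by gradient_accumulation_steps because an accumulating trainer slices each fetched batch into micro-batches and applies a single optimizer step per fetched batch. A generic sketch of that idea, not the actual Seq2SeqBasedTrainer implementation:

import tensorflow as tf

ACCUM_STEPS = 2  # stands in for config["gradient_accumulation_steps"]

def accumulating_train_step(model, optimizer, loss_fn, batch_x, batch_y):
    # Slice the oversized fetched batch into ACCUM_STEPS micro-batches.
    micro_xs = tf.split(batch_x, ACCUM_STEPS)
    micro_ys = tf.split(batch_y, ACCUM_STEPS)
    accum = [tf.zeros_like(v) for v in model.trainable_variables]
    for x, y in zip(micro_xs, micro_ys):
        with tf.GradientTape() as tape:
            # Scale so the summed gradient matches one big-batch step.
            loss = loss_fn(y, model(x, training=True)) / ACCUM_STEPS
        grads = tape.gradient(loss, model.trainable_variables)
        accum = [a + g for a, g in zip(accum, grads)]
    # One optimizer update per fetched (global) batch.
    optimizer.apply_gradients(zip(accum, model.trainable_variables))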

examples/fastspeech2_libritts/train_fastspeech2.py

Lines changed: 58 additions & 39 deletions
@@ -33,22 +33,33 @@
 import json
 
 import tensorflow_tts
-from examples.fastspeech2_libritts.fastspeech2_dataset import \
-    CharactorDurationF0EnergyMelDataset
+from examples.fastspeech2_libritts.fastspeech2_dataset import (
+    CharactorDurationF0EnergyMelDataset,
+)
 from tensorflow_tts.configs import FastSpeech2Config
 from tensorflow_tts.models import TFFastSpeech2
 from tensorflow_tts.optimizers import AdamWeightDecay, WarmUp
 from tensorflow_tts.trainers import Seq2SeqBasedTrainer
-from tensorflow_tts.utils import (calculate_2d_loss, calculate_3d_loss,
-                                  return_strategy, TFGriffinLim)
+from tensorflow_tts.utils import (
+    calculate_2d_loss,
+    calculate_3d_loss,
+    return_strategy,
+    TFGriffinLim,
+)
 
 
 class FastSpeech2Trainer(Seq2SeqBasedTrainer):
     """FastSpeech2 Trainer class based on FastSpeechTrainer."""
 
     def __init__(
-        self, config, strategy, steps=0, epochs=0, is_mixed_precision=False, stats_path: str = "",
-        dataset_config: str = ""
+        self,
+        config,
+        strategy,
+        steps=0,
+        epochs=0,
+        is_mixed_precision=False,
+        stats_path: str = "",
+        dataset_config: str = "",
     ):
         """Initialize trainer.
         Args:
@@ -78,7 +89,9 @@ def __init__(
         self.use_griffin = config.get("use_griffin", False)
         self.griffin_lim_tf = None
         if self.use_griffin:
-            logging.info(f"Load griff stats from {stats_path} and config from {dataset_config}")
+            logging.info(
+                f"Load griff stats from {stats_path} and config from {dataset_config}"
+            )
             self.griff_conf = yaml.load(open(dataset_config), Loader=yaml.Loader)
             self.prepare_grim(stats_path, self.griff_conf)
 
@@ -160,7 +173,9 @@ def generate_and_save_intermediate_result(self, batch):
 
         # check directory
         if self.use_griffin:
-            griff_dir_name = os.path.join(self.config["outdir"], f"predictions/{self.steps}_wav")
+            griff_dir_name = os.path.join(
+                self.config["outdir"], f"predictions/{self.steps}_wav"
+            )
             if not os.path.exists(griff_dir_name):
                 os.makedirs(griff_dir_name)
 
@@ -171,23 +186,31 @@ def generate_and_save_intermediate_result(self, batch):
         for idx, (mel_gt, mel_before, mel_after) in enumerate(
             zip(mel_gts, mels_before, mels_after), 0
         ):
-
-
+
             if self.use_griffin:
                 utt_id = utt_ids[idx]
-                grif_before = self.griffin_lim_tf(tf.reshape(mel_before, [-1, 80])[tf.newaxis, :], n_iter=32)
-                grif_after = self.griffin_lim_tf(tf.reshape(mel_after, [-1, 80])[tf.newaxis, :], n_iter=32)
-                grif_gt = self.griffin_lim_tf(tf.reshape(mel_gt, [-1, 80])[tf.newaxis, :], n_iter=32)
-                self.griffin_lim_tf.save_wav(grif_before, griff_dir_name, f"{utt_id}_before")
-                self.griffin_lim_tf.save_wav(grif_after, griff_dir_name, f"{utt_id}_after")
+                grif_before = self.griffin_lim_tf(
+                    tf.reshape(mel_before, [-1, 80])[tf.newaxis, :], n_iter=32
+                )
+                grif_after = self.griffin_lim_tf(
+                    tf.reshape(mel_after, [-1, 80])[tf.newaxis, :], n_iter=32
+                )
+                grif_gt = self.griffin_lim_tf(
+                    tf.reshape(mel_gt, [-1, 80])[tf.newaxis, :], n_iter=32
+                )
+                self.griffin_lim_tf.save_wav(
+                    grif_before, griff_dir_name, f"{utt_id}_before"
+                )
+                self.griffin_lim_tf.save_wav(
+                    grif_after, griff_dir_name, f"{utt_id}_after"
+                )
                 self.griffin_lim_tf.save_wav(grif_gt, griff_dir_name, f"{utt_id}_gt")
-
+
             utt_id = utt_ids[idx]
             mel_gt = tf.reshape(mel_gt, (-1, 80)).numpy()  # [length, 80]
             mel_before = tf.reshape(mel_before, (-1, 80)).numpy()  # [length, 80]
             mel_after = tf.reshape(mel_after, (-1, 80)).numpy()  # [length, 80]
 
-
             # plit figure and save it
             figname = os.path.join(dirname, f"{utt_id}.png")
             fig = plt.figure(figsize=(10, 8))
@@ -229,10 +252,7 @@ def main():
         "--use-norm", default=1, type=int, help="usr norm-mels for train or raw."
     )
     parser.add_argument(
-        "--f0-stat",
-        default="./dump/stats_f0.npy",
-        type=str,
-        help="f0-stat path.",
+        "--f0-stat", default="./dump/stats_f0.npy", type=str, help="f0-stat path.",
     )
     parser.add_argument(
         "--energy-stat",
@@ -266,26 +286,20 @@ def main():
         help="using mixed precision for generator or not.",
     )
     parser.add_argument(
-        "--dataset_config",
-        default="preprocess/libritts_preprocess.yaml",
-        type=str,
+        "--dataset_config", default="preprocess/libritts_preprocess.yaml", type=str,
     )
     parser.add_argument(
-        "--dataset_stats",
-        default="dump/stats.npy",
-        type=str,
+        "--dataset_stats", default="dump/stats.npy", type=str,
    )
     parser.add_argument(
-        "--dataset_mapping",
-        default="dump/libritts_mapper.npy",
-        type=str,
+        "--dataset_mapping", default="dump/libritts_mapper.npy", type=str,
     )
     parser.add_argument(
         "--pretrained",
         default="",
         type=str,
         nargs="?",
-        help='pretrained weights .h5 file to load weights from. Auto-skips non-matching layers',
+        help="pretrained weights .h5 file to load weights from. Auto-skips non-matching layers",
     )
     args = parser.parse_args()
 
@@ -362,7 +376,9 @@ def main():
 
     # Check n_speakers matches number of speakers in speakers_map
     n_speakers = config["fastspeech2_params"]["n_speakers"]
-    assert n_speakers == len(speakers_map), f"Number of speakers in dataset does not match n_speakers in config"
+    assert n_speakers == len(
+        speakers_map
+    ), f"Number of speakers in dataset does not match n_speakers in config"
 
     # define train/valid dataset
     train_dataset = CharactorDurationF0EnergyMelDataset(
@@ -375,11 +391,13 @@ def main():
         f0_stat=args.f0_stat,
         energy_stat=args.energy_stat,
         mel_length_threshold=mel_length_threshold,
-        speakers_map=speakers_map
+        speakers_map=speakers_map,
     ).create(
         is_shuffle=config["is_shuffle"],
         allow_cache=config["allow_cache"],
-        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
+        batch_size=config["batch_size"]
+        * STRATEGY.num_replicas_in_sync
+        * config["gradient_accumulation_steps"],
     )
 
     valid_dataset = CharactorDurationF0EnergyMelDataset(
@@ -392,7 +410,7 @@ def main():
         f0_stat=args.f0_stat,
         energy_stat=args.energy_stat,
         mel_length_threshold=mel_length_threshold,
-        speakers_map=speakers_map
+        speakers_map=speakers_map,
     ).create(
         is_shuffle=config["is_shuffle"],
         allow_cache=config["allow_cache"],
@@ -407,7 +425,7 @@ def main():
         epochs=0,
         is_mixed_precision=args.mixed_precision,
         stats_path=args.dataset_stats,
-        dataset_config=args.dataset_config
+        dataset_config=args.dataset_config,
     )
 
     with STRATEGY.scope():
@@ -417,11 +435,12 @@ def main():
         )
         fastspeech._build()
         fastspeech.summary()
-
+
         if len(args.pretrained) > 1:
             fastspeech.load_weights(args.pretrained, by_name=True, skip_mismatch=True)
-            logging.info(f"Successfully loaded pretrained weight from {args.pretrained}.")
-
+            logging.info(
+                f"Successfully loaded pretrained weight from {args.pretrained}."
+            )
 
         # AdamW for fastspeech
         learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
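
STRATEGY and its num_replicas_in_sync come from the repo's return_strategy() helper, whose body is not part of this diff; it presumably wraps the standard tf.distribute pattern sketched below:

import tensorflow as tf

# MirroredStrategy replicates the model across local GPUs;
# num_replicas_in_sync is the multiplier used in the batch_size change above.
strategy = tf.distribute.MirroredStrategy()
print("replicas in sync:", strategy.num_replicas_in_sync)

with strategy.scope():
    # Variables created inside the scope are mirrored on every replica,
    # which is why the models above are built under STRATEGY.scope().
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])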

examples/melgan.stft/train_melgan_stft.py

Lines changed: 13 additions & 10 deletions
@@ -36,10 +36,8 @@
 from examples.melgan.audio_mel_dataset import AudioMelDataset
 from examples.melgan.train_melgan import MelganTrainer, collater
 from tensorflow_tts.losses import TFMultiResolutionSTFT
-from tensorflow_tts.models import (TFMelGANGenerator,
-                                   TFMelGANMultiScaleDiscriminator)
-from tensorflow_tts.utils import (calculate_2d_loss, calculate_3d_loss,
-                                  return_strategy)
+from tensorflow_tts.models import TFMelGANGenerator, TFMelGANMultiScaleDiscriminator
+from tensorflow_tts.utils import calculate_2d_loss, calculate_3d_loss, return_strategy
 
 
 class MultiSTFTMelganTrainer(MelganTrainer):
@@ -206,7 +204,7 @@ def main():
         default="",
         type=str,
         nargs="?",
-        help='path of .h5 melgan generator to load weights from',
+        help="path of .h5 melgan generator to load weights from",
     )
     args = parser.parse_args()
 
@@ -295,7 +293,9 @@ def main():
             hop_size=tf.constant(config["hop_size"], dtype=tf.int32),
         ),
         allow_cache=config["allow_cache"],
-        batch_size=config["batch_size"] * STRATEGY.num_replicas_in_sync,
+        batch_size=config["batch_size"]
+        * STRATEGY.num_replicas_in_sync
+        * config["gradient_accumulation_steps"],
     )
 
     valid_dataset = AudioMelDataset(
@@ -336,19 +336,22 @@ def main():
         )
 
         discriminator = TFMelGANMultiScaleDiscriminator(
-            MELGAN_CONFIG.MelGANDiscriminatorConfig(**config["melgan_discriminator_params"]),
+            MELGAN_CONFIG.MelGANDiscriminatorConfig(
+                **config["melgan_discriminator_params"]
+            ),
             name="melgan_discriminator",
         )
 
         # dummy input to build model.
         fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
         y_hat = generator(fake_mels)
         discriminator(y_hat)
-
+
         if len(args.pretrained) > 1:
             generator.load_weights(args.pretrained)
-            logging.info(f"Successfully loaded pretrained weight from {args.pretrained}.")
-
+            logging.info(
+                f"Successfully loaded pretrained weight from {args.pretrained}."
+            )
 
         generator.summary()
         discriminator.summary()
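
The "dummy input to build model" lines rely on subclassed Keras models creating their variables on first call. A stand-alone illustration with Dense stand-ins (not the real MelGAN modules), using the same [batch, frames, n_mels] = [1, 100, 80] shape as the diff:

import tensorflow as tf

generator = tf.keras.Sequential([tf.keras.layers.Dense(1)])
discriminator = tf.keras.Sequential([tf.keras.layers.Dense(1)])

# One fake forward pass materializes the weights so that load_weights()
# and summary() can run before any real training batch arrives.
fake_mels = tf.random.uniform(shape=[1, 100, 80], dtype=tf.float32)
y_hat = generator(fake_mels)
discriminator(y_hat)
generator.summary()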
