From f519012dea95ab2bbd2c8ff5194cfa5ebff512ae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Mon, 12 Apr 2021 11:47:39 +0200
Subject: [PATCH] reformatting and styling

---
 Makefile | 28 +
 TTS/bin/compute_attention_masks.py | 135 +++--
 TTS/bin/compute_embeddings.py | 87 ++--
 TTS/bin/compute_statistics.py | 57 +--
 TTS/bin/convert_melgan_tflite.py | 13 +-
 TTS/bin/convert_melgan_torch_to_tf.py | 45 +-
 TTS/bin/convert_tacotron2_tflite.py | 17 +-
 TTS/bin/convert_tacotron2_torch_to_tf.py | 132 ++---
 TTS/bin/distribute.py | 48 +-
 TTS/bin/find_unique_chars.py | 28 +-
 TTS/bin/resample.py | 71 +--
 TTS/bin/synthesize.py | 109 ++--
 TTS/bin/train_align_tts.py | 367 ++++++------
 TTS/bin/train_encoder.py | 154 +++---
 TTS/bin/train_glow_tts.py | 269 +++++-----
 TTS/bin/train_speedy_speech.py | 276 +++++-----
 TTS/bin/train_tacotron.py | 414 +++++++------
 TTS/bin/train_vocoder_gan.py | 296 +++++------
 TTS/bin/train_vocoder_wavegrad.py | 173 +++----
 TTS/bin/train_vocoder_wavernn.py | 180 +++----
 TTS/bin/tune_wavegrad.py | 59 ++-
 TTS/server/server.py | 87 ++--
 TTS/speaker_encoder/dataset.py | 39 +-
 TTS/speaker_encoder/losses.py | 25 +-
 TTS/speaker_encoder/model.py | 12 +-
 TTS/speaker_encoder/utils/generic_utils.py | 139 ++---
 TTS/speaker_encoder/utils/prepare_voxceleb.py | 93 ++--
 TTS/speaker_encoder/utils/visual.py | 4 +-
 TTS/tts/datasets/TTSDataset.py | 222 ++++----
 TTS/tts/datasets/preprocess.py | 165 +++---
 .../layers/align_tts/duration_predictor.py | 3 +-
 TTS/tts/layers/align_tts/mdn.py | 5 +-
 TTS/tts/layers/feed_forward/decoder.py | 82 +--
 .../layers/feed_forward/duration_predictor.py | 15 +-
 TTS/tts/layers/feed_forward/encoder.py | 92 ++--
 TTS/tts/layers/generic/gated_conv.py | 12 +-
 TTS/tts/layers/generic/normalization.py | 29 +-
 TTS/tts/layers/generic/pos_encoding.py | 21 +-
 TTS/tts/layers/generic/res_conv_bn.py | 36 +-
 TTS/tts/layers/generic/time_depth_sep_conv.py | 30 +-
 TTS/tts/layers/generic/transformer.py | 37 +-
 TTS/tts/layers/generic/wavenet.py | 94 ++--
 TTS/tts/layers/glow_tts/decoder.py | 72 ++-
 TTS/tts/layers/glow_tts/duration_predictor.py | 11 +-
 TTS/tts/layers/glow_tts/encoder.py | 101 ++--
 TTS/tts/layers/glow_tts/glow.py | 92 ++--
 .../glow_tts/monotonic_align/__init__.py | 7 +-
 .../layers/glow_tts/monotonic_align/core.pyx | 4 +-
 TTS/tts/layers/glow_tts/transformer.py | 191 ++---
 TTS/tts/layers/losses.py | 188 ++---
 TTS/tts/layers/tacotron/attentions.py | 219 ++----
 TTS/tts/layers/tacotron/common_layers.py | 50 +-
 TTS/tts/layers/tacotron/gst_layers.py | 75 +--
 TTS/tts/layers/tacotron/tacotron.py | 227 ++++----
 TTS/tts/layers/tacotron/tacotron2.py | 192 ++---
 TTS/tts/models/align_tts.py | 147 ++----
 TTS/tts/models/glow_tts.py | 168 +++---
 TTS/tts/models/speedy_speech.py | 78 ++-
 TTS/tts/models/tacotron.py | 201 +++++---
 TTS/tts/models/tacotron2.py | 212 +++++---
 TTS/tts/models/tacotron_abstract.py | 104 ++--
 TTS/tts/tf/layers/tacotron/common_layers.py | 92 ++--
 TTS/tts/tf/layers/tacotron/tacotron2.py | 198 +++---
 TTS/tts/tf/models/tacotron2.py | 99 ++--
 TTS/tts/tf/utils/convert_torch_to_tf_utils.py | 45 +-
 TTS/tts/tf/utils/generic_utils.py | 80 +--
 TTS/tts/tf/utils/io.py | 25 +-
 TTS/tts/tf/utils/tflite.py | 15 +-
 TTS/tts/utils/chinese_mandarin/numbers.py | 56 +-
 TTS/tts/utils/chinese_mandarin/phonemizer.py | 8 +-
 .../chinese_mandarin/pinyinToPhonemes.py | 1 -
 TTS/tts/utils/data.py | 27 +-
 TTS/tts/utils/generic_utils.py | 484 ++++++++++--------
 TTS/tts/utils/io.py | 66 ++-
 TTS/tts/utils/speakers.py | 44 +-
 TTS/tts/utils/ssim.py | 25 +-
 TTS/tts/utils/synthesis.py | 202 ++----
 TTS/tts/utils/text/__init__.py | 105 ++--
 TTS/tts/utils/text/abbreviations.py | 129 ++---
 TTS/tts/utils/text/cleaners.py | 72 +--
 TTS/tts/utils/text/cmudict.py | 119 ++++-
 TTS/tts/utils/text/number_norm.py | 36 +-
 TTS/tts/utils/text/symbols.py | 53 +-
 TTS/tts/utils/text/time.py | 7 +-
 TTS/tts/utils/visual.py | 95 ++--
 TTS/utils/arguments.py | 81 ++-
 TTS/utils/audio.py | 208 ++++----
 TTS/utils/console_logger.py | 63 +--
 TTS/utils/distribute.py | 16 +-
 TTS/utils/generic_utils.py | 66 +--
 TTS/utils/io.py | 27 +-
 TTS/utils/manage.py | 26 +-
 TTS/utils/radam.py | 66 ++-
 TTS/utils/synthesizer.py | 37 +-
 TTS/utils/tensorboard_logger.py | 33 +-
 TTS/utils/training.py | 37 +-
 TTS/vocoder/datasets/gan_dataset.py | 62 ++-
 TTS/vocoder/datasets/preprocess.py | 8 +-
 TTS/vocoder/datasets/wavegrad_dataset.py | 51 +-
 TTS/vocoder/datasets/wavernn_dataset.py | 51 +-
 TTS/vocoder/layers/hifigan.py | 12 +-
 TTS/vocoder/layers/losses.py | 204 ++++----
 TTS/vocoder/layers/melgan.py | 33 +-
 TTS/vocoder/layers/parallel_wavegan.py | 54 +-
 TTS/vocoder/layers/pqmf.py | 15 +-
 TTS/vocoder/layers/upsample.py | 51 +-
 TTS/vocoder/layers/wavegrad.py | 53 +-
 .../models/fullband_melgan_generator.py | 39 +-
 TTS/vocoder/models/hifigan_discriminator.py | 84 +--
 TTS/vocoder/models/hifigan_generator.py | 197 +++----
 TTS/vocoder/models/melgan_discriminator.py | 69 +--
 TTS/vocoder/models/melgan_generator.py | 77 ++-
 .../models/melgan_multiscale_discriminator.py | 57 ++-
 .../models/multiband_melgan_generator.py | 42 +-
 .../models/parallel_wavegan_discriminator.py | 115 ++---
 .../models/parallel_wavegan_generator.py | 89 ++--
 .../models/random_window_discriminator.py | 95 ++--
 TTS/vocoder/models/wavegrad.py | 71 +--
 TTS/vocoder/models/wavernn.py | 137 +++--
 TTS/vocoder/tf/layers/melgan.py | 28 +-
 TTS/vocoder/tf/layers/pqmf.py | 24 +-
 TTS/vocoder/tf/models/melgan_generator.py | 75 +--
 .../tf/models/multiband_melgan_generator.py | 45 +-
 .../tf/utils/convert_torch_to_tf_utils.py | 26 +-
 TTS/vocoder/tf/utils/generic_utils.py | 27 +-
 TTS/vocoder/tf/utils/io.py | 15 +-
 TTS/vocoder/tf/utils/tflite.py | 15 +-
 TTS/vocoder/utils/distribution.py | 42 +-
 TTS/vocoder/utils/generic_utils.py | 167 +++--
 TTS/vocoder/utils/io.py | 135 +++--
 notebooks/dataset_analysis/analyze.py | 87 ++--
 pyproject.toml | 31 ++
 tests/test_audio.py | 81 +--
 tests/test_feed_forward_layers.py | 137 ++---
 tests/test_glow_tts.py | 77 ++-
 tests/test_layers.py | 49 +-
 tests/test_loader.py | 38 +-
 tests/test_preprocessors.py | 9 +-
 tests/test_speaker_encoder.py | 14 +-
 tests/test_speedy_speech_layers.py | 39 +-
 tests/test_symbols.py | 3 +-
 tests/test_synthesizer.py | 76 ++-
 tests/test_tacotron2_model.py | 190 ++---
 tests/test_tacotron2_tf_model.py | 125 +++--
 tests/test_tacotron_model.py | 274 +++++----
 tests/test_text_cleaners.py | 4 +-
 tests/test_text_processing.py | 21 +-
 tests/test_vocoder_gan_datasets.py | 68 +--
 tests/test_vocoder_losses.py | 16 +-
 tests/test_vocoder_melgan_generator.py | 1 +
 ..._vocoder_parallel_wavegan_discriminator.py | 11 +-
 ...test_vocoder_parallel_wavegan_generator.py | 3 +-
 tests/test_vocoder_pqmf.py | 8 +-
 tests/test_vocoder_rwd.py | 16 +-
 tests/test_vocoder_tf_pqmf.py | 8 +-
 tests/test_vocoder_wavernn.py | 4 +-
 tests/test_vocoder_wavernn_datasets.py | 58 +--
 tests/test_wavegrad_layers.py | 14 +-
 tests/test_wavegrad_train.py | 41 +-
 159 files changed, 6605 insertions(+), 6445 deletions(-)
 create mode 100644 Makefile
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000..c00cd1ceea
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,28 @@
+.DEFAULT_GOAL := help
+.PHONY: test deps style lint install help
+
+help:
+	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
+
+target_dirs := tests TTS notebooks
+
+system-deps: ## install linux system deps
+	sudo apt-get install -y espeak-ng
+	sudo apt-get install -y libsndfile1-dev
+
+deps: ## install 🐸 requirements.
+	pip install -r requirements.txt
+
+test: ## run tests.
+	nosetests --with-cov -cov --cover-erase --cover-package TTS tests
+	./run_bash_tests.sh
+
+style: ## update code style.
+	black ${target_dirs}
+	isort ${target_dirs}
+
+lint: ## run pylint linter.
+	pylint ${target_dirs}
+
+install: ## install 🐸 TTS for development.
+	pip install -e .
diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py
index 53246e07fe..16011dda2c 100644
--- a/TTS/bin/compute_attention_masks.py
+++ b/TTS/bin/compute_attention_masks.py
@@ -1,12 +1,13 @@
 import argparse
 import importlib
 import os
+from argparse import RawTextHelpFormatter
 
 import numpy as np
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from argparse import RawTextHelpFormatter
+
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import load_checkpoint
@@ -14,17 +15,14 @@
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_config
 
-
-if __name__ == '__main__':
+if __name__ == "__main__":
     # pylint: disable=bad-continuation
     parser = argparse.ArgumentParser(
-        description='''Extract attention masks from trained Tacotron/Tacotron2 models.
-These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n'''
-
-'''Each attention mask is written to the same path as the input wav file with ".npy" file extension.
-(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n'''
-
-'''
+        description="""Extract attention masks from trained Tacotron/Tacotron2 models.
+These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n"""
+        """Each attention mask is written to the same path as the input wav file with ".npy" file extension.
+(e.g.
path/bla.wav (wav file) --> path/bla.npy (attention mask))\n""" + """ Example run: CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar @@ -34,53 +32,44 @@ --batch_size 32 --dataset ljspeech --use_cuda True -''', - formatter_class=RawTextHelpFormatter - ) - parser.add_argument('--model_path', - type=str, - required=True, - help='Path to Tacotron/Tacotron2 model file ') +""", + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file ") parser.add_argument( - '--config_path', + "--config_path", type=str, required=True, - help='Path to Tacotron/Tacotron2 config file.', + help="Path to Tacotron/Tacotron2 config file.", ) - parser.add_argument('--dataset', - type=str, - default='', - required=True, - help='Target dataset processor name from TTS.tts.dataset.preprocess.') - parser.add_argument( - '--dataset_metafile', + "--dataset", type=str, - default='', + default="", required=True, - help='Dataset metafile inclusing file paths with transcripts.') + help="Target dataset processor name from TTS.tts.dataset.preprocess.", + ) + parser.add_argument( - '--data_path', + "--dataset_metafile", type=str, - default='', - help='Defines the data path. It overwrites config.json.') - parser.add_argument('--use_cuda', - type=bool, - default=False, - help="enable/disable cuda.") + default="", + required=True, + help="Dataset metafile inclusing file paths with transcripts.", + ) + parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.") + parser.add_argument("--use_cuda", type=bool, default=False, help="enable/disable cuda.") parser.add_argument( - '--batch_size', - default=16, - type=int, - help='Batch size for the model. Use batch_size=1 if you have no CUDA.') + "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA." 
+ ) args = parser.parse_args() C = load_config(args.config_path) ap = AudioProcessor(**C.audio) # if the vocabulary was passed, replace the default - if 'characters' in C.keys(): + if "characters" in C.keys(): symbols, phonemes = make_symbols(**C.characters) # load the model @@ -91,28 +80,32 @@ model.eval() # data loader - preprocessor = importlib.import_module('TTS.tts.datasets.preprocess') + preprocessor = importlib.import_module("TTS.tts.datasets.preprocess") preprocessor = getattr(preprocessor, args.dataset) meta_data = preprocessor(args.data_path, args.dataset_metafile) - dataset = MyDataset(model.decoder.r, - C.text_cleaner, - compute_linear_spec=False, - ap=ap, - meta_data=meta_data, - tp=C.characters if 'characters' in C.keys() else None, - add_blank=C['add_blank'] if 'add_blank' in C.keys() else False, - use_phonemes=C.use_phonemes, - phoneme_cache_path=C.phoneme_cache_path, - phoneme_language=C.phoneme_language, - enable_eos_bos=C.enable_eos_bos_chars) + dataset = MyDataset( + model.decoder.r, + C.text_cleaner, + compute_linear_spec=False, + ap=ap, + meta_data=meta_data, + tp=C.characters if "characters" in C.keys() else None, + add_blank=C["add_blank"] if "add_blank" in C.keys() else False, + use_phonemes=C.use_phonemes, + phoneme_cache_path=C.phoneme_cache_path, + phoneme_language=C.phoneme_language, + enable_eos_bos=C.enable_eos_bos_chars, + ) dataset.sort_items() - loader = DataLoader(dataset, - batch_size=args.batch_size, - num_workers=4, - collate_fn=dataset.collate_fn, - shuffle=False, - drop_last=False) + loader = DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=4, + collate_fn=dataset.collate_fn, + shuffle=False, + drop_last=False, + ) # compute attentions file_paths = [] @@ -134,25 +127,29 @@ mel_input = mel_input.cuda() mel_lengths = mel_lengths.cuda() - mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward( - text_input, text_lengths, mel_input) + mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input) alignments = alignments.detach() for idx, alignment in enumerate(alignments): item_idx = item_idxs[idx] # interpolate if r > 1 - alignment = torch.nn.functional.interpolate( - alignment.transpose(0, 1).unsqueeze(0), - size=None, - scale_factor=model.decoder.r, - mode='nearest', - align_corners=None, - recompute_scale_factor=None).squeeze(0).transpose(0, 1) + alignment = ( + torch.nn.functional.interpolate( + alignment.transpose(0, 1).unsqueeze(0), + size=None, + scale_factor=model.decoder.r, + mode="nearest", + align_corners=None, + recompute_scale_factor=None, + ) + .squeeze(0) + .transpose(0, 1) + ) # remove paddings - alignment = alignment[:mel_lengths[idx], :text_lengths[idx]].cpu().numpy() + alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy() # set file paths wav_file_name = os.path.basename(item_idx) - align_file_name = os.path.splitext(wav_file_name)[0] + '.npy' + align_file_name = os.path.splitext(wav_file_name)[0] + ".npy" file_path = item_idx.replace(wav_file_name, align_file_name) # save output file_paths.append([item_idx, file_path]) diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py index 64edd140c8..c38e0e7e27 100644 --- a/TTS/bin/compute_embeddings.py +++ b/TTS/bin/compute_embeddings.py @@ -3,101 +3,82 @@ import os import numpy as np +import torch from tqdm import tqdm -import torch from TTS.speaker_encoder.model import SpeakerEncoder +from TTS.tts.datasets.preprocess import load_meta_data +from TTS.tts.utils.speakers 
import save_speaker_mapping from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_config -from TTS.tts.utils.speakers import save_speaker_mapping -from TTS.tts.datasets.preprocess import load_meta_data parser = argparse.ArgumentParser( - description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.') -parser.add_argument( - 'model_path', - type=str, - help='Path to model outputs (checkpoint, tensorboard etc.).') -parser.add_argument( - 'config_path', - type=str, - help='Path to config file for training.', + description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.' ) +parser.add_argument("model_path", type=str, help="Path to model outputs (checkpoint, tensorboard etc.).") parser.add_argument( - 'data_path', - type=str, - help='Data path for wav files - directory or CSV file') -parser.add_argument( - 'output_path', - type=str, - help='path for training outputs.') -parser.add_argument( - '--target_dataset', + "config_path", type=str, - default='', - help='Target dataset to pick a processor from TTS.tts.dataset.preprocess. Necessary to create a speakers.json file.' -) -parser.add_argument( - '--use_cuda', type=bool, help='flag to set cuda.', default=False + help="Path to config file for training.", ) +parser.add_argument("data_path", type=str, help="Data path for wav files - directory or CSV file") +parser.add_argument("output_path", type=str, help="path for training outputs.") parser.add_argument( - '--separator', type=str, help='Separator used in file if CSV is passed for data_path', default='|' + "--target_dataset", + type=str, + default="", + help="Target dataset to pick a processor from TTS.tts.dataset.preprocess. 
Necessary to create a speakers.json file.", ) +parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=False) +parser.add_argument("--separator", type=str, help="Separator used in file if CSV is passed for data_path", default="|") args = parser.parse_args() c = load_config(args.config_path) -ap = AudioProcessor(**c['audio']) +ap = AudioProcessor(**c["audio"]) data_path = args.data_path split_ext = os.path.splitext(data_path) sep = args.separator -if args.target_dataset != '': +if args.target_dataset != "": # if target dataset is defined dataset_config = [ - { - "name": args.target_dataset, - "path": args.data_path, - "meta_file_train": None, - "meta_file_val": None - }, + {"name": args.target_dataset, "path": args.data_path, "meta_file_train": None, "meta_file_val": None}, ] wav_files, _ = load_meta_data(dataset_config, eval_split=False) - output_files = [wav_file[1].replace(data_path, args.output_path).replace( - '.wav', '.npy') for wav_file in wav_files] + output_files = [wav_file[1].replace(data_path, args.output_path).replace(".wav", ".npy") for wav_file in wav_files] else: # if target dataset is not defined - if len(split_ext) > 0 and split_ext[1].lower() == '.csv': + if len(split_ext) > 0 and split_ext[1].lower() == ".csv": # Parse CSV - print(f'CSV file: {data_path}') + print(f"CSV file: {data_path}") with open(data_path) as f: - wav_path = os.path.join(os.path.dirname(data_path), 'wavs') + wav_path = os.path.join(os.path.dirname(data_path), "wavs") wav_files = [] - print(f'Separator is: {sep}') + print(f"Separator is: {sep}") for line in f: components = line.split(sep) if len(components) != 2: print("Invalid line") continue - wav_file = os.path.join(wav_path, components[0] + '.wav') - #print(f'wav_file: {wav_file}') + wav_file = os.path.join(wav_path, components[0] + ".wav") + # print(f'wav_file: {wav_file}') if os.path.exists(wav_file): wav_files.append(wav_file) - print(f'Count of wavs imported: {len(wav_files)}') + print(f"Count of wavs imported: {len(wav_files)}") else: # Parse all wav files in data_path - wav_files = glob.glob(data_path + '/**/*.wav', recursive=True) + wav_files = glob.glob(data_path + "/**/*.wav", recursive=True) - output_files = [wav_file.replace(data_path, args.output_path).replace( - '.wav', '.npy') for wav_file in wav_files] + output_files = [wav_file.replace(data_path, args.output_path).replace(".wav", ".npy") for wav_file in wav_files] for output_file in output_files: os.makedirs(os.path.dirname(output_file), exist_ok=True) # define Encoder model model = SpeakerEncoder(**c.model) -model.load_state_dict(torch.load(args.model_path)['model']) +model.load_state_dict(torch.load(args.model_path)["model"]) model.eval() if args.use_cuda: model.cuda() @@ -117,14 +98,14 @@ embedd = embedd.detach().cpu().numpy() np.save(output_files[idx], embedd) - if args.target_dataset != '': + if args.target_dataset != "": # create speaker_mapping if target dataset is defined wav_file_name = os.path.basename(wav_file) speaker_mapping[wav_file_name] = {} - speaker_mapping[wav_file_name]['name'] = speaker_name - speaker_mapping[wav_file_name]['embedding'] = embedd.flatten().tolist() + speaker_mapping[wav_file_name]["name"] = speaker_name + speaker_mapping[wav_file_name]["embedding"] = embedd.flatten().tolist() -if args.target_dataset != '': +if args.target_dataset != "": # save speaker_mapping if target dataset is defined - mapping_file_path = os.path.join(args.output_path, 'speakers.json') + mapping_file_path = os.path.join(args.output_path, 
"speakers.json") save_speaker_mapping(args.output_path, speaker_mapping) diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py index a74fe90aef..9e2b7415d8 100755 --- a/TTS/bin/compute_statistics.py +++ b/TTS/bin/compute_statistics.py @@ -1,39 +1,38 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import os -import glob import argparse +import glob +import os import numpy as np from tqdm import tqdm from TTS.tts.datasets.preprocess import load_meta_data -from TTS.utils.io import load_config from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_config def main(): """Run preprocessing process.""" - parser = argparse.ArgumentParser( - description="Compute mean and variance of spectrogtram features.") - parser.add_argument("--config_path", type=str, required=True, - help="TTS config file path to define audio processin parameters.") - parser.add_argument("--out_path", type=str, required=True, - help="save path (directory and filename).") + parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.") + parser.add_argument( + "--config_path", type=str, required=True, help="TTS config file path to define audio processin parameters." + ) + parser.add_argument("--out_path", type=str, required=True, help="save path (directory and filename).") args = parser.parse_args() # load config CONFIG = load_config(args.config_path) - CONFIG.audio['signal_norm'] = False # do not apply earlier normalization - CONFIG.audio['stats_path'] = None # discard pre-defined stats + CONFIG.audio["signal_norm"] = False # do not apply earlier normalization + CONFIG.audio["stats_path"] = None # discard pre-defined stats # load audio processor ap = AudioProcessor(**CONFIG.audio) # load the meta data of target dataset - if 'data_path' in CONFIG.keys(): - dataset_items = glob.glob(os.path.join(CONFIG.data_path, '**', '*.wav'), recursive=True) + if "data_path" in CONFIG.keys(): + dataset_items = glob.glob(os.path.join(CONFIG.data_path, "**", "*.wav"), recursive=True) else: dataset_items = load_meta_data(CONFIG.datasets)[0] # take only train data print(f" > There are {len(dataset_items)} files.") @@ -63,27 +62,27 @@ def main(): output_file_path = args.out_path stats = {} - stats['mel_mean'] = mel_mean - stats['mel_std'] = mel_scale - stats['linear_mean'] = linear_mean - stats['linear_std'] = linear_scale + stats["mel_mean"] = mel_mean + stats["mel_std"] = mel_scale + stats["linear_mean"] = linear_mean + stats["linear_std"] = linear_scale - print(f' > Avg mel spec mean: {mel_mean.mean()}') - print(f' > Avg mel spec scale: {mel_scale.mean()}') - print(f' > Avg linear spec mean: {linear_mean.mean()}') - print(f' > Avg lienar spec scale: {linear_scale.mean()}') + print(f" > Avg mel spec mean: {mel_mean.mean()}") + print(f" > Avg mel spec scale: {mel_scale.mean()}") + print(f" > Avg linear spec mean: {linear_mean.mean()}") + print(f" > Avg lienar spec scale: {linear_scale.mean()}") # set default config values for mean-var scaling - CONFIG.audio['stats_path'] = output_file_path - CONFIG.audio['signal_norm'] = True + CONFIG.audio["stats_path"] = output_file_path + CONFIG.audio["signal_norm"] = True # remove redundant values - del CONFIG.audio['max_norm'] - del CONFIG.audio['min_level_db'] - del CONFIG.audio['symmetric_norm'] - del CONFIG.audio['clip_norm'] - stats['audio_config'] = CONFIG.audio + del CONFIG.audio["max_norm"] + del CONFIG.audio["min_level_db"] + del CONFIG.audio["symmetric_norm"] + del CONFIG.audio["clip_norm"] + stats["audio_config"] = 
CONFIG.audio np.save(output_file_path, stats, allow_pickle=True) - print(f' > stats saved to {output_file_path}') + print(f" > stats saved to {output_file_path}") if __name__ == "__main__": diff --git a/TTS/bin/convert_melgan_tflite.py b/TTS/bin/convert_melgan_tflite.py index 8df582da60..a3a3fb66fa 100644 --- a/TTS/bin/convert_melgan_tflite.py +++ b/TTS/bin/convert_melgan_tflite.py @@ -7,17 +7,10 @@ from TTS.vocoder.tf.utils.io import load_checkpoint from TTS.vocoder.tf.utils.tflite import convert_melgan_to_tflite - parser = argparse.ArgumentParser() -parser.add_argument('--tf_model', - type=str, - help='Path to target torch model to be converted to TF.') -parser.add_argument('--config_path', - type=str, - help='Path to config file of torch model.') -parser.add_argument('--output_path', - type=str, - help='path to tflite output binary.') +parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.") +parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") +parser.add_argument("--output_path", type=str, help="path to tflite output binary.") args = parser.parse_args() # Set constants diff --git a/TTS/bin/convert_melgan_torch_to_tf.py b/TTS/bin/convert_melgan_torch_to_tf.py index 2eec6157cf..435813483c 100644 --- a/TTS/bin/convert_melgan_torch_to_tf.py +++ b/TTS/bin/convert_melgan_torch_to_tf.py @@ -1,6 +1,6 @@ import argparse -from difflib import SequenceMatcher import os +from difflib import SequenceMatcher import numpy as np import tensorflow as tf @@ -8,27 +8,22 @@ from TTS.utils.io import load_config from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import ( - compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf) -from TTS.vocoder.tf.utils.generic_utils import \ - setup_generator as setup_tf_generator + compare_torch_tf, + convert_tf_name, + transfer_weights_torch_to_tf, +) +from TTS.vocoder.tf.utils.generic_utils import setup_generator as setup_tf_generator from TTS.vocoder.tf.utils.io import save_checkpoint from TTS.vocoder.utils.generic_utils import setup_generator # prevent GPU use -os.environ['CUDA_VISIBLE_DEVICES'] = '' +os.environ["CUDA_VISIBLE_DEVICES"] = "" # define args parser = argparse.ArgumentParser() -parser.add_argument('--torch_model_path', - type=str, - help='Path to target torch model to be converted to TF.') -parser.add_argument('--config_path', - type=str, - help='Path to config file of torch model.') -parser.add_argument( - '--output_path', - type=str, - help='path to output file including file name to save TF model.') +parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.") +parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") +parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.") args = parser.parse_args() # load model config @@ -38,9 +33,8 @@ # init torch model model = setup_generator(c) -checkpoint = torch.load(args.torch_model_path, - map_location=torch.device('cpu')) -state_dict = checkpoint['model'] +checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu")) +state_dict = checkpoint["model"] model.load_state_dict(state_dict) model.remove_weight_norm() state_dict = model.state_dict() @@ -48,7 +42,7 @@ # init tf model model_tf = setup_tf_generator(c) -common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE' +common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE" # get tf_model graph by passing an input # B x D x T dummy_input = 
tf.random.uniform((7, 80, 64), dtype=tf.float32) @@ -66,10 +60,7 @@ if tf_name in [name[0] for name in var_map]: continue tf_name_edited = convert_tf_name(tf_name) - ratios = [ - SequenceMatcher(None, torch_name, tf_name_edited).ratio() - for torch_name in torch_var_names - ] + ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names] max_idx = np.argmax(ratios) matching_name = torch_var_names[max_idx] del torch_var_names[max_idx] @@ -107,10 +98,8 @@ model_tf.inference_padding = 0 output_torch = model.inference(dummy_input_torch) output_tf = model_tf(dummy_input_tf, training=False) -assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf( - output_torch, output_tf) +assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(output_torch, output_tf) # save tf model -save_checkpoint(model_tf, checkpoint['step'], checkpoint['epoch'], - args.output_path) -print(' > Model conversion is successfully completed :).') +save_checkpoint(model_tf, checkpoint["step"], checkpoint["epoch"], args.output_path) +print(" > Model conversion is successfully completed :).") diff --git a/TTS/bin/convert_tacotron2_tflite.py b/TTS/bin/convert_tacotron2_tflite.py index 2fddf4b01c..327d0ae811 100644 --- a/TTS/bin/convert_tacotron2_tflite.py +++ b/TTS/bin/convert_tacotron2_tflite.py @@ -2,23 +2,16 @@ import argparse -from TTS.utils.io import load_config -from TTS.tts.utils.text.symbols import symbols, phonemes from TTS.tts.tf.utils.generic_utils import setup_model from TTS.tts.tf.utils.io import load_checkpoint from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite - +from TTS.tts.utils.text.symbols import phonemes, symbols +from TTS.utils.io import load_config parser = argparse.ArgumentParser() -parser.add_argument('--tf_model', - type=str, - help='Path to target torch model to be converted to TF.') -parser.add_argument('--config_path', - type=str, - help='Path to config file of torch model.') -parser.add_argument('--output_path', - type=str, - help='path to tflite output binary.') +parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.") +parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") +parser.add_argument("--output_path", type=str, help="path to tflite output binary.") args = parser.parse_args() # Set constants diff --git a/TTS/bin/convert_tacotron2_torch_to_tf.py b/TTS/bin/convert_tacotron2_torch_to_tf.py index 71fb8d5e25..d523d01e78 100644 --- a/TTS/bin/convert_tacotron2_torch_to_tf.py +++ b/TTS/bin/convert_tacotron2_torch_to_tf.py @@ -1,34 +1,28 @@ import argparse -from difflib import SequenceMatcher import os import sys +from difflib import SequenceMatcher from pprint import pprint import numpy as np import tensorflow as tf import torch + from TTS.tts.tf.models.tacotron2 import Tacotron2 -from TTS.tts.tf.utils.convert_torch_to_tf_utils import ( - compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf) +from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf from TTS.tts.tf.utils.generic_utils import save_checkpoint from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.text.symbols import phonemes, symbols from TTS.utils.io import load_config -sys.path.append('/home/erogol/Projects') -os.environ['CUDA_VISIBLE_DEVICES'] = '' +sys.path.append("/home/erogol/Projects") +os.environ["CUDA_VISIBLE_DEVICES"] = "" parser = argparse.ArgumentParser() 
-parser.add_argument('--torch_model_path', - type=str, - help='Path to target torch model to be converted to TF.') -parser.add_argument('--config_path', - type=str, - help='Path to config file of torch model.') -parser.add_argument('--output_path', - type=str, - help='path to output file including file name to save TF model.') +parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.") +parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") +parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.") args = parser.parse_args() # load model config @@ -39,51 +33,48 @@ # init torch model num_chars = len(phonemes) if c.use_phonemes else len(symbols) model = setup_model(num_chars, num_speakers, c) -checkpoint = torch.load(args.torch_model_path, - map_location=torch.device('cpu')) -state_dict = checkpoint['model'] +checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu")) +state_dict = checkpoint["model"] model.load_state_dict(state_dict) # init tf model -model_tf = Tacotron2(num_chars=num_chars, - num_speakers=num_speakers, - r=model.decoder.r, - postnet_output_dim=c.audio['num_mels'], - decoder_output_dim=c.audio['num_mels'], - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - bidirectional_decoder=c.bidirectional_decoder) +model_tf = Tacotron2( + num_chars=num_chars, + num_speakers=num_speakers, + r=model.decoder.r, + postnet_output_dim=c.audio["num_mels"], + decoder_output_dim=c.audio["num_mels"], + attn_type=c.attention_type, + attn_win=c.windowing, + attn_norm=c.attention_norm, + prenet_type=c.prenet_type, + prenet_dropout=c.prenet_dropout, + forward_attn=c.use_forward_attn, + trans_agent=c.transition_agent, + forward_attn_mask=c.forward_attn_mask, + location_attn=c.location_attn, + attn_K=c.attention_heads, + separate_stopnet=c.separate_stopnet, + bidirectional_decoder=c.bidirectional_decoder, +) # set initial layer mapping - these are not captured by the below heuristic approach # TODO: set layer names so that we can remove these manual matching -common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE' +common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE" var_map = [ - ('embedding/embeddings:0', 'embedding.weight'), - ('encoder/lstm/forward_lstm/lstm_cell_1/kernel:0', - 'encoder.lstm.weight_ih_l0'), - ('encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0', - 'encoder.lstm.weight_hh_l0'), - ('encoder/lstm/backward_lstm/lstm_cell_2/kernel:0', - 'encoder.lstm.weight_ih_l0_reverse'), - ('encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0', - 'encoder.lstm.weight_hh_l0_reverse'), - ('encoder/lstm/forward_lstm/lstm_cell_1/bias:0', - ('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')), - ('encoder/lstm/backward_lstm/lstm_cell_2/bias:0', - ('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')), - ('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'), - ('decoder/linear_projection/kernel:0', - 'decoder.linear_projection.linear_layer.weight'), - ('decoder/stopnet/kernel:0', 'decoder.stopnet.1.linear_layer.weight') + ("embedding/embeddings:0", "embedding.weight"), + 
("encoder/lstm/forward_lstm/lstm_cell_1/kernel:0", "encoder.lstm.weight_ih_l0"), + ("encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0", "encoder.lstm.weight_hh_l0"), + ("encoder/lstm/backward_lstm/lstm_cell_2/kernel:0", "encoder.lstm.weight_ih_l0_reverse"), + ("encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0", "encoder.lstm.weight_hh_l0_reverse"), + ("encoder/lstm/forward_lstm/lstm_cell_1/bias:0", ("encoder.lstm.bias_ih_l0", "encoder.lstm.bias_hh_l0")), + ( + "encoder/lstm/backward_lstm/lstm_cell_2/bias:0", + ("encoder.lstm.bias_ih_l0_reverse", "encoder.lstm.bias_hh_l0_reverse"), + ), + ("attention/v/kernel:0", "decoder.attention.v.linear_layer.weight"), + ("decoder/linear_projection/kernel:0", "decoder.linear_projection.linear_layer.weight"), + ("decoder/stopnet/kernel:0", "decoder.stopnet.1.linear_layer.weight"), ] # %% @@ -101,10 +92,7 @@ if tf_name in [name[0] for name in var_map]: continue tf_name_edited = convert_tf_name(tf_name) - ratios = [ - SequenceMatcher(None, torch_name, tf_name_edited).ratio() - for torch_name in torch_var_names - ] + ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names] max_idx = np.argmax(ratios) matching_name = torch_var_names[max_idx] del torch_var_names[max_idx] @@ -124,25 +112,21 @@ o_t = model.embedding(input_ids) o_tf = model_tf.embedding(input_ids.detach().numpy()) -assert abs(o_t.detach().numpy() - - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - - o_tf.numpy()).sum() +assert abs(o_t.detach().numpy() - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - o_tf.numpy()).sum() # compare encoder outputs oo_en = model.encoder.inference(o_t.transpose(1, 2)) ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False) assert compare_torch_tf(oo_en, ooo_en) < 1e-5 -#pylint: disable=redefined-builtin +# pylint: disable=redefined-builtin # compare decoder.attention_rnn inp = torch.rand([1, 768]) inp_tf = inp.numpy() -model.decoder._init_states(oo_en, mask=None) #pylint: disable=protected-access +model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access output, cell_state = model.decoder.attention_rnn(inp) states = model_tf.decoder.build_decoder_initial_states(1, 512, 128) -output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, - states[2], - training=False) +output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, states[2], training=False) assert compare_torch_tf(output, output_tf).mean() < 1e-5 query = output @@ -153,8 +137,7 @@ # compare decoder.attention model.decoder.attention.init_states(inputs) processes_inputs = model.decoder.attention.preprocess_inputs(inputs) -loc_attn, proc_query = model.decoder.attention.get_location_attention( - query, processes_inputs) +loc_attn, proc_query = model.decoder.attention.get_location_attention(query, processes_inputs) context = model.decoder.attention(query, inputs, processes_inputs, None) attention_states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)[-1] @@ -169,13 +152,10 @@ # compare decoder.decoder_rnn input = torch.rand([1, 1536]) input_tf = input.numpy() -model.decoder._init_states(oo_en, mask=None) #pylint: disable=protected-access -output, cell_state = model.decoder.decoder_rnn( - input, [model.decoder.decoder_hidden, model.decoder.decoder_cell]) +model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access +output, cell_state = model.decoder.decoder_rnn(input, [model.decoder.decoder_hidden, model.decoder.decoder_cell]) states = 
model_tf.decoder.build_decoder_initial_states(1, 512, 128) -output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, - states[3], - training=False) +output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, states[3], training=False) assert abs(input - input_tf).mean() < 1e-5 assert compare_torch_tf(output, output_tf).mean() < 1e-5 @@ -198,12 +178,10 @@ outputs_torch = model.inference(input_ids) outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy())) print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean()) -assert compare_torch_tf(outputs_torch[2][:, 50, :], - outputs_tf[2][:, 50, :]) < 1e-5 +assert compare_torch_tf(outputs_torch[2][:, 50, :], outputs_tf[2][:, 50, :]) < 1e-5 assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4 # %% # save tf model -save_checkpoint(model_tf, None, checkpoint['step'], checkpoint['epoch'], - checkpoint['r'], args.output_path) -print(' > Model conversion is successfully completed :).') +save_checkpoint(model_tf, None, checkpoint["step"], checkpoint["epoch"], checkpoint["r"], args.output_path) +print(" > Model conversion is successfully completed :).") diff --git a/TTS/bin/distribute.py b/TTS/bin/distribute.py index 390bd738de..0bd2727544 100644 --- a/TTS/bin/distribute.py +++ b/TTS/bin/distribute.py @@ -1,12 +1,13 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import argparse import os -import sys import pathlib -import time import subprocess -import argparse +import sys +import time + import torch @@ -15,26 +16,19 @@ def main(): Call train.py as a new process and pass command arguments """ parser = argparse.ArgumentParser() + parser.add_argument("--script", type=str, help="Target training script to distibute.") parser.add_argument( - '--script', - type=str, - help='Target training script to distibute.') - parser.add_argument( - '--continue_path', + "--continue_path", type=str, help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', - default='', - required='--config_path' not in sys.argv) + default="", + required="--config_path" not in sys.argv, + ) parser.add_argument( - '--restore_path', - type=str, - help='Model file to be restored. Use to finetune a model.', - default='') + "--restore_path", type=str, help="Model file to be restored. 
Use to finetune a model.", default="" + ) parser.add_argument( - '--config_path', - type=str, - help='Path to config file for training.', - required='--continue_path' not in sys.argv + "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv ) args = parser.parse_args() @@ -44,20 +38,20 @@ def main(): # set arguments for train.py folder_path = pathlib.Path(__file__).parent.absolute() command = [os.path.join(folder_path, args.script)] - command.append('--continue_path={}'.format(args.continue_path)) - command.append('--restore_path={}'.format(args.restore_path)) - command.append('--config_path={}'.format(args.config_path)) - command.append('--group_id=group_{}'.format(group_id)) - command.append('') + command.append("--continue_path={}".format(args.continue_path)) + command.append("--restore_path={}".format(args.restore_path)) + command.append("--config_path={}".format(args.config_path)) + command.append("--group_id=group_{}".format(group_id)) + command.append("") # run processes processes = [] for i in range(num_gpus): my_env = os.environ.copy() my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i) - command[-1] = '--rank={}'.format(i) - stdout = None if i == 0 else open(os.devnull, 'w') - p = subprocess.Popen(['python3'] + command, stdout=stdout, env=my_env) + command[-1] = "--rank={}".format(i) + stdout = None if i == 0 else open(os.devnull, "w") + p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) processes.append(p) print(command) @@ -65,5 +59,5 @@ def main(): p.wait() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py index f9b6827b31..d2436c6d0e 100644 --- a/TTS/bin/find_unique_chars.py +++ b/TTS/bin/find_unique_chars.py @@ -1,6 +1,6 @@ """Find all the unique characters in a dataset""" -import os import argparse +import os from argparse import RawTextHelpFormatter from TTS.tts.datasets.preprocess import get_preprocessor_by_name @@ -8,29 +8,23 @@ def main(): # pylint: disable=bad-continuation - parser = argparse.ArgumentParser(description='''Find all the unique characters or phonemes in a dataset.\n\n''' - - '''Target dataset must be defined in TTS.tts.datasets.preprocess\n\n'''\ - ''' + parser = argparse.ArgumentParser( + description="""Find all the unique characters or phonemes in a dataset.\n\n""" + """Target dataset must be defined in TTS.tts.datasets.preprocess\n\n""" + """ Example runs: python TTS/bin/find_unique_chars.py --dataset ljspeech --meta_file /path/to/LJSpeech/metadata.csv - ''', formatter_class=RawTextHelpFormatter) - - parser.add_argument( - '--dataset', - type=str, - default='', - help='One of the target dataset names in TTS.tts.datasets.preprocess.' - ) + """, + formatter_class=RawTextHelpFormatter, + ) parser.add_argument( - '--meta_file', - type=str, - default=None, - help='Path to the transcriptions file of the dataset.' + "--dataset", type=str, default="", help="One of the target dataset names in TTS.tts.datasets.preprocess." 
) + parser.add_argument("--meta_file", type=str, default=None, help="Path to the transcriptions file of the dataset.") + args = parser.parse_args() preprocessor = get_preprocessor_by_name(args.dataset) diff --git a/TTS/bin/resample.py b/TTS/bin/resample.py index 080e2bad19..f1e416b4a4 100644 --- a/TTS/bin/resample.py +++ b/TTS/bin/resample.py @@ -1,21 +1,24 @@ import argparse import glob import os -import librosa -from distutils.dir_util import copy_tree from argparse import RawTextHelpFormatter +from distutils.dir_util import copy_tree from multiprocessing import Pool + +import librosa from tqdm import tqdm + def resample_file(func_args): filename, output_sr = func_args y, sr = librosa.load(filename, sr=output_sr) librosa.output.write_wav(filename, y, sr) -if __name__ == '__main__': + +if __name__ == "__main__": parser = argparse.ArgumentParser( - description='''Resample a folder recusively with librosa + description="""Resample a folder recusively with librosa Can be used in place or create a copy of the folder as an output.\n\n Example run: python TTS/bin/resample.py @@ -23,46 +26,52 @@ def resample_file(func_args): --output_sr 22050 --output_dir /root/resampled_LJSpeech-1.1/ --n_jobs 24 - ''', - formatter_class=RawTextHelpFormatter) + """, + formatter_class=RawTextHelpFormatter, + ) - parser.add_argument('--input_dir', - type=str, - default=None, - required=True, - help='Path of the folder containing the audio files to resample') + parser.add_argument( + "--input_dir", + type=str, + default=None, + required=True, + help="Path of the folder containing the audio files to resample", + ) - parser.add_argument('--output_sr', - type=int, - default=22050, - required=False, - help='Samlple rate to which the audio files should be resampled') + parser.add_argument( + "--output_sr", + type=int, + default=22050, + required=False, + help="Samlple rate to which the audio files should be resampled", + ) - parser.add_argument('--output_dir', - type=str, - default=None, - required=False, - help='Path of the destination folder. If not defined, the operation is done in place') + parser.add_argument( + "--output_dir", + type=str, + default=None, + required=False, + help="Path of the destination folder. 
If not defined, the operation is done in place", + ) - parser.add_argument('--n_jobs', - type=int, - default=None, - help='Number of threads to use, by default it uses all cores') + parser.add_argument( + "--n_jobs", type=int, default=None, help="Number of threads to use, by default it uses all cores" + ) args = parser.parse_args() if args.output_dir: - print('Recursively copying the input folder...') + print("Recursively copying the input folder...") copy_tree(args.input_dir, args.output_dir) args.input_dir = args.output_dir - print('Resampling the audio files...') - audio_files = glob.glob(os.path.join(args.input_dir, '**/*.wav'), recursive=True) - print(f'Found {len(audio_files)} files...') - audio_files = list(zip(audio_files, len(audio_files)*[args.output_sr])) + print("Resampling the audio files...") + audio_files = glob.glob(os.path.join(args.input_dir, "**/*.wav"), recursive=True) + print(f"Found {len(audio_files)} files...") + audio_files = list(zip(audio_files, len(audio_files) * [args.output_sr])) with Pool(processes=args.n_jobs) as p: with tqdm(total=len(audio_files)) as pbar: for i, _ in enumerate(p.imap_unordered(resample_file, audio_files)): pbar.update() - print('Done !') + print("Done !") diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 8b96d945aa..356196b56f 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -4,6 +4,7 @@ import argparse import sys from argparse import RawTextHelpFormatter + # pylint: disable=redefined-outer-name, unused-argument from pathlib import Path @@ -14,22 +15,20 @@ def str2bool(v): if isinstance(v, bool): return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): + if v.lower() in ("yes", "true", "t", "y", "1"): return True - if v.lower() in ('no', 'false', 'f', 'n', '0'): + if v.lower() in ("no", "false", "f", "n", "0"): return False - raise argparse.ArgumentTypeError('Boolean value expected.') + raise argparse.ArgumentTypeError("Boolean value expected.") def main(): # pylint: disable=bad-continuation - parser = argparse.ArgumentParser(description='''Synthesize speech on command line.\n\n''' - - '''You can either use your trained model or choose a model from the provided list.\n\n'''\ - - '''If you don't specify any models, then it uses LJSpeech based English models\n\n'''\ - - ''' + parser = argparse.ArgumentParser( + description="""Synthesize speech on command line.\n\n""" + """You can either use your trained model or choose a model from the provided list.\n\n""" + """If you don't specify any models, then it uses LJSpeech based English models\n\n""" + """ Example runs: # list provided models @@ -51,106 +50,80 @@ def main(): ./TTS/bin/synthesize.py --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth.tar --out_path output/path/speech.wav --vocoder_path path/to/vocoder.pth.tar --vocoder_config_path path/to/vocoder_config.json - ''', - formatter_class=RawTextHelpFormatter) + """, + formatter_class=RawTextHelpFormatter, + ) parser.add_argument( - '--list_models', + "--list_models", type=str2bool, - nargs='?', + nargs="?", const=True, default=False, - help='list available pre-trained tts and vocoder models.' - ) - parser.add_argument( - '--text', - type=str, - default=None, - help='Text to generate speech.' - ) + help="list available pre-trained tts and vocoder models.", + ) + parser.add_argument("--text", type=str, default=None, help="Text to generate speech.") # Args for running pre-trained TTS models. 
parser.add_argument( - '--model_name', + "--model_name", type=str, default="tts_models/en/ljspeech/speedy-speech-wn", - help= - 'Name of one of the pre-trained tts models in format //' + help="Name of one of the pre-trained tts models in format //", ) parser.add_argument( - '--vocoder_name', + "--vocoder_name", type=str, default=None, - help= - 'Name of one of the pre-trained vocoder models in format //' + help="Name of one of the pre-trained vocoder models in format //", ) # Args for running custom models + parser.add_argument("--config_path", default=None, type=str, help="Path to model config file.") parser.add_argument( - '--config_path', - default=None, - type=str, - help='Path to model config file.' - ) - parser.add_argument( - '--model_path', + "--model_path", type=str, default=None, - help='Path to model file.', + help="Path to model file.", ) parser.add_argument( - '--out_path', + "--out_path", type=str, - default='tts_output.wav', - help='Output wav file path.', + default="tts_output.wav", + help="Output wav file path.", ) + parser.add_argument("--use_cuda", type=bool, help="Run model on CUDA.", default=False) parser.add_argument( - '--use_cuda', - type=bool, - help='Run model on CUDA.', - default=False - ) - parser.add_argument( - '--vocoder_path', + "--vocoder_path", type=str, - help= - 'Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).', + help="Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).", default=None, ) - parser.add_argument( - '--vocoder_config_path', - type=str, - help='Path to vocoder model config file.', - default=None) + parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None) # args for multi-speaker synthesis + parser.add_argument("--speakers_json", type=str, help="JSON file for multi-speaker model.", default=None) parser.add_argument( - '--speakers_json', - type=str, - help="JSON file for multi-speaker model.", - default=None) - parser.add_argument( - '--speaker_idx', + "--speaker_idx", type=str, help="if the tts model is trained with x-vectors, then speaker_idx is a file present in speakers.json else speaker_idx is the speaker id corresponding to a speaker in the speaker embedding layer.", - default=None) - parser.add_argument( - '--gst_style', - help="Wav path file for GST stylereference.", - default=None) + default=None, + ) + parser.add_argument("--gst_style", help="Wav path file for GST stylereference.", default=None) # aux args parser.add_argument( - '--save_spectogram', + "--save_spectogram", type=bool, help="If true save raw spectogram for further (vocoder) processing in out_path.", - default=False) + default=False, + ) args = parser.parse_args() # print the description if either text or list_models is not set if args.text is None and not args.list_models: - parser.parse_args(['-h']) + parser.parse_args(["-h"]) # load model manager path = Path(__file__).parent / "../.models.json" @@ -169,7 +142,7 @@ def main(): # CASE2: load pre-trained models if args.model_name is not None: model_path, config_path, model_item = manager.download_model(args.model_name) - args.vocoder_name = model_item['default_vocoder'] if args.vocoder_name is None else args.vocoder_name + args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name if args.vocoder_name is not None: vocoder_path, 
vocoder_config_path, _ = manager.download_model(args.vocoder_name) diff --git a/TTS/bin/train_align_tts.py b/TTS/bin/train_align_tts.py index 1b3e7d5242..6f268ed352 100644 --- a/TTS/bin/train_align_tts.py +++ b/TTS/bin/train_align_tts.py @@ -12,6 +12,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler + from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.layers.losses import AlignTTSLoss @@ -25,12 +26,11 @@ from TTS.utils.arguments import parse_arguments, process_args from TTS.utils.audio import AudioProcessor from TTS.utils.distribute import init_distributed, reduce_tensor -from TTS.utils.generic_utils import (KeepAverage, count_parameters, - remove_experiment_folder, set_init_dict) +from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.radam import RAdam from TTS.utils.training import NoamLR, setup_torch_training_env -if __name__ == '__main__': +if __name__ == "__main__": use_cuda, num_gpus = setup_torch_training_env(True, False) # torch.autograd.set_detect_anomaly(True) @@ -44,10 +44,9 @@ def setup_loader(ap, r, is_val=False, verbose=False): compute_linear_spec=False, meta_data=meta_data_eval if is_val else meta_data_train, ap=ap, - tp=c.characters if 'characters' in c.keys() else None, - add_blank=c['add_blank'] if 'add_blank' in c.keys() else False, - batch_group_size=0 if is_val else c.batch_group_size * - c.batch_size, + tp=c.characters if "characters" in c.keys() else None, + add_blank=c["add_blank"] if "add_blank" in c.keys() else False, + batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=c.min_seq_len, max_seq_len=c.max_seq_len, phoneme_cache_path=c.phoneme_cache_path, @@ -56,8 +55,10 @@ def setup_loader(ap, r, is_val=False, verbose=False): enable_eos_bos=c.enable_eos_bos_chars, use_noise_augment=not is_val, verbose=verbose, - speaker_mapping=speaker_mapping if c.use_speaker_embedding - and c.use_external_speaker_embedding_file else None) + speaker_mapping=speaker_mapping + if c.use_speaker_embedding and c.use_external_speaker_embedding_file + else None, + ) if c.use_phonemes and c.compute_input_seq_cache: # precompute phonemes to have a better estimate of sequence lengths. 
@@ -72,9 +73,9 @@ def setup_loader(ap, r, is_val=False, verbose=False): collate_fn=dataset.collate_fn, drop_last=False, sampler=sampler, - num_workers=c.num_val_loader_workers - if is_val else c.num_loader_workers, - pin_memory=False) + num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, + pin_memory=False, + ) return loader def format_data(data): @@ -94,10 +95,7 @@ def format_data(data): speaker_c = data[8] else: # return speaker_id to be used by an embedding layer - speaker_c = [ - speaker_mapping[speaker_name] - for speaker_name in speaker_names - ] + speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names] speaker_c = torch.LongTensor(speaker_c) else: speaker_c = None @@ -109,18 +107,15 @@ def format_data(data): mel_lengths = mel_lengths.cuda(non_blocking=True) if speaker_c is not None: speaker_c = speaker_c.cuda(non_blocking=True) - return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\ - avg_text_length, avg_spec_length, item_idx + return text_input, text_lengths, mel_input, mel_lengths, speaker_c, avg_text_length, avg_spec_length, item_idx - def train(data_loader, model, criterion, optimizer, scheduler, ap, - global_step, epoch, training_phase): + def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch, training_phase): model.train() epoch_time = 0 keep_avg = KeepAverage() if use_cuda: - batch_n_iter = int( - len(data_loader.dataset) / (c.batch_size * num_gpus)) + batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) else: batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() @@ -130,8 +125,16 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, start_time = time.time() # format data - text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\ - avg_text_length, avg_spec_length, _ = format_data(data) + ( + text_input, + text_lengths, + mel_targets, + mel_lengths, + speaker_c, + avg_text_length, + avg_spec_length, + _, + ) = format_data(data) loader_time = time.time() - end_time @@ -141,36 +144,32 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, # forward pass model with torch.cuda.amp.autocast(enabled=c.mixed_precision): decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward( - text_input, - text_lengths, - mel_targets, - mel_lengths, - g=speaker_c, - phase=training_phase) + text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c, phase=training_phase + ) # compute loss - loss_dict = criterion(logp, - decoder_output, - mel_targets, - mel_lengths, - dur_output, - dur_mas_output, - text_lengths, - global_step, - phase=training_phase) + loss_dict = criterion( + logp, + decoder_output, + mel_targets, + mel_lengths, + dur_output, + dur_mas_output, + text_lengths, + global_step, + phase=training_phase, + ) # backward pass with loss scaling if c.mixed_precision: - scaler.scale(loss_dict['loss']).backward() + scaler.scale(loss_dict["loss"]).backward() scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), c.grad_clip) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) scaler.step(optimizer) scaler.update() else: - loss_dict['loss'].backward() - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), c.grad_clip) + loss_dict["loss"].backward() + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) optimizer.step() # setup lr @@ -178,25 +177,21 @@ def train(data_loader, model, 
criterion, optimizer, scheduler, ap, scheduler.step() # current_lr - current_lr = optimizer.param_groups[0]['lr'] + current_lr = optimizer.param_groups[0]["lr"] # compute alignment error (the lower the better ) align_error = 1 - alignment_diagonal_score(alignments, binary=True) - loss_dict['align_error'] = align_error + loss_dict["align_error"] = align_error step_time = time.time() - start_time epoch_time += step_time # aggregate losses from processes if num_gpus > 1: - loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, - num_gpus) - loss_dict['loss_ssim'] = reduce_tensor( - loss_dict['loss_ssim'].data, num_gpus) - loss_dict['loss_dur'] = reduce_tensor( - loss_dict['loss_dur'].data, num_gpus) - loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, - num_gpus) + loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus) + loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus) + loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) + loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) # detach loss values loss_dict_new = dict() @@ -210,48 +205,43 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value - update_train_values['avg_loader_time'] = loader_time - update_train_values['avg_step_time'] = step_time + update_train_values["avg_" + key] = value + update_train_values["avg_loader_time"] = loader_time + update_train_values["avg_step_time"] = step_time keep_avg.update_values(update_train_values) # print training progress if global_step % c.print_step == 0: log_dict = { - "avg_spec_length": [avg_spec_length, - 1], # value, precision + "avg_spec_length": [avg_spec_length, 1], # value, precision "avg_text_length": [avg_text_length, 1], "step_time": [step_time, 4], "loader_time": [loader_time, 2], "current_lr": current_lr, } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, - log_dict, loss_dict, - keep_avg.avg_values) + c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) if args.rank == 0: # Plot Training Iter Stats # reduce TB load if global_step % c.tb_plot_step == 0: - iter_stats = { - "lr": current_lr, - "grad_norm": grad_norm, - "step_time": step_time - } + iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time} iter_stats.update(loss_dict) tb_logger.tb_train_iter_stats(global_step, iter_stats) if global_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model, - optimizer, - global_step, - epoch, - 1, - OUT_PATH, - model_characters, - model_loss=loss_dict['loss']) + save_checkpoint( + model, + optimizer, + global_step, + epoch, + 1, + OUT_PATH, + model_characters, + model_loss=loss_dict["loss"], + ) # wait all kernels to be completed torch.cuda.synchronize() @@ -259,8 +249,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, # Diagnostic visualizations if decoder_output is not None: idx = np.random.randint(mel_targets.shape[0]) - pred_spec = decoder_output[idx].detach().data.cpu( - ).numpy().T + pred_spec = decoder_output[idx].detach().data.cpu().numpy().T gt_spec = mel_targets[idx].data.cpu().numpy().T align_img = alignments[idx].data.cpu() @@ -274,14 +263,11 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, # Sample audio train_audio = ap.inv_melspectrogram(pred_spec.T) - 
tb_logger.tb_train_audios(global_step, - {'TrainAudio': train_audio}, - c.audio["sample_rate"]) + tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"]) end_time = time.time() # print epoch stats - c_logger.print_train_epoch_end(global_step, epoch, epoch_time, - keep_avg) + c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg) # Plot Epoch Stats if args.rank == 0: @@ -293,8 +279,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, ap, return keep_avg.avg_values, global_step @torch.no_grad() - def evaluate(data_loader, model, criterion, ap, global_step, epoch, - training_phase): + def evaluate(data_loader, model, criterion, ap, global_step, epoch, training_phase): model.eval() epoch_time = 0 keep_avg = KeepAverage() @@ -304,50 +289,41 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch, start_time = time.time() # format data - text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\ - _, _, _ = format_data(data) + text_input, text_lengths, mel_targets, mel_lengths, speaker_c, _, _, _ = format_data(data) # forward pass model with torch.cuda.amp.autocast(enabled=c.mixed_precision): decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward( - text_input, - text_lengths, - mel_targets, - mel_lengths, - g=speaker_c, - phase=training_phase) + text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c, phase=training_phase + ) # compute loss - loss_dict = criterion(logp, - decoder_output, - mel_targets, - mel_lengths, - dur_output, - dur_mas_output, - text_lengths, - global_step, - phase=training_phase) - + loss_dict = criterion( + logp, + decoder_output, + mel_targets, + mel_lengths, + dur_output, + dur_mas_output, + text_lengths, + global_step, + phase=training_phase, + ) # step time step_time = time.time() - start_time epoch_time += step_time # compute alignment score - align_error = 1 - alignment_diagonal_score(alignments, - binary=True) - loss_dict['align_error'] = align_error + align_error = 1 - alignment_diagonal_score(alignments, binary=True) + loss_dict["align_error"] = align_error # aggregate losses from processes if num_gpus > 1: - loss_dict['loss_l1'] = reduce_tensor( - loss_dict['loss_l1'].data, num_gpus) - loss_dict['loss_ssim'] = reduce_tensor( - loss_dict['loss_ssim'].data, num_gpus) - loss_dict['loss_dur'] = reduce_tensor( - loss_dict['loss_dur'].data, num_gpus) - loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, - num_gpus) + loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus) + loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus) + loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) + loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) # detach loss values loss_dict_new = dict() @@ -361,12 +337,11 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch, # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value + update_train_values["avg_" + key] = value keep_avg.update_values(update_train_values) if c.print_eval: - c_logger.print_eval_step(num_iter, loss_dict, - keep_avg.avg_values) + c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) if args.rank == 0: # Diagnostic visualizations @@ -376,19 +351,14 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch, align_img = alignments[idx].data.cpu() eval_figures = { - "prediction": 
plot_spectrogram(pred_spec, - ap, - output_fig=False), - "ground_truth": plot_spectrogram(gt_spec, - ap, - output_fig=False), - "alignment": plot_alignment(align_img, output_fig=False) + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), } # Sample audio eval_audio = ap.inv_melspectrogram(pred_spec.T) - tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, - c.audio["sample_rate"]) + tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"]) # Plot Validation Stats tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) @@ -401,7 +371,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch, "Be a voice, not an echo.", "I'm sorry Dave. I'm afraid I can't do that.", "This cake is great. It's so delicious and moist.", - "Prior to November 22, 1963." + "Prior to November 22, 1963.", ] else: with open(c.test_sentences_file, "r") as f: @@ -413,9 +383,9 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch, print(" | > Synthesizing test sentences") if c.use_speaker_embedding: if c.use_external_speaker_embedding_file: - speaker_embedding = speaker_mapping[list( - speaker_mapping.keys())[randrange( - len(speaker_mapping) - 1)]]['embedding'] + speaker_embedding = speaker_mapping[ + list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)] + ]["embedding"] speaker_id = None else: speaker_id = 0 @@ -437,25 +407,22 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch, speaker_embedding=speaker_embedding, style_wav=style_wav, truncated=False, - enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument + enable_eos_bos_chars=c.enable_eos_bos_chars, # pylint: disable=unused-argument use_griffin_lim=True, - do_trim_silence=False) + do_trim_silence=False, + ) file_path = os.path.join(AUDIO_PATH, str(global_step)) os.makedirs(file_path, exist_ok=True) - file_path = os.path.join(file_path, - "TestSentence_{}.wav".format(idx)) + file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx)) ap.save_wav(wav, file_path) - test_audios['{}-audio'.format(idx)] = wav - test_figures['{}-prediction'.format( - idx)] = plot_spectrogram(postnet_output, ap) - test_figures['{}-alignment'.format(idx)] = plot_alignment( - alignment) - except: #pylint: disable=bare-except + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap) + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment) + except: # pylint: disable=bare-except print(" !! 
Error creating Test Sentence -", idx) traceback.print_exc() - tb_logger.tb_test_audios(global_step, test_audios, - c.audio['sample_rate']) + tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"]) tb_logger.tb_test_figures(global_step, test_figures) return keep_avg.avg_values @@ -464,69 +431,55 @@ def main(args): # pylint: disable=redefined-outer-name global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping # Audio processor ap = AudioProcessor(**c.audio) - if 'characters' in c.keys(): + if "characters" in c.keys(): symbols, phonemes = make_symbols(**c.characters) # DISTRUBUTED if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, - c.distributed["backend"], c.distributed["url"]) + init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) # set model characters model_characters = phonemes if c.use_phonemes else symbols num_chars = len(model_characters) # load data instances - meta_data_train, meta_data_eval = load_meta_data(c.datasets, - eval_split=True) + meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True) # set the portion of the data used for training if set in config.json - if 'train_portion' in c.keys(): - meta_data_train = meta_data_train[:int( - len(meta_data_train) * c.train_portion)] - if 'eval_portion' in c.keys(): - meta_data_eval = meta_data_eval[:int( - len(meta_data_eval) * c.eval_portion)] + if "train_portion" in c.keys(): + meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)] + if "eval_portion" in c.keys(): + meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)] # parse speakers - num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers( - c, args, meta_data_train, OUT_PATH) + num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH) # setup model - model = setup_model(num_chars, - num_speakers, - c, - speaker_embedding_dim=speaker_embedding_dim) - optimizer = RAdam(model.parameters(), - lr=c.lr, - weight_decay=0, - betas=(0.9, 0.98), - eps=1e-9) + model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim) + optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9) criterion = AlignTTSLoss(c) if args.restore_path: - print( - f" > Restoring from {os.path.basename(args.restore_path)} ...") - checkpoint = torch.load(args.restore_path, map_location='cpu') + print(f" > Restoring from {os.path.basename(args.restore_path)} ...") + checkpoint = torch.load(args.restore_path, map_location="cpu") try: # TODO: fix optimizer init, model.cuda() needs to be called before # optimizer restore - optimizer.load_state_dict(checkpoint['optimizer']) + optimizer.load_state_dict(checkpoint["optimizer"]) if c.reinit_layers: raise RuntimeError - model.load_state_dict(checkpoint['model']) - except: #pylint: disable=bare-except + model.load_state_dict(checkpoint["model"]) + except: # pylint: disable=bare-except print(" > Partial model initialization.") model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint['model'], c) + model_dict = set_init_dict(model_dict, checkpoint["model"], c) model.load_state_dict(model_dict) del model_dict for group in optimizer.param_groups: - group['initial_lr'] = c.lr - print(" > Model restored from step %d" % checkpoint['step'], - flush=True) - args.restore_step = checkpoint['step'] + group["initial_lr"] = c.lr + print(" > 
Model restored from step %d" % checkpoint["step"], flush=True) + args.restore_step = checkpoint["step"] else: args.restore_step = 0 @@ -539,9 +492,7 @@ def main(args): # pylint: disable=redefined-outer-name model = DDP_th(model, device_ids=[args.rank]) if c.noam_schedule: - scheduler = NoamLR(optimizer, - warmup_steps=c.warmup_steps, - last_epoch=args.restore_step - 1) + scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1) else: scheduler = None @@ -549,16 +500,14 @@ def main(args): # pylint: disable=redefined-outer-name print("\n > Model has {} parameters".format(num_params), flush=True) if args.restore_step == 0 or not args.best_path: - best_loss = float('inf') + best_loss = float("inf") print(" > Starting with inf best loss.") else: - print(" > Restoring best loss from " - f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, - map_location='cpu')['model_loss'] + print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") + best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = c.get('keep_all_best', False) - keep_after = c.get('keep_after', 10000) # void if keep_all_best False + keep_all_best = c.get("keep_all_best", False) + keep_after = c.get("keep_after", 10000) # void if keep_all_best False # define dataloaders train_loader = setup_loader(ap, 1, is_val=False, verbose=True) @@ -573,9 +522,9 @@ def set_phase(): if not True in vals: phase = 0 else: - phase = len(c.phase_start_steps) - [ - i < global_step for i in c.phase_start_steps - ][::-1].index(True) - 1 + phase = ( + len(c.phase_start_steps) - [i < global_step for i in c.phase_start_steps][::-1].index(True) - 1 + ) else: phase = None return phase @@ -584,32 +533,30 @@ def set_phase(): cur_phase = set_phase() print(f"\n > Current AlignTTS phase: {cur_phase}") c_logger.print_epoch_start(epoch, c.epochs) - train_avg_loss_dict, global_step = train(train_loader, model, - criterion, optimizer, - scheduler, ap, - global_step, epoch, - cur_phase) - eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, - global_step, epoch, cur_phase) + train_avg_loss_dict, global_step = train( + train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch, cur_phase + ) + eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch, cur_phase) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) - target_loss = train_avg_loss_dict['avg_loss'] + target_loss = train_avg_loss_dict["avg_loss"] if c.run_eval: - target_loss = eval_avg_loss_dict['avg_loss'] - best_loss = save_best_model(target_loss, - best_loss, - model, - optimizer, - global_step, - epoch, - 1, - OUT_PATH, - model_characters, - keep_all_best=keep_all_best, - keep_after=keep_after) + target_loss = eval_avg_loss_dict["avg_loss"] + best_loss = save_best_model( + target_loss, + best_loss, + model, + optimizer, + global_step, + epoch, + 1, + OUT_PATH, + model_characters, + keep_all_best=keep_all_best, + keep_after=keep_after, + ) args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( - args, model_class='tts') + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts") try: main(args) diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 12fba6e11c..3a3f876e13 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -9,17 +9,21 @@ import torch from torch.utils.data 
import DataLoader + from TTS.speaker_encoder.dataset import MyDataset from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss from TTS.speaker_encoder.model import SpeakerEncoder -from TTS.speaker_encoder.utils.generic_utils import \ - check_config_speaker_encoder, save_best_model +from TTS.speaker_encoder.utils.generic_utils import check_config_speaker_encoder, save_best_model from TTS.speaker_encoder.utils.visual import plot_embeddings from TTS.tts.datasets.preprocess import load_meta_data from TTS.utils.audio import AudioProcessor -from TTS.utils.generic_utils import (count_parameters, - create_experiment_folder, get_git_branch, - remove_experiment_folder, set_init_dict) +from TTS.utils.generic_utils import ( + count_parameters, + create_experiment_folder, + get_git_branch, + remove_experiment_folder, + set_init_dict, +) from TTS.utils.io import copy_model_files, load_config from TTS.utils.radam import RAdam from TTS.utils.tensorboard_logger import TensorboardLogger @@ -34,28 +38,30 @@ print(" > Number of GPUs: ", num_gpus) -def setup_loader(ap: AudioProcessor, - is_val: bool = False, - verbose: bool = False): +def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False): if is_val: loader = None else: - dataset = MyDataset(ap, - meta_data_eval if is_val else meta_data_train, - voice_len=1.6, - num_utter_per_speaker=c.num_utters_per_speaker, - num_speakers_in_batch=c.num_speakers_in_batch, - skip_speakers=False, - storage_size=c.storage["storage_size"], - sample_from_storage_p=c.storage["sample_from_storage_p"], - additive_noise=c.storage["additive_noise"], - verbose=verbose) + dataset = MyDataset( + ap, + meta_data_eval if is_val else meta_data_train, + voice_len=1.6, + num_utter_per_speaker=c.num_utters_per_speaker, + num_speakers_in_batch=c.num_speakers_in_batch, + skip_speakers=False, + storage_size=c.storage["storage_size"], + sample_from_storage_p=c.storage["sample_from_storage_p"], + additive_noise=c.storage["additive_noise"], + verbose=verbose, + ) # sampler = DistributedSampler(dataset) if num_gpus > 1 else None - loader = DataLoader(dataset, - batch_size=c.num_speakers_in_batch, - shuffle=False, - num_workers=c.num_loader_workers, - collate_fn=dataset.collate_fn) + loader = DataLoader( + dataset, + batch_size=c.num_speakers_in_batch, + shuffle=False, + num_workers=c.num_loader_workers, + collate_fn=dataset.collate_fn, + ) return loader @@ -63,7 +69,7 @@ def train(model, criterion, optimizer, scheduler, ap, global_step): data_loader = setup_loader(ap, is_val=False, verbose=True) model.train() epoch_time = 0 - best_loss = float('inf') + best_loss = float("inf") avg_loss = 0 avg_loader_time = 0 end_time = time.time() @@ -89,9 +95,7 @@ def train(model, criterion, optimizer, scheduler, ap, global_step): outputs = model(inputs) # loss computation - loss = criterion( - outputs.view(c.num_speakers_in_batch, - outputs.shape[0] // c.num_speakers_in_batch, -1)) + loss = criterion(outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1)) loss.backward() grad_norm, _ = check_update(model, c.grad_clip) optimizer.step() @@ -100,11 +104,13 @@ def train(model, criterion, optimizer, scheduler, ap, global_step): epoch_time += step_time # Averaged Loss and Averaged Loader Time - avg_loss = 0.01 * loss.item() \ - + 0.99 * avg_loss if avg_loss != 0 else loss.item() - avg_loader_time = 1/c.num_loader_workers * loader_time + \ - (c.num_loader_workers-1) / c.num_loader_workers * avg_loader_time if avg_loader_time != 0 else loader_time - 
current_lr = optimizer.param_groups[0]['lr'] + avg_loss = 0.01 * loss.item() + 0.99 * avg_loss if avg_loss != 0 else loss.item() + avg_loader_time = ( + 1 / c.num_loader_workers * loader_time + (c.num_loader_workers - 1) / c.num_loader_workers * avg_loader_time + if avg_loader_time != 0 + else loader_time + ) + current_lr = optimizer.param_groups[0]["lr"] if global_step % c.steps_plot_stats == 0: # Plot Training Epoch Stats @@ -113,13 +119,12 @@ def train(model, criterion, optimizer, scheduler, ap, global_step): "lr": current_lr, "grad_norm": grad_norm, "step_time": step_time, - "avg_loader_time": avg_loader_time + "avg_loader_time": avg_loader_time, } tb_logger.tb_train_epoch_stats(global_step, train_stats) figures = { # FIXME: not constant - "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), - 10), + "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10), } tb_logger.tb_train_figures(global_step, figures) @@ -127,13 +132,13 @@ def train(model, criterion, optimizer, scheduler, ap, global_step): print( " | > Step:{} Loss:{:.5f} AvgLoss:{:.5f} GradNorm:{:.5f} " "StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format( - global_step, loss.item(), avg_loss, grad_norm, step_time, - loader_time, avg_loader_time, current_lr), - flush=True) + global_step, loss.item(), avg_loss, grad_norm, step_time, loader_time, avg_loader_time, current_lr + ), + flush=True, + ) # save best model - best_loss = save_best_model(model, optimizer, avg_loss, best_loss, - OUT_PATH, global_step) + best_loss = save_best_model(model, optimizer, avg_loss, best_loss, OUT_PATH, global_step) end_time = time.time() return avg_loss, global_step @@ -145,14 +150,16 @@ def main(args): # pylint: disable=redefined-outer-name global meta_data_eval ap = AudioProcessor(**c.audio) - model = SpeakerEncoder(input_dim=c.model['input_dim'], - proj_dim=c.model['proj_dim'], - lstm_dim=c.model['lstm_dim'], - num_lstm_layers=c.model['num_lstm_layers']) + model = SpeakerEncoder( + input_dim=c.model["input_dim"], + proj_dim=c.model["proj_dim"], + lstm_dim=c.model["lstm_dim"], + num_lstm_layers=c.model["num_lstm_layers"], + ) optimizer = RAdam(model.parameters(), lr=c.lr) if c.loss == "ge2e": - criterion = GE2ELoss(loss_method='softmax') + criterion = GE2ELoss(loss_method="softmax") elif c.loss == "angleproto": criterion = AngleProtoLoss() else: @@ -166,7 +173,7 @@ def main(args): # pylint: disable=redefined-outer-name # optimizer.load_state_dict(checkpoint['optimizer']) if c.reinit_layers: raise RuntimeError - model.load_state_dict(checkpoint['model']) + model.load_state_dict(checkpoint["model"]) except KeyError: print(" > Partial model initialization.") model_dict = model.state_dict() @@ -174,10 +181,9 @@ def main(args): # pylint: disable=redefined-outer-name model.load_state_dict(model_dict) del model_dict for group in optimizer.param_groups: - group['lr'] = c.lr - print(" > Model restored from step %d" % checkpoint['step'], - flush=True) - args.restore_step = checkpoint['step'] + group["lr"] = c.lr + print(" > Model restored from step %d" % checkpoint["step"], flush=True) + args.restore_step = checkpoint["step"] else: args.restore_step = 0 @@ -186,9 +192,7 @@ def main(args): # pylint: disable=redefined-outer-name criterion.cuda() if c.lr_decay: - scheduler = NoamLR(optimizer, - warmup_steps=c.warmup_steps, - last_epoch=args.restore_step - 1) + scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1) else: scheduler = None @@ -199,55 +203,39 @@ def main(args): # pylint: 
disable=redefined-outer-name meta_data_train, meta_data_eval = load_meta_data(c.datasets) global_step = args.restore_step - _, global_step = train(model, criterion, optimizer, scheduler, ap, - global_step) + _, global_step = train(model, criterion, optimizer, scheduler, ap, global_step) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - '--restore_path', - type=str, - help='Path to model outputs (checkpoint, tensorboard etc.).', - default=0) + "--restore_path", type=str, help="Path to model outputs (checkpoint, tensorboard etc.).", default=0 + ) parser.add_argument( - '--config_path', + "--config_path", type=str, required=True, - help='Path to config file for training.', + help="Path to config file for training.", ) - parser.add_argument('--debug', - type=bool, - default=True, - help='Do not verify commit integrity to run training.') - parser.add_argument( - '--data_path', - type=str, - default='', - help='Defines the data path. It overwrites config.json.') - parser.add_argument('--output_path', - type=str, - help='path for training outputs.', - default='') - parser.add_argument('--output_folder', - type=str, - default='', - help='folder name for training outputs.') + parser.add_argument("--debug", type=bool, default=True, help="Do not verify commit integrity to run training.") + parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.") + parser.add_argument("--output_path", type=str, help="path for training outputs.", default="") + parser.add_argument("--output_folder", type=str, default="", help="folder name for training outputs.") args = parser.parse_args() # setup output paths and read configs c = load_config(args.config_path) check_config_speaker_encoder(c) _ = os.path.dirname(os.path.realpath(__file__)) - if args.data_path != '': + if args.data_path != "": c.data_path = args.data_path - if args.output_path == '': + if args.output_path == "": OUT_PATH = os.path.join(_, c.output_path) else: OUT_PATH = args.output_path - if args.output_folder == '': + if args.output_folder == "": OUT_PATH = create_experiment_folder(OUT_PATH, c.run_name, args.debug) else: OUT_PATH = os.path.join(OUT_PATH, args.output_folder) @@ -259,7 +247,7 @@ def main(args): # pylint: disable=redefined-outer-name copy_model_files(c, args.config_path, OUT_PATH, new_fields) LOG_DIR = OUT_PATH - tb_logger = TensorboardLogger(LOG_DIR, model_name='Speaker_Encoder') + tb_logger = TensorboardLogger(LOG_DIR, model_name="Speaker_Encoder") try: main(args) diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py index 117de53112..d3b3d0e2fb 100644 --- a/TTS/bin/train_glow_tts.py +++ b/TTS/bin/train_glow_tts.py @@ -8,12 +8,12 @@ from random import randrange import torch + # DISTRIBUTED from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from TTS.utils.arguments import parse_arguments, process_args from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.layers.losses import GlowTTSLoss @@ -24,10 +24,10 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.arguments import parse_arguments, process_args from TTS.utils.audio import AudioProcessor from TTS.utils.distribute import 
init_distributed, reduce_tensor -from TTS.utils.generic_utils import (KeepAverage, count_parameters, - remove_experiment_folder, set_init_dict) +from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.radam import RAdam from TTS.utils.training import NoamLR, setup_torch_training_env @@ -44,19 +44,21 @@ def setup_loader(ap, r, is_val=False, verbose=False): compute_linear_spec=False, meta_data=meta_data_eval if is_val else meta_data_train, ap=ap, - tp=c.characters if 'characters' in c.keys() else None, - add_blank=c['add_blank'] if 'add_blank' in c.keys() else False, - batch_group_size=0 if is_val else c.batch_group_size * - c.batch_size, + tp=c.characters if "characters" in c.keys() else None, + add_blank=c["add_blank"] if "add_blank" in c.keys() else False, + batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=c.min_seq_len, max_seq_len=c.max_seq_len, phoneme_cache_path=c.phoneme_cache_path, use_phonemes=c.use_phonemes, phoneme_language=c.phoneme_language, enable_eos_bos=c.enable_eos_bos_chars, - use_noise_augment=c['use_noise_augment'] and not is_val, + use_noise_augment=c["use_noise_augment"] and not is_val, verbose=verbose, - speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None) + speaker_mapping=speaker_mapping + if c.use_speaker_embedding and c.use_external_speaker_embedding_file + else None, + ) if c.use_phonemes and c.compute_input_seq_cache: # precompute phonemes to have a better estimate of sequence lengths. @@ -71,9 +73,9 @@ def setup_loader(ap, r, is_val=False, verbose=False): collate_fn=dataset.collate_fn, drop_last=False, sampler=sampler, - num_workers=c.num_val_loader_workers - if is_val else c.num_loader_workers, - pin_memory=False) + num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, + pin_memory=False, + ) return loader @@ -95,9 +97,7 @@ def format_data(data): speaker_c = data[8] else: # return speaker_id to be used by an embedding layer - speaker_c = [ - speaker_mapping[speaker_name] for speaker_name in speaker_names - ] + speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names] speaker_c = torch.LongTensor(speaker_c) else: speaker_c = None @@ -112,13 +112,22 @@ def format_data(data): speaker_c = speaker_c.cuda(non_blocking=True) if attn_mask is not None: attn_mask = attn_mask.cuda(non_blocking=True) - return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\ - avg_text_length, avg_spec_length, attn_mask, item_idx + return ( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_c, + avg_text_length, + avg_spec_length, + attn_mask, + item_idx, + ) def data_depended_init(data_loader, model): """Data depended initialization for activation normalization.""" - if hasattr(model, 'module'): + if hasattr(model, "module"): for f in model.module.decoder.flows: if getattr(f, "set_ddi", False): f.set_ddi(True) @@ -134,17 +143,15 @@ def data_depended_init(data_loader, model): for _, data in enumerate(data_loader): # format data - text_input, text_lengths, mel_input, mel_lengths, spekaer_embed,\ - _, _, attn_mask, _ = format_data(data) + text_input, text_lengths, mel_input, mel_lengths, spekaer_embed, _, _, attn_mask, _ = format_data(data) # forward pass model - _ = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=spekaer_embed) + _ = model.forward(text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=spekaer_embed) if num_iter 
== c.data_dep_init_iter: break num_iter += 1 - if hasattr(model, 'module'): + if hasattr(model, "module"): for f in model.module.decoder.flows: if getattr(f, "set_ddi", False): f.set_ddi(False) @@ -155,15 +162,13 @@ def data_depended_init(data_loader, model): return model -def train(data_loader, model, criterion, optimizer, scheduler, - ap, global_step, epoch): +def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch): model.train() epoch_time = 0 keep_avg = KeepAverage() if use_cuda: - batch_n_iter = int( - len(data_loader.dataset) / (c.batch_size * num_gpus)) + batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) else: batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() @@ -173,8 +178,17 @@ def train(data_loader, model, criterion, optimizer, scheduler, start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, speaker_c,\ - avg_text_length, avg_spec_length, attn_mask, _ = format_data(data) + ( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_c, + avg_text_length, + avg_spec_length, + attn_mask, + _, + ) = format_data(data) loader_time = time.time() - end_time @@ -184,24 +198,22 @@ def train(data_loader, model, criterion, optimizer, scheduler, # forward pass model with torch.cuda.amp.autocast(enabled=c.mixed_precision): z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c) + text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c + ) # compute loss - loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, - o_dur_log, o_total_dur, text_lengths) + loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, text_lengths) # backward pass with loss scaling if c.mixed_precision: - scaler.scale(loss_dict['loss']).backward() + scaler.scale(loss_dict["loss"]).backward() scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), - c.grad_clip) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) scaler.step(optimizer) scaler.update() else: - loss_dict['loss'].backward() - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), - c.grad_clip) + loss_dict["loss"].backward() + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) optimizer.step() # setup lr @@ -209,20 +221,20 @@ def train(data_loader, model, criterion, optimizer, scheduler, scheduler.step() # current_lr - current_lr = optimizer.param_groups[0]['lr'] + current_lr = optimizer.param_groups[0]["lr"] # compute alignment error (the lower the better ) align_error = 1 - alignment_diagonal_score(alignments, binary=True) - loss_dict['align_error'] = align_error + loss_dict["align_error"] = align_error step_time = time.time() - start_time epoch_time += step_time # aggregate losses from processes if num_gpus > 1: - loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus) - loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus) - loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) + loss_dict["log_mle"] = reduce_tensor(loss_dict["log_mle"].data, num_gpus) + loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) + loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) # detach loss values loss_dict_new = dict() @@ -236,9 +248,9 @@ def train(data_loader, model, criterion, optimizer, 
scheduler, # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value - update_train_values['avg_loader_time'] = loader_time - update_train_values['avg_step_time'] = step_time + update_train_values["avg_" + key] = value + update_train_values["avg_loader_time"] = loader_time + update_train_values["avg_step_time"] = step_time keep_avg.update_values(update_train_values) # print training progress @@ -250,26 +262,29 @@ def train(data_loader, model, criterion, optimizer, scheduler, "loader_time": [loader_time, 2], "current_lr": current_lr, } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, - log_dict, loss_dict, keep_avg.avg_values) + c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) if args.rank == 0: # Plot Training Iter Stats # reduce TB load if global_step % c.tb_plot_step == 0: - iter_stats = { - "lr": current_lr, - "grad_norm": grad_norm, - "step_time": step_time - } + iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time} iter_stats.update(loss_dict) tb_logger.tb_train_iter_stats(global_step, iter_stats) if global_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters, - model_loss=loss_dict['loss']) + save_checkpoint( + model, + optimizer, + global_step, + epoch, + 1, + OUT_PATH, + model_characters, + model_loss=loss_dict["loss"], + ) # wait all kernels to be completed torch.cuda.synchronize() @@ -278,7 +293,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, # direct pass on model for spec predictions target_speaker = None if speaker_c is None else speaker_c[:1] - if hasattr(model, 'module'): + if hasattr(model, "module"): spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker) else: spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker) @@ -299,9 +314,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, # Sample audio train_audio = ap.inv_melspectrogram(const_spec.T) - tb_logger.tb_train_audios(global_step, - {'TrainAudio': train_audio}, - c.audio["sample_rate"]) + tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"]) end_time = time.time() # print epoch stats @@ -328,16 +341,15 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, speaker_c,\ - _, _, attn_mask, _ = format_data(data) + text_input, text_lengths, mel_input, mel_lengths, speaker_c, _, _, attn_mask, _ = format_data(data) # forward pass model z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c) + text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c + ) # compute loss - loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, - o_dur_log, o_total_dur, text_lengths) + loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, text_lengths) # step time step_time = time.time() - start_time @@ -345,13 +357,13 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # compute alignment score align_error = 1 - alignment_diagonal_score(alignments) - loss_dict['align_error'] = align_error + loss_dict["align_error"] = align_error # aggregate losses from processes if 
num_gpus > 1: - loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus) - loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus) - loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) + loss_dict["log_mle"] = reduce_tensor(loss_dict["log_mle"].data, num_gpus) + loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) + loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) # detach loss values loss_dict_new = dict() @@ -365,7 +377,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value + update_train_values["avg_" + key] = value keep_avg.update_values(update_train_values) if c.print_eval: @@ -375,7 +387,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # Diagnostic visualizations # direct pass on model for spec predictions target_speaker = None if speaker_c is None else speaker_c[:1] - if hasattr(model, 'module'): + if hasattr(model, "module"): spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker) else: spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker) @@ -389,13 +401,12 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): eval_figures = { "prediction": plot_spectrogram(const_spec, ap), "ground_truth": plot_spectrogram(gt_spec, ap), - "alignment": plot_alignment(align_img) + "alignment": plot_alignment(align_img), } # Sample audio eval_audio = ap.inv_melspectrogram(const_spec.T) - tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, - c.audio["sample_rate"]) + tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"]) # Plot Validation Stats tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) @@ -408,7 +419,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): "Be a voice, not an echo.", "I'm sorry Dave. I'm afraid I can't do that.", "This cake is great. It's so delicious and moist.", - "Prior to November 22, 1963." 
+ "Prior to November 22, 1963.", ] else: with open(c.test_sentences_file, "r") as f: @@ -420,7 +431,9 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): print(" | > Synthesizing test sentences") if c.use_speaker_embedding: if c.use_external_speaker_embedding_file: - speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding'] + speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]][ + "embedding" + ] speaker_id = None else: speaker_id = 0 @@ -442,25 +455,22 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): speaker_embedding=speaker_embedding, style_wav=style_wav, truncated=False, - enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument + enable_eos_bos_chars=c.enable_eos_bos_chars, # pylint: disable=unused-argument use_griffin_lim=True, - do_trim_silence=False) + do_trim_silence=False, + ) file_path = os.path.join(AUDIO_PATH, str(global_step)) os.makedirs(file_path, exist_ok=True) - file_path = os.path.join(file_path, - "TestSentence_{}.wav".format(idx)) + file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx)) ap.save_wav(wav, file_path) - test_audios['{}-audio'.format(idx)] = wav - test_figures['{}-prediction'.format(idx)] = plot_spectrogram( - postnet_output, ap) - test_figures['{}-alignment'.format(idx)] = plot_alignment( - alignment) - except: #pylint: disable=bare-except + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap) + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment) + except: # pylint: disable=bare-except print(" !! Error creating Test Sentence -", idx) traceback.print_exc() - tb_logger.tb_test_audios(global_step, test_audios, - c.audio['sample_rate']) + tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"]) tb_logger.tb_test_figures(global_step, test_figures) return keep_avg.avg_values @@ -470,13 +480,12 @@ def main(args): # pylint: disable=redefined-outer-name global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping # Audio processor ap = AudioProcessor(**c.audio) - if 'characters' in c.keys(): + if "characters" in c.keys(): symbols, phonemes = make_symbols(**c.characters) # DISTRUBUTED if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, - c.distributed["backend"], c.distributed["url"]) + init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) # set model characters model_characters = phonemes if c.use_phonemes else symbols @@ -486,10 +495,10 @@ def main(args): # pylint: disable=redefined-outer-name meta_data_train, meta_data_eval = load_meta_data(c.datasets) # set the portion of the data used for training - if 'train_portion' in c.keys(): - meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)] - if 'eval_portion' in c.keys(): - meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)] + if "train_portion" in c.keys(): + meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)] + if "eval_portion" in c.keys(): + meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)] # parse speakers num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH) @@ -501,26 +510,25 @@ def main(args): # pylint: disable=redefined-outer-name if args.restore_path: print(f" > 
Restoring from {os.path.basename(args.restore_path)} ...") - checkpoint = torch.load(args.restore_path, map_location='cpu') + checkpoint = torch.load(args.restore_path, map_location="cpu") try: # TODO: fix optimizer init, model.cuda() needs to be called before # optimizer restore - optimizer.load_state_dict(checkpoint['optimizer']) + optimizer.load_state_dict(checkpoint["optimizer"]) if c.reinit_layers: raise RuntimeError - model.load_state_dict(checkpoint['model']) - except: #pylint: disable=bare-except + model.load_state_dict(checkpoint["model"]) + except: # pylint: disable=bare-except print(" > Partial model initialization.") model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint['model'], c) + model_dict = set_init_dict(model_dict, checkpoint["model"], c) model.load_state_dict(model_dict) del model_dict for group in optimizer.param_groups: - group['initial_lr'] = c.lr - print(f" > Model restored from step {checkpoint['step']:d}", - flush=True) - args.restore_step = checkpoint['step'] + group["initial_lr"] = c.lr + print(f" > Model restored from step {checkpoint['step']:d}", flush=True) + args.restore_step = checkpoint["step"] else: args.restore_step = 0 @@ -533,9 +541,7 @@ def main(args): # pylint: disable=redefined-outer-name model = DDP_th(model, device_ids=[args.rank]) if c.noam_schedule: - scheduler = NoamLR(optimizer, - warmup_steps=c.warmup_steps, - last_epoch=args.restore_step - 1) + scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1) else: scheduler = None @@ -543,16 +549,14 @@ def main(args): # pylint: disable=redefined-outer-name print("\n > Model has {} parameters".format(num_params), flush=True) if args.restore_step == 0 or not args.best_path: - best_loss = float('inf') + best_loss = float("inf") print(" > Starting with inf best loss.") else: - print(" > Restoring best loss from " - f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, - map_location='cpu')['model_loss'] + print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") + best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = c.get('keep_all_best', False) - keep_after = c.get('keep_after', 10000) # void if keep_all_best False + keep_all_best = c.get("keep_all_best", False) + keep_after = c.get("keep_after", 10000) # void if keep_all_best False # define dataloaders train_loader = setup_loader(ap, 1, is_val=False, verbose=True) @@ -562,25 +566,32 @@ def main(args): # pylint: disable=redefined-outer-name model = data_depended_init(train_loader, model) for epoch in range(0, c.epochs): c_logger.print_epoch_start(epoch, c.epochs) - train_avg_loss_dict, global_step = train(train_loader, model, - criterion, optimizer, - scheduler, ap, global_step, - epoch) - eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, - global_step, epoch) + train_avg_loss_dict, global_step = train( + train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch + ) + eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) - target_loss = train_avg_loss_dict['avg_loss'] + target_loss = train_avg_loss_dict["avg_loss"] if c.run_eval: - target_loss = eval_avg_loss_dict['avg_loss'] - best_loss = save_best_model(target_loss, best_loss, model, optimizer, - global_step, epoch, c.r, OUT_PATH, model_characters, - 
keep_all_best=keep_all_best, keep_after=keep_after) - - -if __name__ == '__main__': + target_loss = eval_avg_loss_dict["avg_loss"] + best_loss = save_best_model( + target_loss, + best_loss, + model, + optimizer, + global_step, + epoch, + c.r, + OUT_PATH, + model_characters, + keep_all_best=keep_all_best, + keep_after=keep_after, + ) + + +if __name__ == "__main__": args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( - args, model_class='tts') + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts") try: main(args) diff --git a/TTS/bin/train_speedy_speech.py b/TTS/bin/train_speedy_speech.py index 026413bb0f..3adbe51377 100644 --- a/TTS/bin/train_speedy_speech.py +++ b/TTS/bin/train_speedy_speech.py @@ -5,15 +5,16 @@ import sys import time import traceback -import numpy as np from random import randrange +import numpy as np import torch -from TTS.utils.arguments import parse_arguments, process_args + # DISTRIBUTED from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler + from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.layers.losses import SpeedySpeechLoss @@ -24,10 +25,10 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.arguments import parse_arguments, process_args from TTS.utils.audio import AudioProcessor from TTS.utils.distribute import init_distributed, reduce_tensor -from TTS.utils.generic_utils import (KeepAverage, count_parameters, - remove_experiment_folder, set_init_dict) +from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.radam import RAdam from TTS.utils.training import NoamLR, setup_torch_training_env @@ -44,10 +45,9 @@ def setup_loader(ap, r, is_val=False, verbose=False): compute_linear_spec=False, meta_data=meta_data_eval if is_val else meta_data_train, ap=ap, - tp=c.characters if 'characters' in c.keys() else None, - add_blank=c['add_blank'] if 'add_blank' in c.keys() else False, - batch_group_size=0 if is_val else c.batch_group_size * - c.batch_size, + tp=c.characters if "characters" in c.keys() else None, + add_blank=c["add_blank"] if "add_blank" in c.keys() else False, + batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=c.min_seq_len, max_seq_len=c.max_seq_len, phoneme_cache_path=c.phoneme_cache_path, @@ -56,7 +56,10 @@ def setup_loader(ap, r, is_val=False, verbose=False): enable_eos_bos=c.enable_eos_bos_chars, use_noise_augment=not is_val, verbose=verbose, - speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None) + speaker_mapping=speaker_mapping + if c.use_speaker_embedding and c.use_external_speaker_embedding_file + else None, + ) if c.use_phonemes and c.compute_input_seq_cache: # precompute phonemes to have a better estimate of sequence lengths. 
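The format_data hunks a little further below (train_speedy_speech.py) re-wrap the duration extraction that turns a hard attention mask into per-character frame counts: take the argmax over the character axis for every mel frame, count how many frames each character received, then trim the surplus off the longest durations so the total matches the spectrogram length. A stand-alone sketch of that idea with a toy alignment; the shapes and variable names are illustrative, not the trainer's exact ones:

    import torch

    # Toy hard alignment: which character each of the 10 mel frames belongs to.
    T_text, T_mel = 4, 10
    char_of_frame = torch.tensor([0, 0, 1, 1, 1, 1, 3, 3, 3, 3])  # char 2 gets no frame
    attn = torch.zeros(T_text, T_mel)
    attn[char_of_frame, torch.arange(T_mel)] = 1.0

    # argmax over the character axis gives the aligned character per frame
    c_idxs = attn.max(0)[1]                         # shape [T_mel]
    idxs, counts = torch.unique(c_idxs, return_counts=True)
    dur = torch.ones(T_text, dtype=counts.dtype)    # every character keeps >= 1 frame
    dur[idxs] = counts

    # force the durations to sum exactly to the number of mel frames
    extra = int(dur.sum()) - T_mel
    if extra > 0:
        dur[torch.argsort(-dur)[:extra]] -= 1
    assert int(dur.sum()) == T_mel

In the trainer this runs per batch item on the sliced attention mask; the commented-out torch.unique_consecutive line in that hunk would count consecutive runs instead of per-character totals.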
@@ -71,9 +74,9 @@ def setup_loader(ap, r, is_val=False, verbose=False): collate_fn=dataset.collate_fn, drop_last=False, sampler=sampler, - num_workers=c.num_val_loader_workers - if is_val else c.num_loader_workers, - pin_memory=False) + num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, + pin_memory=False, + ) return loader @@ -95,9 +98,7 @@ def format_data(data): speaker_c = data[8] else: # return speaker_id to be used by an embedding layer - speaker_c = [ - speaker_mapping[speaker_name] for speaker_name in speaker_names - ] + speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names] speaker_c = torch.LongTensor(speaker_c) else: speaker_c = None @@ -105,7 +106,7 @@ def format_data(data): durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) for idx, am in enumerate(attn_mask): # compute raw durations - c_idxs = am[:, :text_lengths[idx], :mel_lengths[idx]].max(1)[1] + c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1] # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True) c_idxs, counts = torch.unique(c_idxs, return_counts=True) dur = torch.ones([text_lengths[idx]]).to(counts.dtype) @@ -115,8 +116,10 @@ def format_data(data): extra_frames = dur.sum() - mel_lengths[idx] largest_idxs = torch.argsort(-dur)[:extra_frames] dur[largest_idxs] -= 1 - assert dur.sum() == mel_lengths[idx], f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" - durations[idx, :text_lengths[idx]] = dur + assert ( + dur.sum() == mel_lengths[idx] + ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" + durations[idx, : text_lengths[idx]] = dur # dispatch data to GPU if use_cuda: text_input = text_input.cuda(non_blocking=True) @@ -127,19 +130,27 @@ def format_data(data): speaker_c = speaker_c.cuda(non_blocking=True) attn_mask = attn_mask.cuda(non_blocking=True) durations = durations.cuda(non_blocking=True) - return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\ - avg_text_length, avg_spec_length, attn_mask, durations, item_idx - - -def train(data_loader, model, criterion, optimizer, scheduler, - ap, global_step, epoch): + return ( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_c, + avg_text_length, + avg_spec_length, + attn_mask, + durations, + item_idx, + ) + + +def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch): model.train() epoch_time = 0 keep_avg = KeepAverage() if use_cuda: - batch_n_iter = int( - len(data_loader.dataset) / (c.batch_size * num_gpus)) + batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) else: batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() @@ -149,8 +160,18 @@ def train(data_loader, model, criterion, optimizer, scheduler, start_time = time.time() # format data - text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\ - avg_text_length, avg_spec_length, _, dur_target, _ = format_data(data) + ( + text_input, + text_lengths, + mel_targets, + mel_lengths, + speaker_c, + avg_text_length, + avg_spec_length, + _, + dur_target, + _, + ) = format_data(data) loader_time = time.time() - end_time @@ -160,23 +181,24 @@ def train(data_loader, model, criterion, optimizer, scheduler, # forward pass model with torch.cuda.amp.autocast(enabled=c.mixed_precision): decoder_output, dur_output, alignments = model.forward( - text_input, text_lengths, mel_lengths, dur_target, g=speaker_c) + text_input, text_lengths, mel_lengths, dur_target, g=speaker_c + ) # 
compute loss - loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths) + loss_dict = criterion( + decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths + ) # backward pass with loss scaling if c.mixed_precision: - scaler.scale(loss_dict['loss']).backward() + scaler.scale(loss_dict["loss"]).backward() scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), - c.grad_clip) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) scaler.step(optimizer) scaler.update() else: - loss_dict['loss'].backward() - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), - c.grad_clip) + loss_dict["loss"].backward() + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) optimizer.step() # setup lr @@ -184,21 +206,21 @@ def train(data_loader, model, criterion, optimizer, scheduler, scheduler.step() # current_lr - current_lr = optimizer.param_groups[0]['lr'] + current_lr = optimizer.param_groups[0]["lr"] # compute alignment error (the lower the better ) align_error = 1 - alignment_diagonal_score(alignments, binary=True) - loss_dict['align_error'] = align_error + loss_dict["align_error"] = align_error step_time = time.time() - start_time epoch_time += step_time # aggregate losses from processes if num_gpus > 1: - loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus) - loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus) - loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus) - loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) + loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus) + loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus) + loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) + loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) # detach loss values loss_dict_new = dict() @@ -212,41 +234,43 @@ def train(data_loader, model, criterion, optimizer, scheduler, # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value - update_train_values['avg_loader_time'] = loader_time - update_train_values['avg_step_time'] = step_time + update_train_values["avg_" + key] = value + update_train_values["avg_loader_time"] = loader_time + update_train_values["avg_step_time"] = step_time keep_avg.update_values(update_train_values) # print training progress if global_step % c.print_step == 0: log_dict = { - "avg_spec_length": [avg_spec_length, 1], # value, precision "avg_text_length": [avg_text_length, 1], "step_time": [step_time, 4], "loader_time": [loader_time, 2], "current_lr": current_lr, } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, - log_dict, loss_dict, keep_avg.avg_values) + c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) if args.rank == 0: # Plot Training Iter Stats # reduce TB load if global_step % c.tb_plot_step == 0: - iter_stats = { - "lr": current_lr, - "grad_norm": grad_norm, - "step_time": step_time - } + iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time} iter_stats.update(loss_dict) tb_logger.tb_train_iter_stats(global_step, iter_stats) if global_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, 
model_characters, - model_loss=loss_dict['loss']) + save_checkpoint( + model, + optimizer, + global_step, + epoch, + 1, + OUT_PATH, + model_characters, + model_loss=loss_dict["loss"], + ) # wait all kernels to be completed torch.cuda.synchronize() @@ -267,9 +291,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, # Sample audio train_audio = ap.inv_melspectrogram(pred_spec.T) - tb_logger.tb_train_audios(global_step, - {'TrainAudio': train_audio}, - c.audio["sample_rate"]) + tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"]) end_time = time.time() # print epoch stats @@ -296,16 +318,18 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): start_time = time.time() # format data - text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\ - _, _, _, dur_target, _ = format_data(data) + text_input, text_lengths, mel_targets, mel_lengths, speaker_c, _, _, _, dur_target, _ = format_data(data) # forward pass model with torch.cuda.amp.autocast(enabled=c.mixed_precision): decoder_output, dur_output, alignments = model.forward( - text_input, text_lengths, mel_lengths, dur_target, g=speaker_c) + text_input, text_lengths, mel_lengths, dur_target, g=speaker_c + ) # compute loss - loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths) + loss_dict = criterion( + decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths + ) # step time step_time = time.time() - start_time @@ -313,14 +337,14 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # compute alignment score align_error = 1 - alignment_diagonal_score(alignments, binary=True) - loss_dict['align_error'] = align_error + loss_dict["align_error"] = align_error # aggregate losses from processes if num_gpus > 1: - loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus) - loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus) - loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus) - loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) + loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus) + loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus) + loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) + loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) # detach loss values loss_dict_new = dict() @@ -334,7 +358,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value + update_train_values["avg_" + key] = value keep_avg.update_values(update_train_values) if c.print_eval: @@ -350,13 +374,12 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): eval_figures = { "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), - "alignment": plot_alignment(align_img, output_fig=False) + "alignment": plot_alignment(align_img, output_fig=False), } # Sample audio eval_audio = ap.inv_melspectrogram(pred_spec.T) - tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, - c.audio["sample_rate"]) + tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"]) # Plot Validation Stats tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) @@ -369,7 +392,7 
@@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): "Be a voice, not an echo.", "I'm sorry Dave. I'm afraid I can't do that.", "This cake is great. It's so delicious and moist.", - "Prior to November 22, 1963." + "Prior to November 22, 1963.", ] else: with open(c.test_sentences_file, "r") as f: @@ -381,7 +404,9 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): print(" | > Synthesizing test sentences") if c.use_speaker_embedding: if c.use_external_speaker_embedding_file: - speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding'] + speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]][ + "embedding" + ] speaker_id = None else: speaker_id = 0 @@ -403,25 +428,22 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): speaker_embedding=speaker_embedding, style_wav=style_wav, truncated=False, - enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument + enable_eos_bos_chars=c.enable_eos_bos_chars, # pylint: disable=unused-argument use_griffin_lim=True, - do_trim_silence=False) + do_trim_silence=False, + ) file_path = os.path.join(AUDIO_PATH, str(global_step)) os.makedirs(file_path, exist_ok=True) - file_path = os.path.join(file_path, - "TestSentence_{}.wav".format(idx)) + file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx)) ap.save_wav(wav, file_path) - test_audios['{}-audio'.format(idx)] = wav - test_figures['{}-prediction'.format(idx)] = plot_spectrogram( - postnet_output, ap) - test_figures['{}-alignment'.format(idx)] = plot_alignment( - alignment) - except: #pylint: disable=bare-except + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap) + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment) + except: # pylint: disable=bare-except print(" !! 
Error creating Test Sentence -", idx) traceback.print_exc() - tb_logger.tb_test_audios(global_step, test_audios, - c.audio['sample_rate']) + tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"]) tb_logger.tb_test_figures(global_step, test_figures) return keep_avg.avg_values @@ -432,13 +454,12 @@ def main(args): # pylint: disable=redefined-outer-name global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping # Audio processor ap = AudioProcessor(**c.audio) - if 'characters' in c.keys(): + if "characters" in c.keys(): symbols, phonemes = make_symbols(**c.characters) # DISTRUBUTED if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, - c.distributed["backend"], c.distributed["url"]) + init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) # set model characters model_characters = phonemes if c.use_phonemes else symbols @@ -448,10 +469,10 @@ def main(args): # pylint: disable=redefined-outer-name meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True) # set the portion of the data used for training if set in config.json - if 'train_portion' in c.keys(): - meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)] - if 'eval_portion' in c.keys(): - meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)] + if "train_portion" in c.keys(): + meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)] + if "eval_portion" in c.keys(): + meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)] # parse speakers num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH) @@ -463,26 +484,25 @@ def main(args): # pylint: disable=redefined-outer-name if args.restore_path: print(f" > Restoring from {os.path.basename(args.restore_path)} ...") - checkpoint = torch.load(args.restore_path, map_location='cpu') + checkpoint = torch.load(args.restore_path, map_location="cpu") try: # TODO: fix optimizer init, model.cuda() needs to be called before # optimizer restore - optimizer.load_state_dict(checkpoint['optimizer']) + optimizer.load_state_dict(checkpoint["optimizer"]) if c.reinit_layers: raise RuntimeError - model.load_state_dict(checkpoint['model']) - except: #pylint: disable=bare-except + model.load_state_dict(checkpoint["model"]) + except: # pylint: disable=bare-except print(" > Partial model initialization.") model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint['model'], c) + model_dict = set_init_dict(model_dict, checkpoint["model"], c) model.load_state_dict(model_dict) del model_dict for group in optimizer.param_groups: - group['initial_lr'] = c.lr - print(" > Model restored from step %d" % checkpoint['step'], - flush=True) - args.restore_step = checkpoint['step'] + group["initial_lr"] = c.lr + print(" > Model restored from step %d" % checkpoint["step"], flush=True) + args.restore_step = checkpoint["step"] else: args.restore_step = 0 @@ -495,9 +515,7 @@ def main(args): # pylint: disable=redefined-outer-name model = DDP_th(model, device_ids=[args.rank]) if c.noam_schedule: - scheduler = NoamLR(optimizer, - warmup_steps=c.warmup_steps, - last_epoch=args.restore_step - 1) + scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1) else: scheduler = None @@ -505,16 +523,14 @@ def main(args): # pylint: disable=redefined-outer-name print("\n > Model has {} 
parameters".format(num_params), flush=True) if args.restore_step == 0 or not args.best_path: - best_loss = float('inf') + best_loss = float("inf") print(" > Starting with inf best loss.") else: - print(" > Restoring best loss from " - f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, - map_location='cpu')['model_loss'] + print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") + best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = c.get('keep_all_best', False) - keep_after = c.get('keep_after', 10000) # void if keep_all_best False + keep_all_best = c.get("keep_all_best", False) + keep_after = c.get("keep_after", 10000) # void if keep_all_best False # define dataloaders train_loader = setup_loader(ap, 1, is_val=False, verbose=True) @@ -523,24 +539,32 @@ def main(args): # pylint: disable=redefined-outer-name global_step = args.restore_step for epoch in range(0, c.epochs): c_logger.print_epoch_start(epoch, c.epochs) - train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer, - scheduler, ap, global_step, - epoch) - eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, - global_step, epoch) + train_avg_loss_dict, global_step = train( + train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch + ) + eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) - target_loss = train_avg_loss_dict['avg_loss'] + target_loss = train_avg_loss_dict["avg_loss"] if c.run_eval: - target_loss = eval_avg_loss_dict['avg_loss'] - best_loss = save_best_model(target_loss, best_loss, model, optimizer, - global_step, epoch, c.r, OUT_PATH, model_characters, - keep_all_best=keep_all_best, keep_after=keep_after) - - -if __name__ == '__main__': + target_loss = eval_avg_loss_dict["avg_loss"] + best_loss = save_best_model( + target_loss, + best_loss, + model, + optimizer, + global_step, + epoch, + c.r, + OUT_PATH, + model_characters, + keep_all_best=keep_all_best, + keep_after=keep_after, + ) + + +if __name__ == "__main__": args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( - args, model_class='tts') + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts") try: main(args) diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py index ce41980d45..c8346c3abc 100644 --- a/TTS/bin/train_tacotron.py +++ b/TTS/bin/train_tacotron.py @@ -10,7 +10,7 @@ import numpy as np import torch from torch.utils.data import DataLoader -from TTS.utils.arguments import parse_arguments, process_args + from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.layers.losses import TacotronLoss @@ -21,15 +21,19 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.arguments import parse_arguments, process_args from TTS.utils.audio import AudioProcessor -from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce, - init_distributed, reduce_tensor) -from TTS.utils.generic_utils import (KeepAverage, count_parameters, - remove_experiment_folder, set_init_dict) +from TTS.utils.distribute import DistributedSampler, apply_gradient_allreduce, init_distributed, 
reduce_tensor +from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.radam import RAdam -from TTS.utils.training import (NoamLR, adam_weight_decay, check_update, - gradual_training_scheduler, set_weight_decay, - setup_torch_training_env) +from TTS.utils.training import ( + NoamLR, + adam_weight_decay, + check_update, + gradual_training_scheduler, + set_weight_decay, + setup_torch_training_env, +) use_cuda, num_gpus = setup_torch_training_env(True, False) @@ -42,13 +46,12 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None): dataset = MyDataset( r, c.text_cleaner, - compute_linear_spec=c.model.lower() == 'tacotron', + compute_linear_spec=c.model.lower() == "tacotron", meta_data=meta_data_eval if is_val else meta_data_train, ap=ap, - tp=c.characters if 'characters' in c.keys() else None, - add_blank=c['add_blank'] if 'add_blank' in c.keys() else False, - batch_group_size=0 if is_val else c.batch_group_size * - c.batch_size, + tp=c.characters if "characters" in c.keys() else None, + add_blank=c["add_blank"] if "add_blank" in c.keys() else False, + batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=c.min_seq_len, max_seq_len=c.max_seq_len, phoneme_cache_path=c.phoneme_cache_path, @@ -56,11 +59,10 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None): phoneme_language=c.phoneme_language, enable_eos_bos=c.enable_eos_bos_chars, verbose=verbose, - speaker_mapping=(speaker_mapping if ( - c.use_speaker_embedding - and c.use_external_speaker_embedding_file - ) else None) - ) + speaker_mapping=( + speaker_mapping if (c.use_speaker_embedding and c.use_external_speaker_embedding_file) else None + ), + ) if c.use_phonemes and c.compute_input_seq_cache: # precompute phonemes to have a better estimate of sequence lengths. @@ -75,11 +77,12 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None): collate_fn=dataset.collate_fn, drop_last=False, sampler=sampler, - num_workers=c.num_val_loader_workers - if is_val else c.num_loader_workers, - pin_memory=False) + num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, + pin_memory=False, + ) return loader + def format_data(data): # setup input data text_input = data[0] @@ -97,21 +100,16 @@ def format_data(data): speaker_embeddings = data[8] speaker_ids = None else: - speaker_ids = [ - speaker_mapping[speaker_name] for speaker_name in speaker_names - ] + speaker_ids = [speaker_mapping[speaker_name] for speaker_name in speaker_names] speaker_ids = torch.LongTensor(speaker_ids) speaker_embeddings = None else: speaker_embeddings = None speaker_ids = None - # set stop targets view, we predict a single stop token per iteration. 
- stop_targets = stop_targets.view(text_input.shape[0], - stop_targets.size(1) // c.r, -1) - stop_targets = (stop_targets.sum(2) > - 0.0).unsqueeze(2).float().squeeze(2) + stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2) # dispatch data to GPU if use_cuda: @@ -126,17 +124,26 @@ def format_data(data): if speaker_embeddings is not None: speaker_embeddings = speaker_embeddings.cuda(non_blocking=True) - return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length - - -def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, - ap, global_step, epoch, scaler, scaler_st): + return ( + text_input, + text_lengths, + mel_input, + mel_lengths, + linear_input, + stop_targets, + speaker_ids, + speaker_embeddings, + max_text_length, + max_spec_length, + ) + + +def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap, global_step, epoch, scaler, scaler_st): model.train() epoch_time = 0 keep_avg = KeepAverage() if use_cuda: - batch_n_iter = int( - len(data_loader.dataset) / (c.batch_size * num_gpus)) + batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) else: batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() @@ -145,7 +152,18 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length = format_data(data) + ( + text_input, + text_lengths, + mel_input, + mel_lengths, + linear_input, + stop_targets, + speaker_ids, + speaker_embeddings, + max_text_length, + max_spec_length, + ) = format_data(data) loader_time = time.time() - end_time global_step += 1 @@ -161,35 +179,65 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, with torch.cuda.amp.autocast(enabled=c.mixed_precision): # forward pass model if c.bidirectional_decoder or c.double_decoder_consistency: - decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model( - text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings) + ( + decoder_output, + postnet_output, + alignments, + stop_tokens, + decoder_backward_output, + alignments_backward, + ) = model( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids=speaker_ids, + speaker_embeddings=speaker_embeddings, + ) else: decoder_output, postnet_output, alignments, stop_tokens = model( - text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings) + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids=speaker_ids, + speaker_embeddings=speaker_embeddings, + ) decoder_backward_output = None alignments_backward = None # set the [alignment] lengths wrt reduction factor for guided attention if mel_lengths.max() % model.decoder.r != 0: - alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r + alignment_lengths = ( + mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r)) + ) // model.decoder.r else: - alignment_lengths = mel_lengths // model.decoder.r + alignment_lengths = mel_lengths // model.decoder.r # compute loss - 
loss_dict = criterion(postnet_output, decoder_output, mel_input, - linear_input, stop_tokens, stop_targets, - mel_lengths, decoder_backward_output, - alignments, alignment_lengths, - alignments_backward, text_lengths) + loss_dict = criterion( + postnet_output, + decoder_output, + mel_input, + linear_input, + stop_tokens, + stop_targets, + mel_lengths, + decoder_backward_output, + alignments, + alignment_lengths, + alignments_backward, + text_lengths, + ) # check nan loss - if torch.isnan(loss_dict['loss']).any(): - raise RuntimeError(f'Detected NaN loss at step {global_step}.') + if torch.isnan(loss_dict["loss"]).any(): + raise RuntimeError(f"Detected NaN loss at step {global_step}.") # optimizer step if c.mixed_precision: # model optimizer step in mixed precision mode - scaler.scale(loss_dict['loss']).backward() + scaler.scale(loss_dict["loss"]).backward() scaler.unscale_(optimizer) optimizer, current_lr = adam_weight_decay(optimizer) grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True) @@ -198,7 +246,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, # stopnet optimizer step if c.separate_stopnet: - scaler_st.scale(loss_dict['stopnet_loss']).backward() + scaler_st.scale(loss_dict["stopnet_loss"]).backward() scaler.unscale_(optimizer_st) optimizer_st, _ = adam_weight_decay(optimizer_st) grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0) @@ -208,14 +256,14 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, grad_norm_st = 0 else: # main model optimizer step - loss_dict['loss'].backward() + loss_dict["loss"].backward() optimizer, current_lr = adam_weight_decay(optimizer) grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True) optimizer.step() # stopnet optimizer step if c.separate_stopnet: - loss_dict['stopnet_loss'].backward() + loss_dict["stopnet_loss"].backward() optimizer_st, _ = adam_weight_decay(optimizer_st) grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0) optimizer_st.step() @@ -224,17 +272,19 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, # compute alignment error (the lower the better ) align_error = 1 - alignment_diagonal_score(alignments) - loss_dict['align_error'] = align_error + loss_dict["align_error"] = align_error step_time = time.time() - start_time epoch_time += step_time # aggregate losses from processes if num_gpus > 1: - loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus) - loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus) - loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) - loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus) if c.stopnet else loss_dict['stopnet_loss'] + loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus) + loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus) + loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) + loss_dict["stopnet_loss"] = ( + reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus) if c.stopnet else loss_dict["stopnet_loss"] + ) # detach loss values loss_dict_new = dict() @@ -248,9 +298,9 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value - update_train_values['avg_loader_time'] = loader_time - update_train_values['avg_step_time'] = step_time + 
update_train_values["avg_" + key] = value + update_train_values["avg_loader_time"] = loader_time + update_train_values["avg_step_time"] = step_time keep_avg.update_values(update_train_values) # print training progress @@ -262,8 +312,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, "loader_time": [loader_time, 2], "current_lr": current_lr, } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, - log_dict, loss_dict, keep_avg.avg_values) + c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) if args.rank == 0: # Plot Training Iter Stats @@ -273,7 +322,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, "lr": current_lr, "grad_norm": grad_norm, "grad_norm_st": grad_norm_st, - "step_time": step_time + "step_time": step_time, } iter_stats.update(loss_dict) tb_logger.tb_train_iter_stats(global_step, iter_stats) @@ -281,17 +330,26 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, if global_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH, - optimizer_st=optimizer_st, - model_loss=loss_dict['postnet_loss'], - characters=model_characters, - scaler=scaler.state_dict() if c.mixed_precision else None) + save_checkpoint( + model, + optimizer, + global_step, + epoch, + model.decoder.r, + OUT_PATH, + optimizer_st=optimizer_st, + model_loss=loss_dict["postnet_loss"], + characters=model_characters, + scaler=scaler.state_dict() if c.mixed_precision else None, + ) # Diagnostic visualizations const_spec = postnet_output[0].data.cpu().numpy() - gt_spec = linear_input[0].data.cpu().numpy() if c.model in [ - "Tacotron", "TacotronGST" - ] else mel_input[0].data.cpu().numpy() + gt_spec = ( + linear_input[0].data.cpu().numpy() + if c.model in ["Tacotron", "TacotronGST"] + else mel_input[0].data.cpu().numpy() + ) align_img = alignments[0].data.cpu().numpy() figures = { @@ -301,7 +359,9 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, } if c.bidirectional_decoder or c.double_decoder_consistency: - figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False) + figures["alignment_backward"] = plot_alignment( + alignments_backward[0].data.cpu().numpy(), output_fig=False + ) tb_logger.tb_train_figures(global_step, figures) @@ -310,9 +370,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, train_audio = ap.inv_spectrogram(const_spec.T) else: train_audio = ap.inv_melspectrogram(const_spec.T) - tb_logger.tb_train_audios(global_step, - {'TrainAudio': train_audio}, - c.audio["sample_rate"]) + tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"]) end_time = time.time() # print epoch stats @@ -339,31 +397,62 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, _, _ = format_data(data) + ( + text_input, + text_lengths, + mel_input, + mel_lengths, + linear_input, + stop_targets, + speaker_ids, + speaker_embeddings, + _, + _, + ) = format_data(data) assert mel_input.shape[1] % model.decoder.r == 0 # forward pass model if c.bidirectional_decoder or c.double_decoder_consistency: - decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward 
= model( - text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings) + ( + decoder_output, + postnet_output, + alignments, + stop_tokens, + decoder_backward_output, + alignments_backward, + ) = model( + text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings + ) else: decoder_output, postnet_output, alignments, stop_tokens = model( - text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings) + text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings + ) decoder_backward_output = None alignments_backward = None # set the alignment lengths wrt reduction factor for guided attention if mel_lengths.max() % model.decoder.r != 0: - alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r + alignment_lengths = ( + mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r)) + ) // model.decoder.r else: - alignment_lengths = mel_lengths // model.decoder.r + alignment_lengths = mel_lengths // model.decoder.r # compute loss - loss_dict = criterion(postnet_output, decoder_output, mel_input, - linear_input, stop_tokens, stop_targets, - mel_lengths, decoder_backward_output, - alignments, alignment_lengths, alignments_backward, - text_lengths) + loss_dict = criterion( + postnet_output, + decoder_output, + mel_input, + linear_input, + stop_tokens, + stop_targets, + mel_lengths, + decoder_backward_output, + alignments, + alignment_lengths, + alignments_backward, + text_lengths, + ) # step time step_time = time.time() - start_time @@ -371,14 +460,14 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # compute alignment score align_error = 1 - alignment_diagonal_score(alignments) - loss_dict['align_error'] = align_error + loss_dict["align_error"] = align_error # aggregate losses from processes if num_gpus > 1: - loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus) - loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus) + loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus) + loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus) if c.stopnet: - loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus) + loss_dict["stopnet_loss"] = reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus) # detach loss values loss_dict_new = dict() @@ -392,7 +481,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value + update_train_values["avg_" + key] = value keep_avg.update_values(update_train_values) if c.print_eval: @@ -402,15 +491,17 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # Diagnostic visualizations idx = np.random.randint(mel_input.shape[0]) const_spec = postnet_output[idx].data.cpu().numpy() - gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [ - "Tacotron", "TacotronGST" - ] else mel_input[idx].data.cpu().numpy() + gt_spec = ( + linear_input[idx].data.cpu().numpy() + if c.model in ["Tacotron", "TacotronGST"] + else mel_input[idx].data.cpu().numpy() + ) align_img = alignments[idx].data.cpu().numpy() eval_figures = { "prediction": plot_spectrogram(const_spec, ap, output_fig=False), "ground_truth": plot_spectrogram(gt_spec, ap, 
output_fig=False), - "alignment": plot_alignment(align_img, output_fig=False) + "alignment": plot_alignment(align_img, output_fig=False), } # Sample audio @@ -418,14 +509,13 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): eval_audio = ap.inv_spectrogram(const_spec.T) else: eval_audio = ap.inv_melspectrogram(const_spec.T) - tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, - c.audio["sample_rate"]) + tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"]) # Plot Validation Stats if c.bidirectional_decoder or c.double_decoder_consistency: align_b_img = alignments_backward[idx].data.cpu().numpy() - eval_figures['alignment2'] = plot_alignment(align_b_img, output_fig=False) + eval_figures["alignment2"] = plot_alignment(align_b_img, output_fig=False) tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) tb_logger.tb_eval_figures(global_step, eval_figures) @@ -436,7 +526,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): "Be a voice, not an echo.", "I'm sorry Dave. I'm afraid I can't do that.", "This cake is great. It's so delicious and moist.", - "Prior to November 22, 1963." + "Prior to November 22, 1963.", ] else: with open(c.test_sentences_file, "r") as f: @@ -447,13 +537,17 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): test_figures = {} print(" | > Synthesizing test sentences") speaker_id = 0 if c.use_speaker_embedding else None - speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding'] if c.use_external_speaker_embedding_file and c.use_speaker_embedding else None + speaker_embedding = ( + speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]]["embedding"] + if c.use_external_speaker_embedding_file and c.use_speaker_embedding + else None + ) style_wav = c.get("gst_style_input") if style_wav is None and c.use_gst: # inicialize GST with zero dict. style_wav = {} print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!") - for i in range(c.gst['gst_style_tokens']): + for i in range(c.gst["gst_style_tokens"]): style_wav[str(i)] = 0 style_wav = c.get("gst_style_input") for idx, test_sentence in enumerate(test_sentences): @@ -468,25 +562,22 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): speaker_embedding=speaker_embedding, style_wav=style_wav, truncated=False, - enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument + enable_eos_bos_chars=c.enable_eos_bos_chars, # pylint: disable=unused-argument use_griffin_lim=True, - do_trim_silence=False) + do_trim_silence=False, + ) file_path = os.path.join(AUDIO_PATH, str(global_step)) os.makedirs(file_path, exist_ok=True) - file_path = os.path.join(file_path, - "TestSentence_{}.wav".format(idx)) + file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx)) ap.save_wav(wav, file_path) - test_audios['{}-audio'.format(idx)] = wav - test_figures['{}-prediction'.format(idx)] = plot_spectrogram( - postnet_output, ap, output_fig=False) - test_figures['{}-alignment'.format(idx)] = plot_alignment( - alignment, output_fig=False) - except: #pylint: disable=bare-except + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap, output_fig=False) + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False) + except: # pylint: disable=bare-except print(" !! 
Error creating Test Sentence -", idx) traceback.print_exc() - tb_logger.tb_test_audios(global_step, test_audios, - c.audio['sample_rate']) + tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"]) tb_logger.tb_test_figures(global_step, test_figures) return keep_avg.avg_values @@ -498,13 +589,12 @@ def main(args): # pylint: disable=redefined-outer-name ap = AudioProcessor(**c.audio) # setup custom characters if set in config file. - if 'characters' in c.keys(): + if "characters" in c.keys(): symbols, phonemes = make_symbols(**c.characters) # DISTRUBUTED if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, - c.distributed["backend"], c.distributed["url"]) + init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) num_chars = len(phonemes) if c.use_phonemes else len(symbols) model_characters = phonemes if c.use_phonemes else symbols @@ -512,10 +602,10 @@ def main(args): # pylint: disable=redefined-outer-name meta_data_train, meta_data_eval = load_meta_data(c.datasets) # set the portion of the data used for training - if 'train_portion' in c.keys(): - meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)] - if 'eval_portion' in c.keys(): - meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)] + if "train_portion" in c.keys(): + meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)] + if "eval_portion" in c.keys(): + meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)] # parse speakers num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH) @@ -529,9 +619,7 @@ def main(args): # pylint: disable=redefined-outer-name params = set_weight_decay(model, c.wd) optimizer = RAdam(params, lr=c.lr, weight_decay=0) if c.stopnet and c.separate_stopnet: - optimizer_st = RAdam(model.decoder.stopnet.parameters(), - lr=c.lr, - weight_decay=0) + optimizer_st = RAdam(model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0) else: optimizer_st = None @@ -539,13 +627,13 @@ def main(args): # pylint: disable=redefined-outer-name criterion = TacotronLoss(c, stopnet_pos_weight=c.stopnet_pos_weight, ga_sigma=0.4) if args.restore_path: print(f" > Restoring from {os.path.basename(args.restore_path)}...") - checkpoint = torch.load(args.restore_path, map_location='cpu') + checkpoint = torch.load(args.restore_path, map_location="cpu") try: print(" > Restoring Model...") - model.load_state_dict(checkpoint['model']) + model.load_state_dict(checkpoint["model"]) # optimizer restore print(" > Restoring Optimizer...") - optimizer.load_state_dict(checkpoint['optimizer']) + optimizer.load_state_dict(checkpoint["optimizer"]) if "scaler" in checkpoint and c.mixed_precision: print(" > Restoring AMP Scaler...") scaler.load_state_dict(checkpoint["scaler"]) @@ -554,17 +642,16 @@ def main(args): # pylint: disable=redefined-outer-name except (KeyError, RuntimeError): print(" > Partial model initialization...") model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint['model'], c) + model_dict = set_init_dict(model_dict, checkpoint["model"], c) # torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt')) # print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt')) model.load_state_dict(model_dict) del model_dict for group in optimizer.param_groups: - group['lr'] = c.lr - print(" > Model restored from step %d" % checkpoint['step'], - flush=True) - 
args.restore_step = checkpoint['step'] + group["lr"] = c.lr + print(" > Model restored from step %d" % checkpoint["step"], flush=True) + args.restore_step = checkpoint["step"] else: args.restore_step = 0 @@ -577,9 +664,7 @@ def main(args): # pylint: disable=redefined-outer-name model = apply_gradient_allreduce(model) if c.noam_schedule: - scheduler = NoamLR(optimizer, - warmup_steps=c.warmup_steps, - last_epoch=args.restore_step - 1) + scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1) else: scheduler = None @@ -587,22 +672,17 @@ def main(args): # pylint: disable=redefined-outer-name print("\n > Model has {} parameters".format(num_params), flush=True) if args.restore_step == 0 or not args.best_path: - best_loss = float('inf') + best_loss = float("inf") print(" > Starting with inf best loss.") else: - print(" > Restoring best loss from " - f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, - map_location='cpu')['model_loss'] + print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") + best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = c.get('keep_all_best', False) - keep_after = c.get('keep_after', 10000) # void if keep_all_best False + keep_all_best = c.get("keep_all_best", False) + keep_after = c.get("keep_after", 10000) # void if keep_all_best False # define data loaders - train_loader = setup_loader(ap, - model.decoder.r, - is_val=False, - verbose=True) + train_loader = setup_loader(ap, model.decoder.r, is_val=False, verbose=True) eval_loader = setup_loader(ap, model.decoder.r, is_val=True) global_step = args.restore_step @@ -617,28 +697,29 @@ def main(args): # pylint: disable=redefined-outer-name model.decoder_backward.set_r(r) train_loader.dataset.outputs_per_step = r eval_loader.dataset.outputs_per_step = r - train_loader = setup_loader(ap, - model.decoder.r, - is_val=False, - dataset=train_loader.dataset) - eval_loader = setup_loader(ap, - model.decoder.r, - is_val=True, - dataset=eval_loader.dataset) + train_loader = setup_loader(ap, model.decoder.r, is_val=False, dataset=train_loader.dataset) + eval_loader = setup_loader(ap, model.decoder.r, is_val=True, dataset=eval_loader.dataset) print("\n > Number of output frames:", model.decoder.r) # train one epoch - train_avg_loss_dict, global_step = train(train_loader, model, - criterion, optimizer, - optimizer_st, scheduler, ap, - global_step, epoch, scaler, - scaler_st) + train_avg_loss_dict, global_step = train( + train_loader, + model, + criterion, + optimizer, + optimizer_st, + scheduler, + ap, + global_step, + epoch, + scaler, + scaler_st, + ) # eval one epoch - eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, - global_step, epoch) + eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) - target_loss = train_avg_loss_dict['avg_postnet_loss'] + target_loss = train_avg_loss_dict["avg_postnet_loss"] if c.run_eval: - target_loss = eval_avg_loss_dict['avg_postnet_loss'] + target_loss = eval_avg_loss_dict["avg_postnet_loss"] best_loss = save_best_model( target_loss, best_loss, @@ -651,14 +732,13 @@ def main(args): # pylint: disable=redefined-outer-name model_characters, keep_all_best=keep_all_best, keep_after=keep_after, - scaler=scaler.state_dict() if c.mixed_precision else None + scaler=scaler.state_dict() if c.mixed_precision else None, ) -if 
__name__ == '__main__': +if __name__ == "__main__": args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( - args, model_class='tts') + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts") try: main(args) diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index 6357bc0183..43b1ff3589 100644 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -9,21 +9,21 @@ from inspect import signature import torch + # DISTRIBUTED from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler + from TTS.utils.arguments import parse_arguments, process_args from TTS.utils.audio import AudioProcessor from TTS.utils.distribute import init_distributed -from TTS.utils.generic_utils import (KeepAverage, count_parameters, - remove_experiment_folder, set_init_dict) +from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.training import setup_torch_training_env from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss -from TTS.vocoder.utils.generic_utils import (plot_results, setup_discriminator, - setup_generator) +from TTS.vocoder.utils.generic_utils import plot_results, setup_discriminator, setup_generator from TTS.vocoder.utils.io import save_best_model, save_checkpoint use_cuda, num_gpus = setup_torch_training_env(True, True) @@ -32,28 +32,31 @@ def setup_loader(ap, is_val=False, verbose=False): loader = None if not is_val or c.run_eval: - dataset = GANDataset(ap=ap, - items=eval_data if is_val else train_data, - seq_len=c.seq_len, - hop_len=ap.hop_length, - pad_short=c.pad_short, - conv_pad=c.conv_pad, - return_pairs=c.diff_samples_for_G_and_D if 'diff_samples_for_G_and_D' in c else False, - is_training=not is_val, - return_segments=not is_val, - use_noise_augment=c.use_noise_augment, - use_cache=c.use_cache, - verbose=verbose) + dataset = GANDataset( + ap=ap, + items=eval_data if is_val else train_data, + seq_len=c.seq_len, + hop_len=ap.hop_length, + pad_short=c.pad_short, + conv_pad=c.conv_pad, + return_pairs=c.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in c else False, + is_training=not is_val, + return_segments=not is_val, + use_noise_augment=c.use_noise_augment, + use_cache=c.use_cache, + verbose=verbose, + ) dataset.shuffle_mapping() sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None - loader = DataLoader(dataset, - batch_size=1 if is_val else c.batch_size, - shuffle=num_gpus == 0, - drop_last=False, - sampler=sampler, - num_workers=c.num_val_loader_workers - if is_val else c.num_loader_workers, - pin_memory=False) + loader = DataLoader( + dataset, + batch_size=1 if is_val else c.batch_size, + shuffle=num_gpus == 0, + drop_last=False, + sampler=sampler, + num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, + pin_memory=False, + ) return loader @@ -74,16 +77,26 @@ def format_data(data): return x, y, None, None -def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, - scheduler_G, scheduler_D, ap, global_step, epoch): +def train( + model_G, + criterion_G, + optimizer_G, + model_D, + criterion_D, + optimizer_D, + scheduler_G, + scheduler_D, + ap, + global_step, + epoch, +): data_loader = 
setup_loader(ap, is_val=False, verbose=(epoch == 0)) model_G.train() model_D.train() epoch_time = 0 keep_avg = KeepAverage() if use_cuda: - batch_n_iter = int( - len(data_loader.dataset) / (c.batch_size * num_gpus)) + batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) else: batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() @@ -140,22 +153,23 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, scores_fake = D_out_fake # compute losses - loss_G_dict = criterion_G(y_hat=y_hat, - y=y_G, - scores_fake=scores_fake, - feats_fake=feats_fake, - feats_real=feats_real, - y_hat_sub=y_hat_sub, - y_sub=y_G_sub) - - loss_G = loss_G_dict['G_loss'] + loss_G_dict = criterion_G( + y_hat=y_hat, + y=y_G, + scores_fake=scores_fake, + feats_fake=feats_fake, + feats_real=feats_real, + y_hat_sub=y_hat_sub, + y_sub=y_G_sub, + ) + + loss_G = loss_G_dict["G_loss"] # optimizer generator optimizer_G.zero_grad() loss_G.backward() if c.gen_clip_grad > 0: - torch.nn.utils.clip_grad_norm_(model_G.parameters(), - c.gen_clip_grad) + torch.nn.utils.clip_grad_norm_(model_G.parameters(), c.gen_clip_grad) optimizer_G.step() loss_dict = dict() @@ -206,14 +220,13 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, # compute losses loss_D_dict = criterion_D(scores_fake, scores_real) - loss_D = loss_D_dict['D_loss'] + loss_D = loss_D_dict["D_loss"] # optimizer discriminator optimizer_D.zero_grad() loss_D.backward() if c.disc_clip_grad > 0: - torch.nn.utils.clip_grad_norm_(model_D.parameters(), - c.disc_clip_grad) + torch.nn.utils.clip_grad_norm_(model_D.parameters(), c.disc_clip_grad) optimizer_D.step() for key, value in loss_D_dict.items(): @@ -226,36 +239,31 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, epoch_time += step_time # get current learning rates - current_lr_G = list(optimizer_G.param_groups)[0]['lr'] - current_lr_D = list(optimizer_D.param_groups)[0]['lr'] + current_lr_G = list(optimizer_G.param_groups)[0]["lr"] + current_lr_D = list(optimizer_D.param_groups)[0]["lr"] # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value - update_train_values['avg_loader_time'] = loader_time - update_train_values['avg_step_time'] = step_time + update_train_values["avg_" + key] = value + update_train_values["avg_loader_time"] = loader_time + update_train_values["avg_step_time"] = step_time keep_avg.update_values(update_train_values) # print training stats if global_step % c.print_step == 0: log_dict = { - 'step_time': [step_time, 2], - 'loader_time': [loader_time, 4], + "step_time": [step_time, 2], + "loader_time": [loader_time, 4], "current_lr_G": current_lr_G, - "current_lr_D": current_lr_D + "current_lr_D": current_lr_D, } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, - log_dict, loss_dict, keep_avg.avg_values) + c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) if args.rank == 0: # plot step stats if global_step % 10 == 0: - iter_stats = { - "lr_G": current_lr_G, - "lr_D": current_lr_D, - "step_time": step_time - } + iter_stats = {"lr_G": current_lr_G, "lr_D": current_lr_D, "step_time": step_time} iter_stats.update(loss_dict) tb_logger.tb_train_iter_stats(global_step, iter_stats) @@ -263,27 +271,26 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, if global_step % c.save_step == 0: if c.checkpoint: # 
save model - save_checkpoint(model_G, - optimizer_G, - scheduler_G, - model_D, - optimizer_D, - scheduler_D, - global_step, - epoch, - OUT_PATH, - model_losses=loss_dict) + save_checkpoint( + model_G, + optimizer_G, + scheduler_G, + model_D, + optimizer_D, + scheduler_D, + global_step, + epoch, + OUT_PATH, + model_losses=loss_dict, + ) # compute spectrograms - figures = plot_results(y_hat_vis, y_G, ap, global_step, - 'train') + figures = plot_results(y_hat_vis, y_G, ap, global_step, "train") tb_logger.tb_train_figures(global_step, figures) # Sample audio sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy() - tb_logger.tb_train_audios(global_step, - {'train/audio': sample_voice}, - c.audio["sample_rate"]) + tb_logger.tb_train_audios(global_step, {"train/audio": sample_voice}, c.audio["sample_rate"]) end_time = time.time() if scheduler_G is not None: @@ -330,7 +337,7 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) ############################## # generator pass - y_hat = model_G(c_G)[:, :, :y_G.size(2)] + y_hat = model_G(c_G)[:, :, : y_G.size(2)] y_hat_sub = None y_G_sub = None @@ -365,8 +372,7 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) feats_fake, feats_real = None, None # compute losses - loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake, - feats_real, y_hat_sub, y_G_sub) + loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake, feats_real, y_hat_sub, y_G_sub) loss_dict = dict() for key, value in loss_G_dict.items(): @@ -382,7 +388,7 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) if global_step >= c.steps_to_start_discriminator: # discriminator pass with torch.no_grad(): - y_hat = model_G(c_G)[:, :, :y_G.size(2)] + y_hat = model_G(c_G)[:, :, : y_G.size(2)] # PQMF formatting if y_hat.shape[1] > 1: @@ -422,9 +428,9 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) # update avg stats update_eval_values = dict() for key, value in loss_dict.items(): - update_eval_values['avg_' + key] = value - update_eval_values['avg_loader_time'] = loader_time - update_eval_values['avg_step_time'] = step_time + update_eval_values["avg_" + key] = value + update_eval_values["avg_loader_time"] = loader_time + update_eval_values["avg_step_time"] = step_time keep_avg.update_values(update_eval_values) # print eval stats @@ -433,13 +439,12 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) if args.rank == 0: # compute spectrograms - figures = plot_results(y_hat, y_G, ap, global_step, 'eval') + figures = plot_results(y_hat, y_G, ap, global_step, "eval") tb_logger.tb_eval_figures(global_step, figures) # Sample audio sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() - tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice}, - c.audio["sample_rate"]) + tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_voice}, c.audio["sample_rate"]) tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) @@ -455,8 +460,7 @@ def main(args): # pylint: disable=redefined-outer-name print(f" > Loading wavs from: {c.data_path}") if c.feature_path is not None: print(f" > Loading features from: {c.feature_path}") - eval_data, train_data = load_wav_feat_data( - c.data_path, c.feature_path, c.eval_split_size) + eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size) else: eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size) @@ -465,8 +469,7 @@ def main(args): # 
pylint: disable=redefined-outer-name # DISTRUBUTED if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, - c.distributed["backend"], c.distributed["url"]) + init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) # setup models model_gen = setup_generator(c) @@ -494,67 +497,63 @@ def main(args): # pylint: disable=redefined-outer-name # schedulers scheduler_gen = None scheduler_disc = None - if 'lr_scheduler_gen' in c: + if "lr_scheduler_gen" in c: scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen) - scheduler_gen = scheduler_gen( - optimizer_gen, **c.lr_scheduler_gen_params) - if 'lr_scheduler_disc' in c: + scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params) + if "lr_scheduler_disc" in c: scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc) - scheduler_disc = scheduler_disc( - optimizer_disc, **c.lr_scheduler_disc_params) + scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params) if args.restore_path: print(f" > Restoring from {os.path.basename(args.restore_path)}...") - checkpoint = torch.load(args.restore_path, map_location='cpu') + checkpoint = torch.load(args.restore_path, map_location="cpu") try: print(" > Restoring Generator Model...") - model_gen.load_state_dict(checkpoint['model']) + model_gen.load_state_dict(checkpoint["model"]) print(" > Restoring Generator Optimizer...") - optimizer_gen.load_state_dict(checkpoint['optimizer']) + optimizer_gen.load_state_dict(checkpoint["optimizer"]) print(" > Restoring Discriminator Model...") - model_disc.load_state_dict(checkpoint['model_disc']) + model_disc.load_state_dict(checkpoint["model_disc"]) print(" > Restoring Discriminator Optimizer...") - optimizer_disc.load_state_dict(checkpoint['optimizer_disc']) + optimizer_disc.load_state_dict(checkpoint["optimizer_disc"]) # restore schedulers if it is a continuing training. - if args.continue_path != '': - if 'scheduler' in checkpoint and scheduler_gen is not None: + if args.continue_path != "": + if "scheduler" in checkpoint and scheduler_gen is not None: print(" > Restoring Generator LR Scheduler...") - scheduler_gen.load_state_dict(checkpoint['scheduler']) + scheduler_gen.load_state_dict(checkpoint["scheduler"]) # NOTE: Not sure if necessary scheduler_gen.optimizer = optimizer_gen - if 'scheduler_disc' in checkpoint and scheduler_disc is not None: + if "scheduler_disc" in checkpoint and scheduler_disc is not None: print(" > Restoring Discriminator LR Scheduler...") - scheduler_disc.load_state_dict(checkpoint['scheduler_disc']) + scheduler_disc.load_state_dict(checkpoint["scheduler_disc"]) scheduler_disc.optimizer = optimizer_disc if c.lr_scheduler_disc == "ExponentialLR": - scheduler_disc.last_epoch = checkpoint['epoch'] + scheduler_disc.last_epoch = checkpoint["epoch"] except RuntimeError: # restore only matching layers. print(" > Partial model initialization...") model_dict = model_gen.state_dict() - model_dict = set_init_dict(model_dict, checkpoint['model'], c) + model_dict = set_init_dict(model_dict, checkpoint["model"], c) model_gen.load_state_dict(model_dict) model_dict = model_disc.state_dict() - model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c) + model_dict = set_init_dict(model_dict, checkpoint["model_disc"], c) model_disc.load_state_dict(model_dict) del model_dict # reset lr if not countinuining training. 
- if args.continue_path == '': + if args.continue_path == "": for group in optimizer_gen.param_groups: - group['lr'] = c.lr_gen + group["lr"] = c.lr_gen for group in optimizer_disc.param_groups: - group['lr'] = c.lr_disc + group["lr"] = c.lr_disc - print(f" > Model restored from step {checkpoint['step']:d}", - flush=True) - args.restore_step = checkpoint['step'] + print(f" > Model restored from step {checkpoint['step']:d}", flush=True) + args.restore_step = checkpoint["step"] else: args.restore_step = 0 - # DISTRUBUTED if num_gpus > 1: model_gen = DDP_th(model_gen, device_ids=[args.rank]) @@ -566,50 +565,55 @@ def main(args): # pylint: disable=redefined-outer-name print(" > Discriminator has {} parameters".format(num_params), flush=True) if args.restore_step == 0 or not args.best_path: - best_loss = float('inf') + best_loss = float("inf") print(" > Starting with inf best loss.") else: - print(" > Restoring best loss from " - f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, - map_location='cpu')['model_loss'] + print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") + best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with best loss of {best_loss}.") - keep_all_best = c.get('keep_all_best', False) - keep_after = c.get('keep_after', 10000) # void if keep_all_best False + keep_all_best = c.get("keep_all_best", False) + keep_after = c.get("keep_after", 10000) # void if keep_all_best False global_step = args.restore_step for epoch in range(0, c.epochs): c_logger.print_epoch_start(epoch, c.epochs) - _, global_step = train(model_gen, criterion_gen, optimizer_gen, - model_disc, criterion_disc, optimizer_disc, - scheduler_gen, scheduler_disc, ap, global_step, - epoch) - eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, - criterion_disc, ap, - global_step, epoch) + _, global_step = train( + model_gen, + criterion_gen, + optimizer_gen, + model_disc, + criterion_disc, + optimizer_disc, + scheduler_gen, + scheduler_disc, + ap, + global_step, + epoch, + ) + eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) target_loss = eval_avg_loss_dict[c.target_loss] - best_loss = save_best_model(target_loss, - best_loss, - model_gen, - optimizer_gen, - scheduler_gen, - model_disc, - optimizer_disc, - scheduler_disc, - global_step, - epoch, - OUT_PATH, - keep_all_best=keep_all_best, - keep_after=keep_after, - model_losses=eval_avg_loss_dict, - ) - - -if __name__ == '__main__': + best_loss = save_best_model( + target_loss, + best_loss, + model_gen, + optimizer_gen, + scheduler_gen, + model_disc, + optimizer_disc, + scheduler_disc, + global_step, + epoch, + OUT_PATH, + keep_all_best=keep_all_best, + keep_after=keep_after, + model_losses=eval_avg_loss_dict, + ) + + +if __name__ == "__main__": args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( - args, model_class='vocoder') + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="vocoder") try: main(args) diff --git a/TTS/bin/train_vocoder_wavegrad.py b/TTS/bin/train_vocoder_wavegrad.py index 68d76598a4..1f039a6779 100644 --- a/TTS/bin/train_vocoder_wavegrad.py +++ b/TTS/bin/train_vocoder_wavegrad.py @@ -5,19 +5,20 @@ import sys import time import traceback -import numpy as np +import numpy as np import torch + # DISTRIBUTED from torch.nn.parallel import DistributedDataParallel 
as DDP_th from torch.optim import Adam from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler + from TTS.utils.arguments import parse_arguments, process_args from TTS.utils.audio import AudioProcessor from TTS.utils.distribute import init_distributed -from TTS.utils.generic_utils import (KeepAverage, count_parameters, - remove_experiment_folder, set_init_dict) +from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.training import setup_torch_training_env from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset @@ -31,27 +32,29 @@ def setup_loader(ap, is_val=False, verbose=False): if is_val and not c.run_eval: loader = None else: - dataset = WaveGradDataset(ap=ap, - items=eval_data if is_val else train_data, - seq_len=c.seq_len, - hop_len=ap.hop_length, - pad_short=c.pad_short, - conv_pad=c.conv_pad, - is_training=not is_val, - return_segments=True, - use_noise_augment=False, - use_cache=c.use_cache, - verbose=verbose) + dataset = WaveGradDataset( + ap=ap, + items=eval_data if is_val else train_data, + seq_len=c.seq_len, + hop_len=ap.hop_length, + pad_short=c.pad_short, + conv_pad=c.conv_pad, + is_training=not is_val, + return_segments=True, + use_noise_augment=False, + use_cache=c.use_cache, + verbose=verbose, + ) sampler = DistributedSampler(dataset) if num_gpus > 1 else None - loader = DataLoader(dataset, - batch_size=c.batch_size, - shuffle=num_gpus <= 1, - drop_last=False, - sampler=sampler, - num_workers=c.num_val_loader_workers - if is_val else c.num_loader_workers, - pin_memory=False) - + loader = DataLoader( + dataset, + batch_size=c.batch_size, + shuffle=num_gpus <= 1, + drop_last=False, + sampler=sampler, + num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, + pin_memory=False, + ) return loader @@ -77,24 +80,21 @@ def format_test_data(data): return m, x -def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, - epoch): +def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch): data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0)) model.train() epoch_time = 0 keep_avg = KeepAverage() if use_cuda: - batch_n_iter = int( - len(data_loader.dataset) / (c.batch_size * num_gpus)) + batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) else: batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() c_logger.print_train_start() # setup noise schedule - noise_schedule = c['train_noise_schedule'] - betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], - noise_schedule['num_steps']) - if hasattr(model, 'module'): + noise_schedule = c["train_noise_schedule"] + betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) + if hasattr(model, "module"): model.module.compute_noise_level(betas) else: model.compute_noise_level(betas) @@ -109,7 +109,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, with torch.cuda.amp.autocast(enabled=c.mixed_precision): # compute noisy input - if hasattr(model, 'module'): + if hasattr(model, "module"): noise, x_noisy, noise_scale = model.module.compute_y_n(x) else: noise, x_noisy, noise_scale = model.compute_y_n(x) @@ -119,11 +119,11 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, # compute losses loss = criterion(noise, 
noise_hat) - loss_wavegrad_dict = {'wavegrad_loss': loss} + loss_wavegrad_dict = {"wavegrad_loss": loss} # check nan loss if torch.isnan(loss).any(): - raise RuntimeError(f'Detected NaN loss at step {global_step}.') + raise RuntimeError(f"Detected NaN loss at step {global_step}.") optimizer.zero_grad() @@ -131,14 +131,12 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, if c.mixed_precision: scaler.scale(loss).backward() scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), - c.clip_grad) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.clip_grad) scaler.step(optimizer) scaler.update() else: loss.backward() - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), - c.clip_grad) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.clip_grad) optimizer.step() # schedule update @@ -158,35 +156,30 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch_time += step_time # get current learning rates - current_lr = list(optimizer.param_groups)[0]['lr'] + current_lr = list(optimizer.param_groups)[0]["lr"] # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value - update_train_values['avg_loader_time'] = loader_time - update_train_values['avg_step_time'] = step_time + update_train_values["avg_" + key] = value + update_train_values["avg_loader_time"] = loader_time + update_train_values["avg_step_time"] = step_time keep_avg.update_values(update_train_values) # print training stats if global_step % c.print_step == 0: log_dict = { - 'step_time': [step_time, 2], - 'loader_time': [loader_time, 4], + "step_time": [step_time, 2], + "loader_time": [loader_time, 4], "current_lr": current_lr, - "grad_norm": grad_norm.item() + "grad_norm": grad_norm.item(), } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, - log_dict, loss_dict, keep_avg.avg_values) + c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) if args.rank == 0: # plot step stats if global_step % 10 == 0: - iter_stats = { - "lr": current_lr, - "grad_norm": grad_norm.item(), - "step_time": step_time - } + iter_stats = {"lr": current_lr, "grad_norm": grad_norm.item(), "step_time": step_time} iter_stats.update(loss_dict) tb_logger.tb_train_iter_stats(global_step, iter_stats) @@ -205,7 +198,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch, OUT_PATH, model_losses=loss_dict, - scaler=scaler.state_dict() if c.mixed_precision else None + scaler=scaler.state_dict() if c.mixed_precision else None, ) end_time = time.time() @@ -242,19 +235,17 @@ def evaluate(model, criterion, ap, global_step, epoch): global_step += 1 # compute noisy input - if hasattr(model, 'module'): + if hasattr(model, "module"): noise, x_noisy, noise_scale = model.module.compute_y_n(x) else: noise, x_noisy, noise_scale = model.compute_y_n(x) - # forward pass noise_hat = model(x_noisy, m, noise_scale) # compute losses loss = criterion(noise, noise_hat) - loss_wavegrad_dict = {'wavegrad_loss': loss} - + loss_wavegrad_dict = {"wavegrad_loss": loss} loss_dict = dict() for key, value in loss_wavegrad_dict.items(): @@ -269,9 +260,9 @@ def evaluate(model, criterion, ap, global_step, epoch): # update avg stats update_eval_values = dict() for key, value in loss_dict.items(): - update_eval_values['avg_' + key] = value - update_eval_values['avg_loader_time'] = loader_time - 
update_eval_values['avg_step_time'] = step_time + update_eval_values["avg_" + key] = value + update_eval_values["avg_loader_time"] = loader_time + update_eval_values["avg_step_time"] = step_time keep_avg.update_values(update_eval_values) # print eval stats @@ -284,11 +275,9 @@ def evaluate(model, criterion, ap, global_step, epoch): m, x = format_test_data(samples[0]) # setup noise schedule and inference - noise_schedule = c['test_noise_schedule'] - betas = np.linspace(noise_schedule['min_val'], - noise_schedule['max_val'], - noise_schedule['num_steps']) - if hasattr(model, 'module'): + noise_schedule = c["test_noise_schedule"] + betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) + if hasattr(model, "module"): model.module.compute_noise_level(betas) # compute voice x_pred = model.module.inference(m) @@ -298,13 +287,12 @@ def evaluate(model, criterion, ap, global_step, epoch): x_pred = model.inference(m) # compute spectrograms - figures = plot_results(x_pred, x, ap, global_step, 'eval') + figures = plot_results(x_pred, x, ap, global_step, "eval") tb_logger.tb_eval_figures(global_step, figures) # Sample audio sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy() - tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice}, - c.audio["sample_rate"]) + tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_voice}, c.audio["sample_rate"]) tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) data_loader.dataset.return_segments = True @@ -318,8 +306,7 @@ def main(args): # pylint: disable=redefined-outer-name print(f" > Loading wavs from: {c.data_path}") if c.feature_path is not None: print(f" > Loading features from: {c.feature_path}") - eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, - c.eval_split_size) + eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size) else: eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size) @@ -328,8 +315,7 @@ def main(args): # pylint: disable=redefined-outer-name # DISTRUBUTED if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, - c.distributed["backend"], c.distributed["url"]) + init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) # setup models model = setup_generator(c) @@ -342,7 +328,7 @@ def main(args): # pylint: disable=redefined-outer-name # schedulers scheduler = None - if 'lr_scheduler' in c: + if "lr_scheduler" in c: scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler) scheduler = scheduler(optimizer, **c.lr_scheduler_params) @@ -355,15 +341,15 @@ def main(args): # pylint: disable=redefined-outer-name if args.restore_path: print(f" > Restoring from {os.path.basename(args.restore_path)}...") - checkpoint = torch.load(args.restore_path, map_location='cpu') + checkpoint = torch.load(args.restore_path, map_location="cpu") try: print(" > Restoring Model...") - model.load_state_dict(checkpoint['model']) + model.load_state_dict(checkpoint["model"]) print(" > Restoring Optimizer...") - optimizer.load_state_dict(checkpoint['optimizer']) - if 'scheduler' in checkpoint: + optimizer.load_state_dict(checkpoint["optimizer"]) + if "scheduler" in checkpoint: print(" > Restoring LR Scheduler...") - scheduler.load_state_dict(checkpoint['scheduler']) + scheduler.load_state_dict(checkpoint["scheduler"]) # NOTE: Not sure if necessary scheduler.optimizer = optimizer if "scaler" in checkpoint and c.mixed_precision: @@ -373,17 +359,16 @@ def main(args): # 
pylint: disable=redefined-outer-name # retore only matching layers. print(" > Partial model initialization...") model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint['model'], c) + model_dict = set_init_dict(model_dict, checkpoint["model"], c) model.load_state_dict(model_dict) del model_dict # reset lr if not countinuining training. for group in optimizer.param_groups: - group['lr'] = c.lr + group["lr"] = c.lr - print(" > Model restored from step %d" % checkpoint['step'], - flush=True) - args.restore_step = checkpoint['step'] + print(" > Model restored from step %d" % checkpoint["step"], flush=True) + args.restore_step = checkpoint["step"] else: args.restore_step = 0 @@ -395,22 +380,19 @@ def main(args): # pylint: disable=redefined-outer-name print(" > WaveGrad has {} parameters".format(num_params), flush=True) if args.restore_step == 0 or not args.best_path: - best_loss = float('inf') + best_loss = float("inf") print(" > Starting with inf best loss.") else: - print(" > Restoring best loss from " - f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, - map_location='cpu')['model_loss'] + print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") + best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = c.get('keep_all_best', False) - keep_after = c.get('keep_after', 10000) # void if keep_all_best False + keep_all_best = c.get("keep_all_best", False) + keep_after = c.get("keep_after", 10000) # void if keep_all_best False global_step = args.restore_step for epoch in range(0, c.epochs): c_logger.print_epoch_start(epoch, c.epochs) - _, global_step = train(model, criterion, optimizer, scheduler, scaler, - ap, global_step, epoch) + _, global_step = train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch) eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) target_loss = eval_avg_loss_dict[c.target_loss] @@ -429,14 +411,13 @@ def main(args): # pylint: disable=redefined-outer-name keep_all_best=keep_all_best, keep_after=keep_after, model_losses=eval_avg_loss_dict, - scaler=scaler.state_dict() if c.mixed_precision else None + scaler=scaler.state_dict() if c.mixed_precision else None, ) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( - args, model_class='vocoder') + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="vocoder") try: main(args) diff --git a/TTS/bin/train_vocoder_wavernn.py b/TTS/bin/train_vocoder_wavernn.py index 6b75405a52..3f6f5836df 100644 --- a/TTS/bin/train_vocoder_wavernn.py +++ b/TTS/bin/train_vocoder_wavernn.py @@ -2,36 +2,29 @@ """Train WaveRNN vocoder model.""" import os +import random import sys -import traceback import time -import random +import traceback import torch from torch.utils.data import DataLoader -# from torch.utils.data.distributed import DistributedSampler - -from TTS.utils.arguments import parse_arguments, process_args from TTS.tts.utils.visual import plot_spectrogram +from TTS.utils.arguments import parse_arguments, process_args from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.radam import RAdam from TTS.utils.training import 
setup_torch_training_env -from TTS.utils.generic_utils import ( - KeepAverage, - count_parameters, - remove_experiment_folder, - set_init_dict, -) +from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset -from TTS.vocoder.datasets.preprocess import ( - load_wav_data, - load_wav_feat_data -) from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss from TTS.vocoder.utils.generic_utils import setup_generator from TTS.vocoder.utils.io import save_best_model, save_checkpoint +# from torch.utils.data.distributed import DistributedSampler + + use_cuda, num_gpus = setup_torch_training_env(True, True) @@ -40,26 +33,26 @@ def setup_loader(ap, is_val=False, verbose=False): if is_val and not c.run_eval: loader = None else: - dataset = WaveRNNDataset(ap=ap, - items=eval_data if is_val else train_data, - seq_len=c.seq_len, - hop_len=ap.hop_length, - pad=c.padding, - mode=c.mode, - mulaw=c.mulaw, - is_training=not is_val, - verbose=verbose, - ) + dataset = WaveRNNDataset( + ap=ap, + items=eval_data if is_val else train_data, + seq_len=c.seq_len, + hop_len=ap.hop_length, + pad=c.padding, + mode=c.mode, + mulaw=c.mulaw, + is_training=not is_val, + verbose=verbose, + ) # sampler = DistributedSampler(dataset) if num_gpus > 1 else None - loader = DataLoader(dataset, - shuffle=True, - collate_fn=dataset.collate, - batch_size=c.batch_size, - num_workers=c.num_val_loader_workers - if is_val - else c.num_loader_workers, - pin_memory=True, - ) + loader = DataLoader( + dataset, + shuffle=True, + collate_fn=dataset.collate, + batch_size=c.batch_size, + num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, + pin_memory=True, + ) return loader @@ -85,8 +78,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch epoch_time = 0 keep_avg = KeepAverage() if use_cuda: - batch_n_iter = int(len(data_loader.dataset) / - (c.batch_size * num_gpus)) + batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) else: batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() @@ -114,8 +106,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch scaler.scale(loss).backward() scaler.unscale_(optimizer) if c.grad_clip > 0: - torch.nn.utils.clip_grad_norm_( - model.parameters(), c.grad_clip) + torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) scaler.step(optimizer) scaler.update() else: @@ -132,8 +123,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch raise RuntimeError(" [!] None loss. 
Exiting ...") loss.backward() if c.grad_clip > 0: - torch.nn.utils.clip_grad_norm_( - model.parameters(), c.grad_clip) + torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) optimizer.step() if scheduler is not None: @@ -156,17 +146,19 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch # print training stats if global_step % c.print_step == 0: - log_dict = {"step_time": [step_time, 2], - "loader_time": [loader_time, 4], - "current_lr": cur_lr, - } - c_logger.print_train_step(batch_n_iter, - num_iter, - global_step, - log_dict, - loss_dict, - keep_avg.avg_values, - ) + log_dict = { + "step_time": [step_time, 2], + "loader_time": [loader_time, 4], + "current_lr": cur_lr, + } + c_logger.print_train_step( + batch_n_iter, + num_iter, + global_step, + log_dict, + loss_dict, + keep_avg.avg_values, + ) # plot step stats if global_step % 10 == 0: @@ -189,36 +181,36 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch epoch, OUT_PATH, model_losses=loss_dict, - scaler=scaler.state_dict() if c.mixed_precision else None + scaler=scaler.state_dict() if c.mixed_precision else None, ) # synthesize a full voice rand_idx = random.randrange(0, len(train_data)) - wav_path = train_data[rand_idx] if not isinstance( - train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0] + wav_path = ( + train_data[rand_idx] if not isinstance(train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0] + ) wav = ap.load_wav(wav_path) ground_mel = ap.melspectrogram(wav) ground_mel = torch.FloatTensor(ground_mel) if use_cuda: ground_mel = ground_mel.cuda(non_blocking=True) - sample_wav = model.inference(ground_mel, - c.batched, - c.target_samples, - c.overlap_samples, - ) + sample_wav = model.inference( + ground_mel, + c.batched, + c.target_samples, + c.overlap_samples, + ) predict_mel = ap.melspectrogram(sample_wav) # compute spectrograms - figures = {"train/ground_truth": plot_spectrogram(ground_mel.T), - "train/prediction": plot_spectrogram(predict_mel.T) - } + figures = { + "train/ground_truth": plot_spectrogram(ground_mel.T), + "train/prediction": plot_spectrogram(predict_mel.T), + } tb_logger.tb_train_figures(global_step, figures) # Sample audio - tb_logger.tb_train_audios( - global_step, { - "train/audio": sample_wav}, c.audio["sample_rate"] - ) + tb_logger.tb_train_audios(global_step, {"train/audio": sample_wav}, c.audio["sample_rate"]) end_time = time.time() # print epoch stats @@ -277,36 +269,32 @@ def evaluate(model, criterion, ap, global_step, epoch): # print eval stats if c.print_eval: - c_logger.print_eval_step( - num_iter, loss_dict, keep_avg.avg_values) + c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) if epoch % c.test_every_epochs == 0 and epoch != 0: # synthesize a full voice rand_idx = random.randrange(0, len(eval_data)) - wav_path = eval_data[rand_idx] if not isinstance( - eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0] + wav_path = eval_data[rand_idx] if not isinstance(eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0] wav = ap.load_wav(wav_path) ground_mel = ap.melspectrogram(wav) ground_mel = torch.FloatTensor(ground_mel) if use_cuda: ground_mel = ground_mel.cuda(non_blocking=True) - sample_wav = model.inference(ground_mel, - c.batched, - c.target_samples, - c.overlap_samples, - ) + sample_wav = model.inference( + ground_mel, + c.batched, + c.target_samples, + c.overlap_samples, + ) predict_mel = ap.melspectrogram(sample_wav) # Sample audio - tb_logger.tb_eval_audios( - 
global_step, { - "eval/audio": sample_wav}, c.audio["sample_rate"] - ) + tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_wav}, c.audio["sample_rate"]) # compute spectrograms figures = { "eval/ground_truth": plot_spectrogram(ground_mel.T), - "eval/prediction": plot_spectrogram(predict_mel.T) + "eval/prediction": plot_spectrogram(predict_mel.T), } tb_logger.tb_eval_figures(global_step, figures) @@ -347,11 +335,9 @@ def main(args): # pylint: disable=redefined-outer-name print(f" > Loading wavs from: {c.data_path}") if c.feature_path is not None: print(f" > Loading features from: {c.feature_path}") - eval_data, train_data = load_wav_feat_data( - c.data_path, c.feature_path, c.eval_split_size) + eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size) else: - eval_data, train_data = load_wav_data( - c.data_path, c.eval_split_size) + eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size) # setup model model_wavernn = setup_generator(c) @@ -404,8 +390,7 @@ def main(args): # pylint: disable=redefined-outer-name model_dict = set_init_dict(model_dict, checkpoint["model"], c) model_wavernn.load_state_dict(model_dict) - print(" > Model restored from step %d" % - checkpoint["step"], flush=True) + print(" > Model restored from step %d" % checkpoint["step"], flush=True) args.restore_step = checkpoint["step"] else: args.restore_step = 0 @@ -418,24 +403,20 @@ def main(args): # pylint: disable=redefined-outer-name print(" > Model has {} parameters".format(num_parameters), flush=True) if args.restore_step == 0 or not args.best_path: - best_loss = float('inf') + best_loss = float("inf") print(" > Starting with inf best loss.") else: - print(" > Restoring best loss from " - f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, - map_location='cpu')['model_loss'] + print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") + best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = c.get('keep_all_best', False) - keep_after = c.get('keep_after', 10000) # void if keep_all_best False + keep_all_best = c.get("keep_all_best", False) + keep_after = c.get("keep_after", 10000) # void if keep_all_best False global_step = args.restore_step for epoch in range(0, c.epochs): c_logger.print_epoch_start(epoch, c.epochs) - _, global_step = train(model_wavernn, optimizer, - criterion, scheduler, scaler, ap, global_step, epoch) - eval_avg_loss_dict = evaluate( - model_wavernn, criterion, ap, global_step, epoch) + _, global_step = train(model_wavernn, optimizer, criterion, scheduler, scaler, ap, global_step, epoch) + eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) target_loss = eval_avg_loss_dict["avg_model_loss"] best_loss = save_best_model( @@ -453,14 +434,13 @@ def main(args): # pylint: disable=redefined-outer-name keep_all_best=keep_all_best, keep_after=keep_after, model_losses=eval_avg_loss_dict, - scaler=scaler.state_dict() if c.mixed_precision else None + scaler=scaler.state_dict() if c.mixed_precision else None, ) if __name__ == "__main__": args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( - args, model_class='vocoder') + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="vocoder") try: main(args) diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py 
index 436a276467..a31d6c4548 100644 --- a/TTS/bin/tune_wavegrad.py +++ b/TTS/bin/tune_wavegrad.py @@ -6,6 +6,7 @@ import torch from torch.utils.data import DataLoader from tqdm import tqdm + from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_config from TTS.vocoder.datasets.preprocess import load_wav_data @@ -13,14 +14,21 @@ from TTS.vocoder.utils.generic_utils import setup_generator parser = argparse.ArgumentParser() -parser.add_argument('--model_path', type=str, help='Path to model checkpoint.') -parser.add_argument('--config_path', type=str, help='Path to model config file.') -parser.add_argument('--data_path', type=str, help='Path to data directory.') -parser.add_argument('--output_path', type=str, help='path for output file including file name and extension.') -parser.add_argument('--num_iter', type=int, help='Number of model inference iterations that you like to optimize noise schedule for.') -parser.add_argument('--use_cuda', type=bool, help='enable/disable CUDA.') -parser.add_argument('--num_samples', type=int, default=1, help='Number of datasamples used for inference.') -parser.add_argument('--search_depth', type=int, default=3, help='Search granularity. Increasing this increases the run-time exponentially.') +parser.add_argument("--model_path", type=str, help="Path to model checkpoint.") +parser.add_argument("--config_path", type=str, help="Path to model config file.") +parser.add_argument("--data_path", type=str, help="Path to data directory.") +parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.") +parser.add_argument( + "--num_iter", type=int, help="Number of model inference iterations that you like to optimize noise schedule for." +) +parser.add_argument("--use_cuda", type=bool, help="enable/disable CUDA.") +parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.") +parser.add_argument( + "--search_depth", + type=int, + default=3, + help="Search granularity. 
Increasing this increases the run-time exponentially.", +) # load config args = parser.parse_args() @@ -31,18 +39,20 @@ # load dataset _, train_data = load_wav_data(args.data_path, 0) -train_data = train_data[:args.num_samples] -dataset = WaveGradDataset(ap=ap, - items=train_data, - seq_len=-1, - hop_len=ap.hop_length, - pad_short=config.pad_short, - conv_pad=config.conv_pad, - is_training=True, - return_segments=False, - use_noise_augment=False, - use_cache=False, - verbose=True) +train_data = train_data[: args.num_samples] +dataset = WaveGradDataset( + ap=ap, + items=train_data, + seq_len=-1, + hop_len=ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + is_training=True, + return_segments=False, + use_noise_augment=False, + use_cache=False, + verbose=True, +) loader = DataLoader( dataset, batch_size=1, @@ -50,7 +60,8 @@ collate_fn=dataset.collate_full_clips, drop_last=False, num_workers=config.num_loader_workers, - pin_memory=False) + pin_memory=False, +) # setup the model model = setup_generator(config) @@ -61,9 +72,9 @@ base_values = sorted(10 * np.random.uniform(size=args.search_depth)) print(base_values) exponents = 10 ** np.linspace(-6, -1, num=args.num_iter) -best_error = float('inf') +best_error = float("inf") best_schedule = None -total_search_iter = len(base_values)**args.num_iter +total_search_iter = len(base_values) ** args.num_iter for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter): beta = exponents * base model.compute_noise_level(beta) @@ -84,6 +95,6 @@ mse = torch.sum((mel - mel_hat) ** 2).mean() if mse.item() < best_error: best_error = mse.item() - best_schedule = {'beta': beta} + best_schedule = {"beta": beta} print(f" > Found a better schedule. - MSE: {mse.item()}") np.save(args.output_path, best_schedule) diff --git a/TTS/server/server.py b/TTS/server/server.py index 05960f88de..9e533dc6f5 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -1,35 +1,53 @@ #!flask/bin/python import argparse +import io import os import sys -import io from pathlib import Path from flask import Flask, render_template, request, send_file -from TTS.utils.synthesizer import Synthesizer -from TTS.utils.manage import ModelManager + from TTS.utils.io import load_config +from TTS.utils.manage import ModelManager +from TTS.utils.synthesizer import Synthesizer def create_argparser(): def convert_boolean(x): - return x.lower() in ['true', '1', 'yes'] + return x.lower() in ["true", "1", "yes"] parser = argparse.ArgumentParser() - parser.add_argument('--list_models', type=convert_boolean, nargs='?', const=True, default=False, help='list available pre-trained tts and vocoder models.') - parser.add_argument('--model_name', type=str, default="tts_models/en/ljspeech/speedy-speech-wn", help='name of one of the released tts models.') - parser.add_argument('--vocoder_name', type=str, default=None, help='name of one of the released vocoder models.') - parser.add_argument('--tts_checkpoint', type=str, help='path to custom tts checkpoint file') - parser.add_argument('--tts_config', type=str, help='path to custom tts config.json file') - parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model') - parser.add_argument('--vocoder_config', type=str, default=None, help='path to vocoder config file.') - parser.add_argument('--vocoder_checkpoint', type=str, default=None, help='path to vocoder checkpoint file.') - parser.add_argument('--port', type=int, 
default=5002, help='port to listen on.') - parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.') - parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.') - parser.add_argument('--show_details', type=convert_boolean, default=False, help='Generate model detail page.') + parser.add_argument( + "--list_models", + type=convert_boolean, + nargs="?", + const=True, + default=False, + help="list available pre-trained tts and vocoder models.", + ) + parser.add_argument( + "--model_name", + type=str, + default="tts_models/en/ljspeech/speedy-speech-wn", + help="name of one of the released tts models.", + ) + parser.add_argument("--vocoder_name", type=str, default=None, help="name of one of the released vocoder models.") + parser.add_argument("--tts_checkpoint", type=str, help="path to custom tts checkpoint file") + parser.add_argument("--tts_config", type=str, help="path to custom tts config.json file") + parser.add_argument( + "--tts_speakers", + type=str, + help="path to JSON file containing speaker ids, if speaker ids are used in the model", + ) + parser.add_argument("--vocoder_config", type=str, default=None, help="path to vocoder config file.") + parser.add_argument("--vocoder_checkpoint", type=str, default=None, help="path to vocoder checkpoint file.") + parser.add_argument("--port", type=int, default=5002, help="port to listen on.") + parser.add_argument("--use_cuda", type=convert_boolean, default=False, help="true to use CUDA.") + parser.add_argument("--debug", type=convert_boolean, default=False, help="true to enable Flask debug mode.") + parser.add_argument("--show_details", type=convert_boolean, default=False, help="Generate model detail page.") return parser + # parse the args args = create_argparser().parse_args() @@ -43,7 +61,7 @@ def convert_boolean(x): # update in-use models to the specified released models. 
if args.model_name is not None: tts_checkpoint_file, tts_config_file, tts_json_dict = manager.download_model(args.model_name) - args.vocoder_name = tts_json_dict['default_vocoder'] if args.vocoder_name is None else args.vocoder_name + args.vocoder_name = tts_json_dict["default_vocoder"] if args.vocoder_name is None else args.vocoder_name if args.vocoder_name is not None: vocoder_checkpoint_file, vocoder_config_file, vocoder_json_dict = manager.download_model(args.vocoder_name) @@ -59,16 +77,19 @@ def convert_boolean(x): if not args.vocoder_config and os.path.isfile(vocoder_config_file): args.vocoder_config = vocoder_config_file -synthesizer = Synthesizer(args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda) +synthesizer = Synthesizer( + args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda +) app = Flask(__name__) -@app.route('/') +@app.route("/") def index(): - return render_template('index.html', show_details=args.show_details) + return render_template("index.html", show_details=args.show_details) -@app.route('/details') + +@app.route("/details") def details(): model_config = load_config(args.tts_config) if args.vocoder_config is not None and os.path.isfile(args.vocoder_config): @@ -76,26 +97,28 @@ def details(): else: vocoder_config = None - return render_template('details.html', - show_details=args.show_details - , model_config=model_config - , vocoder_config=vocoder_config - , args=args.__dict__ - ) + return render_template( + "details.html", + show_details=args.show_details, + model_config=model_config, + vocoder_config=vocoder_config, + args=args.__dict__, + ) + -@app.route('/api/tts', methods=['GET']) +@app.route("/api/tts", methods=["GET"]) def tts(): - text = request.args.get('text') + text = request.args.get("text") print(" > Model input: {}".format(text)) wavs = synthesizer.tts(text) out = io.BytesIO() synthesizer.save_wav(wavs, out) - return send_file(out, mimetype='audio/wav') + return send_file(out, mimetype="audio/wav") def main(): - app.run(debug=args.debug, host='0.0.0.0', port=args.port) + app.run(debug=args.debug, host="0.0.0.0", port=args.port) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/TTS/speaker_encoder/dataset.py b/TTS/speaker_encoder/dataset.py index 748f513663..38d8b5f9f6 100644 --- a/TTS/speaker_encoder/dataset.py +++ b/TTS/speaker_encoder/dataset.py @@ -7,9 +7,19 @@ class MyDataset(Dataset): - def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64, - storage_size=1, sample_from_storage_p=0.5, additive_noise=0, - num_utter_per_speaker=10, skip_speakers=False, verbose=False): + def __init__( + self, + ap, + meta_data, + voice_len=1.6, + num_speakers_in_batch=64, + storage_size=1, + sample_from_storage_p=0.5, + additive_noise=0, + num_utter_per_speaker=10, + skip_speakers=False, + verbose=False, + ): """ Args: ap (TTS.tts.utils.AudioProcessor): audio processor object. 
@@ -28,7 +38,7 @@ def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64, self.ap = ap self.verbose = verbose self.__parse_items() - self.storage = queue.Queue(maxsize=storage_size*num_speakers_in_batch) + self.storage = queue.Queue(maxsize=storage_size * num_speakers_in_batch) self.sample_from_storage_p = float(sample_from_storage_p) self.additive_noise = float(additive_noise) if self.verbose: @@ -69,11 +79,14 @@ def __parse_items(self): if speaker_ in self.speaker_to_utters.keys(): self.speaker_to_utters[speaker_].append(path_) else: - self.speaker_to_utters[speaker_] = [path_, ] + self.speaker_to_utters[speaker_] = [ + path_, + ] if self.skip_speakers: - self.speaker_to_utters = {k: v for (k, v) in self.speaker_to_utters.items() if - len(v) >= self.num_utter_per_speaker} + self.speaker_to_utters = { + k: v for (k, v) in self.speaker_to_utters.items() if len(v) >= self.num_utter_per_speaker + } self.speakers = [k for (k, v) in self.speaker_to_utters.items()] @@ -100,13 +113,9 @@ def __len__(self): def __sample_speaker(self): speaker = random.sample(self.speakers, 1)[0] if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]): - utters = random.choices( - self.speaker_to_utters[speaker], k=self.num_utter_per_speaker - ) + utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker) else: - utters = random.sample( - self.speaker_to_utters[speaker], self.num_utter_per_speaker - ) + utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker) return speaker, utters def __sample_speaker_utterances(self, speaker): @@ -160,7 +169,9 @@ def collate_fn(self, batch): # get a random subset of each of the wavs and convert to MFCC. offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_] - mels_ = [self.ap.melspectrogram(wavs_[i][offsets_[i]: offsets_[i] + self.seq_len]) for i in range(len(wavs_))] + mels_ = [ + self.ap.melspectrogram(wavs_[i][offsets_[i] : offsets_[i] + self.seq_len]) for i in range(len(wavs_)) + ] feats_ = [torch.FloatTensor(mel) for mel in mels_] labels.append(labels_) diff --git a/TTS/speaker_encoder/losses.py b/TTS/speaker_encoder/losses.py index fc085674ff..d683df01eb 100644 --- a/TTS/speaker_encoder/losses.py +++ b/TTS/speaker_encoder/losses.py @@ -23,7 +23,7 @@ def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"): self.b = nn.Parameter(torch.tensor(init_b)) self.loss_method = loss_method - print(' > Initialised Generalized End-to-End loss') + print(" > Initialised Generalized End-to-End loss") assert self.loss_method in ["softmax", "contrast"] @@ -55,9 +55,7 @@ def calc_cosine_sim(self, dvecs, centroids): for spkr_idx, speaker in enumerate(dvecs): cs_row = [] for utt_idx, utterance in enumerate(speaker): - new_centroids = self.calc_new_centroids( - dvecs, centroids, spkr_idx, utt_idx - ) + new_centroids = self.calc_new_centroids(dvecs, centroids, spkr_idx, utt_idx) # vector based cosine similarity for speed cs_row.append( torch.clamp( @@ -99,14 +97,8 @@ def embed_loss_contrast(self, dvecs, cos_sim_matrix): L_row = [] for i in range(M): centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i]) - excl_centroids_sigmoids = torch.cat( - (centroids_sigmoids[:j], centroids_sigmoids[j + 1 :]) - ) - L_row.append( - 1.0 - - torch.sigmoid(cos_sim_matrix[j, i, j]) - + torch.max(excl_centroids_sigmoids) - ) + excl_centroids_sigmoids = torch.cat((centroids_sigmoids[:j], centroids_sigmoids[j + 1 :])) + L_row.append(1.0 - torch.sigmoid(cos_sim_matrix[j, i, j]) + 
torch.max(excl_centroids_sigmoids)) L_row = torch.stack(L_row) L.append(L_row) return torch.stack(L) @@ -122,6 +114,7 @@ def forward(self, dvecs): L = self.embed_loss(dvecs, cos_sim_matrix) return L.mean() + # adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py class AngleProtoLoss(nn.Module): """ @@ -134,6 +127,7 @@ class AngleProtoLoss(nn.Module): - init_w (float): defines the initial value of w - init_b (float): definies the initial value of b """ + def __init__(self, init_w=10.0, init_b=-5.0): super(AngleProtoLoss, self).__init__() # pylint: disable=E1102 @@ -142,7 +136,7 @@ def __init__(self, init_w=10.0, init_b=-5.0): self.b = nn.Parameter(torch.tensor(init_b)) self.criterion = torch.nn.CrossEntropyLoss() - print(' > Initialised Angular Prototypical loss') + print(" > Initialised Angular Prototypical loss") def forward(self, x): """ @@ -152,7 +146,10 @@ def forward(self, x): out_positive = x[:, 0, :] num_speakers = out_anchor.size()[0] - cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2)) + cos_sim_matrix = F.cosine_similarity( + out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), + out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2), + ) torch.clamp(self.w, 1e-6) cos_sim_matrix = cos_sim_matrix * self.w + self.b label = torch.arange(num_speakers).to(cos_sim_matrix.device) diff --git a/TTS/speaker_encoder/model.py b/TTS/speaker_encoder/model.py index 322ee42f78..7a3dc09ce2 100644 --- a/TTS/speaker_encoder/model.py +++ b/TTS/speaker_encoder/model.py @@ -16,19 +16,19 @@ def forward(self, x): o, (_, _) = self.lstm(x) return self.linear(o) + class LSTMWithoutProjection(nn.Module): def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers): super().__init__() - self.lstm = nn.LSTM(input_size=input_dim, - hidden_size=lstm_dim, - num_layers=num_lstm_layers, - batch_first=True) + self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True) self.linear = nn.Linear(lstm_dim, proj_dim, bias=True) self.relu = nn.ReLU() + def forward(self, x): _, (hidden, _) = self.lstm(x) return self.relu(self.linear(hidden[-1])) + class SpeakerEncoder(nn.Module): def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True): super().__init__() @@ -106,7 +106,5 @@ def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5): if embed is None: embed = self.inference(frames) else: - embed[cur_iter <= num_iters, :] += self.inference( - frames[cur_iter <= num_iters, :, :] - ) + embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :]) return embed / num_iters diff --git a/TTS/speaker_encoder/utils/generic_utils.py b/TTS/speaker_encoder/utils/generic_utils.py index 47bf79cc30..c9bfa67989 100644 --- a/TTS/speaker_encoder/utils/generic_utils.py +++ b/TTS/speaker_encoder/utils/generic_utils.py @@ -3,114 +3,121 @@ import re import torch + from TTS.speaker_encoder.model import SpeakerEncoder from TTS.utils.generic_utils import check_argument def to_camel(text): text = text.capitalize() - return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text) + return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) def setup_model(c): - model = SpeakerEncoder(c.model['input_dim'], c.model['proj_dim'], - c.model['lstm_dim'], c.model['num_lstm_layers']) + model = 
SpeakerEncoder(c.model["input_dim"], c.model["proj_dim"], c.model["lstm_dim"], c.model["num_lstm_layers"]) return model -def save_checkpoint(model, optimizer, model_loss, out_path, - current_step, epoch): - checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step) +def save_checkpoint(model, optimizer, model_loss, out_path, current_step, epoch): + checkpoint_path = "checkpoint_{}.pth.tar".format(current_step) checkpoint_path = os.path.join(out_path, checkpoint_path) print(" | | > Checkpoint saving : {}".format(checkpoint_path)) new_state_dict = model.state_dict() state = { - 'model': new_state_dict, - 'optimizer': optimizer.state_dict() if optimizer is not None else None, - 'step': current_step, - 'epoch': epoch, - 'loss': model_loss, - 'date': datetime.date.today().strftime("%B %d, %Y"), + "model": new_state_dict, + "optimizer": optimizer.state_dict() if optimizer is not None else None, + "step": current_step, + "epoch": epoch, + "loss": model_loss, + "date": datetime.date.today().strftime("%B %d, %Y"), } torch.save(state, checkpoint_path) -def save_best_model(model, optimizer, model_loss, best_loss, out_path, - current_step): +def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step): if model_loss < best_loss: new_state_dict = model.state_dict() state = { - 'model': new_state_dict, - 'optimizer': optimizer.state_dict(), - 'step': current_step, - 'loss': model_loss, - 'date': datetime.date.today().strftime("%B %d, %Y"), + "model": new_state_dict, + "optimizer": optimizer.state_dict(), + "step": current_step, + "loss": model_loss, + "date": datetime.date.today().strftime("%B %d, %Y"), } best_loss = model_loss - bestmodel_path = 'best_model.pth.tar' + bestmodel_path = "best_model.pth.tar" bestmodel_path = os.path.join(out_path, bestmodel_path) - print("\n > BEST MODEL ({0:.5f}) : {1:}".format( - model_loss, bestmodel_path)) + print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) torch.save(state, bestmodel_path) return best_loss def check_config_speaker_encoder(c): """Check the config.json file of the speaker encoder""" - check_argument('run_name', c, restricted=True, val_type=str) - check_argument('run_description', c, val_type=str) + check_argument("run_name", c, restricted=True, val_type=str) + check_argument("run_description", c, val_type=str) # audio processing parameters - check_argument('audio', c, restricted=True, val_type=dict) - check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056) - check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058) - check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000) - check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length') - check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length') - check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1) - check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10) - check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000) - check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5) - check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000) + check_argument("audio", c, restricted=True, 
val_type=dict) + check_argument("num_mels", c["audio"], restricted=True, val_type=int, min_val=10, max_val=2056) + check_argument("fft_size", c["audio"], restricted=True, val_type=int, min_val=128, max_val=4058) + check_argument("sample_rate", c["audio"], restricted=True, val_type=int, min_val=512, max_val=100000) + check_argument( + "frame_length_ms", + c["audio"], + restricted=True, + val_type=float, + min_val=10, + max_val=1000, + alternative="win_length", + ) + check_argument( + "frame_shift_ms", c["audio"], restricted=True, val_type=float, min_val=1, max_val=1000, alternative="hop_length" + ) + check_argument("preemphasis", c["audio"], restricted=True, val_type=float, min_val=0, max_val=1) + check_argument("min_level_db", c["audio"], restricted=True, val_type=int, min_val=-1000, max_val=10) + check_argument("ref_level_db", c["audio"], restricted=True, val_type=int, min_val=0, max_val=1000) + check_argument("power", c["audio"], restricted=True, val_type=float, min_val=1, max_val=5) + check_argument("griffin_lim_iters", c["audio"], restricted=True, val_type=int, min_val=10, max_val=1000) # training parameters - check_argument('loss', c, enum_list=['ge2e', 'angleproto'], restricted=True, val_type=str) - check_argument('grad_clip', c, restricted=True, val_type=float) - check_argument('epochs', c, restricted=True, val_type=int, min_val=1) - check_argument('lr', c, restricted=True, val_type=float, min_val=0) - check_argument('lr_decay', c, restricted=True, val_type=bool) - check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0) - check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) - check_argument('num_speakers_in_batch', c, restricted=True, val_type=int) - check_argument('num_loader_workers', c, restricted=True, val_type=int) - check_argument('wd', c, restricted=True, val_type=float, min_val=0.0, max_val=1.0) + check_argument("loss", c, enum_list=["ge2e", "angleproto"], restricted=True, val_type=str) + check_argument("grad_clip", c, restricted=True, val_type=float) + check_argument("epochs", c, restricted=True, val_type=int, min_val=1) + check_argument("lr", c, restricted=True, val_type=float, min_val=0) + check_argument("lr_decay", c, restricted=True, val_type=bool) + check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0) + check_argument("tb_model_param_stats", c, restricted=True, val_type=bool) + check_argument("num_speakers_in_batch", c, restricted=True, val_type=int) + check_argument("num_loader_workers", c, restricted=True, val_type=int) + check_argument("wd", c, restricted=True, val_type=float, min_val=0.0, max_val=1.0) # checkpoint and output parameters - check_argument('steps_plot_stats', c, restricted=True, val_type=int) - check_argument('checkpoint', c, restricted=True, val_type=bool) - check_argument('save_step', c, restricted=True, val_type=int) - check_argument('print_step', c, restricted=True, val_type=int) - check_argument('output_path', c, restricted=True, val_type=str) + check_argument("steps_plot_stats", c, restricted=True, val_type=int) + check_argument("checkpoint", c, restricted=True, val_type=bool) + check_argument("save_step", c, restricted=True, val_type=int) + check_argument("print_step", c, restricted=True, val_type=int) + check_argument("output_path", c, restricted=True, val_type=str) # model parameters - check_argument('model', c, restricted=True, val_type=dict) - check_argument('input_dim', c['model'], restricted=True, val_type=int) - check_argument('proj_dim', c['model'], restricted=True, 
val_type=int) - check_argument('lstm_dim', c['model'], restricted=True, val_type=int) - check_argument('num_lstm_layers', c['model'], restricted=True, val_type=int) - check_argument('use_lstm_with_projection', c['model'], restricted=True, val_type=bool) + check_argument("model", c, restricted=True, val_type=dict) + check_argument("input_dim", c["model"], restricted=True, val_type=int) + check_argument("proj_dim", c["model"], restricted=True, val_type=int) + check_argument("lstm_dim", c["model"], restricted=True, val_type=int) + check_argument("num_lstm_layers", c["model"], restricted=True, val_type=int) + check_argument("use_lstm_with_projection", c["model"], restricted=True, val_type=bool) # in-memory storage parameters - check_argument('storage', c, restricted=True, val_type=dict) - check_argument('sample_from_storage_p', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0) - check_argument('storage_size', c['storage'], restricted=True, val_type=int, min_val=1, max_val=100) - check_argument('additive_noise', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0) + check_argument("storage", c, restricted=True, val_type=dict) + check_argument("sample_from_storage_p", c["storage"], restricted=True, val_type=float, min_val=0.0, max_val=1.0) + check_argument("storage_size", c["storage"], restricted=True, val_type=int, min_val=1, max_val=100) + check_argument("additive_noise", c["storage"], restricted=True, val_type=float, min_val=0.0, max_val=1.0) # datasets - checking only the first entry - check_argument('datasets', c, restricted=True, val_type=list) - for dataset_entry in c['datasets']: - check_argument('name', dataset_entry, restricted=True, val_type=str) - check_argument('path', dataset_entry, restricted=True, val_type=str) - check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list]) - check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str) + check_argument("datasets", c, restricted=True, val_type=list) + for dataset_entry in c["datasets"]: + check_argument("name", dataset_entry, restricted=True, val_type=str) + check_argument("path", dataset_entry, restricted=True, val_type=str) + check_argument("meta_file_train", dataset_entry, restricted=True, val_type=[str, list]) + check_argument("meta_file_val", dataset_entry, restricted=True, val_type=str) diff --git a/TTS/speaker_encoder/utils/prepare_voxceleb.py b/TTS/speaker_encoder/utils/prepare_voxceleb.py index 758e1cb3b3..1901a21c73 100644 --- a/TTS/speaker_encoder/utils/prepare_voxceleb.py +++ b/TTS/speaker_encoder/utils/prepare_voxceleb.py @@ -17,55 +17,54 @@ # Only support eager mode and TF>=2.0.0 # pylint: disable=no-member, invalid-name, relative-beyond-top-level # pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes -''' voxceleb 1 & 2 ''' +""" voxceleb 1 & 2 """ +import hashlib import os +import subprocess import sys import zipfile -import subprocess -import hashlib + import pandas -from absl import logging -import tensorflow as tf import soundfile as sf +import tensorflow as tf +from absl import logging gfile = tf.compat.v1.gfile SUBSETS = { - "vox1_dev_wav": - ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad"], - "vox1_test_wav": - 
["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"], - "vox2_dev_aac": - ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah"], - "vox2_test_aac": - ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"] + "vox1_dev_wav": [ + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad", + ], + "vox1_test_wav": ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"], + "vox2_dev_aac": [ + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah", + ], + "vox2_test_aac": ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"], } MD5SUM = { "vox1_dev_wav": "ae63e55b951748cc486645f532ba230b", "vox2_dev_aac": "bbc063c46078a602ca71605645c2a402", "vox1_test_wav": "185fdc63c3c739954633d50379a3d102", - "vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312" + "vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312", } -USER = { - "user": "", - "password": "" -} +USER = {"user": "", "password": ""} speaker_id_dict = {} + def download_and_extract(directory, subset, urls): """Download and extract the given split of dataset. 
@@ -83,31 +82,30 @@ def download_and_extract(directory, subset, urls): if os.path.exists(zip_filepath): continue logging.info("Downloading %s to %s" % (url, zip_filepath)) - subprocess.call('wget %s --user %s --password %s -O %s' % - (url, USER["user"], USER["password"], zip_filepath), shell=True) + subprocess.call( + "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath), + shell=True, + ) statinfo = os.stat(zip_filepath) - logging.info( - "Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size) - ) + logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)) # concatenate all parts into zip files if ".zip" not in zip_filepath: zip_filepath = "_".join(zip_filepath.split("_")[:-1]) - subprocess.call('cat %s* > %s.zip' % - (zip_filepath, zip_filepath), shell=True) + subprocess.call("cat %s* > %s.zip" % (zip_filepath, zip_filepath), shell=True) zip_filepath += ".zip" extract_path = zip_filepath.strip(".zip") # check zip file md5sum - md5 = hashlib.md5(open(zip_filepath, 'rb').read()).hexdigest() + md5 = hashlib.md5(open(zip_filepath, "rb").read()).hexdigest() if md5 != MD5SUM[subset]: raise ValueError("md5sum of %s mismatch" % zip_filepath) with zipfile.ZipFile(zip_filepath, "r") as zfile: zfile.extractall(directory) extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename) - subprocess.call('mv %s %s' % (extract_path_ori, extract_path), shell=True) + subprocess.call("mv %s %s" % (extract_path_ori, extract_path), shell=True) finally: # gfile.Remove(zip_filepath) pass @@ -148,8 +146,7 @@ def decode_aac_with_ffmpeg(aac_file, wav_file): return True -def convert_audio_and_make_label(input_dir, subset, - output_dir, output_file): +def convert_audio_and_make_label(input_dir, subset, output_dir, output_file): """Optionally convert AAC to WAV and make speaker labels. Args: input_dir: the directory which holds the input dataset. @@ -167,7 +164,7 @@ def convert_audio_and_make_label(input_dir, subset, for filename in filenames: name, ext = os.path.splitext(filename) if ext.lower() == ".wav": - _, ext2 = (os.path.splitext(name)) + _, ext2 = os.path.splitext(name) if ext2: continue wav_file = os.path.join(root, filename) @@ -186,15 +183,12 @@ def convert_audio_and_make_label(input_dir, subset, speaker_id_dict[speaker_name] = num # wav_filesize = os.path.getsize(wav_file) wav_length = len(sf.read(wav_file)[0]) - files.append( - (os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name) - ) + files.append((os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name)) # Write to CSV file which contains four columns: # "wav_filename", "wav_length_ms", "speaker_id", "speaker_name". 
csv_file_path = os.path.join(output_dir, output_file) - df = pandas.DataFrame( - data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]) + df = pandas.DataFrame(data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]) df.to_csv(csv_file_path, index=False, sep="\t") logging.info("Successfully generated csv file {}".format(csv_file_path)) @@ -205,19 +199,14 @@ def processor(directory, subset, force_process): if subset not in urls: raise ValueError(subset, "is not in voxceleb") - subset_csv = os.path.join(directory, subset + '.csv') + subset_csv = os.path.join(directory, subset + ".csv") if not force_process and os.path.exists(subset_csv): return subset_csv logging.info("Downloading and process the voxceleb in %s", directory) logging.info("Preparing subset %s", subset) download_and_extract(directory, subset, urls[subset]) - convert_audio_and_make_label( - directory, - subset, - directory, - subset + ".csv" - ) + convert_audio_and_make_label(directory, subset, directory, subset + ".csv") logging.info("Finished downloading and processing") return subset_csv diff --git a/TTS/speaker_encoder/utils/visual.py b/TTS/speaker_encoder/utils/visual.py index 68c48f1234..4f40f68c9d 100644 --- a/TTS/speaker_encoder/utils/visual.py +++ b/TTS/speaker_encoder/utils/visual.py @@ -1,7 +1,7 @@ -import umap -import numpy as np import matplotlib import matplotlib.pyplot as plt +import numpy as np +import umap matplotlib.use("Agg") diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index eaabb42b53..6d055bf756 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -7,31 +7,32 @@ import torch import tqdm from torch.utils.data import Dataset -from TTS.tts.utils.data import (prepare_data, prepare_stop_target, - prepare_tensor) -from TTS.tts.utils.text import (pad_with_eos_bos, phoneme_to_sequence, - text_to_sequence) + +from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor +from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence class MyDataset(Dataset): - def __init__(self, - outputs_per_step, - text_cleaner, - compute_linear_spec, - ap, - meta_data, - tp=None, - add_blank=False, - batch_group_size=0, - min_seq_len=0, - max_seq_len=float("inf"), - use_phonemes=True, - phoneme_cache_path=None, - phoneme_language="en-us", - enable_eos_bos=False, - speaker_mapping=None, - use_noise_augment=False, - verbose=False): + def __init__( + self, + outputs_per_step, + text_cleaner, + compute_linear_spec, + ap, + meta_data, + tp=None, + add_blank=False, + batch_group_size=0, + min_seq_len=0, + max_seq_len=float("inf"), + use_phonemes=True, + phoneme_cache_path=None, + phoneme_language="en-us", + enable_eos_bos=False, + speaker_mapping=None, + use_noise_augment=False, + verbose=False, + ): """ Args: outputs_per_step (int): number of time frames predicted per step. @@ -88,45 +89,42 @@ def load_wav(self, filename): @staticmethod def load_np(filename): - data = np.load(filename).astype('float32') + data = np.load(filename).astype("float32") return data @staticmethod - def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, - language, tp, add_blank): + def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, tp, add_blank): """generate a phoneme sequence from text. since the usage is for subsequent caching, we never add bos and eos chars here. 
Instead we add those dynamically later; based on the config option.""" - phonemes = phoneme_to_sequence(text, [cleaners], - language=language, - enable_eos_bos=False, - tp=tp, - add_blank=add_blank) + phonemes = phoneme_to_sequence( + text, [cleaners], language=language, enable_eos_bos=False, tp=tp, add_blank=add_blank + ) phonemes = np.asarray(phonemes, dtype=np.int32) np.save(cache_path, phonemes) return phonemes @staticmethod - def _load_or_generate_phoneme_sequence(wav_file, text, phoneme_cache_path, - enable_eos_bos, cleaners, language, - tp, add_blank): + def _load_or_generate_phoneme_sequence( + wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, tp, add_blank + ): file_name = os.path.splitext(os.path.basename(wav_file))[0] # different names for normal phonemes and with blank chars. - file_name_ext = '_blanked_phoneme.npy' if add_blank else '_phoneme.npy' - cache_path = os.path.join(phoneme_cache_path, - file_name + file_name_ext) + file_name_ext = "_blanked_phoneme.npy" if add_blank else "_phoneme.npy" + cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext) try: phonemes = np.load(cache_path) except FileNotFoundError: phonemes = MyDataset._generate_and_cache_phoneme_sequence( - text, cache_path, cleaners, language, tp, add_blank) + text, cache_path, cleaners, language, tp, add_blank + ) except (ValueError, IOError): - print(" [!] failed loading phonemes for {}. " - "Recomputing.".format(wav_file)) + print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file)) phonemes = MyDataset._generate_and_cache_phoneme_sequence( - text, cache_path, cleaners, language, tp, add_blank) + text, cache_path, cleaners, language, tp, add_blank + ) if enable_eos_bos: phonemes = pad_with_eos_bos(phonemes, tp=tp) phonemes = np.asarray(phonemes, dtype=np.int32) @@ -150,15 +148,20 @@ def load_data(self, idx): if not self.input_seq_computed: if self.use_phonemes: text = self._load_or_generate_phoneme_sequence( - wav_file, text, self.phoneme_cache_path, - self.enable_eos_bos, self.cleaners, self.phoneme_language, - self.tp, self.add_blank) + wav_file, + text, + self.phoneme_cache_path, + self.enable_eos_bos, + self.cleaners, + self.phoneme_language, + self.tp, + self.add_blank, + ) else: - text = np.asarray(text_to_sequence(text, [self.cleaners], - tp=self.tp, - add_blank=self.add_blank), - dtype=np.int32) + text = np.asarray( + text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32 + ) assert text.size > 0, self.items[idx][1] assert wav.size > 0, self.items[idx][1] @@ -173,12 +176,12 @@ def load_data(self, idx): return self.load_data(100) sample = { - 'text': text, - 'wav': wav, - 'attn': attn, - 'item_idx': self.items[idx][1], - 'speaker_name': speaker_name, - 'wav_file_name': os.path.basename(wav_file) + "text": text, + "wav": wav, + "attn": attn, + "item_idx": self.items[idx][1], + "speaker_name": speaker_name, + "wav_file_name": os.path.basename(wav_file), } return sample @@ -187,8 +190,7 @@ def _phoneme_worker(args): item = args[0] func_args = args[1] text, wav_file, *_ = item - phonemes = MyDataset._load_or_generate_phoneme_sequence( - wav_file, text, *func_args) + phonemes = MyDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args) return phonemes def compute_input_seq(self, num_workers=0): @@ -199,17 +201,19 @@ def compute_input_seq(self, num_workers=0): print(" | > Computing input sequences ...") for idx, item in enumerate(tqdm.tqdm(self.items)): text, *_ = item - sequence = 
np.asarray(text_to_sequence( - text, [self.cleaners], - tp=self.tp, - add_blank=self.add_blank), - dtype=np.int32) + sequence = np.asarray( + text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32 + ) self.items[idx][0] = sequence else: func_args = [ - self.phoneme_cache_path, self.enable_eos_bos, self.cleaners, - self.phoneme_language, self.tp, self.add_blank + self.phoneme_cache_path, + self.enable_eos_bos, + self.cleaners, + self.phoneme_language, + self.tp, + self.add_blank, ] if self.verbose: print(" | > Computing phonemes ...") @@ -220,10 +224,11 @@ def compute_input_seq(self, num_workers=0): else: with Pool(num_workers) as p: phonemes = list( - tqdm.tqdm(p.imap(MyDataset._phoneme_worker, - [[item, func_args] - for item in self.items]), - total=len(self.items))) + tqdm.tqdm( + p.imap(MyDataset._phoneme_worker, [[item, func_args] for item in self.items]), + total=len(self.items), + ) + ) for idx, p in enumerate(phonemes): self.items[idx][0] = p @@ -255,8 +260,10 @@ def sort_items(self): print(" | > Min length sequence: {}".format(np.min(lengths))) print(" | > Avg length sequence: {}".format(np.mean(lengths))) print( - " | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}" - .format(self.max_seq_len, self.min_seq_len, len(ignored))) + " | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format( + self.max_seq_len, self.min_seq_len, len(ignored) + ) + ) print(" | > Batch group size: {}.".format(self.batch_group_size)) def __len__(self): @@ -267,11 +274,11 @@ def __getitem__(self, idx): def collate_fn(self, batch): r""" - Perform preprocessing and create a final data batch: - 1. Sort batch instances by text-length - 2. Convert Audio signal to Spectrograms. - 3. PAD sequences wrt r. - 4. Load to Torch. + Perform preprocessing and create a final data batch: + 1. Sort batch instances by text-length + 2. Convert Audio signal to Spectrograms. + 3. PAD sequences wrt r. + 4. Load to Torch. 
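For intuition about the "PAD sequences wrt r" step, here is a toy stop-target example; the lengths are invented, and prepare_stop_target is assumed (not shown here) to pad every target to a shared, r-aligned length:

    import numpy as np

    # Invented mel lengths for two utterances in a batch.
    mel_lengths = [4, 6]
    stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]
    # stop_targets[0] -> [0., 0., 0., 1.]
    # stop_targets[1] -> [0., 0., 0., 0., 0., 1.]
    # prepare_stop_target(stop_targets, outputs_per_step) is then expected to pad
    # both to a common length aligned with r (outputs_per_step), e.g. 6 for r=2.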
""" # Puts each data field into a tensor with outer dimension batch size @@ -280,44 +287,29 @@ def collate_fn(self, batch): text_lenghts = np.array([len(d["text"]) for d in batch]) # sort items with text input length for RNN efficiency - text_lenghts, ids_sorted_decreasing = torch.sort( - torch.LongTensor(text_lenghts), dim=0, descending=True) + text_lenghts, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lenghts), dim=0, descending=True) - wav = [batch[idx]['wav'] for idx in ids_sorted_decreasing] - item_idxs = [ - batch[idx]['item_idx'] for idx in ids_sorted_decreasing - ] - text = [batch[idx]['text'] for idx in ids_sorted_decreasing] + wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing] + item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing] + text = [batch[idx]["text"] for idx in ids_sorted_decreasing] - speaker_name = [ - batch[idx]['speaker_name'] for idx in ids_sorted_decreasing - ] + speaker_name = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing] # get speaker embeddings if self.speaker_mapping is not None: - wav_files_names = [ - batch[idx]['wav_file_name'] - for idx in ids_sorted_decreasing - ] - speaker_embedding = [ - self.speaker_mapping[w]['embedding'] - for w in wav_files_names - ] + wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing] + speaker_embedding = [self.speaker_mapping[w]["embedding"] for w in wav_files_names] else: speaker_embedding = None # compute features - mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] + mel = [self.ap.melspectrogram(w).astype("float32") for w in wav] mel_lengths = [m.shape[1] for m in mel] # compute 'stop token' targets - stop_targets = [ - np.array([0.] * (mel_len - 1) + [1.]) - for mel_len in mel_lengths - ] + stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths] # PAD stop targets - stop_targets = prepare_stop_target(stop_targets, - self.outputs_per_step) + stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step) # PAD sequences with longest instance in the batch text = prepare_data(text).astype(np.int32) @@ -340,9 +332,7 @@ def collate_fn(self, batch): # compute linear spectrogram if self.compute_linear_spec: - linear = [ - self.ap.spectrogram(w).astype('float32') for w in wav - ] + linear = [self.ap.spectrogram(w).astype("float32") for w in wav] linear = prepare_tensor(linear, self.outputs_per_step) linear = linear.transpose(0, 2, 1) assert mel.shape[1] == linear.shape[1] @@ -351,8 +341,8 @@ def collate_fn(self, batch): linear = None # collate attention alignments - if batch[0]['attn'] is not None: - attns = [batch[idx]['attn'].T for idx in ids_sorted_decreasing] + if batch[0]["attn"] is not None: + attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing] for idx, attn in enumerate(attns): pad2 = mel.shape[1] - attn.shape[1] pad1 = text.shape[1] - attn.shape[0] @@ -362,8 +352,24 @@ def collate_fn(self, batch): attns = torch.FloatTensor(attns).unsqueeze(1) else: attns = None - return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \ - stop_targets, item_idxs, speaker_embedding, attns - - raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ - found {}".format(type(batch[0])))) + return ( + text, + text_lenghts, + speaker_name, + linear, + mel, + mel_lengths, + stop_targets, + item_idxs, + speaker_embedding, + attns, + ) + + raise TypeError( + ( + "batch must contain tensors, numbers, dicts or lists;\ + found {}".format( + type(batch[0]) + ) + ) + ) diff --git 
a/TTS/tts/datasets/preprocess.py b/TTS/tts/datasets/preprocess.py index 12148b1e24..0f82814db5 100644 --- a/TTS/tts/datasets/preprocess.py +++ b/TTS/tts/datasets/preprocess.py @@ -7,20 +7,22 @@ from typing import List from tqdm import tqdm + from TTS.tts.utils.generic_utils import split_dataset #################### # UTILITIES #################### + def load_meta_data(datasets, eval_split=True): meta_data_train_all = [] meta_data_eval_all = [] if eval_split else None for dataset in datasets: - name = dataset['name'] - root_path = dataset['path'] - meta_file_train = dataset['meta_file_train'] - meta_file_val = dataset['meta_file_val'] + name = dataset["name"] + root_path = dataset["path"] + meta_file_train = dataset["meta_file_train"] + meta_file_val = dataset["meta_file_val"] # setup the right data processor preprocessor = get_preprocessor_by_name(name) # load train set @@ -35,8 +37,8 @@ def load_meta_data(datasets, eval_split=True): meta_data_eval_all += meta_data_eval meta_data_train_all += meta_data_train # load attention masks for duration predictor training - if 'meta_file_attn_mask' in dataset and dataset['meta_file_attn_mask'] is not None: - meta_data = dict(load_attention_mask_meta_data(dataset['meta_file_attn_mask'])) + if "meta_file_attn_mask" in dataset and dataset["meta_file_attn_mask"] is not None: + meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) for idx, ins in enumerate(meta_data_train_all): attn_file = meta_data[ins[1]].strip() meta_data_train_all[idx].append(attn_file) @@ -49,12 +51,12 @@ def load_meta_data(datasets, eval_split=True): def load_attention_mask_meta_data(metafile_path): """Load meta data file created by compute_attention_masks.py""" - with open(metafile_path, 'r') as f: + with open(metafile_path, "r") as f: lines = f.readlines() meta_data = [] for line in lines: - wav_file, attn_file = line.split('|') + wav_file, attn_file = line.split("|") meta_data.append([wav_file, attn_file]) return meta_data @@ -69,6 +71,7 @@ def get_preprocessor_by_name(name): # DATASETS ######################## + def tweb(root_path, meta_file): """Normalize TWEB dataset. 
https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset @@ -76,10 +79,10 @@ def tweb(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "tweb" - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: - cols = line.split('\t') - wav_file = os.path.join(root_path, cols[0] + '.wav') + cols = line.split("\t") + wav_file = os.path.join(root_path, cols[0] + ".wav") text = cols[1] items.append([text, wav_file, speaker_name]) return items @@ -90,9 +93,9 @@ def mozilla(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "mozilla" - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: - cols = line.split('|') + cols = line.split("|") wav_file = cols[1].strip() text = cols[0].strip() wav_file = os.path.join(root_path, "wavs", wav_file) @@ -105,9 +108,9 @@ def mozilla_de(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "mozilla" - with open(txt_file, 'r', encoding="ISO 8859-1") as ttf: + with open(txt_file, "r", encoding="ISO 8859-1") as ttf: for line in ttf: - cols = line.strip().split('|') + cols = line.strip().split("|") wav_file = cols[0].strip() text = cols[1].strip() folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL" @@ -118,8 +121,7 @@ def mozilla_de(root_path, meta_file): def mailabs(root_path, meta_files=None): """Normalizes M-AI-Labs meta data files to TTS format""" - speaker_regex = re.compile( - "by_book/(male|female)/(?P[^/]+)/") + speaker_regex = re.compile("by_book/(male|female)/(?P[^/]+)/") if meta_files is None: csv_files = glob(root_path + "/**/metadata.csv", recursive=True) else: @@ -135,21 +137,18 @@ def mailabs(root_path, meta_files=None): continue speaker_name = speaker_name_match.group("speaker_name") print(" | > {}".format(csv_file)) - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: - cols = line.split('|') + cols = line.split("|") if meta_files is None: - wav_file = os.path.join(folder, 'wavs', cols[0] + '.wav') + wav_file = os.path.join(folder, "wavs", cols[0] + ".wav") else: - wav_file = os.path.join(root_path, - folder.replace("metadata.csv", ""), - 'wavs', cols[0] + '.wav') + wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav") if os.path.isfile(wav_file): text = cols[1].strip() items.append([text, wav_file, speaker_name]) else: - raise RuntimeError("> File %s does not exist!" % - (wav_file)) + raise RuntimeError("> File %s does not exist!" 
% (wav_file)) return items @@ -159,10 +158,10 @@ def ljspeech(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "ljspeech" - with open(txt_file, 'r', encoding="utf-8") as ttf: + with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: - cols = line.split('|') - wav_file = os.path.join(root_path, 'wavs', cols[0] + '.wav') + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") text = cols[1] items.append([text, wav_file, speaker_name]) return items @@ -171,15 +170,15 @@ def ljspeech(root_path, meta_file): def sam_accenture(root_path, meta_file): """Normalizes the sam-accenture meta data file to TTS format https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files""" - xml_file = os.path.join(root_path, 'voice_over_recordings', meta_file) + xml_file = os.path.join(root_path, "voice_over_recordings", meta_file) xml_root = ET.parse(xml_file).getroot() items = [] speaker_name = "sam_accenture" - for item in xml_root.findall('./fileid'): + for item in xml_root.findall("./fileid"): text = item.text - wav_file = os.path.join(root_path, 'vo_voice_quality_transformation', item.get('id')+'.wav') + wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav") if not os.path.exists(wav_file): - print(f' [!] {wav_file} in metafile does not exist. Skipping...') + print(f" [!] {wav_file} in metafile does not exist. Skipping...") continue items.append([text, wav_file, speaker_name]) return items @@ -191,10 +190,10 @@ def ruslan(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "ljspeech" - with open(txt_file, 'r', encoding="utf-8") as ttf: + with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: - cols = line.split('|') - wav_file = os.path.join(root_path, 'RUSLAN', cols[0] + '.wav') + cols = line.split("|") + wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav") text = cols[1] items.append([text, wav_file, speaker_name]) return items @@ -205,9 +204,9 @@ def css10(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "ljspeech" - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: - cols = line.split('|') + cols = line.split("|") wav_file = os.path.join(root_path, cols[0]) text = cols[1] items.append([text, wav_file, speaker_name]) @@ -219,10 +218,10 @@ def nancy(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "nancy" - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: utt_id = line.split()[1] - text = line[line.find('"') + 1:line.rfind('"') - 1] + text = line[line.find('"') + 1 : line.rfind('"') - 1] wav_file = os.path.join(root_path, "wavn", utt_id + ".wav") items.append([text, wav_file, speaker_name]) return items @@ -232,7 +231,7 @@ def common_voice(root_path, meta_file): """Normalize the common voice meta data file to TTS format.""" txt_file = os.path.join(root_path, meta_file) items = [] - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: if line.startswith("client_id"): continue @@ -240,7 +239,7 @@ def common_voice(root_path, meta_file): text = cols[2] speaker_name = cols[0] wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav")) - items.append([text, wav_file, 'MCV_' + speaker_name]) + items.append([text, wav_file, "MCV_" + speaker_name]) return items @@ -250,19 +249,18 @@ def 
libri_tts(root_path, meta_files=None): if meta_files is None: meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True) for meta_file in meta_files: - _meta_file = os.path.basename(meta_file).split('.')[0] - speaker_name = _meta_file.split('_')[0] - chapter_id = _meta_file.split('_')[1] + _meta_file = os.path.basename(meta_file).split(".")[0] + speaker_name = _meta_file.split("_")[0] + chapter_id = _meta_file.split("_")[1] _root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}") - with open(meta_file, 'r') as ttf: + with open(meta_file, "r") as ttf: for line in ttf: - cols = line.split('\t') - wav_file = os.path.join(_root_path, cols[0] + '.wav') + cols = line.split("\t") + wav_file = os.path.join(_root_path, cols[0] + ".wav") text = cols[1] - items.append([text, wav_file, 'LTTS_' + speaker_name]) + items.append([text, wav_file, "LTTS_" + speaker_name]) for item in items: - assert os.path.exists( - item[1]), f" [!] wav files don't exist - {item[1]}" + assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}" return items @@ -271,11 +269,10 @@ def custom_turkish(root_path, meta_file): items = [] speaker_name = "turkish-female" skipped_files = [] - with open(txt_file, 'r', encoding='utf-8') as ttf: + with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: - cols = line.split('|') - wav_file = os.path.join(root_path, 'wavs', - cols[0].strip() + '.wav') + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0].strip() + ".wav") if not os.path.exists(wav_file): skipped_files.append(wav_file) continue @@ -287,14 +284,14 @@ def custom_turkish(root_path, meta_file): # ToDo: add the dataset link when the dataset is released publicly def brspeech(root_path, meta_file): - '''BRSpeech 3.0 beta''' + """BRSpeech 3.0 beta""" txt_file = os.path.join(root_path, meta_file) items = [] - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: if line.startswith("wav_filename"): continue - cols = line.split('|') + cols = line.split("|") wav_file = os.path.join(root_path, cols[0]) text = cols[2] speaker_name = cols[3] @@ -302,45 +299,41 @@ def brspeech(root_path, meta_file): return items -def vctk(root_path, meta_files=None, wavs_path='wav48'): +def vctk(root_path, meta_files=None, wavs_path="wav48"): """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" test_speakers = meta_files items = [] meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) for meta_file in meta_files: - _, speaker_id, txt_file = os.path.relpath(meta_file, - root_path).split(os.sep) - file_id = txt_file.split('.')[0] - if isinstance(test_speakers, - list): # if is list ignore this speakers ids + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] + if isinstance(test_speakers, list): # if is list ignore this speakers ids if speaker_id in test_speakers: continue with open(meta_file) as file_text: text = file_text.readlines()[0] - wav_file = os.path.join(root_path, wavs_path, speaker_id, - file_id + '.wav') - items.append([text, wav_file, 'VCTK_' + speaker_id]) + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") + items.append([text, wav_file, "VCTK_" + speaker_id]) return items -def vctk_slim(root_path, meta_files=None, wavs_path='wav48'): +def vctk_slim(root_path, meta_files=None, wavs_path="wav48"): """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" items = [] txt_files = 
glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) for text_file in txt_files: - _, speaker_id, txt_file = os.path.relpath(text_file, - root_path).split(os.sep) - file_id = txt_file.split('.')[0] + _, speaker_id, txt_file = os.path.relpath(text_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] if isinstance(meta_files, list): # if is list ignore this speakers ids if speaker_id in meta_files: continue - wav_file = os.path.join(root_path, wavs_path, speaker_id, - file_id + '.wav') - items.append([None, wav_file, 'VCTK_' + speaker_id]) + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") + items.append([None, wav_file, "VCTK_" + speaker_id]) return items + # ======================================== VOX CELEB =========================================== def voxceleb2(root_path, meta_file=None): """ @@ -365,31 +358,33 @@ def _voxcel_x(root_path, meta_file, voxcel_idx): # if not exists meta file, crawl recursively for 'wav' files if meta_file is not None: - with open(str(meta_file), 'r') as f: - return [x.strip().split('|') for x in f.readlines()] + with open(str(meta_file), "r") as f: + return [x.strip().split("|") for x in f.readlines()] elif not cache_to.exists(): cnt = 0 meta_data = [] wav_files = voxceleb_path.rglob("**/*.wav") - for path in tqdm(wav_files, desc=f"Building VoxCeleb {voxcel_idx} Meta file ... this needs to be done only once.", - total=expected_count): + for path in tqdm( + wav_files, + desc=f"Building VoxCeleb {voxcel_idx} Meta file ... this needs to be done only once.", + total=expected_count, + ): speaker_id = str(Path(path).parent.parent.stem) - assert speaker_id.startswith('id') + assert speaker_id.startswith("id") text = None # VoxCel does not provide transciptions, and they are not needed for training the SE meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n") cnt += 1 - with open(str(cache_to), 'w') as f: + with open(str(cache_to), "w") as f: f.write("".join(meta_data)) if cnt < expected_count: raise ValueError(f"Found too few instances for Voxceleb. 
Should be around {expected_count}, is: {cnt}") - with open(str(cache_to), 'r') as f: - return [x.strip().split('|') for x in f.readlines()] - + with open(str(cache_to), "r") as f: + return [x.strip().split("|") for x in f.readlines()] -def baker(root_path: str, meta_file: str) -> List[List[str]]: +def baker(root_path: str, meta_file: str) -> List[List[str]]: """Normalizes the Baker meta data file to TTS format Args: @@ -401,9 +396,9 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]: txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "baker" - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: - wav_name, text = line.rstrip('\n').split("|") + wav_name, text = line.rstrip("\n").split("|") wav_path = os.path.join(root_path, "clips_22", wav_name) items.append([text, wav_path, speaker_name]) return items diff --git a/TTS/tts/layers/align_tts/duration_predictor.py b/TTS/tts/layers/align_tts/duration_predictor.py index 8391646415..b2b83894cc 100644 --- a/TTS/tts/layers/align_tts/duration_predictor.py +++ b/TTS/tts/layers/align_tts/duration_predictor.py @@ -1,6 +1,7 @@ from torch import nn -from TTS.tts.layers.generic.transformer import FFTransformerBlock + from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.layers.generic.transformer import FFTransformerBlock class DurationPredictor(nn.Module): diff --git a/TTS/tts/layers/align_tts/mdn.py b/TTS/tts/layers/align_tts/mdn.py index f5847cb47c..cdb332524b 100644 --- a/TTS/tts/layers/align_tts/mdn.py +++ b/TTS/tts/layers/align_tts/mdn.py @@ -5,6 +5,7 @@ class MDNBlock(nn.Module): """Mixture of Density Network implementation https://arxiv.org/pdf/2003.01950.pdf """ + def __init__(self, in_channels, out_channels): super().__init__() self.out_channels = out_channels @@ -24,6 +25,6 @@ def forward(self, x): mu_sigma = self.conv2(o) # TODO: check this sigmoid # mu = torch.sigmoid(mu_sigma[:, :self.out_channels//2, :]) - mu = mu_sigma[:, :self.out_channels//2, :] - log_sigma = mu_sigma[:, self.out_channels//2:, :] + mu = mu_sigma[:, : self.out_channels // 2, :] + log_sigma = mu_sigma[:, self.out_channels // 2 :, :] return mu, log_sigma diff --git a/TTS/tts/layers/feed_forward/decoder.py b/TTS/tts/layers/feed_forward/decoder.py index 5293e8bc98..34c586aab2 100644 --- a/TTS/tts/layers/feed_forward/decoder.py +++ b/TTS/tts/layers/feed_forward/decoder.py @@ -1,9 +1,10 @@ import torch from torch import nn -from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock, ResidualConv1dBNBlock, Conv1dBN + +from TTS.tts.layers.generic.res_conv_bn import Conv1dBN, Conv1dBNBlock, ResidualConv1dBNBlock +from TTS.tts.layers.generic.transformer import FFTransformerBlock from TTS.tts.layers.generic.wavenet import WNBlocks from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer -from TTS.tts.layers.generic.transformer import FFTransformerBlock class WaveNetDecoder(nn.Module): @@ -31,15 +32,16 @@ class WaveNetDecoder(nn.Module): hidden_channels (int): number of hidden channels for prenet and postnet. params (dict): dictionary for residual convolutional blocks. 
""" + def __init__(self, in_channels, out_channels, hidden_channels, c_in_channels, params): super().__init__() # prenet - self.prenet = torch.nn.Conv1d(in_channels, params['hidden_channels'], 1) + self.prenet = torch.nn.Conv1d(in_channels, params["hidden_channels"], 1) # wavenet layers - self.wn = WNBlocks(params['hidden_channels'], c_in_channels=c_in_channels, **params) + self.wn = WNBlocks(params["hidden_channels"], c_in_channels=c_in_channels, **params) # postnet self.postnet = [ - torch.nn.Conv1d(params['hidden_channels'], hidden_channels, 1), + torch.nn.Conv1d(params["hidden_channels"], hidden_channels, 1), torch.nn.ReLU(), torch.nn.Conv1d(hidden_channels, hidden_channels, 1), torch.nn.ReLU(), @@ -77,12 +79,12 @@ class RelativePositionTransformerDecoder(nn.Module): hidden_channels (int): number of hidden channels including Transformer layers. params (dict): dictionary for residual convolutional blocks. """ + def __init__(self, in_channels, out_channels, hidden_channels, params): super().__init__() self.prenet = Conv1dBN(in_channels, hidden_channels, 1, 1) - self.rel_pos_transformer = RelativePositionTransformer( - in_channels, out_channels, hidden_channels, **params) + self.rel_pos_transformer = RelativePositionTransformer(in_channels, out_channels, hidden_channels, **params) def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument o = self.prenet(x) * x_mask @@ -107,6 +109,7 @@ class FFTransformerDecoder(nn.Module): hidden_channels (int): number of hidden channels including Transformer layers. params (dict): dictionary for residual convolutional blocks. """ + def __init__(self, in_channels, out_channels, params): super().__init__() @@ -117,7 +120,7 @@ def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument # TODO: handle multi-speaker x_mask = 1 if x_mask is None else x_mask o = self.transformer_block(x) * x_mask - o = self.postnet(o)* x_mask + o = self.postnet(o) * x_mask return o @@ -141,19 +144,15 @@ class ResidualConv1dBNDecoder(nn.Module): hidden_channels (int): number of hidden channels including ResidualConv1dBNBlock layers. params (dict): dictionary for residual convolutional blocks. 
""" + def __init__(self, in_channels, out_channels, hidden_channels, params): super().__init__() - self.res_conv_block = ResidualConv1dBNBlock(in_channels, - hidden_channels, - hidden_channels, **params) + self.res_conv_block = ResidualConv1dBNBlock(in_channels, hidden_channels, hidden_channels, **params) self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1) self.postnet = nn.Sequential( - Conv1dBNBlock(hidden_channels, - hidden_channels, - hidden_channels, - params['kernel_size'], - 1, - num_conv_blocks=2), + Conv1dBNBlock( + hidden_channels, hidden_channels, hidden_channels, params["kernel_size"], 1, num_conv_blocks=2 + ), nn.Conv1d(hidden_channels, out_channels, 1), ) @@ -178,17 +177,18 @@ class Decoder(nn.Module): # pylint: disable=dangerous-default-value def __init__( - self, - out_channels, - in_hidden_channels, - decoder_type='residual_conv_bn', - decoder_params={ - "kernel_size": 4, - "dilations": 4 * [1, 2, 4, 8] + [1], - "num_conv_blocks": 2, - "num_res_blocks": 17 - }, - c_in_channels=0): + self, + out_channels, + in_hidden_channels, + decoder_type="residual_conv_bn", + decoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17, + }, + c_in_channels=0, + ): super().__init__() if decoder_type.lower() == "relative_position_transformer": @@ -196,23 +196,27 @@ def __init__( in_channels=in_hidden_channels, out_channels=out_channels, hidden_channels=in_hidden_channels, - params=decoder_params) - elif decoder_type.lower() == 'residual_conv_bn': + params=decoder_params, + ) + elif decoder_type.lower() == "residual_conv_bn": self.decoder = ResidualConv1dBNDecoder( in_channels=in_hidden_channels, out_channels=out_channels, hidden_channels=in_hidden_channels, - params=decoder_params) - elif decoder_type.lower() == 'wavenet': - self.decoder = WaveNetDecoder(in_channels=in_hidden_channels, - out_channels=out_channels, - hidden_channels=in_hidden_channels, - c_in_channels=c_in_channels, - params=decoder_params) - elif decoder_type.lower() == 'fftransformer': + params=decoder_params, + ) + elif decoder_type.lower() == "wavenet": + self.decoder = WaveNetDecoder( + in_channels=in_hidden_channels, + out_channels=out_channels, + hidden_channels=in_hidden_channels, + c_in_channels=c_in_channels, + params=decoder_params, + ) + elif decoder_type.lower() == "fftransformer": self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params) else: - raise ValueError(f'[!] Unknown decoder type - {decoder_type}') + raise ValueError(f"[!] Unknown decoder type - {decoder_type}") def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument """ diff --git a/TTS/tts/layers/feed_forward/duration_predictor.py b/TTS/tts/layers/feed_forward/duration_predictor.py index 5c5c4f3a27..5392aeca3c 100644 --- a/TTS/tts/layers/feed_forward/duration_predictor.py +++ b/TTS/tts/layers/feed_forward/duration_predictor.py @@ -16,16 +16,19 @@ class DurationPredictor(nn.Module): Args: hidden_channels (int): number of channels in the inner layers. 
""" + def __init__(self, hidden_channels): super().__init__() - self.layers = nn.ModuleList([ - Conv1dBN(hidden_channels, hidden_channels, 4, 1), - Conv1dBN(hidden_channels, hidden_channels, 3, 1), - Conv1dBN(hidden_channels, hidden_channels, 1, 1), - nn.Conv1d(hidden_channels, 1, 1) - ]) + self.layers = nn.ModuleList( + [ + Conv1dBN(hidden_channels, hidden_channels, 4, 1), + Conv1dBN(hidden_channels, hidden_channels, 3, 1), + Conv1dBN(hidden_channels, hidden_channels, 1, 1), + nn.Conv1d(hidden_channels, 1, 1), + ] + ) def forward(self, x, x_mask): """ diff --git a/TTS/tts/layers/feed_forward/encoder.py b/TTS/tts/layers/feed_forward/encoder.py index 6bc46cfa7f..81ffdeef53 100644 --- a/TTS/tts/layers/feed_forward/encoder.py +++ b/TTS/tts/layers/feed_forward/encoder.py @@ -1,8 +1,8 @@ from torch import nn -from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer -from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock +from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock from TTS.tts.layers.generic.transformer import FFTransformerBlock +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer class RelativePositionTransformerEncoder(nn.Module): @@ -16,17 +16,19 @@ class RelativePositionTransformerEncoder(nn.Module): hidden_channels (int): number of hidden channels params (dict): dictionary for residual convolutional blocks. """ + def __init__(self, in_channels, out_channels, hidden_channels, params): super().__init__() - self.prenet = ResidualConv1dBNBlock(in_channels, - hidden_channels, - hidden_channels, - kernel_size=5, - num_res_blocks=3, - num_conv_blocks=1, - dilations=[1, 1, 1]) - self.rel_pos_transformer = RelativePositionTransformer( - hidden_channels, out_channels, hidden_channels, **params) + self.prenet = ResidualConv1dBNBlock( + in_channels, + hidden_channels, + hidden_channels, + kernel_size=5, + num_res_blocks=3, + num_conv_blocks=1, + dilations=[1, 1, 1], + ) + self.rel_pos_transformer = RelativePositionTransformer(hidden_channels, out_channels, hidden_channels, **params) def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument if x_mask is None: @@ -47,20 +49,20 @@ class ResidualConv1dBNEncoder(nn.Module): hidden_channels (int): number of hidden channels params (dict): dictionary for residual convolutional blocks. 
""" + def __init__(self, in_channels, out_channels, hidden_channels, params): super().__init__() - self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1), - nn.ReLU()) - self.res_conv_block = ResidualConv1dBNBlock(hidden_channels, - hidden_channels, - hidden_channels, **params) - - self.postnet = nn.Sequential(*[ - nn.Conv1d(hidden_channels, hidden_channels, 1), - nn.ReLU(), - nn.BatchNorm1d(hidden_channels), - nn.Conv1d(hidden_channels, out_channels, 1) - ]) + self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1), nn.ReLU()) + self.res_conv_block = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **params) + + self.postnet = nn.Sequential( + *[ + nn.Conv1d(hidden_channels, hidden_channels, 1), + nn.ReLU(), + nn.BatchNorm1d(hidden_channels), + nn.Conv1d(hidden_channels, out_channels, 1), + ] + ) def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument if x_mask is None: @@ -115,18 +117,15 @@ class Encoder(nn.Module): } ``` """ + def __init__( - self, - in_hidden_channels, - out_channels, - encoder_type='residual_conv_bn', - encoder_params={ - "kernel_size": 4, - "dilations": 4 * [1, 2, 4] + [1], - "num_conv_blocks": 2, - "num_res_blocks": 13 - }, - c_in_channels=0): + self, + in_hidden_channels, + out_channels, + encoder_type="residual_conv_bn", + encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13}, + c_in_channels=0, + ): super().__init__() self.out_channels = out_channels self.in_channels = in_hidden_channels @@ -138,20 +137,19 @@ def __init__( if encoder_type.lower() == "relative_position_transformer": # text encoder self.encoder = RelativePositionTransformerEncoder( - in_hidden_channels, out_channels, in_hidden_channels, - encoder_params) # pylint: disable=unexpected-keyword-arg - elif encoder_type.lower() == 'residual_conv_bn': - self.encoder = ResidualConv1dBNEncoder(in_hidden_channels, - out_channels, - in_hidden_channels, - encoder_params) - elif encoder_type.lower() == 'fftransformer': - assert in_hidden_channels == out_channels, \ - "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'" - self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params) # pylint: disable=unexpected-keyword-arg + in_hidden_channels, out_channels, in_hidden_channels, encoder_params + ) # pylint: disable=unexpected-keyword-arg + elif encoder_type.lower() == "residual_conv_bn": + self.encoder = ResidualConv1dBNEncoder(in_hidden_channels, out_channels, in_hidden_channels, encoder_params) + elif encoder_type.lower() == "fftransformer": + assert ( + in_hidden_channels == out_channels + ), "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'" + self.encoder = FFTransformerBlock( + in_hidden_channels, **encoder_params + ) # pylint: disable=unexpected-keyword-arg else: - raise NotImplementedError(' [!] unknown encoder type.') - + raise NotImplementedError(" [!] unknown encoder type.") def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument """ diff --git a/TTS/tts/layers/generic/gated_conv.py b/TTS/tts/layers/generic/gated_conv.py index ec95565a99..9a29c4499f 100644 --- a/TTS/tts/layers/generic/gated_conv.py +++ b/TTS/tts/layers/generic/gated_conv.py @@ -10,6 +10,7 @@ class GatedConvBlock(nn.Module): kernel_size (int): convolution kernel size. dropout_p (float): dropout rate. 
""" + def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers): super().__init__() # class arguments @@ -20,21 +21,14 @@ def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers): self.norm_layers = nn.ModuleList() self.layers = nn.ModuleList() for _ in range(num_layers): - self.conv_layers += [ - nn.Conv1d(in_out_channels, - 2 * in_out_channels, - kernel_size, - padding=kernel_size // 2) - ] + self.conv_layers += [nn.Conv1d(in_out_channels, 2 * in_out_channels, kernel_size, padding=kernel_size // 2)] self.norm_layers += [LayerNorm(2 * in_out_channels)] def forward(self, x, x_mask): o = x res = x for idx in range(self.num_layers): - o = nn.functional.dropout(o, - p=self.dropout_p, - training=self.training) + o = nn.functional.dropout(o, p=self.dropout_p, training=self.training) o = self.conv_layers[idx](o * x_mask) o = self.norm_layers[idx](o) o = nn.functional.glu(o, dim=1) diff --git a/TTS/tts/layers/generic/normalization.py b/TTS/tts/layers/generic/normalization.py index e3dbb52f13..fd607b75a3 100644 --- a/TTS/tts/layers/generic/normalization.py +++ b/TTS/tts/layers/generic/normalization.py @@ -22,24 +22,17 @@ def __init__(self, channels, eps=1e-4): def forward(self, x): mean = torch.mean(x, 1, keepdim=True) - variance = torch.mean((x - mean)**2, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) x = (x - mean) * torch.rsqrt(variance + self.eps) x = x * self.gamma + self.beta return x class TemporalBatchNorm1d(nn.BatchNorm1d): - """Normalize each channel separately over time and batch. - """ - def __init__(self, - channels, - affine=True, - track_running_stats=True, - momentum=0.1): - super().__init__(channels, - affine=affine, - track_running_stats=track_running_stats, - momentum=momentum) + """Normalize each channel separately over time and batch.""" + + def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1): + super().__init__(channels, affine=affine, track_running_stats=track_running_stats, momentum=momentum) def forward(self, x): return super().forward(x.transpose(2, 1)).transpose(2, 1) @@ -58,6 +51,7 @@ class ActNorm(nn.Module): - inputs: (B, C, T) - outputs: (B, C, T) """ + def __init__(self, channels, ddi=False, **kwargs): # pylint: disable=unused-argument super().__init__() self.channels = channels @@ -68,8 +62,7 @@ def __init__(self, channels, ddi=False, **kwargs): # pylint: disable=unused-arg def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument if x_mask is None: - x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, - dtype=x.dtype) + x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype) x_len = torch.sum(x_mask, [1, 2]) if not self.initialized: self.initialize(x, x_mask) @@ -95,13 +88,11 @@ def initialize(self, x, x_mask): denom = torch.sum(x_mask, [0, 2]) m = torch.sum(x * x_mask, [0, 2]) / denom m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom - v = m_sq - (m**2) + v = m_sq - (m ** 2) logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) - bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to( - dtype=self.bias.dtype) - logs_init = (-logs).view(*self.logs.shape).to( - dtype=self.logs.dtype) + bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) + logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype) self.bias.data.copy_(bias_init) self.logs.data.copy_(logs_init) diff --git a/TTS/tts/layers/generic/pos_encoding.py b/TTS/tts/layers/generic/pos_encoding.py index 
95330b4acf..46a0b516b9 100644 --- a/TTS/tts/layers/generic/pos_encoding.py +++ b/TTS/tts/layers/generic/pos_encoding.py @@ -1,6 +1,6 @@ -import torch import math +import torch from torch import nn @@ -11,20 +11,20 @@ class PositionalEncoding(nn.Module): channels (int): embedding size dropout (float): dropout parameter """ + def __init__(self, channels, dropout_p=0.0, max_len=5000): super().__init__() if channels % 2 != 0: raise ValueError( - "Cannot use sin/cos positional encoding with " - "odd channels (got channels={:d})".format(channels)) + "Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels) + ) pe = torch.zeros(max_len, channels) position = torch.arange(0, max_len).unsqueeze(1) - div_term = torch.pow(10000, - torch.arange(0, channels, 2).float() / channels) + div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels) pe[:, 0::2] = torch.sin(position.float() * div_term) pe[:, 1::2] = torch.cos(position.float() * div_term) pe = pe.unsqueeze(0).transpose(1, 2) - self.register_buffer('pe', pe) + self.register_buffer("pe", pe) if dropout_p > 0: self.dropout = nn.Dropout(p=dropout_p) self.channels = channels @@ -43,14 +43,15 @@ def forward(self, x, mask=None, first_idx=None, last_idx=None): if self.pe.size(2) < x.size(2): raise RuntimeError( f"Sequence is {x.size(2)} but PositionalEncoding is" - f" limited to {self.pe.size(2)}. See max_len argument.") + f" limited to {self.pe.size(2)}. See max_len argument." + ) if mask is not None: - pos_enc = (self.pe[:, :, :x.size(2)] * mask) + pos_enc = self.pe[:, :, : x.size(2)] * mask else: - pos_enc = self.pe[:, :, :x.size(2)] + pos_enc = self.pe[:, :, : x.size(2)] x = x + pos_enc else: x = x + self.pe[:, :, first_idx:last_idx] - if hasattr(self, 'dropout'): + if hasattr(self, "dropout"): x = self.dropout(x) return x diff --git a/TTS/tts/layers/generic/res_conv_bn.py b/TTS/tts/layers/generic/res_conv_bn.py index 964afd0a5f..30c134cd70 100644 --- a/TTS/tts/layers/generic/res_conv_bn.py +++ b/TTS/tts/layers/generic/res_conv_bn.py @@ -3,9 +3,10 @@ class ZeroTemporalPad(nn.Module): """Pad sequences to equal lentgh in the temporal dimension""" + def __init__(self, kernel_size, dilation): super().__init__() - total_pad = (dilation * (kernel_size - 1)) + total_pad = dilation * (kernel_size - 1) begin = total_pad // 2 end = total_pad - begin self.pad_layer = nn.ZeroPad2d((0, 0, begin, end)) @@ -27,9 +28,10 @@ class Conv1dBN(nn.Module): kernel_size (int): kernel size for convolutional filters. dilation (int): dilation for convolution layers. """ + def __init__(self, in_channels, out_channels, kernel_size, dilation): super().__init__() - padding = (dilation * (kernel_size - 1)) + padding = dilation * (kernel_size - 1) pad_s = padding // 2 pad_e = padding - pad_s self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation) @@ -55,14 +57,17 @@ class Conv1dBNBlock(nn.Module): dilation (int): dilation for convolution layers. num_conv_blocks (int, optional): number of convolutional blocks. Defaults to 2. 
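Conv1dBN and ZeroTemporalPad share the same "same-length" padding rule for stride-1 convolutions: the total padding dilation * (kernel_size - 1) is split before and after the sequence, unevenly when it is odd. A quick worked example with arbitrary values:

    # kernel_size=4, dilation=2 -> 6 padded frames, split 3 / 3
    total_pad = 2 * (4 - 1)                                          # 6
    pad_start, pad_end = total_pad // 2, total_pad - total_pad // 2  # 3, 3

    # kernel_size=4, dilation=1 -> 3 padded frames, split 1 / 2 (the uneven case)
    total_pad = 1 * (4 - 1)                                          # 3
    pad_start, pad_end = total_pad // 2, total_pad - total_pad // 2  # 1, 2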
""" + def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation, num_conv_blocks=2): super().__init__() self.conv_bn_blocks = [] for idx in range(num_conv_blocks): - layer = Conv1dBN(in_channels if idx == 0 else hidden_channels, - out_channels if idx == (num_conv_blocks - 1) else hidden_channels, - kernel_size, - dilation) + layer = Conv1dBN( + in_channels if idx == 0 else hidden_channels, + out_channels if idx == (num_conv_blocks - 1) else hidden_channels, + kernel_size, + dilation, + ) self.conv_bn_blocks.append(layer) self.conv_bn_blocks = nn.Sequential(*self.conv_bn_blocks) @@ -91,18 +96,23 @@ class ResidualConv1dBNBlock(nn.Module): num_res_blocks (int, optional): number of residual blocks. Defaults to 13. num_conv_blocks (int, optional): number of convolutional blocks in each residual block. Defaults to 2. """ - def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2): + + def __init__( + self, in_channels, out_channels, hidden_channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2 + ): super().__init__() assert len(dilations) == num_res_blocks self.res_blocks = nn.ModuleList() for idx, dilation in enumerate(dilations): - block = Conv1dBNBlock(in_channels if idx == 0 else hidden_channels, - out_channels if (idx + 1) == len(dilations) else hidden_channels, - hidden_channels, - kernel_size, - dilation, - num_conv_blocks) + block = Conv1dBNBlock( + in_channels if idx == 0 else hidden_channels, + out_channels if (idx + 1) == len(dilations) else hidden_channels, + hidden_channels, + kernel_size, + dilation, + num_conv_blocks, + ) self.res_blocks.append(block) def forward(self, x, x_mask=None): diff --git a/TTS/tts/layers/generic/time_depth_sep_conv.py b/TTS/tts/layers/generic/time_depth_sep_conv.py index c9a117c8e5..186cea02e7 100644 --- a/TTS/tts/layers/generic/time_depth_sep_conv.py +++ b/TTS/tts/layers/generic/time_depth_sep_conv.py @@ -5,12 +5,8 @@ class TimeDepthSeparableConv(nn.Module): """Time depth separable convolution as in https://arxiv.org/pdf/1904.02619.pdf It shows competative results with less computation and memory footprint.""" - def __init__(self, - in_channels, - hid_channels, - out_channels, - kernel_size, - bias=True): + + def __init__(self, in_channels, hid_channels, out_channels, kernel_size, bias=True): super().__init__() self.in_channels = in_channels @@ -62,28 +58,24 @@ def forward(self, x): class TimeDepthSeparableConvBlock(nn.Module): - def __init__(self, - in_channels, - hid_channels, - out_channels, - num_layers, - kernel_size, - bias=True): + def __init__(self, in_channels, hid_channels, out_channels, num_layers, kernel_size, bias=True): super().__init__() assert (kernel_size - 1) % 2 == 0 assert num_layers > 1 self.layers = nn.ModuleList() layer = TimeDepthSeparableConv( - in_channels, hid_channels, - out_channels if num_layers == 1 else hid_channels, kernel_size, - bias) + in_channels, hid_channels, out_channels if num_layers == 1 else hid_channels, kernel_size, bias + ) self.layers.append(layer) for idx in range(num_layers - 1): layer = TimeDepthSeparableConv( - hid_channels, hid_channels, out_channels if - (idx + 1) == (num_layers - 1) else hid_channels, kernel_size, - bias) + hid_channels, + hid_channels, + out_channels if (idx + 1) == (num_layers - 1) else hid_channels, + kernel_size, + bias, + ) self.layers.append(layer) def forward(self, x, mask): diff --git a/TTS/tts/layers/generic/transformer.py b/TTS/tts/layers/generic/transformer.py index 
2324938e91..24d604f6ae 100644 --- a/TTS/tts/layers/generic/transformer.py +++ b/TTS/tts/layers/generic/transformer.py @@ -4,16 +4,9 @@ class FFTransformer(nn.Module): - def __init__(self, - in_out_channels, - num_heads, - hidden_channels_ffn=1024, - kernel_size_fft=3, - dropout_p=0.1): + def __init__(self, in_out_channels, num_heads, hidden_channels_ffn=1024, kernel_size_fft=3, dropout_p=0.1): super().__init__() - self.self_attn = nn.MultiheadAttention(in_out_channels, - num_heads, - dropout=dropout_p) + self.self_attn = nn.MultiheadAttention(in_out_channels, num_heads, dropout=dropout_p) padding = (kernel_size_fft - 1) // 2 self.conv1 = nn.Conv1d(in_out_channels, hidden_channels_ffn, kernel_size=kernel_size_fft, padding=padding) @@ -27,11 +20,7 @@ def __init__(self, def forward(self, src, src_mask=None, src_key_padding_mask=None): """😦 ugly looking with all the transposing """ src = src.permute(2, 0, 1) - src2, enc_align = self.self_attn(src, - src, - src, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask) + src2, enc_align = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask) src = self.norm1(src + src2) # T x B x D -> B x D x T src = src.permute(1, 2, 0) @@ -45,15 +34,19 @@ def forward(self, src, src_mask=None, src_key_padding_mask=None): class FFTransformerBlock(nn.Module): - def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, - num_layers, dropout_p): + def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, dropout_p): super().__init__() - self.fft_layers = nn.ModuleList([ - FFTransformer(in_out_channels=in_out_channels, - num_heads=num_heads, - hidden_channels_ffn=hidden_channels_ffn, - dropout_p=dropout_p) for _ in range(num_layers) - ]) + self.fft_layers = nn.ModuleList( + [ + FFTransformer( + in_out_channels=in_out_channels, + num_heads=num_heads, + hidden_channels_ffn=hidden_channels_ffn, + dropout_p=dropout_p, + ) + for _ in range(num_layers) + ] + ) def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument """ diff --git a/TTS/tts/layers/generic/wavenet.py b/TTS/tts/layers/generic/wavenet.py index 97eee8793e..0c87e9dff2 100644 --- a/TTS/tts/layers/generic/wavenet.py +++ b/TTS/tts/layers/generic/wavenet.py @@ -32,15 +32,18 @@ class WN(torch.nn.Module): dropout_p (float): dropout rate. weight_norm (bool): enable/disable weight norm for convolution layers. 
""" - def __init__(self, - in_channels, - hidden_channels, - kernel_size, - dilation_rate, - num_layers, - c_in_channels=0, - dropout_p=0, - weight_norm=True): + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + c_in_channels=0, + dropout_p=0, + weight_norm=True, + ): super().__init__() assert kernel_size % 2 == 1 assert hidden_channels % 2 == 0 @@ -58,20 +61,16 @@ def __init__(self, # init conditioning layer if c_in_channels > 0: - cond_layer = torch.nn.Conv1d(c_in_channels, - 2 * hidden_channels * num_layers, 1) - self.cond_layer = torch.nn.utils.weight_norm(cond_layer, - name='weight') + cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") # intermediate layers for i in range(num_layers): - dilation = dilation_rate**i + dilation = dilation_rate ** i padding = int((kernel_size * dilation - dilation) / 2) - in_layer = torch.nn.Conv1d(hidden_channels, - 2 * hidden_channels, - kernel_size, - dilation=dilation, - padding=padding) - in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') + in_layer = torch.nn.Conv1d( + hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding + ) + in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") self.in_layers.append(in_layer) if i < num_layers - 1: @@ -79,10 +78,8 @@ def __init__(self, else: res_skip_channels = hidden_channels - res_skip_layer = torch.nn.Conv1d(hidden_channels, - res_skip_channels, 1) - res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, - name='weight') + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") self.res_skip_layers.append(res_skip_layer) # setup weight norm if not weight_norm: @@ -99,15 +96,14 @@ def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-a x_in = self.dropout(x_in) if g is not None: cond_offset = i * 2 * self.hidden_channels - g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] else: g_l = torch.zeros_like(x_in) - acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, - n_channels_tensor) + acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) res_skip_acts = self.res_skip_layers[i](acts) if i < self.num_layers - 1: - x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask - output = output + res_skip_acts[:, self.hidden_channels:, :] + x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] else: output = output + res_skip_acts return output * x_mask @@ -140,28 +136,32 @@ class WNBlocks(nn.Module): weight_norm (bool): enable/disable weight norm for convolution layers. 
""" - def __init__(self, - in_channels, - hidden_channels, - kernel_size, - dilation_rate, - num_blocks, - num_layers, - c_in_channels=0, - dropout_p=0, - weight_norm=True): + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_blocks, + num_layers, + c_in_channels=0, + dropout_p=0, + weight_norm=True, + ): super().__init__() self.wn_blocks = nn.ModuleList() for idx in range(num_blocks): - layer = WN(in_channels=in_channels if idx == 0 else hidden_channels, - hidden_channels=hidden_channels, - kernel_size=kernel_size, - dilation_rate=dilation_rate, - num_layers=num_layers, - c_in_channels=c_in_channels, - dropout_p=dropout_p, - weight_norm=weight_norm) + layer = WN( + in_channels=in_channels if idx == 0 else hidden_channels, + hidden_channels=hidden_channels, + kernel_size=kernel_size, + dilation_rate=dilation_rate, + num_layers=num_layers, + c_in_channels=c_in_channels, + dropout_p=dropout_p, + weight_norm=weight_norm, + ) self.wn_blocks.append(layer) def forward(self, x, x_mask=None, g=None): diff --git a/TTS/tts/layers/glow_tts/decoder.py b/TTS/tts/layers/glow_tts/decoder.py index 46533ed187..7b3f0ed1f5 100644 --- a/TTS/tts/layers/glow_tts/decoder.py +++ b/TTS/tts/layers/glow_tts/decoder.py @@ -1,8 +1,8 @@ import torch from torch import nn -from TTS.tts.layers.glow_tts.glow import InvConvNear, CouplingBlock from TTS.tts.layers.generic.normalization import ActNorm +from TTS.tts.layers.glow_tts.glow import CouplingBlock, InvConvNear def squeeze(x, x_mask=None, num_sqz=2): @@ -18,14 +18,12 @@ def squeeze(x, x_mask=None, num_sqz=2): t = (t // num_sqz) * num_sqz x = x[:, :, :t] x_sqz = x.view(b, c, t // num_sqz, num_sqz) - x_sqz = x_sqz.permute(0, 3, 1, - 2).contiguous().view(b, c * num_sqz, t // num_sqz) + x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t // num_sqz) if x_mask is not None: - x_mask = x_mask[:, :, num_sqz - 1::num_sqz] + x_mask = x_mask[:, :, num_sqz - 1 :: num_sqz] else: - x_mask = torch.ones(b, 1, t // num_sqz).to(device=x.device, - dtype=x.dtype) + x_mask = torch.ones(b, 1, t // num_sqz).to(device=x.device, dtype=x.dtype) return x_sqz * x_mask, x_mask @@ -34,20 +32,16 @@ def unsqueeze(x, x_mask=None, num_sqz=2): Note: each 's' is a n-dimensional vector. - [[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]] """ + [[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]]""" b, c, t = x.size() x_unsqz = x.view(b, num_sqz, c // num_sqz, t) - x_unsqz = x_unsqz.permute(0, 2, 3, - 1).contiguous().view(b, c // num_sqz, - t * num_sqz) + x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // num_sqz, t * num_sqz) if x_mask is not None: - x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, - num_sqz).view(b, 1, t * num_sqz) + x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, num_sqz).view(b, 1, t * num_sqz) else: - x_mask = torch.ones(b, 1, t * num_sqz).to(device=x.device, - dtype=x.dtype) + x_mask = torch.ones(b, 1, t * num_sqz).to(device=x.device, dtype=x.dtype) return x_unsqz * x_mask, x_mask @@ -65,18 +59,21 @@ class Decoder(nn.Module): dropout_p (float): wavenet dropout rate. sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer. 
""" - def __init__(self, - in_channels, - hidden_channels, - kernel_size, - dilation_rate, - num_flow_blocks, - num_coupling_layers, - dropout_p=0., - num_splits=4, - num_squeeze=2, - sigmoid_scale=False, - c_in_channels=0): + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_flow_blocks, + num_coupling_layers, + dropout_p=0.0, + num_splits=4, + num_squeeze=2, + sigmoid_scale=False, + c_in_channels=0, + ): super().__init__() self.in_channels = in_channels @@ -94,18 +91,19 @@ def __init__(self, self.flows = nn.ModuleList() for _ in range(num_flow_blocks): self.flows.append(ActNorm(channels=in_channels * num_squeeze)) + self.flows.append(InvConvNear(channels=in_channels * num_squeeze, num_splits=num_splits)) self.flows.append( - InvConvNear(channels=in_channels * num_squeeze, - num_splits=num_splits)) - self.flows.append( - CouplingBlock(in_channels * num_squeeze, - hidden_channels, - kernel_size=kernel_size, - dilation_rate=dilation_rate, - num_layers=num_coupling_layers, - c_in_channels=c_in_channels, - dropout_p=dropout_p, - sigmoid_scale=sigmoid_scale)) + CouplingBlock( + in_channels * num_squeeze, + hidden_channels, + kernel_size=kernel_size, + dilation_rate=dilation_rate, + num_layers=num_coupling_layers, + c_in_channels=c_in_channels, + dropout_p=dropout_p, + sigmoid_scale=sigmoid_scale, + ) + ) def forward(self, x, x_mask, g=None, reverse=False): if not reverse: diff --git a/TTS/tts/layers/glow_tts/duration_predictor.py b/TTS/tts/layers/glow_tts/duration_predictor.py index a08f64a870..51d1066a76 100644 --- a/TTS/tts/layers/glow_tts/duration_predictor.py +++ b/TTS/tts/layers/glow_tts/duration_predictor.py @@ -14,6 +14,7 @@ class DurationPredictor(nn.Module): kernel_size ([type]): [description] dropout_p ([type]): [description] """ + def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p): super().__init__() # class arguments @@ -23,15 +24,9 @@ def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p): self.dropout_p = dropout_p # layers self.drop = nn.Dropout(dropout_p) - self.conv_1 = nn.Conv1d(in_channels, - hidden_channels, - kernel_size, - padding=kernel_size // 2) + self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2) self.norm_1 = LayerNorm(hidden_channels) - self.conv_2 = nn.Conv1d(hidden_channels, - hidden_channels, - kernel_size, - padding=kernel_size // 2) + self.conv_2 = nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2) self.norm_2 = LayerNorm(hidden_channels) # output layer self.proj = nn.Conv1d(hidden_channels, 1, 1) diff --git a/TTS/tts/layers/glow_tts/encoder.py b/TTS/tts/layers/glow_tts/encoder.py index 8de006a978..48bb3008c3 100644 --- a/TTS/tts/layers/glow_tts/encoder.py +++ b/TTS/tts/layers/glow_tts/encoder.py @@ -1,14 +1,15 @@ import math + import torch from torch import nn -from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer from TTS.tts.layers.generic.gated_conv import GatedConvBlock -from TTS.tts.utils.generic_utils import sequence_mask -from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock -from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor -from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlock from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock +from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlock +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from 
TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer +from TTS.tts.utils.generic_utils import sequence_mask class Encoder(nn.Module): @@ -69,17 +70,20 @@ class Encoder(nn.Module): 'num_layers': 9, } """ - def __init__(self, - num_chars, - out_channels, - hidden_channels, - hidden_channels_dp, - encoder_type, - encoder_params, - dropout_p_dp=0.1, - mean_only=False, - use_prenet=True, - c_in_channels=0): + + def __init__( + self, + num_chars, + out_channels, + hidden_channels, + hidden_channels_dp, + encoder_type, + encoder_params, + dropout_p_dp=0.1, + mean_only=False, + use_prenet=True, + c_in_channels=0, + ): super().__init__() # class arguments self.num_chars = num_chars @@ -93,47 +97,33 @@ def __init__(self, self.encoder_type = encoder_type # embedding layer self.emb = nn.Embedding(num_chars, hidden_channels) - nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) # init encoder module if encoder_type.lower() == "rel_pos_transformer": if use_prenet: - self.prenet = ResidualConv1dLayerNormBlock(hidden_channels, - hidden_channels, - hidden_channels, - kernel_size=5, - num_layers=3, - dropout_p=0.5) - self.encoder = RelativePositionTransformer(hidden_channels, - hidden_channels, - hidden_channels, - **encoder_params) - elif encoder_type.lower() == 'gated_conv': + self.prenet = ResidualConv1dLayerNormBlock( + hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5 + ) + self.encoder = RelativePositionTransformer( + hidden_channels, hidden_channels, hidden_channels, **encoder_params + ) + elif encoder_type.lower() == "gated_conv": self.encoder = GatedConvBlock(hidden_channels, **encoder_params) - elif encoder_type.lower() == 'residual_conv_bn': + elif encoder_type.lower() == "residual_conv_bn": if use_prenet: - self.prenet = nn.Sequential( - nn.Conv1d(hidden_channels, hidden_channels, 1), - nn.ReLU() - ) - self.encoder = ResidualConv1dBNBlock(hidden_channels, - hidden_channels, - hidden_channels, - **encoder_params) + self.prenet = nn.Sequential(nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU()) + self.encoder = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **encoder_params) self.postnet = nn.Sequential( - nn.Conv1d(self.hidden_channels, self.hidden_channels, 1), - nn.BatchNorm1d(self.hidden_channels)) - elif encoder_type.lower() == 'time_depth_separable': + nn.Conv1d(self.hidden_channels, self.hidden_channels, 1), nn.BatchNorm1d(self.hidden_channels) + ) + elif encoder_type.lower() == "time_depth_separable": if use_prenet: - self.prenet = ResidualConv1dLayerNormBlock(hidden_channels, - hidden_channels, - hidden_channels, - kernel_size=5, - num_layers=3, - dropout_p=0.5) - self.encoder = TimeDepthSeparableConvBlock(hidden_channels, - hidden_channels, - hidden_channels, - **encoder_params) + self.prenet = ResidualConv1dLayerNormBlock( + hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5 + ) + self.encoder = TimeDepthSeparableConvBlock( + hidden_channels, hidden_channels, hidden_channels, **encoder_params + ) else: raise ValueError(" [!] 
Unkown encoder type.") @@ -143,8 +133,8 @@ def __init__(self, self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1) # duration predictor self.duration_predictor = DurationPredictor( - hidden_channels + c_in_channels, hidden_channels_dp, 3, - dropout_p_dp) + hidden_channels + c_in_channels, hidden_channels_dp, 3, dropout_p_dp + ) def forward(self, x, x_lengths, g=None): """ @@ -159,15 +149,14 @@ def forward(self, x, x_lengths, g=None): # [B, D, T] x = torch.transpose(x, 1, -1) # compute input sequence mask - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), - 1).to(x.dtype) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # prenet - if hasattr(self, 'prenet') and self.use_prenet: + if hasattr(self, "prenet") and self.use_prenet: x = self.prenet(x, x_mask) # encoder x = self.encoder(x, x_mask) # postnet - if hasattr(self, 'postnet'): + if hasattr(self, "postnet"): x = self.postnet(x) * x_mask # set duration predictor input if g is not None: diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py index c8ad410d49..18c491e3d1 100644 --- a/TTS/tts/layers/glow_tts/glow.py +++ b/TTS/tts/layers/glow_tts/glow.py @@ -1,14 +1,14 @@ import torch from torch import nn from torch.nn import functional as F + from TTS.tts.layers.generic.wavenet import WN from ..generic.normalization import LayerNorm class ResidualConv1dLayerNormBlock(nn.Module): - def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, - num_layers, dropout_p): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, num_layers, dropout_p): """Conv1d with Layer Normalization and residual connection as in GlowTTS paper. https://arxiv.org/pdf/1811.00002.pdf @@ -38,10 +38,10 @@ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, for idx in range(num_layers): self.conv_layers.append( - nn.Conv1d(in_channels if idx == 0 else hidden_channels, - hidden_channels, - kernel_size, - padding=kernel_size // 2)) + nn.Conv1d( + in_channels if idx == 0 else hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2 + ) + ) self.norm_layers.append(LayerNorm(hidden_channels)) self.proj = nn.Conv1d(hidden_channels, out_channels, 1) @@ -72,6 +72,7 @@ class InvConvNear(nn.Module): perform 1x1 convolution separately. Cast 1x1 conv operation to 2d by reshaping the input for efficiency. 
""" + def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pylint: disable=unused-argument super().__init__() assert num_splits % 2 == 0 @@ -80,8 +81,7 @@ def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pyli self.no_jacobian = no_jacobian self.weight_inv = None - w_init = torch.qr( - torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0] + w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0] if torch.det(w_init) < 0: w_init[:, 0] = -1 * w_init[:, 0] self.weight = nn.Parameter(w_init) @@ -97,28 +97,25 @@ def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=u assert c % self.num_splits == 0 if x_mask is None: x_mask = 1 - x_len = torch.ones((b, ), dtype=x.dtype, device=x.device) * t + x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t else: x_len = torch.sum(x_mask, [1, 2]) x = x.view(b, 2, c // self.num_splits, self.num_splits // 2, t) - x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits, - c // self.num_splits, t) + x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits, c // self.num_splits, t) if reverse: if self.weight_inv is not None: weight = self.weight_inv else: - weight = torch.inverse( - self.weight.float()).to(dtype=self.weight.dtype) + weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) logdet = None else: weight = self.weight if self.no_jacobian: logdet = 0 else: - logdet = torch.logdet( - self.weight) * (c / self.num_splits) * x_len # [b] + logdet = torch.logdet(self.weight) * (c / self.num_splits) * x_len # [b] weight = weight.view(self.num_splits, self.num_splits, 1, 1) z = F.conv2d(x, weight) @@ -128,40 +125,42 @@ def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=u return z, logdet def store_inverse(self): - weight_inv = torch.inverse( - self.weight.float()).to(dtype=self.weight.dtype) + weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) self.weight_inv = nn.Parameter(weight_inv, requires_grad=False) class CouplingBlock(nn.Module): """Glow Affine Coupling block as in GlowTTS paper. - https://arxiv.org/pdf/1811.00002.pdf - - x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o - '-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^ - - Args: - in_channels (int): number of input tensor channels. - hidden_channels (int): number of hidden channels. - kernel_size (int): WaveNet filter kernel size. - dilation_rate (int): rate to increase dilation by each layer in a decoder block. - num_layers (int): number of WaveNet layers. - c_in_channels (int): number of conditioning input channels. - dropout_p (int): wavenet dropout rate. - sigmoid_scale (bool): enable/disable sigmoid scaling for output scale. + https://arxiv.org/pdf/1811.00002.pdf - Note: - It does not use conditional inputs differently from WaveGlow. + x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o + '-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^ + + Args: + in_channels (int): number of input tensor channels. + hidden_channels (int): number of hidden channels. + kernel_size (int): WaveNet filter kernel size. + dilation_rate (int): rate to increase dilation by each layer in a decoder block. + num_layers (int): number of WaveNet layers. + c_in_channels (int): number of conditioning input channels. + dropout_p (int): wavenet dropout rate. + sigmoid_scale (bool): enable/disable sigmoid scaling for output scale. 
+ + Note: + It does not use conditional inputs differently from WaveGlow. """ - def __init__(self, - in_channels, - hidden_channels, - kernel_size, - dilation_rate, - num_layers, - c_in_channels=0, - dropout_p=0, - sigmoid_scale=False): + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + c_in_channels=0, + dropout_p=0, + sigmoid_scale=False, + ): super().__init__() self.in_channels = in_channels self.hidden_channels = hidden_channels @@ -183,8 +182,7 @@ def __init__(self, end.bias.data.zero_() self.end = end # coupling layers - self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate, - num_layers, c_in_channels, dropout_p) + self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p) def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument """ @@ -195,15 +193,15 @@ def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: d """ if x_mask is None: x_mask = 1 - x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:] + x_0, x_1 = x[:, : self.in_channels // 2], x[:, self.in_channels // 2 :] x = self.start(x_0) * x_mask x = self.wn(x, x_mask, g) out = self.end(x) z_0 = x_0 - t = out[:, :self.in_channels // 2, :] - s = out[:, self.in_channels // 2:, :] + t = out[:, : self.in_channels // 2, :] + s = out[:, self.in_channels // 2 :, :] if self.sigmoid_scale: s = torch.log(1e-6 + torch.sigmoid(s + 2)) diff --git a/TTS/tts/layers/glow_tts/monotonic_align/__init__.py b/TTS/tts/layers/glow_tts/monotonic_align/__init__.py index 78fa0fbf6c..7be124f402 100644 --- a/TTS/tts/layers/glow_tts/monotonic_align/__init__.py +++ b/TTS/tts/layers/glow_tts/monotonic_align/__init__.py @@ -1,11 +1,13 @@ import numpy as np import torch from torch.nn import functional as F + from TTS.tts.utils.generic_utils import sequence_mask try: # TODO: fix pypi cython installation problem. from TTS.tts.layers.glow_tts.monotonic_align.core import maximum_path_c + CYTHON = True except ModuleNotFoundError: CYTHON = False @@ -30,8 +32,7 @@ def generate_path(duration, mask): cum_duration_flat = cum_duration.view(b * t_x) path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0] - ]))[:, :-1] + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] path = path * mask return path @@ -43,7 +44,7 @@ def maximum_path(value, mask): def maximum_path_cython(value, mask): - """ Cython optimised version. + """Cython optimised version. value: [b, t_x, t_y] mask: [b, t_x, t_y] """ diff --git a/TTS/tts/layers/glow_tts/monotonic_align/core.pyx b/TTS/tts/layers/glow_tts/monotonic_align/core.pyx index 6aabccc4c4..091fcc3a50 100644 --- a/TTS/tts/layers/glow_tts/monotonic_align/core.pyx +++ b/TTS/tts/layers/glow_tts/monotonic_align/core.pyx @@ -1,6 +1,8 @@ import numpy as np -cimport numpy as np + cimport cython +cimport numpy as np + from cython.parallel import prange diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py index 77ea05f937..1a67d0ba07 100644 --- a/TTS/tts/layers/glow_tts/transformer.py +++ b/TTS/tts/layers/glow_tts/transformer.py @@ -1,4 +1,5 @@ import math + import torch from torch import nn from torch.nn import functional as F @@ -48,16 +49,19 @@ class RelativePositionMultiHeadAttention(nn.Module): proximal_init (bool, optional): enable/disable poximal init as in the paper. 
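The coupling block leaves `x_0` untouched and predicts a shift `t` and scale `s` for `x_1`, which keeps the transform trivially invertible. A generic affine-coupling sketch of that forward/inverse relationship; the exact scaling and masking inside `CouplingBlock` may differ slightly from this stand-in:

```python
import torch

def affine_coupling(x, t, s, reverse=False):
    # forward: z1 = x1 * exp(s) + t      inverse: x1 = (z1 - t) * exp(-s)
    x0, x1 = x.chunk(2, dim=1)
    if not reverse:
        return torch.cat([x0, x1 * torch.exp(s) + t], dim=1), s.sum(dim=[1, 2])
    return torch.cat([x0, (x1 - t) * torch.exp(-s)], dim=1), None

x = torch.randn(2, 8, 5)
t, s = torch.randn(2, 4, 5), torch.randn(2, 4, 5)
z, logdet = affine_coupling(x, t, s)
x_rec, _ = affine_coupling(z, t, s, reverse=True)
print(torch.allclose(x_rec, x, atol=1e-5))     # True; logdet is just the sum of s
```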
Init key and query layer weights the same. Defaults to False. """ - def __init__(self, - channels, - out_channels, - num_heads, - rel_attn_window_size=None, - heads_share=True, - dropout_p=0., - input_length=None, - proximal_bias=False, - proximal_init=False): + + def __init__( + self, + channels, + out_channels, + num_heads, + rel_attn_window_size=None, + heads_share=True, + dropout_p=0.0, + input_length=None, + proximal_bias=False, + proximal_init=False, + ): super().__init__() assert channels % num_heads == 0, " [!] channels should be divisible by num_heads." @@ -82,15 +86,15 @@ def __init__(self, # relative positional encoding layers if rel_attn_window_size is not None: n_heads_rel = 1 if heads_share else num_heads - rel_stddev = self.k_channels**-0.5 + rel_stddev = self.k_channels ** -0.5 emb_rel_k = nn.Parameter( - torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, - self.k_channels) * rel_stddev) + torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev + ) emb_rel_v = nn.Parameter( - torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, - self.k_channels) * rel_stddev) - self.register_parameter('emb_rel_k', emb_rel_k) - self.register_parameter('emb_rel_v', emb_rel_v) + torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev + ) + self.register_parameter("emb_rel_k", emb_rel_k) + self.register_parameter("emb_rel_v", emb_rel_v) # init layers nn.init.xavier_uniform_(self.conv_q.weight) @@ -112,38 +116,30 @@ def forward(self, x, c, attn_mask=None): def attention(self, query, key, value, mask=None): # reshape [b, d, t] -> [b, n_h, t, d_k] b, d, t_s, t_t = (*key.size(), query.size(2)) - query = query.view(b, self.num_heads, self.k_channels, - t_t).transpose(2, 3) + query = query.view(b, self.num_heads, self.k_channels, t_t).transpose(2, 3) key = key.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3) - value = value.view(b, self.num_heads, self.k_channels, - t_s).transpose(2, 3) + value = value.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3) # compute raw attention scores - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt( - self.k_channels) + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) # relative positional encoding for scores if self.rel_attn_window_size is not None: assert t_s == t_t, "Relative attention is only available for self-attention." # get relative key embeddings - key_relative_embeddings = self._get_relative_embeddings( - self.emb_rel_k, t_s) - rel_logits = self._matmul_with_relative_keys( - query, key_relative_embeddings) - rel_logits = self._relative_position_to_absolute_position( - rel_logits) + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) + rel_logits = self._relative_position_to_absolute_position(rel_logits) scores_local = rel_logits / math.sqrt(self.k_channels) scores = scores + scores_local # proximan bias if self.proximal_bias: assert t_s == t_t, "Proximal bias is only available for self-attention." - scores = scores + self._attn_proximity_bias(t_s).to( - device=scores.device, dtype=scores.dtype) + scores = scores + self._attn_proximity_bias(t_s).to(device=scores.device, dtype=scores.dtype) # attention score masking if mask is not None: # add small value to prevent oor error. 
scores = scores.masked_fill(mask == 0, -1e4) if self.input_length is not None: - block_mask = torch.ones_like(scores).triu( - -1 * self.input_length).tril(self.input_length) + block_mask = torch.ones_like(scores).triu(-1 * self.input_length).tril(self.input_length) scores = scores * block_mask + -1e4 * (1 - block_mask) # attention score normalization p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] @@ -153,14 +149,10 @@ def attention(self, query, key, value, mask=None): output = torch.matmul(p_attn, value) # relative positional encoding for values if self.rel_attn_window_size is not None: - relative_weights = self._absolute_position_to_relative_position( - p_attn) - value_relative_embeddings = self._get_relative_embeddings( - self.emb_rel_v, t_s) - output = output + self._matmul_with_relative_values( - relative_weights, value_relative_embeddings) - output = output.transpose(2, 3).contiguous().view( - b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] return output, p_attn @staticmethod @@ -195,20 +187,16 @@ def _matmul_with_relative_keys(query, re): return logits def _get_relative_embeddings(self, relative_embeddings, length): - """Convert embedding vestors to a tensor of embeddings - """ + """Convert embedding vestors to a tensor of embeddings""" # Pad first before slice to avoid using cond ops. pad_length = max(length - (self.rel_attn_window_size + 1), 0) slice_start_position = max((self.rel_attn_window_size + 1) - length, 0) slice_end_position = slice_start_position + 2 * length - 1 if pad_length > 0: - padded_relative_embeddings = F.pad( - relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) + padded_relative_embeddings = F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) else: padded_relative_embeddings = relative_embeddings - used_relative_embeddings = padded_relative_embeddings[:, - slice_start_position: - slice_end_position] + used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] return used_relative_embeddings @staticmethod @@ -226,8 +214,7 @@ def _relative_position_to_absolute_position(x): x_flat = x.view([batch, heads, length * 2 * length]) x_flat = F.pad(x_flat, [0, length - 1, 0, 0, 0, 0]) # Reshape and slice out the padded elements. - x_final = x_flat.view([batch, heads, length + 1, - 2 * length - 1])[:, :, :length, length - 1:] + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :] return x_final @staticmethod @@ -239,7 +226,7 @@ def _absolute_position_to_relative_position(x): batch, heads, length, _ = x.size() # padd along column x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0]) - x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)]) # add 0's in the beginning that will skew the elements after reshape x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0]) x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] @@ -267,19 +254,15 @@ def _attn_proximity_bias(length): class FeedForwardNetwork(nn.Module): """Feed Forward Inner layers for Transformer. - Args: - in_channels (int): input tensor channels. 
- out_channels (int): output tensor channels. - hidden_channels (int): inner layers hidden channels. - kernel_size (int): conv1d filter kernel size. - dropout_p (float, optional): dropout rate. Defaults to 0. + Args: + in_channels (int): input tensor channels. + out_channels (int): output tensor channels. + hidden_channels (int): inner layers hidden channels. + kernel_size (int): conv1d filter kernel size. + dropout_p (float, optional): dropout rate. Defaults to 0. """ - def __init__(self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dropout_p=0.): + + def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0): super().__init__() self.in_channels = in_channels @@ -288,14 +271,8 @@ def __init__(self, self.kernel_size = kernel_size self.dropout_p = dropout_p - self.conv_1 = nn.Conv1d(in_channels, - hidden_channels, - kernel_size, - padding=kernel_size // 2) - self.conv_2 = nn.Conv1d(hidden_channels, - out_channels, - kernel_size, - padding=kernel_size // 2) + self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size, padding=kernel_size // 2) self.dropout = nn.Dropout(dropout_p) def forward(self, x, x_mask): @@ -308,34 +285,37 @@ def forward(self, x, x_mask): class RelativePositionTransformer(nn.Module): """Transformer with Relative Potional Encoding. - https://arxiv.org/abs/1803.02155 + https://arxiv.org/abs/1803.02155 - Args: - in_channels (int): number of channels of the input tensor. - out_chanels (int): number of channels of the output tensor. - hidden_channels (int): model hidden channels. - hidden_channels_ffn (int): hidden channels of FeedForwardNetwork. - num_heads (int): number of attention heads. - num_layers (int): number of transformer layers. - kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1. - dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers_per_stack. Defaults to 0. - rel_attn_window_size (int, optional): relation attention window size. - If 4, for each time step next and previous 4 time steps are attended. - If default, relative encoding is disabled and it is a regular transformer. - Defaults to None. - input_length (int, optional): input lenght to limit position encoding. Defaults to None. + Args: + in_channels (int): number of channels of the input tensor. + out_chanels (int): number of channels of the output tensor. + hidden_channels (int): model hidden channels. + hidden_channels_ffn (int): hidden channels of FeedForwardNetwork. + num_heads (int): number of attention heads. + num_layers (int): number of transformer layers. + kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1. + dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers_per_stack. Defaults to 0. + rel_attn_window_size (int, optional): relation attention window size. + If 4, for each time step next and previous 4 time steps are attended. + If default, relative encoding is disabled and it is a regular transformer. + Defaults to None. + input_length (int, optional): input lenght to limit position encoding. Defaults to None. 
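`rel_attn_window_size` limits relative attention to a window of +/- N steps, so only `2 * N + 1` relative embeddings (`emb_rel_k` / `emb_rel_v`) are learned. A conceptual sketch of the offset bucketing those embeddings represent; the actual code reaches the same effect with the pad-and-slice tricks shown above rather than explicit indexing:

```python
import torch

def relative_position_buckets(t, window):
    # offset of column j from row i, clipped to [-window, window], shifted to [0, 2*window]
    pos = torch.arange(t)
    rel = pos[None, :] - pos[:, None]
    return rel.clamp(-window, window) + window

print(relative_position_buckets(6, window=2))
# each entry indexes one of the 2*window + 1 learned relative embeddings
```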
""" - def __init__(self, - in_channels, - out_channels, - hidden_channels, - hidden_channels_ffn, - num_heads, - num_layers, - kernel_size=1, - dropout_p=0., - rel_attn_window_size=None, - input_length=None): + + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + hidden_channels_ffn, + num_heads, + num_layers, + kernel_size=1, + dropout_p=0.0, + rel_attn_window_size=None, + input_length=None, + ): super().__init__() self.hidden_channels = hidden_channels self.hidden_channels_ffn = hidden_channels_ffn @@ -359,7 +339,9 @@ def __init__(self, num_heads, rel_attn_window_size=rel_attn_window_size, dropout_p=dropout_p, - input_length=input_length)) + input_length=input_length, + ) + ) self.norm_layers_1.append(LayerNorm(hidden_channels)) if hidden_channels != out_channels and (idx + 1) == self.num_layers: @@ -368,15 +350,14 @@ def __init__(self, self.ffn_layers.append( FeedForwardNetwork( hidden_channels, - hidden_channels if - (idx + 1) != self.num_layers else out_channels, + hidden_channels if (idx + 1) != self.num_layers else out_channels, hidden_channels_ffn, kernel_size, - dropout_p=dropout_p)) + dropout_p=dropout_p, + ) + ) - self.norm_layers_2.append( - LayerNorm(hidden_channels if ( - idx + 1) != self.num_layers else out_channels)) + self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels)) def forward(self, x, x_mask): """ @@ -394,7 +375,7 @@ def forward(self, x, x_mask): y = self.ffn_layers[i](x, x_mask) y = self.dropout(y) - if (i + 1) == self.num_layers and hasattr(self, 'proj'): + if (i + 1) == self.num_layers and hasattr(self, "proj"): x = self.proj(x) x = self.norm_layers_2[i](x + y) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 03baf488dd..51772babcf 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -1,8 +1,10 @@ import math + import numpy as np import torch from torch import nn from torch.nn import functional + from TTS.tts.utils.generic_utils import sequence_mask from TTS.tts.utils.ssim import ssim @@ -34,19 +36,16 @@ class for each corresponding step. """ # mask: (batch, max_len, 1) target.requires_grad = False - mask = sequence_mask(sequence_length=length, - max_len=target.size(1)).unsqueeze(2).float() + mask = sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2).float() if self.seq_len_norm: norm_w = mask / mask.sum(dim=1, keepdim=True) out_weights = norm_w.div(target.shape[0] * target.shape[2]) mask = mask.expand_as(x) - loss = functional.l1_loss(x * mask, - target * mask, - reduction='none') + loss = functional.l1_loss(x * mask, target * mask, reduction="none") loss = loss.mul(out_weights.to(loss.device)).sum() else: mask = mask.expand_as(x) - loss = functional.l1_loss(x * mask, target * mask, reduction='sum') + loss = functional.l1_loss(x * mask, target * mask, reduction="sum") loss = loss / mask.sum() return loss @@ -76,27 +75,23 @@ class for each corresponding step. 
""" # mask: (batch, max_len, 1) target.requires_grad = False - mask = sequence_mask(sequence_length=length, - max_len=target.size(1)).unsqueeze(2).float() + mask = sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2).float() if self.seq_len_norm: norm_w = mask / mask.sum(dim=1, keepdim=True) out_weights = norm_w.div(target.shape[0] * target.shape[2]) mask = mask.expand_as(x) - loss = functional.mse_loss(x * mask, - target * mask, - reduction='none') + loss = functional.mse_loss(x * mask, target * mask, reduction="none") loss = loss.mul(out_weights.to(loss.device)).sum() else: mask = mask.expand_as(x) - loss = functional.mse_loss(x * mask, - target * mask, - reduction='sum') + loss = functional.mse_loss(x * mask, target * mask, reduction="sum") loss = loss / mask.sum() return loss class SSIMLoss(torch.nn.Module): """SSIM loss as explained here https://en.wikipedia.org/wiki/Structural_similarity""" + def __init__(self): super().__init__() self.loss_func = ssim @@ -115,9 +110,7 @@ def forward(self, y_hat, y, length=None): loss: An average loss value in range [0, 1] masked by the length. """ if length is not None: - m = sequence_mask(sequence_length=length, - max_len=y.size(1)).unsqueeze(2).float().to( - y_hat.device) + m = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float().to(y_hat.device) y_hat, y = y_hat * m, y * m return 1 - self.loss_func(y_hat.unsqueeze(1), y.unsqueeze(1)) @@ -163,25 +156,20 @@ class for each corresponding step. # mask: (batch, max_len, 1) target.requires_grad = False if length is not None: - mask = sequence_mask(sequence_length=length, - max_len=target.size(1)).float() + mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float() x = x * mask target = target * mask num_items = mask.sum() else: num_items = torch.numel(x) - loss = functional.binary_cross_entropy_with_logits( - x, - target, - pos_weight=self.pos_weight, - reduction='sum') + loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum") loss = loss / num_items return loss class DifferentailSpectralLoss(nn.Module): """Differential Spectral Loss - https://arxiv.org/ftp/arxiv/papers/1909/1909.10302.pdf""" + https://arxiv.org/ftp/arxiv/papers/1909/1909.10302.pdf""" def __init__(self, loss_func): super().__init__() @@ -200,7 +188,7 @@ def forward(self, x, target, length=None): target_diff = target[:, 1:] - target[:, :-1] if length is None: return self.loss_func(x_diff, target_diff) - return self.loss_func(x_diff, target_diff, length-1) + return self.loss_func(x_diff, target_diff, length - 1) class GuidedAttentionLoss(torch.nn.Module): @@ -214,8 +202,7 @@ def _make_ga_masks(self, ilens, olens): max_olen = max(olens) ga_masks = torch.zeros((B, max_olen, max_ilen)) for idx, (ilen, olen) in enumerate(zip(ilens, olens)): - ga_masks[idx, :olen, :ilen] = self._make_ga_mask( - ilen, olen, self.sigma) + ga_masks[idx, :olen, :ilen] = self._make_ga_mask(ilen, olen, self.sigma) return ga_masks def forward(self, att_ws, ilens, olens): @@ -229,8 +216,7 @@ def forward(self, att_ws, ilens, olens): def _make_ga_mask(ilen, olen, sigma): grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen)) grid_x, grid_y = grid_x.float(), grid_y.float() - return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen)**2 / - (2 * (sigma**2))) + return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2))) @staticmethod def _make_masks(ilens, olens): @@ -249,16 +235,17 @@ def forward(self, 
x, y, length=None): length: B """ mask = sequence_mask(sequence_length=length, max_len=y.size(1)).float() - return torch.nn.functional.smooth_l1_loss( - x * mask, y * mask, reduction='sum') / mask.sum() + return torch.nn.functional.smooth_l1_loss(x * mask, y * mask, reduction="sum") / mask.sum() ######################## # MODEL LOSS LAYERS ######################## + class TacotronLoss(torch.nn.Module): """Collection of Tacotron set-up based on provided config.""" + def __init__(self, c, stopnet_pos_weight=10, ga_sigma=0.4): super(TacotronLoss, self).__init__() self.stopnet_pos_weight = stopnet_pos_weight @@ -273,12 +260,9 @@ def __init__(self, c, stopnet_pos_weight=10, ga_sigma=0.4): # postnet and decoder loss if c.loss_masking: - self.criterion = L1LossMasked(c.seq_len_norm) if c.model in [ - "Tacotron" - ] else MSELossMasked(c.seq_len_norm) + self.criterion = L1LossMasked(c.seq_len_norm) if c.model in ["Tacotron"] else MSELossMasked(c.seq_len_norm) else: - self.criterion = nn.L1Loss() if c.model in ["Tacotron" - ] else nn.MSELoss() + self.criterion = nn.L1Loss() if c.model in ["Tacotron"] else nn.MSELoss() # guided attention loss if c.ga_alpha > 0: self.criterion_ga = GuidedAttentionLoss(sigma=ga_sigma) @@ -290,13 +274,23 @@ def __init__(self, c, stopnet_pos_weight=10, ga_sigma=0.4): self.criterion_ssim = SSIMLoss() # stopnet loss # pylint: disable=not-callable - self.criterion_st = BCELossMasked( - pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None - - def forward(self, postnet_output, decoder_output, mel_input, linear_input, - stopnet_output, stopnet_target, output_lens, decoder_b_output, - alignments, alignment_lens, alignments_backwards, input_lens): - + self.criterion_st = BCELossMasked(pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None + + def forward( + self, + postnet_output, + decoder_output, + mel_input, + linear_input, + stopnet_output, + stopnet_target, + output_lens, + decoder_b_output, + alignments, + alignment_lens, + alignments_backwards, + input_lens, + ): # decoder outputs linear or mel spectrograms for Tacotron and Tacotron2 # the target should be set acccordingly @@ -309,85 +303,80 @@ def forward(self, postnet_output, decoder_output, mel_input, linear_input, # decoder and postnet losses if self.config.loss_masking: if self.decoder_alpha > 0: - decoder_loss = self.criterion(decoder_output, mel_input, - output_lens) + decoder_loss = self.criterion(decoder_output, mel_input, output_lens) if self.postnet_alpha > 0: - postnet_loss = self.criterion(postnet_output, postnet_target, - output_lens) + postnet_loss = self.criterion(postnet_output, postnet_target, output_lens) else: if self.decoder_alpha > 0: decoder_loss = self.criterion(decoder_output, mel_input) if self.postnet_alpha > 0: postnet_loss = self.criterion(postnet_output, postnet_target) loss = self.decoder_alpha * decoder_loss + self.postnet_alpha * postnet_loss - return_dict['decoder_loss'] = decoder_loss - return_dict['postnet_loss'] = postnet_loss + return_dict["decoder_loss"] = decoder_loss + return_dict["postnet_loss"] = postnet_loss # stopnet loss - stop_loss = self.criterion_st( - stopnet_output, stopnet_target, - output_lens) if self.config.stopnet else torch.zeros(1) + stop_loss = ( + self.criterion_st(stopnet_output, stopnet_target, output_lens) if self.config.stopnet else torch.zeros(1) + ) if not self.config.separate_stopnet and self.config.stopnet: loss += stop_loss - return_dict['stopnet_loss'] = stop_loss + return_dict["stopnet_loss"] = stop_loss # backward decoder 
loss (if enabled) if self.config.bidirectional_decoder: if self.config.loss_masking: - decoder_b_loss = self.criterion( - torch.flip(decoder_b_output, dims=(1, )), mel_input, - output_lens) + decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1,)), mel_input, output_lens) else: - decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1, )), mel_input) - decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_b_output, dims=(1, )), decoder_output) + decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1,)), mel_input) + decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_b_output, dims=(1,)), decoder_output) loss += self.decoder_alpha * (decoder_b_loss + decoder_c_loss) - return_dict['decoder_b_loss'] = decoder_b_loss - return_dict['decoder_c_loss'] = decoder_c_loss + return_dict["decoder_b_loss"] = decoder_b_loss + return_dict["decoder_c_loss"] = decoder_c_loss # double decoder consistency loss (if enabled) if self.config.double_decoder_consistency: if self.config.loss_masking: - decoder_b_loss = self.criterion(decoder_b_output, mel_input, - output_lens) + decoder_b_loss = self.criterion(decoder_b_output, mel_input, output_lens) else: decoder_b_loss = self.criterion(decoder_b_output, mel_input) # decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output) attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards) loss += self.decoder_alpha * (decoder_b_loss + attention_c_loss) - return_dict['decoder_coarse_loss'] = decoder_b_loss - return_dict['decoder_ddc_loss'] = attention_c_loss + return_dict["decoder_coarse_loss"] = decoder_b_loss + return_dict["decoder_ddc_loss"] = attention_c_loss # guided attention loss (if enabled) if self.config.ga_alpha > 0: ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens) loss += ga_loss * self.ga_alpha - return_dict['ga_loss'] = ga_loss + return_dict["ga_loss"] = ga_loss # decoder differential spectral loss if self.config.decoder_diff_spec_alpha > 0: decoder_diff_spec_loss = self.criterion_diff_spec(decoder_output, mel_input, output_lens) loss += decoder_diff_spec_loss * self.decoder_diff_spec_alpha - return_dict['decoder_diff_spec_loss'] = decoder_diff_spec_loss + return_dict["decoder_diff_spec_loss"] = decoder_diff_spec_loss # postnet differential spectral loss if self.config.postnet_diff_spec_alpha > 0: postnet_diff_spec_loss = self.criterion_diff_spec(postnet_output, postnet_target, output_lens) loss += postnet_diff_spec_loss * self.postnet_diff_spec_alpha - return_dict['postnet_diff_spec_loss'] = postnet_diff_spec_loss + return_dict["postnet_diff_spec_loss"] = postnet_diff_spec_loss # decoder ssim loss if self.config.decoder_ssim_alpha > 0: decoder_ssim_loss = self.criterion_ssim(decoder_output, mel_input, output_lens) loss += decoder_ssim_loss * self.postnet_ssim_alpha - return_dict['decoder_ssim_loss'] = decoder_ssim_loss + return_dict["decoder_ssim_loss"] = decoder_ssim_loss # postnet ssim loss if self.config.postnet_ssim_alpha > 0: postnet_ssim_loss = self.criterion_ssim(postnet_output, postnet_target, output_lens) loss += postnet_ssim_loss * self.postnet_ssim_alpha - return_dict['postnet_ssim_loss'] = postnet_ssim_loss + return_dict["postnet_ssim_loss"] = postnet_ssim_loss - return_dict['loss'] = loss + return_dict["loss"] = loss # check if any loss is NaN for key, loss in return_dict.items(): @@ -401,22 +390,18 @@ def __init__(self): super().__init__() self.constant_factor = 0.5 * math.log(2 * math.pi) - def forward(self, z, 
means, scales, log_det, y_lengths, o_dur_log, - o_attn_dur, x_lengths): + def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, o_attn_dur, x_lengths): return_dict = {} # flow loss - neg log likelihood - pz = torch.sum(scales) + 0.5 * torch.sum( - torch.exp(-2 * scales) * (z - means)**2) - log_mle = self.constant_factor + (pz - torch.sum(log_det)) / ( - torch.sum(y_lengths) * z.shape[1]) + pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means) ** 2) + log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths) * z.shape[1]) # duration loss - MSE # loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths) # duration loss - huber loss - loss_dur = torch.nn.functional.smooth_l1_loss( - o_dur_log, o_attn_dur, reduction='sum') / torch.sum(x_lengths) - return_dict['loss'] = log_mle + loss_dur - return_dict['log_mle'] = log_mle - return_dict['loss_dur'] = loss_dur + loss_dur = torch.nn.functional.smooth_l1_loss(o_dur_log, o_attn_dur, reduction="sum") / torch.sum(x_lengths) + return_dict["loss"] = log_mle + loss_dur + return_dict["log_mle"] = log_mle + return_dict["loss_dur"] = loss_dur # check if any loss is NaN for key, loss in return_dict.items(): @@ -441,7 +426,7 @@ def forward(self, decoder_output, decoder_target, decoder_output_lens, dur_outpu ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) huber_loss = self.huber(dur_output, dur_target, input_lens) loss = self.l1_alpha * l1_loss + self.ssim_alpha * ssim_loss + self.huber_alpha * huber_loss - return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss} + return {"loss": loss, "loss_l1": l1_loss, "loss_ssim": ssim_loss, "loss_dur": huber_loss} def mse_loss_custom(x, y): @@ -452,26 +437,27 @@ def mse_loss_custom(x, y): class MDNLoss(nn.Module): - """Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf. - """ + """Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf.""" def forward(self, logp, text_lengths, mel_lengths): # pylint: disable=no-self-use - ''' + """ Shapes: mu: [B, D, T] log_sigma: [B, D, T] mel_spec: [B, D, T] - ''' + """ B, T_seq, T_mel = logp.shape - log_alpha = logp.new_ones(B, T_seq, T_mel)*(-1e4) + log_alpha = logp.new_ones(B, T_seq, T_mel) * (-1e4) log_alpha[:, 0, 0] = logp[:, 0, 0] for t in range(1, T_mel): - prev_step = torch.cat([log_alpha[:, :, t-1:t], functional.pad(log_alpha[:, :, t-1:t], - (0, 0, 1, -1), value=-1e4)], dim=-1) + prev_step = torch.cat( + [log_alpha[:, :, t - 1 : t], functional.pad(log_alpha[:, :, t - 1 : t], (0, 0, 1, -1), value=-1e4)], + dim=-1, + ) log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + logp[:, :, t] - alpha_last = log_alpha[torch.arange(B), text_lengths-1, mel_lengths-1] + alpha_last = log_alpha[torch.arange(B), text_lengths - 1, mel_lengths - 1] mdn_loss = -alpha_last.mean() / T_seq - return mdn_loss#, log_prob_matrix + return mdn_loss # , log_prob_matrix class AlignTTSLoss(nn.Module): @@ -487,6 +473,7 @@ class AlignTTSLoss(nn.Module): Args: c (dict): TTS model configuration. 
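`GlowTTSLoss` combines the flow's negative log-likelihood per valid frame with a Huber duration loss. The likelihood term on its own, evaluated with standard-normal latents so the expected value is easy to sanity-check (a sketch, not a training snippet):

```python
import math
import torch

z = torch.randn(2, 80, 50)                                  # latents [B, D, T]
means, scales = torch.zeros_like(z), torch.zeros_like(z)    # zero mean, log-sigma = 0
log_det = torch.zeros(2)
y_lengths = torch.tensor([50, 50])

constant = 0.5 * math.log(2 * math.pi)
pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means) ** 2)
log_mle = constant + (pz - torch.sum(log_det)) / (torch.sum(y_lengths) * z.shape[1])
print(log_mle)                                              # ~1.42, i.e. 0.5*log(2*pi) + 0.5 for unit Gaussians
```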
""" + def __init__(self, c): super().__init__() self.mdn_loss = MDNLoss() @@ -499,10 +486,10 @@ def __init__(self, c): self.spec_loss_alpha = c.spec_loss_alpha self.mdn_alpha = c.mdn_alpha - def forward(self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, - input_lens, step, phase): - ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas( - step) + def forward( + self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step, phase + ): + ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step) spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0 if phase == 0: mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens) @@ -521,11 +508,11 @@ def forward(self, logp, decoder_output, decoder_target, decoder_output_lens, dur ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens) loss = spec_loss_alpha * spec_loss + ssim_alpha * ssim_loss + dur_loss_alpha * dur_loss + mdn_alpha * mdn_loss - return {'loss': loss, 'loss_l1': spec_loss, 'loss_ssim': ssim_loss, 'loss_dur': dur_loss, 'mdn_loss': mdn_loss} + return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss} @staticmethod def _set_alpha(step, alpha_settings): - '''Set the loss alpha wrt number of steps. + """Set the loss alpha wrt number of steps. Return the corresponding value if no schedule is set. Example: @@ -536,7 +523,7 @@ def _set_alpha(step, alpha_settings): Args: step (int): number of training steps. alpha_settings (int or list): constant alpha value or a list defining the schedule as explained above. - ''' + """ return_alpha = None if isinstance(alpha_settings, list): for key, alpha in alpha_settings: @@ -547,8 +534,7 @@ def _set_alpha(step, alpha_settings): return return_alpha def set_alphas(self, step): - '''Set the alpha values for all the loss functions - ''' + """Set the alpha values for all the loss functions""" ssim_alpha = self._set_alpha(step, self.ssim_alpha) dur_loss_alpha = self._set_alpha(step, self.dur_loss_alpha) spec_loss_alpha = self._set_alpha(step, self.spec_loss_alpha) diff --git a/TTS/tts/layers/tacotron/attentions.py b/TTS/tts/layers/tacotron/attentions.py index 1f682e4cf9..cbb643b845 100644 --- a/TTS/tts/layers/tacotron/attentions.py +++ b/TTS/tts/layers/tacotron/attentions.py @@ -1,9 +1,9 @@ import torch +from scipy.stats import betabinom from torch import nn from torch.nn import functional as F from TTS.tts.layers.tacotron.common_layers import Linear -from scipy.stats import betabinom class LocationLayer(nn.Module): @@ -14,10 +14,8 @@ class LocationLayer(nn.Module): attention_n_filters (int, optional): number of filters in convolution. Defaults to 32. attention_kernel_size (int, optional): kernel size of convolution filter. Defaults to 31. 
""" - def __init__(self, - attention_dim, - attention_n_filters=32, - attention_kernel_size=31): + + def __init__(self, attention_dim, attention_n_filters=32, attention_kernel_size=31): super(LocationLayer, self).__init__() self.location_conv1d = nn.Conv1d( in_channels=2, @@ -25,9 +23,9 @@ def __init__(self, kernel_size=attention_kernel_size, stride=1, padding=(attention_kernel_size - 1) // 2, - bias=False) - self.location_dense = Linear( - attention_n_filters, attention_dim, bias=False, init_gain='tanh') + bias=False, + ) + self.location_dense = Linear(attention_n_filters, attention_dim, bias=False, init_gain="tanh") def forward(self, attention_cat): """ @@ -35,8 +33,7 @@ def forward(self, attention_cat): attention_cat: [B, 2, C] """ processed_attention = self.location_conv1d(attention_cat) - processed_attention = self.location_dense( - processed_attention.transpose(1, 2)) + processed_attention = self.location_dense(processed_attention.transpose(1, 2)) return processed_attention @@ -49,6 +46,7 @@ class GravesAttention(nn.Module): query_dim (int): number of channels in query tensor. K (int): number of Gaussian heads to be used for computing attention. """ + COEF = 0.3989422917366028 # numpy.sqrt(1/(2*numpy.pi)) def __init__(self, query_dim, K): @@ -60,20 +58,19 @@ def __init__(self, query_dim, K): self.eps = 1e-5 self.J = None self.N_a = nn.Sequential( - nn.Linear(query_dim, query_dim, bias=True), - nn.ReLU(), - nn.Linear(query_dim, 3*K, bias=True)) + nn.Linear(query_dim, query_dim, bias=True), nn.ReLU(), nn.Linear(query_dim, 3 * K, bias=True) + ) self.attention_weights = None self.mu_prev = None self.init_layers() def init_layers(self): - torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.) # bias mean - torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10) # bias std + torch.nn.init.constant_(self.N_a[2].bias[(2 * self.K) : (3 * self.K)], 1.0) # bias mean + torch.nn.init.constant_(self.N_a[2].bias[self.K : (2 * self.K)], 10) # bias std def init_states(self, inputs): - if self.J is None or inputs.shape[1]+1 > self.J.shape[-1]: - self.J = torch.arange(0, inputs.shape[1]+2.0).to(inputs.device) + 0.5 + if self.J is None or inputs.shape[1] + 1 > self.J.shape[-1]: + self.J = torch.arange(0, inputs.shape[1] + 2.0).to(inputs.device) + 0.5 self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device) self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device) @@ -108,7 +105,7 @@ def forward(self, query, inputs, processed_inputs, mask): mu_t = self.mu_prev + torch.nn.functional.softplus(k_t) g_t = torch.softmax(g_t, dim=-1) + self.eps - j = self.J[:inputs.size(1)+1] + j = self.J[: inputs.size(1) + 1] # attention weights phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.sigmoid((mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1)))) @@ -164,21 +161,29 @@ class OriginalAttention(nn.Module): trans_agent (bool): enable/disable transition agent in the forward attention. forward_attn_mask (int): enable/disable an explicit masking in forward attention. It is useful to set at especially inference time. 
""" + # Pylint gets confused by PyTorch conventions here - #pylint: disable=attribute-defined-outside-init - def __init__(self, query_dim, embedding_dim, attention_dim, - location_attention, attention_location_n_filters, - attention_location_kernel_size, windowing, norm, forward_attn, - trans_agent, forward_attn_mask): + # pylint: disable=attribute-defined-outside-init + def __init__( + self, + query_dim, + embedding_dim, + attention_dim, + location_attention, + attention_location_n_filters, + attention_location_kernel_size, + windowing, + norm, + forward_attn, + trans_agent, + forward_attn_mask, + ): super(OriginalAttention, self).__init__() - self.query_layer = Linear( - query_dim, attention_dim, bias=False, init_gain='tanh') - self.inputs_layer = Linear( - embedding_dim, attention_dim, bias=False, init_gain='tanh') + self.query_layer = Linear(query_dim, attention_dim, bias=False, init_gain="tanh") + self.inputs_layer = Linear(embedding_dim, attention_dim, bias=False, init_gain="tanh") self.v = Linear(attention_dim, 1, bias=True) if trans_agent: - self.ta = nn.Linear( - query_dim + embedding_dim, 1, bias=True) + self.ta = nn.Linear(query_dim + embedding_dim, 1, bias=True) if location_attention: self.location_layer = LocationLayer( attention_dim, @@ -202,9 +207,7 @@ def init_win_idx(self): def init_forward_attn(self, inputs): B = inputs.shape[0] T = inputs.shape[1] - self.alpha = torch.cat( - [torch.ones([B, 1]), - torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device) + self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device) self.u = (0.5 * torch.ones([B, 1])).to(inputs.device) def init_location_attention(self, inputs): @@ -230,14 +233,10 @@ def update_location_attention(self, alignments): self.attention_weights_cum += alignments def get_location_attention(self, query, processed_inputs): - attention_cat = torch.cat((self.attention_weights.unsqueeze(1), - self.attention_weights_cum.unsqueeze(1)), - dim=1) + attention_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1) processed_query = self.query_layer(query.unsqueeze(1)) processed_attention_weights = self.location_layer(attention_cat) - energies = self.v( - torch.tanh(processed_query + processed_attention_weights + - processed_inputs)) + energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_inputs)) energies = energies.squeeze(-1) return energies, processed_query @@ -264,24 +263,17 @@ def apply_windowing(self, attention, inputs): def apply_forward_attention(self, alignment): # forward attention - fwd_shifted_alpha = F.pad( - self.alpha[:, :-1].clone().to(alignment.device), (1, 0, 0, 0)) + fwd_shifted_alpha = F.pad(self.alpha[:, :-1].clone().to(alignment.device), (1, 0, 0, 0)) # compute transition potentials - alpha = ((1 - self.u) * self.alpha - + self.u * fwd_shifted_alpha - + 1e-8) * alignment + alpha = ((1 - self.u) * self.alpha + self.u * fwd_shifted_alpha + 1e-8) * alignment # force incremental alignment if not self.training and self.forward_attn_mask: _, n = fwd_shifted_alpha.max(1) val, _ = alpha.max(1) for b in range(alignment.shape[0]): - alpha[b, n[b] + 3:] = 0 - alpha[b, :( - n[b] - 1 - )] = 0 # ignore all previous states to prevent repetition. - alpha[b, - (n[b] - 2 - )] = 0.01 * val[b] # smoothing factor for the prev step + alpha[b, n[b] + 3 :] = 0 + alpha[b, : (n[b] - 1)] = 0 # ignore all previous states to prevent repetition. 
+ alpha[b, (n[b] - 2)] = 0.01 * val[b] # smoothing factor for the prev step # renormalize attention weights alpha = alpha / alpha.sum(dim=1, keepdim=True) return alpha @@ -295,11 +287,9 @@ def forward(self, query, inputs, processed_inputs, mask): mask: [B, T_en] """ if self.location_attention: - attention, _ = self.get_location_attention( - query, processed_inputs) + attention, _ = self.get_location_attention(query, processed_inputs) else: - attention, _ = self.get_attention( - query, processed_inputs) + attention, _ = self.get_attention(query, processed_inputs) # apply masking if mask is not None: attention.data.masked_fill_(~mask, self._mask_value) @@ -311,9 +301,7 @@ def forward(self, query, inputs, processed_inputs, mask): if self.norm == "softmax": alignment = torch.softmax(attention, dim=-1) elif self.norm == "sigmoid": - alignment = torch.sigmoid(attention) / torch.sigmoid( - attention).sum( - dim=1, keepdim=True) + alignment = torch.sigmoid(attention) / torch.sigmoid(attention).sum(dim=1, keepdim=True) else: raise ValueError("Unknown value for attention norm type") @@ -367,19 +355,20 @@ class MonotonicDynamicConvolutionAttention(nn.Module): alpha (float, optional): [description]. Defaults to 0.1 from the paper. beta (float, optional): [description]. Defaults to 0.9 from the paper. """ + def __init__( - self, - query_dim, - embedding_dim, # pylint: disable=unused-argument - attention_dim, - static_filter_dim, - static_kernel_size, - dynamic_filter_dim, - dynamic_kernel_size, - prior_filter_len=11, - alpha=0.1, - beta=0.9, - ): + self, + query_dim, + embedding_dim, # pylint: disable=unused-argument + attention_dim, + static_filter_dim, + static_kernel_size, + dynamic_filter_dim, + dynamic_kernel_size, + prior_filter_len=11, + alpha=0.1, + beta=0.9, + ): super().__init__() self._mask_value = 1e-8 self.dynamic_filter_dim = dynamic_filter_dim @@ -388,9 +377,7 @@ def __init__( self.attention_weights = None # setup key and query layers self.query_layer = nn.Linear(query_dim, attention_dim) - self.key_layer = nn.Linear( - attention_dim, dynamic_filter_dim * dynamic_kernel_size, bias=False - ) + self.key_layer = nn.Linear(attention_dim, dynamic_filter_dim * dynamic_kernel_size, bias=False) self.static_filter_conv = nn.Conv1d( 1, static_filter_dim, @@ -402,8 +389,7 @@ def __init__( self.dynamic_filter_layer = nn.Linear(dynamic_filter_dim, attention_dim) self.v = nn.Linear(attention_dim, 1, bias=False) - prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1, - alpha, beta) + prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1, alpha, beta) self.register_buffer("prior", torch.FloatTensor(prior).flip(0)) # pylint: disable=unused-argument @@ -416,8 +402,8 @@ def forward(self, query, inputs, processed_inputs, mask): """ # compute prior filters prior_filter = F.conv1d( - F.pad(self.attention_weights.unsqueeze(1), - (self.prior_filter_len - 1, 0)), self.prior.view(1, 1, -1)) + F.pad(self.attention_weights.unsqueeze(1), (self.prior_filter_len - 1, 0)), self.prior.view(1, 1, -1) + ) prior_filter = torch.log(prior_filter.clamp_min_(1e-6)).squeeze(1) G = self.key_layer(torch.tanh(self.query_layer(query))) # compute dynamic filters @@ -430,10 +416,12 @@ def forward(self, query, inputs, processed_inputs, mask): dynamic_filter = dynamic_filter.view(query.size(0), self.dynamic_filter_dim, -1).transpose(1, 2) # compute static filters static_filter = self.static_filter_conv(self.attention_weights.unsqueeze(1)).transpose(1, 2) - alignment = self.v( - torch.tanh( - 
self.static_filter_layer(static_filter) + - self.dynamic_filter_layer(dynamic_filter))).squeeze(-1) + prior_filter + alignment = ( + self.v( + torch.tanh(self.static_filter_layer(static_filter) + self.dynamic_filter_layer(dynamic_filter)) + ).squeeze(-1) + + prior_filter + ) # compute attention weights attention_weights = F.softmax(alignment, dim=-1) # apply masking @@ -451,33 +439,52 @@ def init_states(self, inputs): B = inputs.size(0) T = inputs.size(1) self.attention_weights = torch.zeros([B, T], device=inputs.device) - self.attention_weights[:, 0] = 1. - - -def init_attn(attn_type, query_dim, embedding_dim, attention_dim, - location_attention, attention_location_n_filters, - attention_location_kernel_size, windowing, norm, forward_attn, - trans_agent, forward_attn_mask, attn_K): + self.attention_weights[:, 0] = 1.0 + + +def init_attn( + attn_type, + query_dim, + embedding_dim, + attention_dim, + location_attention, + attention_location_n_filters, + attention_location_kernel_size, + windowing, + norm, + forward_attn, + trans_agent, + forward_attn_mask, + attn_K, +): if attn_type == "original": - return OriginalAttention(query_dim, embedding_dim, attention_dim, - location_attention, - attention_location_n_filters, - attention_location_kernel_size, windowing, - norm, forward_attn, trans_agent, - forward_attn_mask) + return OriginalAttention( + query_dim, + embedding_dim, + attention_dim, + location_attention, + attention_location_n_filters, + attention_location_kernel_size, + windowing, + norm, + forward_attn, + trans_agent, + forward_attn_mask, + ) if attn_type == "graves": return GravesAttention(query_dim, attn_K) if attn_type == "dynamic_convolution": - return MonotonicDynamicConvolutionAttention(query_dim, - embedding_dim, - attention_dim, - static_filter_dim=8, - static_kernel_size=21, - dynamic_filter_dim=8, - dynamic_kernel_size=21, - prior_filter_len=11, - alpha=0.1, - beta=0.9) - - raise RuntimeError( - " [!] Given Attention Type '{attn_type}' is not exist.") + return MonotonicDynamicConvolutionAttention( + query_dim, + embedding_dim, + attention_dim, + static_filter_dim=8, + static_kernel_size=21, + dynamic_filter_dim=8, + dynamic_kernel_size=21, + prior_filter_len=11, + alpha=0.1, + beta=0.9, + ) + + raise RuntimeError(" [!] Given Attention Type '{attn_type}' is not exist.") diff --git a/TTS/tts/layers/tacotron/common_layers.py b/TTS/tts/layers/tacotron/common_layers.py index a23bb3f9bf..e2660cda2c 100644 --- a/TTS/tts/layers/tacotron/common_layers.py +++ b/TTS/tts/layers/tacotron/common_layers.py @@ -12,20 +12,14 @@ class Linear(nn.Module): bias (bool, optional): enable/disable bias in the layer. Defaults to True. init_gain (str, optional): method to compute the gain in the weight initializtion based on the nonlinear activation used afterwards. Defaults to 'linear'. 
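The dynamic-convolution attention above adds a log-domain prior built from a beta-binomial PMF, which nudges the alignment to stay put or advance by roughly one encoder step per decoder step. A standalone look at that prior applied to the initial one-hot alignment:

```python
import torch
from scipy.stats import betabinom
from torch.nn import functional as F

prior_filter_len, alpha, beta = 11, 0.1, 0.9
prior = torch.FloatTensor(
    betabinom.pmf(range(prior_filter_len), prior_filter_len - 1, alpha, beta)
).flip(0)

attention_weights = torch.zeros(1, 20)
attention_weights[:, 0] = 1.0                              # init_states: all mass on the first token
prior_filter = F.conv1d(
    F.pad(attention_weights.unsqueeze(1), (prior_filter_len - 1, 0)), prior.view(1, 1, -1)
)
prior_filter = torch.log(prior_filter.clamp_min(1e-6)).squeeze(1)
print(prior_filter[0, :4])                                 # least negative at the current and next token
```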
""" - def __init__(self, - in_features, - out_features, - bias=True, - init_gain='linear'): + + def __init__(self, in_features, out_features, bias=True, init_gain="linear"): super(Linear, self).__init__() - self.linear_layer = torch.nn.Linear( - in_features, out_features, bias=bias) + self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias) self._init_w(init_gain) def _init_w(self, init_gain): - torch.nn.init.xavier_uniform_( - self.linear_layer.weight, - gain=torch.nn.init.calculate_gain(init_gain)) + torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain)) def forward(self, x): return self.linear_layer(x) @@ -42,21 +36,15 @@ class LinearBN(nn.Module): bias (bool, optional): enable/disable bias in the linear layer. Defaults to True. init_gain (str, optional): method to set the gain for weight initialization. Defaults to 'linear'. """ - def __init__(self, - in_features, - out_features, - bias=True, - init_gain='linear'): + + def __init__(self, in_features, out_features, bias=True, init_gain="linear"): super(LinearBN, self).__init__() - self.linear_layer = torch.nn.Linear( - in_features, out_features, bias=bias) + self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias) self.batch_normalization = nn.BatchNorm1d(out_features, momentum=0.1, eps=1e-5) self._init_w(init_gain) def _init_w(self, init_gain): - torch.nn.init.xavier_uniform_( - self.linear_layer.weight, - gain=torch.nn.init.calculate_gain(init_gain)) + torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain)) def forward(self, x): """ @@ -96,27 +84,21 @@ class Prenet(nn.Module): Defaults to [256, 256]. bias (bool, optional): enable/disable bias in prenet linear layers. Defaults to True. 
""" + # pylint: disable=dangerous-default-value - def __init__(self, - in_features, - prenet_type="original", - prenet_dropout=True, - out_features=[256, 256], - bias=True): + def __init__(self, in_features, prenet_type="original", prenet_dropout=True, out_features=[256, 256], bias=True): super(Prenet, self).__init__() self.prenet_type = prenet_type self.prenet_dropout = prenet_dropout in_features = [in_features] + out_features[:-1] if prenet_type == "bn": - self.linear_layers = nn.ModuleList([ - LinearBN(in_size, out_size, bias=bias) - for (in_size, out_size) in zip(in_features, out_features) - ]) + self.linear_layers = nn.ModuleList( + [LinearBN(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)] + ) elif prenet_type == "original": - self.linear_layers = nn.ModuleList([ - Linear(in_size, out_size, bias=bias) - for (in_size, out_size) in zip(in_features, out_features) - ]) + self.linear_layers = nn.ModuleList( + [Linear(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)] + ) def forward(self, x): for linear in self.linear_layers: diff --git a/TTS/tts/layers/tacotron/gst_layers.py b/TTS/tts/layers/tacotron/gst_layers.py index 63e760703a..e2784e5dc6 100644 --- a/TTS/tts/layers/tacotron/gst_layers.py +++ b/TTS/tts/layers/tacotron/gst_layers.py @@ -11,8 +11,7 @@ class GST(nn.Module): def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, speaker_embedding_dim=None): super().__init__() self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim) - self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, - gst_embedding_dim, speaker_embedding_dim) + self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, speaker_embedding_dim) def forward(self, inputs, speaker_embedding=None): enc_out = self.encoder(inputs) @@ -39,24 +38,17 @@ def __init__(self, num_mel, embedding_dim): num_layers = len(filters) - 1 convs = [ nn.Conv2d( - in_channels=filters[i], - out_channels=filters[i + 1], - kernel_size=(3, 3), - stride=(2, 2), - padding=(1, 1)) for i in range(num_layers) + in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1) + ) + for i in range(num_layers) ] self.convs = nn.ModuleList(convs) - self.bns = nn.ModuleList([ - nn.BatchNorm2d(num_features=filter_size) - for filter_size in filters[1:] - ]) + self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]]) - post_conv_height = self.calculate_post_conv_height( - num_mel, 3, 2, 1, num_layers) + post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 1, num_layers) self.recurrence = nn.GRU( - input_size=filters[-1] * post_conv_height, - hidden_size=embedding_dim // 2, - batch_first=True) + input_size=filters[-1] * post_conv_height, hidden_size=embedding_dim // 2, batch_first=True + ) def forward(self, inputs): batch_size = inputs.size(0) @@ -81,8 +73,7 @@ def forward(self, inputs): return out.squeeze(0) @staticmethod - def calculate_post_conv_height(height, kernel_size, stride, pad, - n_convs): + def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs): """Height of spec after n convolutions with fixed kernel/stride/pad.""" for _ in range(n_convs): height = (height - kernel_size + 2 * pad) // stride + 1 @@ -92,8 +83,7 @@ def calculate_post_conv_height(height, kernel_size, stride, pad, class StyleTokenLayer(nn.Module): """NN Module attending to style tokens based on prosody encodings.""" - def 
__init__(self, num_heads, num_style_tokens, - embedding_dim, speaker_embedding_dim=None): + def __init__(self, num_heads, num_style_tokens, embedding_dim, speaker_embedding_dim=None): super().__init__() self.query_dim = embedding_dim // 2 @@ -102,35 +92,31 @@ def __init__(self, num_heads, num_style_tokens, self.query_dim += speaker_embedding_dim self.key_dim = embedding_dim // num_heads - self.style_tokens = nn.Parameter( - torch.FloatTensor(num_style_tokens, self.key_dim)) + self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim)) nn.init.normal_(self.style_tokens, mean=0, std=0.5) self.attention = MultiHeadAttention( - query_dim=self.query_dim, - key_dim=self.key_dim, - num_units=embedding_dim, - num_heads=num_heads) + query_dim=self.query_dim, key_dim=self.key_dim, num_units=embedding_dim, num_heads=num_heads + ) def forward(self, inputs): batch_size = inputs.size(0) prosody_encoding = inputs.unsqueeze(1) # prosody_encoding: 3D tensor [batch_size, 1, encoding_size==128] - tokens = torch.tanh(self.style_tokens) \ - .unsqueeze(0) \ - .expand(batch_size, -1, -1) + tokens = torch.tanh(self.style_tokens).unsqueeze(0).expand(batch_size, -1, -1) # tokens: 3D tensor [batch_size, num tokens, token embedding size] style_embed = self.attention(prosody_encoding, tokens) return style_embed + class MultiHeadAttention(nn.Module): - ''' + """ input: query --- [N, T_q, query_dim] key --- [N, T_k, key_dim] output: out --- [N, T_q, num_units] - ''' + """ def __init__(self, query_dim, key_dim, num_units, num_heads): @@ -139,12 +125,9 @@ def __init__(self, query_dim, key_dim, num_units, num_heads): self.num_heads = num_heads self.key_dim = key_dim - self.W_query = nn.Linear( - in_features=query_dim, out_features=num_units, bias=False) - self.W_key = nn.Linear( - in_features=key_dim, out_features=num_units, bias=False) - self.W_value = nn.Linear( - in_features=key_dim, out_features=num_units, bias=False) + self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False) + self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) def forward(self, query, key): queries = self.W_query(query) # [N, T_q, num_units] @@ -152,25 +135,17 @@ def forward(self, query, key): values = self.W_value(key) split_size = self.num_units // self.num_heads - queries = torch.stack( - torch.split(queries, split_size, dim=2), - dim=0) # [h, N, T_q, num_units/h] - keys = torch.stack( - torch.split(keys, split_size, dim=2), - dim=0) # [h, N, T_k, num_units/h] - values = torch.stack( - torch.split(values, split_size, dim=2), - dim=0) # [h, N, T_k, num_units/h] + queries = torch.stack(torch.split(queries, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h] + keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] # score = softmax(QK^T / (d_k ** 0.5)) scores = torch.matmul(queries, keys.transpose(2, 3)) # [h, N, T_q, T_k] - scores = scores / (self.key_dim**0.5) + scores = scores / (self.key_dim ** 0.5) scores = F.softmax(scores, dim=3) # out = score * V out = torch.matmul(scores, values) # [h, N, T_q, num_units/h] - out = torch.cat( - torch.split(out, 1, dim=0), - dim=3).squeeze(0) # [N, T_q, num_units] + out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] return out diff --git a/TTS/tts/layers/tacotron/tacotron.py 
b/TTS/tts/layers/tacotron/tacotron.py index c79edcc300..95930a0576 100644 --- a/TTS/tts/layers/tacotron/tacotron.py +++ b/TTS/tts/layers/tacotron/tacotron.py @@ -1,8 +1,9 @@ # coding: utf-8 import torch from torch import nn -from .common_layers import Prenet + from .attentions import init_attn +from .common_layers import Prenet class BatchNormConv1d(nn.Module): @@ -23,24 +24,14 @@ class BatchNormConv1d(nn.Module): - output: (B, D) """ - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - activation=None): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, activation=None): super(BatchNormConv1d, self).__init__() self.padding = padding self.padder = nn.ConstantPad1d(padding, 0) self.conv1d = nn.Conv1d( - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - padding=0, - bias=False) + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0, bias=False + ) # Following tensorflow's default parameters self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3) self.activation = activation @@ -48,15 +39,14 @@ def __init__(self, def init_layers(self): if isinstance(self.activation, torch.nn.ReLU): - w_gain = 'relu' + w_gain = "relu" elif isinstance(self.activation, torch.nn.Tanh): - w_gain = 'tanh' + w_gain = "tanh" elif self.activation is None: - w_gain = 'linear' + w_gain = "linear" else: - raise RuntimeError('Unknown activation function') - torch.nn.init.xavier_uniform_( - self.conv1d.weight, gain=torch.nn.init.calculate_gain(w_gain)) + raise RuntimeError("Unknown activation function") + torch.nn.init.xavier_uniform_(self.conv1d.weight, gain=torch.nn.init.calculate_gain(w_gain)) def forward(self, x): x = self.padder(x) @@ -91,10 +81,8 @@ def __init__(self, in_features, out_feature): # self.init_layers() def init_layers(self): - torch.nn.init.xavier_uniform_( - self.H.weight, gain=torch.nn.init.calculate_gain('relu')) - torch.nn.init.xavier_uniform_( - self.T.weight, gain=torch.nn.init.calculate_gain('sigmoid')) + torch.nn.init.xavier_uniform_(self.H.weight, gain=torch.nn.init.calculate_gain("relu")) + torch.nn.init.xavier_uniform_(self.T.weight, gain=torch.nn.init.calculate_gain("sigmoid")) def forward(self, inputs): H = self.relu(self.H(inputs)) @@ -104,29 +92,32 @@ def forward(self, inputs): class CBHG(nn.Module): """CBHG module: a recurrent neural network composed of: - - 1-d convolution banks - - Highway networks + residual connections - - Bidirectional gated recurrent units + - 1-d convolution banks + - Highway networks + residual connections + - Bidirectional gated recurrent units - Args: - in_features (int): sample size - K (int): max filter size in conv bank - projections (list): conv channel sizes for conv projections - num_highways (int): number of highways layers + Args: + in_features (int): sample size + K (int): max filter size in conv bank + projections (list): conv channel sizes for conv projections + num_highways (int): number of highways layers - Shapes: - - input: (B, C, T_in) - - output: (B, T_in, C*2) + Shapes: + - input: (B, C, T_in) + - output: (B, T_in, C*2) """ - #pylint: disable=dangerous-default-value - def __init__(self, - in_features, - K=16, - conv_bank_features=128, - conv_projections=[128, 128], - highway_features=128, - gru_features=128, - num_highways=4): + + # pylint: disable=dangerous-default-value + def __init__( + self, + in_features, + K=16, + conv_bank_features=128, + conv_projections=[128, 128], + highway_features=128, + gru_features=128, + 
num_highways=4, + ): super(CBHG, self).__init__() self.in_features = in_features self.conv_bank_features = conv_bank_features @@ -136,14 +127,19 @@ def __init__(self, self.relu = nn.ReLU() # list of conv1d bank with filter size k=1...K # TODO: try dilated layers instead - self.conv1d_banks = nn.ModuleList([ - BatchNormConv1d(in_features, - conv_bank_features, - kernel_size=k, - stride=1, - padding=[(k - 1) // 2, k // 2], - activation=self.relu) for k in range(1, K + 1) - ]) + self.conv1d_banks = nn.ModuleList( + [ + BatchNormConv1d( + in_features, + conv_bank_features, + kernel_size=k, + stride=1, + padding=[(k - 1) // 2, k // 2], + activation=self.relu, + ) + for k in range(1, K + 1) + ] + ) # max pooling of conv bank, with padding # TODO: try average pooling OR larger kernel size out_features = [K * conv_bank_features] + conv_projections[:-1] @@ -151,31 +147,16 @@ def __init__(self, activations += [None] # setup conv1d projection layers layer_set = [] - for (in_size, out_size, ac) in zip(out_features, conv_projections, - activations): - layer = BatchNormConv1d(in_size, - out_size, - kernel_size=3, - stride=1, - padding=[1, 1], - activation=ac) + for (in_size, out_size, ac) in zip(out_features, conv_projections, activations): + layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, padding=[1, 1], activation=ac) layer_set.append(layer) self.conv1d_projections = nn.ModuleList(layer_set) # setup Highway layers if self.highway_features != conv_projections[-1]: - self.pre_highway = nn.Linear(conv_projections[-1], - highway_features, - bias=False) - self.highways = nn.ModuleList([ - Highway(highway_features, highway_features) - for _ in range(num_highways) - ]) + self.pre_highway = nn.Linear(conv_projections[-1], highway_features, bias=False) + self.highways = nn.ModuleList([Highway(highway_features, highway_features) for _ in range(num_highways)]) # bi-directional GRU layer - self.gru = nn.GRU(gru_features, - gru_features, - 1, - batch_first=True, - bidirectional=True) + self.gru = nn.GRU(gru_features, gru_features, 1, batch_first=True, bidirectional=True) def forward(self, inputs): # (B, in_features, T_in) @@ -218,7 +199,8 @@ def __init__(self): conv_projections=[128, 128], highway_features=128, gru_features=128, - num_highways=4) + num_highways=4, + ) def forward(self, x): return self.cbhg(x) @@ -256,7 +238,8 @@ def __init__(self, mel_dim): conv_projections=[256, mel_dim], highway_features=128, gru_features=128, - num_highways=4) + num_highways=4, + ) def forward(self, x): return self.cbhg(x) @@ -289,10 +272,24 @@ class Decoder(nn.Module): # Pylint gets confused by PyTorch conventions here # pylint: disable=attribute-defined-outside-init - def __init__(self, in_channels, frame_channels, r, memory_size, attn_type, attn_windowing, - attn_norm, prenet_type, prenet_dropout, forward_attn, - trans_agent, forward_attn_mask, location_attn, attn_K, - separate_stopnet): + def __init__( + self, + in_channels, + frame_channels, + r, + memory_size, + attn_type, + attn_windowing, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + ): super(Decoder, self).__init__() self.r_init = r self.r = r @@ -305,33 +302,30 @@ def __init__(self, in_channels, frame_channels, r, memory_size, attn_type, attn_ self.query_dim = 256 # memory -> |Prenet| -> processed_memory prenet_dim = frame_channels * self.memory_size if self.use_memory_queue else frame_channels - self.prenet = Prenet( - prenet_dim, - prenet_type,
- prenet_dropout, - out_features=[256, 128]) + self.prenet = Prenet(prenet_dim, prenet_type, prenet_dropout, out_features=[256, 128]) # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State # attention_rnn generates queries for the attention mechanism self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim) - self.attention = init_attn(attn_type=attn_type, - query_dim=self.query_dim, - embedding_dim=in_channels, - attention_dim=128, - location_attention=location_attn, - attention_location_n_filters=32, - attention_location_kernel_size=31, - windowing=attn_windowing, - norm=attn_norm, - forward_attn=forward_attn, - trans_agent=trans_agent, - forward_attn_mask=forward_attn_mask, - attn_K=attn_K) + self.attention = init_attn( + attn_type=attn_type, + query_dim=self.query_dim, + embedding_dim=in_channels, + attention_dim=128, + location_attention=location_attn, + attention_location_n_filters=32, + attention_location_kernel_size=31, + windowing=attn_windowing, + norm=attn_norm, + forward_attn=forward_attn, + trans_agent=trans_agent, + forward_attn_mask=forward_attn_mask, + attn_K=attn_K, + ) # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input self.project_to_decoder_in = nn.Linear(256 + in_channels, 256) # decoder_RNN_input -> |RNN| -> RNN_state - self.decoder_rnns = nn.ModuleList( - [nn.GRUCell(256, 256) for _ in range(2)]) + self.decoder_rnns = nn.ModuleList([nn.GRUCell(256, 256) for _ in range(2)]) # RNN_state -> |Linear| -> mel_spec self.proj_to_mel = nn.Linear(256, frame_channels * self.r_init) # learn init values instead of zero init. @@ -364,8 +358,7 @@ def _init_states(self, inputs): # decoder states self.attention_rnn_hidden = torch.zeros(1, device=inputs.device).repeat(B, 256) self.decoder_rnn_hiddens = [ - torch.zeros(1, device=inputs.device).repeat(B, 256) - for idx in range(len(self.decoder_rnns)) + torch.zeros(1, device=inputs.device).repeat(B, 256) for idx in range(len(self.decoder_rnns)) ] self.context_vec = inputs.data.new(B, self.in_channels).zero_() # cache attention inputs @@ -376,8 +369,7 @@ def _parse_outputs(self, outputs, attentions, stop_tokens): attentions = torch.stack(attentions).transpose(0, 1) stop_tokens = torch.stack(stop_tokens).transpose(0, 1) outputs = torch.stack(outputs).transpose(0, 1).contiguous() - outputs = outputs.view( - outputs.size(0), -1, self.frame_channels) + outputs = outputs.view(outputs.size(0), -1, self.frame_channels) outputs = outputs.transpose(1, 2) return outputs, attentions, stop_tokens @@ -386,18 +378,15 @@ def decode(self, inputs, mask=None): processed_memory = self.prenet(self.memory_input) # Attention RNN self.attention_rnn_hidden = self.attention_rnn( - torch.cat((processed_memory, self.context_vec), -1), - self.attention_rnn_hidden) - self.context_vec = self.attention( - self.attention_rnn_hidden, inputs, self.processed_inputs, mask) + torch.cat((processed_memory, self.context_vec), -1), self.attention_rnn_hidden + ) + self.context_vec = self.attention(self.attention_rnn_hidden, inputs, self.processed_inputs, mask) # Concat RNN output and attention context vector - decoder_input = self.project_to_decoder_in( - torch.cat((self.attention_rnn_hidden, self.context_vec), -1)) + decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1)) # Pass through the decoder RNNs for idx in range(len(self.decoder_rnns)): - self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx]( - decoder_input, self.decoder_rnn_hiddens[idx]) + 
self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](decoder_input, self.decoder_rnn_hiddens[idx]) # Residual connection decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input decoder_output = decoder_input @@ -418,17 +407,17 @@ def _update_memory_input(self, new_memory): if self.use_memory_queue: if self.memory_size > self.r: # memory queue size is larger than number of frames per decoder iter - self.memory_input = torch.cat([ - new_memory, self.memory_input[:, :( - self.memory_size - self.r) * self.frame_channels].clone() - ], dim=-1) + self.memory_input = torch.cat( + [new_memory, self.memory_input[:, : (self.memory_size - self.r) * self.frame_channels].clone()], + dim=-1, + ) else: # memory queue size smaller than number of frames per decoder iter - self.memory_input = new_memory[:, :self.memory_size * self.frame_channels] + self.memory_input = new_memory[:, : self.memory_size * self.frame_channels] else: # use only the last frame prediction # assert new_memory.shape[-1] == self.r * self.frame_channels - self.memory_input = new_memory[:, self.frame_channels * (self.r - 1):] + self.memory_input = new_memory[:, self.frame_channels * (self.r - 1) :] def forward(self, inputs, memory, mask): """ @@ -487,8 +476,7 @@ def inference(self, inputs): attentions += [attention] stop_tokens += [stop_token] t += 1 - if t > inputs.shape[1] / 4 and (stop_token > 0.6 - or attention[:, -1].item() > 0.6): + if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6): break if t > self.max_decoder_steps: print(" | > Decoder stopped with 'max_decoder_steps'") @@ -506,8 +494,7 @@ def __init__(self, in_features): super(StopNet, self).__init__() self.dropout = nn.Dropout(0.1) self.linear = nn.Linear(in_features, 1) - torch.nn.init.xavier_uniform_( - self.linear.weight, gain=torch.nn.init.calculate_gain('linear')) + torch.nn.init.xavier_uniform_(self.linear.weight, gain=torch.nn.init.calculate_gain("linear")) def forward(self, inputs): outputs = self.dropout(inputs) diff --git a/TTS/tts/layers/tacotron/tacotron2.py b/TTS/tts/layers/tacotron/tacotron2.py index 8e6dbc1510..7893cf4aaa 100644 --- a/TTS/tts/layers/tacotron/tacotron2.py +++ b/TTS/tts/layers/tacotron/tacotron2.py @@ -1,12 +1,14 @@ import torch from torch import nn from torch.nn import functional as F -from .common_layers import Prenet, Linear + from .attentions import init_attn +from .common_layers import Linear, Prenet + # NOTE: linter has a problem with the current TF release -#pylint: disable=no-value-for-parameter -#pylint: disable=unexpected-keyword-arg +# pylint: disable=no-value-for-parameter +# pylint: disable=unexpected-keyword-arg class ConvBNBlock(nn.Module): r"""Convolutions with Batch Normalization and non-linear activation.
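The queue slicing in `Decoder._update_memory_input` above is easier to follow in isolation. Below is a minimal standalone sketch of the same arithmetic; the sizes are made up for illustration and are not the values the model actually uses (those come from the config):

```python
import torch

# Hypothetical sizes, only to illustrate the slicing in _update_memory_input.
B, frame_channels, r, memory_size = 2, 80, 5, 7

memory_input = torch.zeros(B, memory_size * frame_channels)  # flat queue of the last decoded frames
new_memory = torch.randn(B, r * frame_channels)              # r frames decoded in this step

if memory_size > r:
    # prepend the new frames and keep the most recent (memory_size - r) queued frames
    memory_input = torch.cat(
        [new_memory, memory_input[:, : (memory_size - r) * frame_channels].clone()], dim=-1
    )
else:
    # queue shorter than one decoder step: keep only the first memory_size frames
    memory_input = new_memory[:, : memory_size * frame_channels]

assert memory_input.shape == (B, memory_size * frame_channels)
```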
@@ -20,19 +22,17 @@ class ConvBNBlock(nn.Module): - input: (B, C_in, T) - output: (B, C_out, T) """ + def __init__(self, in_channels, out_channels, kernel_size, activation=None): super(ConvBNBlock, self).__init__() assert (kernel_size - 1) % 2 == 0 padding = (kernel_size - 1) // 2 - self.convolution1d = nn.Conv1d(in_channels, - out_channels, - kernel_size, - padding=padding) + self.convolution1d = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding) self.batch_normalization = nn.BatchNorm1d(out_channels, momentum=0.1, eps=1e-5) self.dropout = nn.Dropout(p=0.5) - if activation == 'relu': + if activation == "relu": self.activation = nn.ReLU() - elif activation == 'tanh': + elif activation == "tanh": self.activation = nn.Tanh() else: self.activation = nn.Identity() @@ -55,16 +55,14 @@ class Postnet(nn.Module): - input: (B, C_in, T) - output: (B, C_in, T) """ + def __init__(self, in_out_channels, num_convs=5): super(Postnet, self).__init__() self.convolutions = nn.ModuleList() - self.convolutions.append( - ConvBNBlock(in_out_channels, 512, kernel_size=5, activation='tanh')) + self.convolutions.append(ConvBNBlock(in_out_channels, 512, kernel_size=5, activation="tanh")) for _ in range(1, num_convs - 1): - self.convolutions.append( - ConvBNBlock(512, 512, kernel_size=5, activation='tanh')) - self.convolutions.append( - ConvBNBlock(512, in_out_channels, kernel_size=5, activation=None)) + self.convolutions.append(ConvBNBlock(512, 512, kernel_size=5, activation="tanh")) + self.convolutions.append(ConvBNBlock(512, in_out_channels, kernel_size=5, activation=None)) def forward(self, x): o = x @@ -83,18 +81,15 @@ class Encoder(nn.Module): - input: (B, C_in, T) - output: (B, C_in, T) """ + def __init__(self, in_out_channels=512): super(Encoder, self).__init__() self.convolutions = nn.ModuleList() for _ in range(3): - self.convolutions.append( - ConvBNBlock(in_out_channels, in_out_channels, 5, 'relu')) - self.lstm = nn.LSTM(in_out_channels, - int(in_out_channels / 2), - num_layers=1, - batch_first=True, - bias=True, - bidirectional=True) + self.convolutions.append(ConvBNBlock(in_out_channels, in_out_channels, 5, "relu")) + self.lstm = nn.LSTM( + in_out_channels, int(in_out_channels / 2), num_layers=1, batch_first=True, bias=True, bidirectional=True + ) self.rnn_state = None def forward(self, x, input_lengths): @@ -102,9 +97,7 @@ def forward(self, x, input_lengths): for layer in self.convolutions: o = layer(o) o = o.transpose(1, 2) - o = nn.utils.rnn.pack_padded_sequence(o, - input_lengths.cpu(), - batch_first=True) + o = nn.utils.rnn.pack_padded_sequence(o, input_lengths.cpu(), batch_first=True) self.lstm.flatten_parameters() o, _ = self.lstm(o) o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True) @@ -143,11 +136,26 @@ class Decoder(nn.Module): attn_K (int): number of attention heads for GravesAttention. separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow. 
""" + # Pylint gets confused by PyTorch conventions here - #pylint: disable=attribute-defined-outside-init - def __init__(self, in_channels, frame_channels, r, attn_type, attn_win, attn_norm, - prenet_type, prenet_dropout, forward_attn, trans_agent, - forward_attn_mask, location_attn, attn_K, separate_stopnet): + # pylint: disable=attribute-defined-outside-init + def __init__( + self, + in_channels, + frame_channels, + r, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + ): super(Decoder, self).__init__() self.frame_channels = frame_channels self.r_init = r @@ -167,43 +175,36 @@ def __init__(self, in_channels, frame_channels, r, attn_type, attn_win, attn_nor # memory -> |Prenet| -> processed_memory prenet_dim = self.frame_channels - self.prenet = Prenet(prenet_dim, - prenet_type, - prenet_dropout, - out_features=[self.prenet_dim, self.prenet_dim], - bias=False) - - self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_channels, - self.query_dim, - bias=True) - - self.attention = init_attn(attn_type=attn_type, - query_dim=self.query_dim, - embedding_dim=in_channels, - attention_dim=128, - location_attention=location_attn, - attention_location_n_filters=32, - attention_location_kernel_size=31, - windowing=attn_win, - norm=attn_norm, - forward_attn=forward_attn, - trans_agent=trans_agent, - forward_attn_mask=forward_attn_mask, - attn_K=attn_K) - - self.decoder_rnn = nn.LSTMCell(self.query_dim + in_channels, - self.decoder_rnn_dim, - bias=True) - - self.linear_projection = Linear(self.decoder_rnn_dim + in_channels, - self.frame_channels * self.r_init) + self.prenet = Prenet( + prenet_dim, prenet_type, prenet_dropout, out_features=[self.prenet_dim, self.prenet_dim], bias=False + ) + + self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_channels, self.query_dim, bias=True) + + self.attention = init_attn( + attn_type=attn_type, + query_dim=self.query_dim, + embedding_dim=in_channels, + attention_dim=128, + location_attention=location_attn, + attention_location_n_filters=32, + attention_location_kernel_size=31, + windowing=attn_win, + norm=attn_norm, + forward_attn=forward_attn, + trans_agent=trans_agent, + forward_attn_mask=forward_attn_mask, + attn_K=attn_K, + ) + + self.decoder_rnn = nn.LSTMCell(self.query_dim + in_channels, self.decoder_rnn_dim, bias=True) + + self.linear_projection = Linear(self.decoder_rnn_dim + in_channels, self.frame_channels * self.r_init) self.stopnet = nn.Sequential( nn.Dropout(0.1), - Linear(self.decoder_rnn_dim + self.frame_channels * self.r_init, - 1, - bias=True, - init_gain='sigmoid')) + Linear(self.decoder_rnn_dim + self.frame_channels * self.r_init, 1, bias=True, init_gain="sigmoid"), + ) self.memory_truncated = None def set_r(self, new_r): @@ -211,24 +212,18 @@ def set_r(self, new_r): def get_go_frame(self, inputs): B = inputs.size(0) - memory = torch.zeros(1, device=inputs.device).repeat( - B, self.frame_channels * self.r) + memory = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels * self.r) return memory def _init_states(self, inputs, mask, keep_states=False): B = inputs.size(0) # T = inputs.size(1) if not keep_states: - self.query = torch.zeros(1, device=inputs.device).repeat( - B, self.query_dim) - self.attention_rnn_cell_state = torch.zeros( - 1, device=inputs.device).repeat(B, self.query_dim) - self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat( - B, self.decoder_rnn_dim) - self.decoder_cell = 
torch.zeros(1, device=inputs.device).repeat( - B, self.decoder_rnn_dim) - self.context = torch.zeros(1, device=inputs.device).repeat( - B, self.encoder_embedding_dim) + self.query = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim) + self.attention_rnn_cell_state = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim) + self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(B, self.decoder_rnn_dim) + self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(B, self.decoder_rnn_dim) + self.context = torch.zeros(1, device=inputs.device).repeat(B, self.encoder_embedding_dim) self.inputs = inputs self.processed_inputs = self.attention.preprocess_inputs(inputs) self.mask = mask @@ -254,38 +249,36 @@ def _parse_outputs(self, outputs, stop_tokens, alignments): def _update_memory(self, memory): if len(memory.shape) == 2: - return memory[:, self.frame_channels * (self.r - 1):] - return memory[:, :, self.frame_channels * (self.r - 1):] + return memory[:, self.frame_channels * (self.r - 1) :] + return memory[:, :, self.frame_channels * (self.r - 1) :] def decode(self, memory): - ''' - shapes: - - memory: B x r * self.frame_channels - ''' + """ + shapes: + - memory: B x r * self.frame_channels + """ # self.context: B x D_en # query_input: B x D_en + (r * self.frame_channels) query_input = torch.cat((memory, self.context), -1) # self.query and self.attention_rnn_cell_state : B x D_attn_rnn self.query, self.attention_rnn_cell_state = self.attention_rnn( - query_input, (self.query, self.attention_rnn_cell_state)) - self.query = F.dropout(self.query, self.p_attention_dropout, - self.training) + query_input, (self.query, self.attention_rnn_cell_state) + ) + self.query = F.dropout(self.query, self.p_attention_dropout, self.training) self.attention_rnn_cell_state = F.dropout( - self.attention_rnn_cell_state, self.p_attention_dropout, - self.training) + self.attention_rnn_cell_state, self.p_attention_dropout, self.training + ) # B x D_en - self.context = self.attention(self.query, self.inputs, - self.processed_inputs, self.mask) + self.context = self.attention(self.query, self.inputs, self.processed_inputs, self.mask) # B x (D_en + D_attn_rnn) decoder_rnn_input = torch.cat((self.query, self.context), -1) # self.decoder_hidden and self.decoder_cell: B x D_decoder_rnn self.decoder_hidden, self.decoder_cell = self.decoder_rnn( - decoder_rnn_input, (self.decoder_hidden, self.decoder_cell)) - self.decoder_hidden = F.dropout(self.decoder_hidden, - self.p_decoder_dropout, self.training) + decoder_rnn_input, (self.decoder_hidden, self.decoder_cell) + ) + self.decoder_hidden = F.dropout(self.decoder_hidden, self.p_decoder_dropout, self.training) # B x (D_decoder_rnn + D_en) - decoder_hidden_context = torch.cat((self.decoder_hidden, self.context), - dim=1) + decoder_hidden_context = torch.cat((self.decoder_hidden, self.context), dim=1) # B x (self.r * self.frame_channels) decoder_output = self.linear_projection(decoder_hidden_context) # B x (D_decoder_rnn + (self.r * self.frame_channels)) @@ -295,7 +288,7 @@ def decode(self, memory): else: stop_token = self.stopnet(stopnet_input) # select outputs for the reduction rate self.r - decoder_output = decoder_output[:, :self.r * self.frame_channels] + decoder_output = decoder_output[:, : self.r * self.frame_channels] return decoder_output, self.attention.attention_weights, stop_token def forward(self, inputs, memories, mask): @@ -329,8 +322,7 @@ def forward(self, inputs, memories, mask): stop_tokens += [stop_token.squeeze(1)] alignments 
+= [attention_weights] - outputs, stop_tokens, alignments = self._parse_outputs( - outputs, stop_tokens, alignments) + outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments) return outputs, alignments, stop_tokens def inference(self, inputs): @@ -369,8 +361,7 @@ def inference(self, inputs): memory = self._update_memory(decoder_output) t += 1 - outputs, stop_tokens, alignments = self._parse_outputs( - outputs, stop_tokens, alignments) + outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments) return outputs, alignments, stop_tokens @@ -404,8 +395,7 @@ def inference_truncated(self, inputs): self.memory_truncated = decoder_output t += 1 - outputs, stop_tokens, alignments = self._parse_outputs( - outputs, stop_tokens, alignments) + outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments) return outputs, alignments, stop_tokens diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 16cb013af1..e097ac503f 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -1,12 +1,13 @@ import torch import torch.nn as nn -from TTS.tts.layers.generic.pos_encoding import PositionalEncoding + +from TTS.tts.layers.align_tts.mdn import MDNBlock +from TTS.tts.layers.feed_forward.decoder import Decoder from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.utils.generic_utils import sequence_mask -from TTS.tts.layers.align_tts.mdn import MDNBlock -from TTS.tts.layers.feed_forward.encoder import Encoder -from TTS.tts.layers.feed_forward.decoder import Decoder class AlignTTS(nn.Module): @@ -62,39 +63,27 @@ class AlignTTS(nn.Module): # pylint: disable=dangerous-default-value def __init__( - self, - num_chars, - out_channels, - hidden_channels=256, - hidden_channels_dp=256, - encoder_type='fftransformer', - encoder_params={ - 'hidden_channels_ffn': 1024, - 'num_heads': 2, - 'num_layers': 6, - 'dropout_p': 0.1 - }, - decoder_type='fftransformer', - decoder_params={ - 'hidden_channels_ffn': 1024, - 'num_heads': 2, - 'num_layers': 6, - 'dropout_p': 0.1 - }, - length_scale=1, - num_speakers=0, - external_c=False, - c_in_channels=0): + self, + num_chars, + out_channels, + hidden_channels=256, + hidden_channels_dp=256, + encoder_type="fftransformer", + encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}, + decoder_type="fftransformer", + decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}, + length_scale=1, + num_speakers=0, + external_c=False, + c_in_channels=0, + ): super().__init__() - self.length_scale = float(length_scale) if isinstance( - length_scale, int) else length_scale + self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale self.emb = nn.Embedding(num_chars, hidden_channels) self.pos_encoder = PositionalEncoding(hidden_channels) - self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, - encoder_params, c_in_channels) - self.decoder = Decoder(out_channels, hidden_channels, decoder_type, - decoder_params) + self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels) + self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params) 
self.duration_predictor = DurationPredictor(hidden_channels_dp) self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1) @@ -111,13 +100,13 @@ def __init__( @staticmethod def compute_log_probs(mu, log_sigma, y): # pylint: disable=protected-access, c-extension-no-member - y = y.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D] - mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] - log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] + y = y.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D] + mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] + log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] expanded_y, expanded_mu = torch.broadcast_tensors(y, mu) - exponential = -0.5 * torch.mean(torch._C._nn.mse_loss( - expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), - dim=-1) # B, L, T + exponential = -0.5 * torch.mean( + torch._C._nn.mse_loss(expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), dim=-1 + ) # B, L, T logp = exponential - 0.5 * log_sigma.mean(dim=-1) return logp @@ -151,9 +140,7 @@ def expand_encoder_outputs(self, en, dr, x_mask, y_mask): [1, 0, 0, 0, 0, 0, 0]] """ attn = self.convert_dr_to_align(dr, x_mask, y_mask) - o_en_ex = torch.matmul( - attn.squeeze(1).transpose(1, 2), en.transpose(1, - 2)).transpose(1, 2) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) return o_en_ex, attn def format_durations(self, o_dr_log, x_mask): @@ -170,12 +157,12 @@ def _concat_speaker_embedding(o_en, g): def _sum_speaker_embedding(self, x, g): # project g to decoder dim. - if hasattr(self, 'proj_g'): + if hasattr(self, "proj_g"): g = self.proj_g(g) return x + g def _forward_encoder(self, x, x_lengths, g=None): - if hasattr(self, 'emb_g'): + if hasattr(self, "emb_g"): g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1] if g is not None: @@ -187,8 +174,7 @@ def _forward_encoder(self, x, x_lengths, g=None): x_emb = torch.transpose(x_emb, 1, -1) # compute sequence masks - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), - 1).to(x.dtype) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) # encoder pass o_en = self.encoder(x_emb, x_mask) @@ -201,12 +187,11 @@ def _forward_encoder(self, x, x_lengths, g=None): return o_en, o_en_dp, x_mask, g def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), - 1).to(o_en_dp.dtype) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) # expand o_en with durations o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) # positional encoding - if hasattr(self, 'pos_encoder'): + if hasattr(self, "pos_encoder"): o_en_ex = self.pos_encoder(o_en_ex, y_mask) # speaker embedding if g is not None: @@ -218,10 +203,8 @@ def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): def _forward_mdn(self, o_en, y, y_lengths, x_mask): # MAS potentials and alignment mu, log_sigma = self.mdn_block(o_en) - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), - 1).to(o_en.dtype) - dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, - y_mask) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask) return dr_mas, mu, log_sigma, logp def forward(self, x, x_lengths, y, y_lengths, phase=None, g=None): # pylint: disable=unused-argument @@ -237,56 +220,31 @@ def forward(self, x, x_lengths, y, y_lengths, phase=None, g=None): # 
pylint: di if phase == 0: # train encoder and MDN o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) - dr_mas, mu, log_sigma, logp = self._forward_mdn( - o_en, y, y_lengths, x_mask) - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), - 1).to(o_en_dp.dtype) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask) elif phase == 1: # train decoder o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask) - o_de, attn = self._forward_decoder(o_en.detach(), - o_en_dp.detach(), - dr_mas.detach(), - x_mask, - y_lengths, - g=g) + o_de, attn = self._forward_decoder(o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g) elif phase == 2: # train the whole except duration predictor o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) - dr_mas, mu, log_sigma, logp = self._forward_mdn( - o_en, y, y_lengths, x_mask) - o_de, attn = self._forward_decoder(o_en, - o_en_dp, - dr_mas, - x_mask, - y_lengths, - g=g) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) elif phase == 3: # train duration predictor o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) o_dr_log = self.duration_predictor(x, x_mask) - dr_mas, mu, log_sigma, logp = self._forward_mdn( - o_en, y, y_lengths, x_mask) - o_de, attn = self._forward_decoder(o_en, - o_en_dp, - dr_mas, - x_mask, - y_lengths, - g=g) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) o_dr_log = o_dr_log.squeeze(1) else: o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) - dr_mas, mu, log_sigma, logp = self._forward_mdn( - o_en, y, y_lengths, x_mask) - o_de, attn = self._forward_decoder(o_en, - o_en_dp, - dr_mas, - x_mask, - y_lengths, - g=g) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) o_dr_log = o_dr_log.squeeze(1) dr_mas_log = torch.log(dr_mas + 1).squeeze(1) return o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp @@ -307,17 +265,14 @@ def inference(self, x, x_lengths, g=None): # pylint: disable=unused-argument # duration predictor pass o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) y_lengths = o_dr.sum(1) - o_de, attn = self._forward_decoder(o_en, - o_en_dp, - o_dr, - x_mask, - y_lengths, - g=g) + o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) return o_de, attn - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) if eval: self.eval() assert not self.training diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 2e01f87c83..0717e2a830 100644 --- a/TTS/tts/models/glow_tts.py +++ 
b/TTS/tts/models/glow_tts.py @@ -1,12 +1,13 @@ import math + import torch from torch import nn from torch.nn import functional as F -from TTS.tts.layers.glow_tts.encoder import Encoder from TTS.tts.layers.glow_tts.decoder import Decoder +from TTS.tts.layers.glow_tts.encoder import Encoder +from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.utils.generic_utils import sequence_mask -from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path class GlowTTS(nn.Module): @@ -34,28 +35,31 @@ class GlowTTS(nn.Module): encoder_params (dict): encoder module parameters. external_speaker_embedding_dim (int): channels of external speaker embedding vectors. """ - def __init__(self, - num_chars, - hidden_channels_enc, - hidden_channels_dec, - use_encoder_prenet, - hidden_channels_dp, - out_channels, - num_flow_blocks_dec=12, - kernel_size_dec=5, - dilation_rate=5, - num_block_layers=4, - dropout_p_dp=0.1, - dropout_p_dec=0.05, - num_speakers=0, - c_in_channels=0, - num_splits=4, - num_squeeze=1, - sigmoid_scale=False, - mean_only=False, - encoder_type="transformer", - encoder_params=None, - external_speaker_embedding_dim=None): + + def __init__( + self, + num_chars, + hidden_channels_enc, + hidden_channels_dec, + use_encoder_prenet, + hidden_channels_dp, + out_channels, + num_flow_blocks_dec=12, + kernel_size_dec=5, + dilation_rate=5, + num_block_layers=4, + dropout_p_dp=0.1, + dropout_p_dec=0.05, + num_speakers=0, + c_in_channels=0, + num_splits=4, + num_squeeze=1, + sigmoid_scale=False, + mean_only=False, + encoder_type="transformer", + encoder_params=None, + external_speaker_embedding_dim=None, + ): super().__init__() self.num_chars = num_chars @@ -78,7 +82,7 @@ def __init__(self, # model constants. self.noise_scale = 0.33 # defines the noise variance applied to the random z vector at inference. - self.length_scale = 1. # scaler for the duration predictor. The larger it is, the slower the speech. + self.length_scale = 1.0 # scaler for the duration predictor. The larger it is, the slower the speech. 
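A small standalone illustration of the comment above: a larger `length_scale` stretches the predicted durations and therefore slows down the synthesized speech. It reuses the duration arithmetic from `GlowTTS.inference` with made-up log-durations and simplified shapes:

```python
import torch

o_dur_log = torch.tensor([[0.2, 0.9, 1.5]])  # hypothetical per-token log-durations
x_mask = torch.ones_like(o_dur_log)

for length_scale in (1.0, 1.5):
    w = (torch.exp(o_dur_log) - 1) * x_mask * length_scale  # per-token durations in frames
    w_ceil = torch.ceil(w)
    print(length_scale, int(w_ceil.sum().item()))  # more output frames -> slower speech
```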
self.external_speaker_embedding_dim = external_speaker_embedding_dim # if it is a multi-speaker model and c_in_channels is 0, set to 256 @@ -88,28 +92,32 @@ def __init__(self, elif self.external_speaker_embedding_dim: self.c_in_channels = self.external_speaker_embedding_dim - self.encoder = Encoder(num_chars, - out_channels=out_channels, - hidden_channels=hidden_channels_enc, - hidden_channels_dp=hidden_channels_dp, - encoder_type=encoder_type, - encoder_params=encoder_params, - mean_only=mean_only, - use_prenet=use_encoder_prenet, - dropout_p_dp=dropout_p_dp, - c_in_channels=self.c_in_channels) - - self.decoder = Decoder(out_channels, - hidden_channels_dec, - kernel_size_dec, - dilation_rate, - num_flow_blocks_dec, - num_block_layers, - dropout_p=dropout_p_dec, - num_splits=num_splits, - num_squeeze=num_squeeze, - sigmoid_scale=sigmoid_scale, - c_in_channels=self.c_in_channels) + self.encoder = Encoder( + num_chars, + out_channels=out_channels, + hidden_channels=hidden_channels_enc, + hidden_channels_dp=hidden_channels_dp, + encoder_type=encoder_type, + encoder_params=encoder_params, + mean_only=mean_only, + use_prenet=use_encoder_prenet, + dropout_p_dp=dropout_p_dp, + c_in_channels=self.c_in_channels, + ) + + self.decoder = Decoder( + out_channels, + hidden_channels_dec, + kernel_size_dec, + dilation_rate, + num_flow_blocks_dec, + num_block_layers, + dropout_p=dropout_p_dec, + num_splits=num_splits, + num_squeeze=num_squeeze, + sigmoid_scale=sigmoid_scale, + c_in_channels=self.c_in_channels, + ) if num_speakers > 1 and not external_speaker_embedding_dim: # speaker embedding layer @@ -119,12 +127,12 @@ def __init__(self, @staticmethod def compute_outputs(attn, o_mean, o_log_scale, x_mask): # compute final values with the computed alignment - y_mean = torch.matmul( - attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose( - 1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] - y_log_scale = torch.matmul( - attn.squeeze(1).transpose(1, 2), o_log_scale.transpose( - 1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + y_log_scale = torch.matmul(attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] # compute total duration with adjustment o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask return y_mean, y_log_scale, o_attn_dur @@ -144,37 +152,27 @@ def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None): if self.external_speaker_embedding_dim: g = F.normalize(g).unsqueeze(-1) else: - g = F.normalize(self.emb_g(g)).unsqueeze(-1)# [b, h, 1] + g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] # embedding pass - o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, - x_lengths, - g=g) + o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) # drop residual frames wrt num_squeeze and set y_lengths.
- y, y_lengths, y_max_length, attn = self.preprocess( - y, y_lengths, y_max_length, None) + y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None) # create masks - y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), - 1).to(x_mask.dtype) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) # decoder pass z, logdet = self.decoder(y, y_mask, g=g, reverse=False) # find the alignment path with torch.no_grad(): o_scale = torch.exp(-2 * o_log_scale) - logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, - [1]).unsqueeze(-1) # [b, t, 1] - logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * - (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t'] - logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), - z) # [b, t, d] x [b, d, t'] = [b, t, t'] - logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, - [1]).unsqueeze(-1) # [b, t, 1] + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] - attn = maximum_path(logp, - attn_mask.squeeze(1)).unsqueeze(1).detach() - y_mean, y_log_scale, o_attn_dur = self.compute_outputs( - attn, o_mean, o_log_scale, x_mask) + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() + y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) attn = attn.squeeze(1).permute(0, 2, 1) return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur @@ -187,26 +185,20 @@ def inference(self, x, x_lengths, g=None): g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h] # embedding pass - o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, - x_lengths, - g=g) + o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) # compute output durations w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale w_ceil = torch.ceil(w) y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() y_max_length = None # compute masks - y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), - 1).to(x_mask.dtype) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) # compute attention mask - attn = generate_path(w_ceil.squeeze(1), - attn_mask.squeeze(1)).unsqueeze(1) - y_mean, y_log_scale, o_attn_dur = self.compute_outputs( - attn, o_mean, o_log_scale, x_mask) + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) + y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) - z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * - self.noise_scale) * y_mask + z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.noise_scale) * y_mask # decoder pass y, logdet = self.decoder(z, y_mask, g=g, reverse=True) attn = attn.squeeze(1).permute(0, 2, 1) @@ -224,9 +216,11 @@ def preprocess(self, y, y_lengths, y_max_length, attn=None): def store_inverse(self): self.decoder.store_inverse() - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = 
torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) if eval: self.eval() self.store_inverse() diff --git a/TTS/tts/models/speedy_speech.py b/TTS/tts/models/speedy_speech.py index 00cba5c7f1..9880b82ba3 100644 --- a/TTS/tts/models/speedy_speech.py +++ b/TTS/tts/models/speedy_speech.py @@ -1,11 +1,12 @@ import torch from torch import nn + from TTS.tts.layers.feed_forward.decoder import Decoder from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor from TTS.tts.layers.feed_forward.encoder import Encoder from TTS.tts.layers.generic.pos_encoding import PositionalEncoding -from TTS.tts.utils.generic_utils import sequence_mask from TTS.tts.layers.glow_tts.monotonic_align import generate_path +from TTS.tts.utils.generic_utils import sequence_mask class SpeedySpeech(nn.Module): @@ -34,42 +35,37 @@ class SpeedySpeech(nn.Module): external_c (bool, optional): enable external speaker embeddings. Defaults to False. c_in_channels (int, optional): number of channels in speaker embedding vectors. Defaults to 0. """ + # pylint: disable=dangerous-default-value def __init__( - self, - num_chars, - out_channels, - hidden_channels, - positional_encoding=True, - length_scale=1, - encoder_type='residual_conv_bn', - encoder_params={ - "kernel_size": 4, - "dilations": 4 * [1, 2, 4] + [1], - "num_conv_blocks": 2, - "num_res_blocks": 13 - }, - decoder_type='residual_conv_bn', - decoder_params={ - "kernel_size": 4, - "dilations": 4 * [1, 2, 4, 8] + [1], - "num_conv_blocks": 2, - "num_res_blocks": 17 - }, - num_speakers=0, - external_c=False, - c_in_channels=0): + self, + num_chars, + out_channels, + hidden_channels, + positional_encoding=True, + length_scale=1, + encoder_type="residual_conv_bn", + encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13}, + decoder_type="residual_conv_bn", + decoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17, + }, + num_speakers=0, + external_c=False, + c_in_channels=0, + ): super().__init__() self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale self.emb = nn.Embedding(num_chars, hidden_channels) - self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, - encoder_params, c_in_channels) + self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels) if positional_encoding: self.pos_encoder = PositionalEncoding(hidden_channels) - self.decoder = Decoder(out_channels, hidden_channels, - decoder_type, decoder_params) + self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params) self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels) if num_speakers > 1 and not external_c: @@ -97,9 +93,7 @@ def expand_encoder_outputs(en, dr, x_mask, y_mask): """ attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype) - o_en_ex = torch.matmul( - attn.squeeze(1).transpose(1, 2), en.transpose(1, - 2)).transpose(1, 2) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) return o_en_ex, attn def format_durations(self, o_dr_log, x_mask): @@ 
-116,12 +110,12 @@ def _concat_speaker_embedding(o_en, g): def _sum_speaker_embedding(self, x, g): # project g to decoder dim. - if hasattr(self, 'proj_g'): + if hasattr(self, "proj_g"): g = self.proj_g(g) return x + g def _forward_encoder(self, x, x_lengths, g=None): - if hasattr(self, 'emb_g'): + if hasattr(self, "emb_g"): g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1] if g is not None: @@ -133,8 +127,7 @@ def _forward_encoder(self, x, x_lengths, g=None): x_emb = torch.transpose(x_emb, 1, -1) # compute sequence masks - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), - 1).to(x.dtype) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) # encoder pass o_en = self.encoder(x_emb, x_mask) @@ -147,12 +140,11 @@ def _forward_encoder(self, x, x_lengths, g=None): return o_en, o_en_dp, x_mask, g def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), - 1).to(o_en_dp.dtype) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) # expand o_en with durations o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) # positional encoding - if hasattr(self, 'pos_encoder'): + if hasattr(self, "pos_encoder"): o_en_ex = self.pos_encoder(o_en_ex, y_mask) # speaker embedding if g is not None: @@ -187,7 +179,7 @@ def inference(self, x, x_lengths, g=None): # pylint: disable=unused-argument if x.shape[1] < 13: inference_padding += 13 - x.shape[1] # pad input to prevent dropping the last word - x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode='constant', value=0) + x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0) o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) # duration predictor pass o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) @@ -196,9 +188,11 @@ def inference(self, x, x_lengths, g=None): # pylint: disable=unused-argument o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) return o_de, attn - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) if eval: self.eval() assert not self.training diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index 541c4159a2..1fb011108c 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -47,45 +47,67 @@ class Tacotron(TacotronAbstract): memory_size (int, optional): size of the history queue fed to the prenet. Model feeds the last ```memory_size``` output frames to the prenet. 
""" - def __init__(self, - num_chars, - num_speakers, - r=5, - postnet_output_dim=1025, - decoder_output_dim=80, - attn_type='original', - attn_win=False, - attn_norm="sigmoid", - prenet_type="original", - prenet_dropout=True, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - attn_K=5, - separate_stopnet=True, - bidirectional_decoder=False, - double_decoder_consistency=False, - ddc_r=None, - encoder_in_features=256, - decoder_in_features=256, - speaker_embedding_dim=None, - gst=False, - gst_embedding_dim=256, - gst_num_heads=4, - gst_style_tokens=10, - memory_size=5, - gst_use_speaker_embedding=False): - super(Tacotron, - self).__init__(num_chars, num_speakers, r, postnet_output_dim, - decoder_output_dim, attn_type, attn_win, - attn_norm, prenet_type, prenet_dropout, - forward_attn, trans_agent, forward_attn_mask, - location_attn, attn_K, separate_stopnet, - bidirectional_decoder, double_decoder_consistency, - ddc_r, encoder_in_features, decoder_in_features, - speaker_embedding_dim, gst, gst_embedding_dim, - gst_num_heads, gst_style_tokens, gst_use_speaker_embedding) + + def __init__( + self, + num_chars, + num_speakers, + r=5, + postnet_output_dim=1025, + decoder_output_dim=80, + attn_type="original", + attn_win=False, + attn_norm="sigmoid", + prenet_type="original", + prenet_dropout=True, + forward_attn=False, + trans_agent=False, + forward_attn_mask=False, + location_attn=True, + attn_K=5, + separate_stopnet=True, + bidirectional_decoder=False, + double_decoder_consistency=False, + ddc_r=None, + encoder_in_features=256, + decoder_in_features=256, + speaker_embedding_dim=None, + gst=False, + gst_embedding_dim=256, + gst_num_heads=4, + gst_style_tokens=10, + memory_size=5, + gst_use_speaker_embedding=False, + ): + super(Tacotron, self).__init__( + num_chars, + num_speakers, + r, + postnet_output_dim, + decoder_output_dim, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + bidirectional_decoder, + double_decoder_consistency, + ddc_r, + encoder_in_features, + decoder_in_features, + speaker_embedding_dim, + gst, + gst_embedding_dim, + gst_num_heads, + gst_style_tokens, + gst_use_speaker_embedding, + ) # speaker embedding layers if self.num_speakers > 1: @@ -96,7 +118,7 @@ def __init__(self, # speaker and gst embeddings is concat in decoder input if self.num_speakers > 1: - self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim + self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim # embedding layer self.embedding = nn.Embedding(num_chars, 256, padding_idx=0) @@ -104,32 +126,59 @@ def __init__(self, # base model layers self.encoder = Encoder(self.encoder_in_features) - self.decoder = Decoder(self.decoder_in_features, decoder_output_dim, r, - memory_size, attn_type, attn_win, attn_norm, - prenet_type, prenet_dropout, forward_attn, - trans_agent, forward_attn_mask, location_attn, - attn_K, separate_stopnet) + self.decoder = Decoder( + self.decoder_in_features, + decoder_output_dim, + r, + memory_size, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + ) self.postnet = PostCBHG(decoder_output_dim) - self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, - postnet_output_dim) + self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, 
postnet_output_dim) # global style token layers if self.gst: - self.gst_layer = GST(num_mel=80, - num_heads=gst_num_heads, - num_style_tokens=gst_style_tokens, - gst_embedding_dim=self.gst_embedding_dim, - speaker_embedding_dim=speaker_embedding_dim if self.embeddings_per_sample and self.gst_use_speaker_embedding else None) + self.gst_layer = GST( + num_mel=80, + num_heads=gst_num_heads, + num_style_tokens=gst_style_tokens, + gst_embedding_dim=self.gst_embedding_dim, + speaker_embedding_dim=speaker_embedding_dim + if self.embeddings_per_sample and self.gst_use_speaker_embedding + else None, + ) # backward pass decoder if self.bidirectional_decoder: self._init_backward_decoder() # setup DDC if self.double_decoder_consistency: self.coarse_decoder = Decoder( - self.decoder_in_features, decoder_output_dim, ddc_r, memory_size, - attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, - forward_attn, trans_agent, forward_attn_mask, location_attn, - attn_K, separate_stopnet) + self.decoder_in_features, + decoder_output_dim, + ddc_r, + memory_size, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + ) def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None, speaker_embeddings=None): """ @@ -151,9 +200,9 @@ def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker # global style token if self.gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, - mel_specs, - speaker_embeddings if self.gst_use_speaker_embedding else None) + encoder_outputs = self.compute_gst( + encoder_outputs, mel_specs, speaker_embeddings if self.gst_use_speaker_embedding else None + ) # speaker embedding if self.num_speakers > 1: if not self.embeddings_per_sample: @@ -166,8 +215,7 @@ def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker # decoder_outputs: B x decoder_in_features x T_out # alignments: B x T_in x encoder_in_features # stop_tokens: B x T_in - decoder_outputs, alignments, stop_tokens = self.decoder( - encoder_outputs, mel_specs, input_mask) + decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask) # sequence masking if output_mask is not None: decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) @@ -182,10 +230,26 @@ def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker decoder_outputs = decoder_outputs.transpose(1, 2).contiguous() if self.bidirectional_decoder: decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) - return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward + return ( + decoder_outputs, + postnet_outputs, + alignments, + stop_tokens, + decoder_outputs_backward, + alignments_backward, + ) if self.double_decoder_consistency: - decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask) - return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward + decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( + mel_specs, encoder_outputs, alignments, input_mask + ) + return ( + decoder_outputs, + postnet_outputs, + alignments, + stop_tokens, + decoder_outputs_backward, + alignments_backward, + ) return decoder_outputs, postnet_outputs, 
alignments, stop_tokens @torch.no_grad() @@ -194,9 +258,9 @@ def inference(self, characters, speaker_ids=None, style_mel=None, speaker_embedd encoder_outputs = self.encoder(inputs) if self.gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, - style_mel, - speaker_embeddings if self.gst_use_speaker_embedding else None) + encoder_outputs = self.compute_gst( + encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None + ) if self.num_speakers > 1: if not self.embeddings_per_sample: # B x 1 x speaker_embed_dim @@ -205,8 +269,7 @@ def inference(self, characters, speaker_ids=None, style_mel=None, speaker_embedd # B x 1 x speaker_embed_dim speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1) encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) - decoder_outputs, alignments, stop_tokens = self.decoder.inference( - encoder_outputs) + decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs) postnet_outputs = self.postnet(decoder_outputs) postnet_outputs = self.last_linear(postnet_outputs) decoder_outputs = decoder_outputs.transpose(1, 2) diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index 0e751c32ec..af8c47e03e 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -5,6 +5,7 @@ from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet from TTS.tts.models.tacotron_abstract import TacotronAbstract + # TODO: match function arguments with tacotron class Tacotron2(TacotronAbstract): """Tacotron2 as in https://arxiv.org/abs/1712.05884 @@ -44,44 +45,66 @@ class Tacotron2(TacotronAbstract): gst_style_tokens (int, optional): number of GST tokens. Defaults to 10. gst_use_speaker_embedding (bool, optional): enable/disable inputing speaker embedding to GST. Defaults to False. 
""" - def __init__(self, - num_chars, - num_speakers, - r, - postnet_output_dim=80, - decoder_output_dim=80, - attn_type='original', - attn_win=False, - attn_norm="softmax", - prenet_type="original", - prenet_dropout=True, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - attn_K=5, - separate_stopnet=True, - bidirectional_decoder=False, - double_decoder_consistency=False, - ddc_r=None, - encoder_in_features=512, - decoder_in_features=512, - speaker_embedding_dim=None, - gst=False, - gst_embedding_dim=512, - gst_num_heads=4, - gst_style_tokens=10, - gst_use_speaker_embedding=False): - super(Tacotron2, - self).__init__(num_chars, num_speakers, r, postnet_output_dim, - decoder_output_dim, attn_type, attn_win, - attn_norm, prenet_type, prenet_dropout, - forward_attn, trans_agent, forward_attn_mask, - location_attn, attn_K, separate_stopnet, - bidirectional_decoder, double_decoder_consistency, - ddc_r, encoder_in_features, decoder_in_features, - speaker_embedding_dim, gst, gst_embedding_dim, - gst_num_heads, gst_style_tokens, gst_use_speaker_embedding) + + def __init__( + self, + num_chars, + num_speakers, + r, + postnet_output_dim=80, + decoder_output_dim=80, + attn_type="original", + attn_win=False, + attn_norm="softmax", + prenet_type="original", + prenet_dropout=True, + forward_attn=False, + trans_agent=False, + forward_attn_mask=False, + location_attn=True, + attn_K=5, + separate_stopnet=True, + bidirectional_decoder=False, + double_decoder_consistency=False, + ddc_r=None, + encoder_in_features=512, + decoder_in_features=512, + speaker_embedding_dim=None, + gst=False, + gst_embedding_dim=512, + gst_num_heads=4, + gst_style_tokens=10, + gst_use_speaker_embedding=False, + ): + super(Tacotron2, self).__init__( + num_chars, + num_speakers, + r, + postnet_output_dim, + decoder_output_dim, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + bidirectional_decoder, + double_decoder_consistency, + ddc_r, + encoder_in_features, + decoder_in_features, + speaker_embedding_dim, + gst, + gst_embedding_dim, + gst_num_heads, + gst_style_tokens, + gst_use_speaker_embedding, + ) # speaker embedding layer if self.num_speakers > 1: @@ -92,36 +115,63 @@ def __init__(self, # speaker and gst embeddings is concat in decoder input if self.num_speakers > 1: - self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim + self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim # embedding layer self.embedding = nn.Embedding(num_chars, 512, padding_idx=0) # base model layers self.encoder = Encoder(self.encoder_in_features) - self.decoder = Decoder(self.decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win, - attn_norm, prenet_type, prenet_dropout, - forward_attn, trans_agent, forward_attn_mask, - location_attn, attn_K, separate_stopnet) + self.decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + r, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + ) self.postnet = Postnet(self.postnet_output_dim) # global style token layers if self.gst: - self.gst_layer = GST(num_mel=80, - num_heads=self.gst_num_heads, - num_style_tokens=self.gst_style_tokens, - gst_embedding_dim=self.gst_embedding_dim, - speaker_embedding_dim=speaker_embedding_dim if self.embeddings_per_sample 
and self.gst_use_speaker_embedding else None) + self.gst_layer = GST( + num_mel=80, + num_heads=self.gst_num_heads, + num_style_tokens=self.gst_style_tokens, + gst_embedding_dim=self.gst_embedding_dim, + speaker_embedding_dim=speaker_embedding_dim + if self.embeddings_per_sample and self.gst_use_speaker_embedding + else None, + ) # backward pass decoder if self.bidirectional_decoder: self._init_backward_decoder() # setup DDC if self.double_decoder_consistency: self.coarse_decoder = Decoder( - self.decoder_in_features, self.decoder_output_dim, ddc_r, attn_type, - attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, - trans_agent, forward_attn_mask, location_attn, attn_K, - separate_stopnet) + self.decoder_in_features, + self.decoder_output_dim, + ddc_r, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + ) @staticmethod def shape_outputs(mel_outputs, mel_outputs_postnet, alignments): @@ -148,9 +198,9 @@ def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ encoder_outputs = self.encoder(embedded_inputs, text_lengths) if self.gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, - mel_specs, - speaker_embeddings if self.gst_use_speaker_embedding else None) + encoder_outputs = self.compute_gst( + encoder_outputs, mel_specs, speaker_embeddings if self.gst_use_speaker_embedding else None + ) if self.num_speakers > 1: if not self.embeddings_per_sample: # B x 1 x speaker_embed_dim @@ -163,8 +213,7 @@ def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) # B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r - decoder_outputs, alignments, stop_tokens = self.decoder( - encoder_outputs, mel_specs, input_mask) + decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask) # sequence masking if mel_lengths is not None: decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) @@ -175,14 +224,29 @@ def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ if output_mask is not None: postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs) # B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in - decoder_outputs, postnet_outputs, alignments = self.shape_outputs( - decoder_outputs, postnet_outputs, alignments) + decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments) if self.bidirectional_decoder: decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) - return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward + return ( + decoder_outputs, + postnet_outputs, + alignments, + stop_tokens, + decoder_outputs_backward, + alignments_backward, + ) if self.double_decoder_consistency: - decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask) - return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward + decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( + mel_specs, encoder_outputs, alignments, input_mask + ) + return ( + decoder_outputs, + postnet_outputs, + alignments, + stop_tokens, + 
decoder_outputs_backward, + alignments_backward, + ) return decoder_outputs, postnet_outputs, alignments, stop_tokens @torch.no_grad() @@ -192,20 +256,18 @@ def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=N if self.gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, - style_mel, - speaker_embeddings if self.gst_use_speaker_embedding else None) + encoder_outputs = self.compute_gst( + encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None + ) if self.num_speakers > 1: if not self.embeddings_per_sample: speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) - decoder_outputs, alignments, stop_tokens = self.decoder.inference( - encoder_outputs) + decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs) postnet_outputs = self.postnet(decoder_outputs) postnet_outputs = decoder_outputs + postnet_outputs - decoder_outputs, postnet_outputs, alignments = self.shape_outputs( - decoder_outputs, postnet_outputs, alignments) + decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments) return decoder_outputs, postnet_outputs, alignments, stop_tokens def inference_truncated(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None): @@ -217,19 +279,17 @@ def inference_truncated(self, text, speaker_ids=None, style_mel=None, speaker_em if self.gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, - style_mel, - speaker_embeddings if self.gst_use_speaker_embedding else None) + encoder_outputs = self.compute_gst( + encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None + ) if self.num_speakers > 1: if not self.embeddings_per_sample: speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) - mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated( - encoder_outputs) + mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(encoder_outputs) mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = mel_outputs + mel_outputs_postnet - mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs( - mel_outputs, mel_outputs_postnet, alignments) + mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(mel_outputs, mel_outputs_postnet, alignments) return mel_outputs, mel_outputs_postnet, alignments, stop_tokens diff --git a/TTS/tts/models/tacotron_abstract.py b/TTS/tts/models/tacotron_abstract.py index 22e86ee4df..820dd8b858 100644 --- a/TTS/tts/models/tacotron_abstract.py +++ b/TTS/tts/models/tacotron_abstract.py @@ -8,34 +8,36 @@ class TacotronAbstract(ABC, nn.Module): - def __init__(self, - num_chars, - num_speakers, - r, - postnet_output_dim=80, - decoder_output_dim=80, - attn_type='original', - attn_win=False, - attn_norm="softmax", - prenet_type="original", - prenet_dropout=True, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - attn_K=5, - separate_stopnet=True, - bidirectional_decoder=False, - double_decoder_consistency=False, - ddc_r=None, - encoder_in_features=512, - decoder_in_features=512, - speaker_embedding_dim=None, - gst=False, - gst_embedding_dim=512, - gst_num_heads=4, - gst_style_tokens=10, - gst_use_speaker_embedding=False): + def __init__( + self, + num_chars, + num_speakers, 
+ r, + postnet_output_dim=80, + decoder_output_dim=80, + attn_type="original", + attn_win=False, + attn_norm="softmax", + prenet_type="original", + prenet_dropout=True, + forward_attn=False, + trans_agent=False, + forward_attn_mask=False, + location_attn=True, + attn_K=5, + separate_stopnet=True, + bidirectional_decoder=False, + double_decoder_consistency=False, + ddc_r=None, + encoder_in_features=512, + decoder_in_features=512, + speaker_embedding_dim=None, + gst=False, + gst_embedding_dim=512, + gst_num_heads=4, + gst_style_tokens=10, + gst_use_speaker_embedding=False, + ): """ Abstract Tacotron class """ super().__init__() self.num_chars = num_chars @@ -82,7 +84,7 @@ def __init__(self, # global style token if self.gst: - self.decoder_in_features += gst_embedding_dim # add gst embedding dim + self.decoder_in_features += gst_embedding_dim # add gst embedding dim self.gst_layer = None # model states @@ -121,10 +123,12 @@ def forward(self): def inference(self): pass - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) - self.decoder.set_r(state['r']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + self.decoder.set_r(state["r"]) if eval: self.eval() assert not self.training @@ -149,25 +153,24 @@ def compute_masks(self, text_lengths, mel_lengths): def _backward_pass(self, mel_specs, encoder_outputs, mask): """ Run backwards decoder """ decoder_outputs_b, alignments_b, _ = self.decoder_backward( - encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask) + encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask + ) decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous() return decoder_outputs_b, alignments_b - def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, - input_mask): + def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask): """ Double Decoder Consistency """ T = mel_specs.shape[1] if T % self.coarse_decoder.r > 0: padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r) - mel_specs = torch.nn.functional.pad(mel_specs, - (0, 0, 0, padding_size, 0, 0)) + mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0)) decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder( - encoder_outputs.detach(), mel_specs, input_mask) + encoder_outputs.detach(), mel_specs, input_mask + ) # scale_factor = self.decoder.r_init / self.decoder.r alignments_backward = torch.nn.functional.interpolate( - alignments_backward.transpose(1, 2), - size=alignments.shape[1], - mode='nearest').transpose(1, 2) + alignments_backward.transpose(1, 2), size=alignments.shape[1], mode="nearest" + ).transpose(1, 2) decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2) decoder_outputs_backward = decoder_outputs_backward[:, :T, :] return decoder_outputs_backward, alignments_backward @@ -179,20 +182,17 @@ def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, def compute_speaker_embedding(self, speaker_ids): """ Compute speaker embedding vectors """ if hasattr(self, "speaker_embedding") and speaker_ids is None: - raise RuntimeError( - " [!] Model has speaker embedding layer but speaker_id is not provided" - ) + raise RuntimeError(" [!] 
Model has speaker embedding layer but speaker_id is not provided") if hasattr(self, "speaker_embedding") and speaker_ids is not None: self.speaker_embeddings = self.speaker_embedding(speaker_ids).unsqueeze(1) if hasattr(self, "speaker_project_mel") and speaker_ids is not None: - self.speaker_embeddings_projected = self.speaker_project_mel( - self.speaker_embeddings).squeeze(1) + self.speaker_embeddings_projected = self.speaker_project_mel(self.speaker_embeddings).squeeze(1) def compute_gst(self, inputs, style_input, speaker_embedding=None): """ Compute global style token """ device = inputs.device if isinstance(style_input, dict): - query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device) + query = torch.zeros(1, 1, self.gst_embedding_dim // 2).to(device) if speaker_embedding is not None: query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1) @@ -205,20 +205,18 @@ def compute_gst(self, inputs, style_input, speaker_embedding=None): elif style_input is None: gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device) else: - gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable + gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable inputs = self._concat_speaker_embedding(inputs, gst_outputs) return inputs @staticmethod def _add_speaker_embedding(outputs, speaker_embeddings): - speaker_embeddings_ = speaker_embeddings.expand( - outputs.size(0), outputs.size(1), -1) + speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), outputs.size(1), -1) outputs = outputs + speaker_embeddings_ return outputs @staticmethod def _concat_speaker_embedding(outputs, speaker_embeddings): - speaker_embeddings_ = speaker_embeddings.expand( - outputs.size(0), outputs.size(1), -1) + speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), outputs.size(1), -1) outputs = torch.cat([outputs, speaker_embeddings_], dim=-1) return outputs diff --git a/TTS/tts/tf/layers/tacotron/common_layers.py b/TTS/tts/tf/layers/tacotron/common_layers.py index ad18b9fc6b..b208d7fed1 100644 --- a/TTS/tts/tf/layers/tacotron/common_layers.py +++ b/TTS/tts/tf/layers/tacotron/common_layers.py @@ -1,16 +1,18 @@ import tensorflow as tf from tensorflow import keras from tensorflow.python.ops import math_ops + # from tensorflow_addons.seq2seq import BahdanauAttention # NOTE: linter has a problem with the current TF release -#pylint: disable=no-value-for-parameter -#pylint: disable=unexpected-keyword-arg +# pylint: disable=no-value-for-parameter +# pylint: disable=unexpected-keyword-arg + class Linear(keras.layers.Layer): def __init__(self, units, use_bias, **kwargs): super(Linear, self).__init__(**kwargs) - self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer') + self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer") self.activation = keras.layers.ReLU() def call(self, x): @@ -24,8 +26,10 @@ def call(self, x): class LinearBN(keras.layers.Layer): def __init__(self, units, use_bias, **kwargs): super(LinearBN, self).__init__(**kwargs) - self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer') - self.batch_normalization = keras.layers.BatchNormalization(axis=-1, momentum=0.90, epsilon=1e-5, name='batch_normalization') + self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer") + self.batch_normalization = keras.layers.BatchNormalization( + axis=-1, momentum=0.90, epsilon=1e-5, name="batch_normalization" + ) 
self.activation = keras.layers.ReLU() def call(self, x, training=None): @@ -39,22 +43,21 @@ def call(self, x, training=None): class Prenet(keras.layers.Layer): - def __init__(self, - prenet_type, - prenet_dropout, - units, - bias, - **kwargs): + def __init__(self, prenet_type, prenet_dropout, units, bias, **kwargs): super(Prenet, self).__init__(**kwargs) self.prenet_type = prenet_type self.prenet_dropout = prenet_dropout self.linear_layers = [] if prenet_type == "bn": - self.linear_layers += [LinearBN(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)] + self.linear_layers += [ + LinearBN(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units) + ] elif prenet_type == "original": - self.linear_layers += [Linear(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)] + self.linear_layers += [ + Linear(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units) + ] else: - raise RuntimeError(' [!] Unknown prenet type.') + raise RuntimeError(" [!] Unknown prenet type.") if prenet_dropout: self.dropout = keras.layers.Dropout(rate=0.5) @@ -80,10 +83,21 @@ def _sigmoid_norm(score): class Attention(keras.layers.Layer): """TODO: implement forward_attention TODO: location sensitive attention - TODO: implement attention windowing """ - def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters, - loc_attn_kernel_size, use_windowing, norm, use_forward_attn, - use_trans_agent, use_forward_attn_mask, **kwargs): + TODO: implement attention windowing""" + + def __init__( + self, + attn_dim, + use_loc_attn, + loc_attn_n_filters, + loc_attn_kernel_size, + use_windowing, + norm, + use_forward_attn, + use_trans_agent, + use_forward_attn_mask, + **kwargs, + ): super(Attention, self).__init__(**kwargs) self.use_loc_attn = use_loc_attn self.loc_attn_n_filters = loc_attn_n_filters @@ -93,20 +107,23 @@ def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters, self.use_forward_attn = use_forward_attn self.use_trans_agent = use_trans_agent self.use_forward_attn_mask = use_forward_attn_mask - self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name='query_layer/linear_layer') - self.inputs_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name=f'{self.name}/inputs_layer/linear_layer') - self.v = tf.keras.layers.Dense(1, use_bias=True, name='v/linear_layer') + self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name="query_layer/linear_layer") + self.inputs_layer = tf.keras.layers.Dense( + attn_dim, use_bias=False, name=f"{self.name}/inputs_layer/linear_layer" + ) + self.v = tf.keras.layers.Dense(1, use_bias=True, name="v/linear_layer") if use_loc_attn: self.location_conv1d = keras.layers.Conv1D( filters=loc_attn_n_filters, kernel_size=loc_attn_kernel_size, - padding='same', + padding="same", use_bias=False, - name='location_layer/location_conv1d') - self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name='location_layer/location_dense') - if norm == 'softmax': + name="location_layer/location_conv1d", + ) + self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name="location_layer/location_dense") + if norm == "softmax": self.norm_func = tf.nn.softmax - elif norm == 'sigmoid': + elif norm == "sigmoid": self.norm_func = _sigmoid_norm else: raise ValueError("Unknown value for attention norm type") @@ -118,30 +135,25 @@ def init_states(self, batch_size, value_length): attention_old = tf.zeros([batch_size, value_length]) states = 
[attention_cum, attention_old] if self.use_forward_attn: - alpha = tf.concat([ - tf.ones([batch_size, 1]), - tf.zeros([batch_size, value_length])[:, :-1] + 1e-7 - ], 1) + alpha = tf.concat([tf.ones([batch_size, 1]), tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], 1) states.append(alpha) return tuple(states) def process_values(self, values): """ cache values for decoder iterations """ - #pylint: disable=attribute-defined-outside-init + # pylint: disable=attribute-defined-outside-init self.processed_values = self.inputs_layer(values) self.values = values def get_loc_attn(self, query, states): - """ compute location attention, query layer and + """compute location attention, query layer and unnorm. attention weights""" attention_cum, attention_old = states[:2] attn_cat = tf.stack([attention_old, attention_cum], axis=2) processed_query = self.query_layer(tf.expand_dims(query, 1)) processed_attn = self.location_dense(self.location_conv1d(attn_cat)) - score = self.v( - tf.nn.tanh(self.processed_values + processed_query + - processed_attn)) + score = self.v(tf.nn.tanh(self.processed_values + processed_query + processed_attn)) score = tf.squeeze(score, axis=2) return score, processed_query @@ -152,14 +164,14 @@ def get_attn(self, query): score = tf.squeeze(score, axis=2) return score, processed_query - def apply_score_masking(self, score, mask): #pylint: disable=no-self-use + def apply_score_masking(self, score, mask): # pylint: disable=no-self-use """ ignore sequence paddings """ padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2) # Bias so padding positions do not contribute to attention distribution. - score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32) + score -= 1.0e9 * math_ops.cast(padding_mask, dtype=tf.float32) return score - def apply_forward_attention(self, alignment, alpha): #pylint: disable=no-self-use + def apply_forward_attention(self, alignment, alpha): # pylint: disable=no-self-use # forward attention fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0) # compute transition potentials @@ -206,7 +218,9 @@ def call(self, query, states): states = self.update_states(states, scores_norm, attn_weights, new_alpha) # context_vector shape after sum == (batch_size, hidden_size) - context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False) + context_vector = tf.matmul( + tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False + ) context_vector = tf.squeeze(context_vector, axis=1) return context_vector, attn_weights, states diff --git a/TTS/tts/tf/layers/tacotron/tacotron2.py b/TTS/tts/tf/layers/tacotron/tacotron2.py index 094d9e4a8d..0e3b575635 100644 --- a/TTS/tts/tf/layers/tacotron/tacotron2.py +++ b/TTS/tts/tf/layers/tacotron/tacotron2.py @@ -1,19 +1,22 @@ import tensorflow as tf from tensorflow import keras + +from TTS.tts.tf.layers.tacotron.common_layers import Attention, Prenet from TTS.tts.tf.utils.tf_utils import shape_list -from TTS.tts.tf.layers.tacotron.common_layers import Prenet, Attention # NOTE: linter has a problem with the current TF release -#pylint: disable=no-value-for-parameter -#pylint: disable=unexpected-keyword-arg +# pylint: disable=no-value-for-parameter +# pylint: disable=unexpected-keyword-arg class ConvBNBlock(keras.layers.Layer): def __init__(self, filters, kernel_size, activation, **kwargs): super(ConvBNBlock, self).__init__(**kwargs) - self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding='same', 
name='convolution1d') - self.batch_normalization = keras.layers.BatchNormalization(axis=2, momentum=0.90, epsilon=1e-5, name='batch_normalization') - self.dropout = keras.layers.Dropout(rate=0.5, name='dropout') - self.activation = keras.layers.Activation(activation, name='activation') + self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding="same", name="convolution1d") + self.batch_normalization = keras.layers.BatchNormalization( + axis=2, momentum=0.90, epsilon=1e-5, name="batch_normalization" + ) + self.dropout = keras.layers.Dropout(rate=0.5, name="dropout") + self.activation = keras.layers.Activation(activation, name="activation") def call(self, x, training=None): o = self.convolution1d(x) @@ -27,10 +30,10 @@ class Postnet(keras.layers.Layer): def __init__(self, output_filters, num_convs, **kwargs): super(Postnet, self).__init__(**kwargs) self.convolutions = [] - self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name='convolutions_0')) + self.convolutions.append(ConvBNBlock(512, 5, "tanh", name="convolutions_0")) for idx in range(1, num_convs - 1): - self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name=f'convolutions_{idx}')) - self.convolutions.append(ConvBNBlock(output_filters, 5, 'linear', name=f'convolutions_{idx+1}')) + self.convolutions.append(ConvBNBlock(512, 5, "tanh", name=f"convolutions_{idx}")) + self.convolutions.append(ConvBNBlock(output_filters, 5, "linear", name=f"convolutions_{idx+1}")) def call(self, x, training=None): o = x @@ -44,8 +47,10 @@ def __init__(self, output_input_dim, **kwargs): super(Encoder, self).__init__(**kwargs) self.convolutions = [] for idx in range(3): - self.convolutions.append(ConvBNBlock(output_input_dim, 5, 'relu', name=f'convolutions_{idx}')) - self.lstm = keras.layers.Bidirectional(keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name='lstm') + self.convolutions.append(ConvBNBlock(output_input_dim, 5, "relu", name=f"convolutions_{idx}")) + self.lstm = keras.layers.Bidirectional( + keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name="lstm" + ) def call(self, x, training=None): o = x @@ -56,10 +61,26 @@ def call(self, x, training=None): class Decoder(keras.layers.Layer): - #pylint: disable=unused-argument - def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type, - prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask, - use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, enable_tflite, **kwargs): + # pylint: disable=unused-argument + def __init__( + self, + frame_dim, + r, + attn_type, + use_attn_win, + attn_norm, + prenet_type, + prenet_dropout, + use_forward_attn, + use_trans_agent, + use_forward_attn_mask, + use_location_attn, + attn_K, + separate_stopnet, + speaker_emb_dim, + enable_tflite, + **kwargs, + ): super(Decoder, self).__init__(**kwargs) self.frame_dim = frame_dim self.r_init = tf.constant(r, dtype=tf.int32) @@ -80,30 +101,31 @@ def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type self.p_attention_dropout = 0.1 self.p_decoder_dropout = 0.1 - self.prenet = Prenet(prenet_type, - prenet_dropout, - [self.prenet_dim, self.prenet_dim], - bias=False, - name='prenet') - self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name='attention_rnn', ) + self.prenet = Prenet(prenet_type, prenet_dropout, [self.prenet_dim, self.prenet_dim], bias=False, name="prenet") + self.attention_rnn = keras.layers.LSTMCell( + self.query_dim, + use_bias=True, + 
name="attention_rnn", + ) self.attention_rnn_dropout = keras.layers.Dropout(0.5) # TODO: implement other attn options - self.attention = Attention(attn_dim=self.attn_dim, - use_loc_attn=True, - loc_attn_n_filters=32, - loc_attn_kernel_size=31, - use_windowing=False, - norm=attn_norm, - use_forward_attn=use_forward_attn, - use_trans_agent=use_trans_agent, - use_forward_attn_mask=use_forward_attn_mask, - name='attention') - self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name='decoder_rnn') + self.attention = Attention( + attn_dim=self.attn_dim, + use_loc_attn=True, + loc_attn_n_filters=32, + loc_attn_kernel_size=31, + use_windowing=False, + norm=attn_norm, + use_forward_attn=use_forward_attn, + use_trans_agent=use_trans_agent, + use_forward_attn_mask=use_forward_attn_mask, + name="attention", + ) + self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name="decoder_rnn") self.decoder_rnn_dropout = keras.layers.Dropout(0.5) - self.linear_projection = keras.layers.Dense(self.frame_dim * r, name='linear_projection/linear_layer') - self.stopnet = keras.layers.Dense(1, name='stopnet/linear_layer') - + self.linear_projection = keras.layers.Dense(self.frame_dim * r, name="linear_projection/linear_layer") + self.stopnet = keras.layers.Dense(1, name="stopnet/linear_layer") def set_max_decoder_steps(self, new_max_steps): self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32) @@ -120,25 +142,31 @@ def build_decoder_initial_states(self, batch_size, memory_dim, memory_length): attention_states = self.attention.init_states(batch_size, memory_length) return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states - def step(self, prenet_next, states, - memory_seq_length=None, training=None): + def step(self, prenet_next, states, memory_seq_length=None, training=None): _, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states attention_rnn_input = tf.concat([prenet_next, context_next], -1) - attention_rnn_output, attention_rnn_state = \ - self.attention_rnn(attention_rnn_input, - attention_rnn_state, training=training) + attention_rnn_output, attention_rnn_state = self.attention_rnn( + attention_rnn_input, attention_rnn_state, training=training + ) attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training) context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training) decoder_rnn_input = tf.concat([attention_rnn_output, context], -1) - decoder_rnn_output, decoder_rnn_state = \ - self.decoder_rnn(decoder_rnn_input, decoder_rnn_state, training=training) + decoder_rnn_output, decoder_rnn_state = self.decoder_rnn( + decoder_rnn_input, decoder_rnn_state, training=training + ) decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training) linear_projection_input = tf.concat([decoder_rnn_output, context], -1) output_frame = self.linear_projection(linear_projection_input, training=training) stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1) stopnet_output = self.stopnet(stopnet_input, training=training) - output_frame = output_frame[:, :self.r * self.frame_dim] - states = (output_frame[:, self.frame_dim * (self.r - 1):], context, attention_rnn_state, decoder_rnn_state, attention_states) + output_frame = output_frame[:, : self.r * self.frame_dim] + states = ( + output_frame[:, self.frame_dim * (self.r - 1) :], + context, + attention_rnn_state, + decoder_rnn_state, + 
attention_states, + ) return output_frame, stopnet_output, states, attention def decode(self, memory, states, frames, memory_seq_length=None): @@ -157,21 +185,20 @@ def decode(self, memory, states, frames, memory_seq_length=None): def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions): prenet_next = prenet_output[:, step] - output, stop_token, states, attention = self.step(prenet_next, - states, - memory_seq_length) + output, stop_token, states, attention = self.step(prenet_next, states, memory_seq_length) outputs = outputs.write(step, output) attentions = attentions.write(step, attention) stop_tokens = stop_tokens.write(step, stop_token) return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions - _, memory, _, states, outputs, stop_tokens, attentions = \ - tf.while_loop(lambda *arg: True, - _body, - loop_vars=(step_count, memory, prenet_output, - states, outputs, stop_tokens, attentions), - parallel_iterations=32, - swap_memory=True, - maximum_iterations=num_iter) + + _, memory, _, states, outputs, stop_tokens, attentions = tf.while_loop( + lambda *arg: True, + _body, + loop_vars=(step_count, memory, prenet_output, states, outputs, stop_tokens, attentions), + parallel_iterations=32, + swap_memory=True, + maximum_iterations=num_iter, + ) outputs = outputs.stack() attentions = attentions.stack() @@ -200,10 +227,7 @@ def decode_inference(self, memory, states): def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag): frame_next = states[0] prenet_next = self.prenet(frame_next, training=False) - output, stop_token, states, attention = self.step(prenet_next, - states, - None, - training=False) + output, stop_token, states, attention = self.step(prenet_next, states, None, training=False) stop_token = tf.math.sigmoid(stop_token) outputs = outputs.write(step, output) attentions = attentions.write(step, attention) @@ -213,14 +237,14 @@ def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag): return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool)) - _, memory, states, outputs, stop_tokens, attentions, stop_flag = \ - tf.while_loop(cond, - _body, - loop_vars=(step_count, memory, states, outputs, - stop_tokens, attentions, stop_flag), - parallel_iterations=32, - swap_memory=True, - maximum_iterations=self.max_decoder_steps) + _, memory, states, outputs, stop_tokens, attentions, stop_flag = tf.while_loop( + cond, + _body, + loop_vars=(step_count, memory, states, outputs, stop_tokens, attentions, stop_flag), + parallel_iterations=32, + swap_memory=True, + maximum_iterations=self.max_decoder_steps, + ) outputs = outputs.stack() attentions = attentions.stack() @@ -238,12 +262,13 @@ def decode_inference_tflite(self, memory, states): batch_size is 1""" # init states # dynamic_shape is not supported in TFLite - outputs = tf.TensorArray(dtype=tf.float32, - size=self.max_decoder_steps, - element_shape=tf.TensorShape( - [self.output_dim]), - clear_after_read=False, - dynamic_size=False) + outputs = tf.TensorArray( + dtype=tf.float32, + size=self.max_decoder_steps, + element_shape=tf.TensorShape([self.output_dim]), + clear_after_read=False, + dynamic_size=False, + ) # stop_flags = tf.TensorArray(dtype=tf.bool, # size=self.max_decoder_steps, # element_shape=tf.TensorShape( @@ -263,10 +288,7 @@ def decode_inference_tflite(self, memory, states): def _body(step, memory, states, outputs, stop_flag): 
frame_next = states[0] prenet_next = self.prenet(frame_next, training=False) - output, stop_token, states, _ = self.step(prenet_next, - states, - None, - training=False) + output, stop_token, states, _ = self.step(prenet_next, states, None, training=False) stop_token = tf.math.sigmoid(stop_token) stop_flag = tf.greater(stop_token, self.stop_thresh) stop_flag = tf.reduce_all(stop_flag) @@ -276,24 +298,22 @@ def _body(step, memory, states, outputs, stop_flag): return step + 1, memory, states, outputs, stop_flag cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool)) - step_count, memory, states, outputs, stop_flag = \ - tf.while_loop(cond, - _body, - loop_vars=(step_count, memory, states, outputs, - stop_flag), - parallel_iterations=32, - swap_memory=True, - maximum_iterations=self.max_decoder_steps) - + step_count, memory, states, outputs, stop_flag = tf.while_loop( + cond, + _body, + loop_vars=(step_count, memory, states, outputs, stop_flag), + parallel_iterations=32, + swap_memory=True, + maximum_iterations=self.max_decoder_steps, + ) outputs = outputs.stack() - outputs = tf.gather(outputs, tf.range(step_count)) # pylint: disable=no-value-for-parameter + outputs = tf.gather(outputs, tf.range(step_count)) # pylint: disable=no-value-for-parameter outputs = tf.expand_dims(outputs, axis=[0]) outputs = tf.transpose(outputs, [1, 0, 2]) outputs = tf.reshape(outputs, [1, -1, self.frame_dim]) return outputs, stop_tokens, attentions - def call(self, memory, states, frames=None, memory_seq_length=None, training=False): if training: return self.decode(memory, states, frames, memory_seq_length) diff --git a/TTS/tts/tf/models/tacotron2.py b/TTS/tts/tf/models/tacotron2.py index 882af5175b..2c0e5d6f70 100644 --- a/TTS/tts/tf/models/tacotron2.py +++ b/TTS/tts/tf/models/tacotron2.py @@ -1,31 +1,33 @@ import tensorflow as tf from tensorflow import keras -from TTS.tts.tf.layers.tacotron.tacotron2 import Encoder, Decoder, Postnet +from TTS.tts.tf.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet from TTS.tts.tf.utils.tf_utils import shape_list -#pylint: disable=too-many-ancestors, abstract-method +# pylint: disable=too-many-ancestors, abstract-method class Tacotron2(keras.models.Model): - def __init__(self, - num_chars, - num_speakers, - r, - postnet_output_dim=80, - decoder_output_dim=80, - attn_type='original', - attn_win=False, - attn_norm="softmax", - attn_K=4, - prenet_type="original", - prenet_dropout=True, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - separate_stopnet=True, - bidirectional_decoder=False, - enable_tflite=False): + def __init__( + self, + num_chars, + num_speakers, + r, + postnet_output_dim=80, + decoder_output_dim=80, + attn_type="original", + attn_win=False, + attn_norm="softmax", + attn_K=4, + prenet_type="original", + prenet_dropout=True, + forward_attn=False, + trans_agent=False, + forward_attn_mask=False, + location_attn=True, + separate_stopnet=True, + bidirectional_decoder=False, + enable_tflite=False, + ): super(Tacotron2, self).__init__() self.r = r self.decoder_output_dim = decoder_output_dim @@ -35,26 +37,28 @@ def __init__(self, self.speaker_embed_dim = 256 self.enable_tflite = enable_tflite - self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding') - self.encoder = Encoder(512, name='encoder') + self.embedding = keras.layers.Embedding(num_chars, 512, name="embedding") + self.encoder = Encoder(512, name="encoder") # TODO: most of the decoder args have no use at 
the momment - self.decoder = Decoder(decoder_output_dim, - r, - attn_type=attn_type, - use_attn_win=attn_win, - attn_norm=attn_norm, - prenet_type=prenet_type, - prenet_dropout=prenet_dropout, - use_forward_attn=forward_attn, - use_trans_agent=trans_agent, - use_forward_attn_mask=forward_attn_mask, - use_location_attn=location_attn, - attn_K=attn_K, - separate_stopnet=separate_stopnet, - speaker_emb_dim=self.speaker_embed_dim, - name='decoder', - enable_tflite=enable_tflite) - self.postnet = Postnet(postnet_output_dim, 5, name='postnet') + self.decoder = Decoder( + decoder_output_dim, + r, + attn_type=attn_type, + use_attn_win=attn_win, + attn_norm=attn_norm, + prenet_type=prenet_type, + prenet_dropout=prenet_dropout, + use_forward_attn=forward_attn, + use_trans_agent=trans_agent, + use_forward_attn_mask=forward_attn_mask, + use_location_attn=location_attn, + attn_K=attn_K, + separate_stopnet=separate_stopnet, + speaker_emb_dim=self.speaker_embed_dim, + name="decoder", + enable_tflite=enable_tflite, + ) + self.postnet = Postnet(postnet_output_dim, 5, name="postnet") @tf.function(experimental_relax_shapes=True) def call(self, characters, text_lengths=None, frames=None, training=None): @@ -62,14 +66,16 @@ def call(self, characters, text_lengths=None, frames=None, training=None): return self.training(characters, text_lengths, frames) if not training: return self.inference(characters) - raise RuntimeError(' [!] Set model training mode True or False') + raise RuntimeError(" [!] Set model training mode True or False") def training(self, characters, text_lengths, frames): B, T = shape_list(characters) embedding_vectors = self.embedding(characters, training=True) encoder_output = self.encoder(embedding_vectors, training=True) decoder_states = self.decoder.build_decoder_initial_states(B, 512, T) - decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, frames, text_lengths, training=True) + decoder_frames, stop_tokens, attentions = self.decoder( + encoder_output, decoder_states, frames, text_lengths, training=True + ) postnet_frames = self.postnet(decoder_frames, training=True) output_frames = decoder_frames + postnet_frames return decoder_frames, output_frames, attentions, stop_tokens @@ -89,7 +95,8 @@ def inference(self, characters): experimental_relax_shapes=True, input_signature=[ tf.TensorSpec([1, None], dtype=tf.int32), - ],) + ], + ) def inference_tflite(self, characters): B, T = shape_list(characters) embedding_vectors = self.embedding(characters, training=False) @@ -101,7 +108,9 @@ def inference_tflite(self, characters): print(output_frames.shape) return decoder_frames, output_frames, attentions, stop_tokens - def build_inference(self, ): + def build_inference( + self, + ): # TODO: issue https://github.com/PyCQA/pylint/issues/3613 - input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32) #pylint: disable=unexpected-keyword-arg + input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32) # pylint: disable=unexpected-keyword-arg self(input_ids) diff --git a/TTS/tts/tf/utils/convert_torch_to_tf_utils.py b/TTS/tts/tf/utils/convert_torch_to_tf_utils.py index 03b4180366..5cc072d0d7 100644 --- a/TTS/tts/tf/utils/convert_torch_to_tf_utils.py +++ b/TTS/tts/tf/utils/convert_torch_to_tf_utils.py @@ -2,8 +2,9 @@ import tensorflow as tf # NOTE: linter has a problem with the current TF release -#pylint: disable=no-value-for-parameter -#pylint: disable=unexpected-keyword-arg +# pylint: disable=no-value-for-parameter +# pylint: 
disable=unexpected-keyword-arg + def tf_create_dummy_inputs(): """ Create dummy inputs for TF Tacotron2 model """ @@ -13,11 +14,11 @@ def tf_create_dummy_inputs(): pad = 1 n_chars = 24 input_ids = tf.random.uniform([batch_size, max_input_length + pad], maxval=n_chars, dtype=tf.int32) - input_lengths = np.random.randint(0, high=max_input_length+1 + pad, size=[batch_size]) + input_lengths = np.random.randint(0, high=max_input_length + 1 + pad, size=[batch_size]) input_lengths[-1] = max_input_length input_lengths = tf.convert_to_tensor(input_lengths, dtype=tf.int32) mel_outputs = tf.random.uniform(shape=[batch_size, max_mel_length + pad, 80]) - mel_lengths = np.random.randint(0, high=max_mel_length+1 + pad, size=[batch_size]) + mel_lengths = np.random.randint(0, high=max_mel_length + 1 + pad, size=[batch_size]) mel_lengths[-1] = max_mel_length mel_lengths = tf.convert_to_tensor(mel_lengths, dtype=tf.int32) return input_ids, input_lengths, mel_outputs, mel_lengths @@ -31,14 +32,14 @@ def compare_torch_tf(torch_tensor, tf_tensor): def convert_tf_name(tf_name): """ Convert certain patterns in TF layer names to Torch patterns """ tf_name_tmp = tf_name - tf_name_tmp = tf_name_tmp.replace(':0', '') - tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_1/recurrent_kernel', '/weight_hh_l0') - tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_2/kernel', '/weight_ih_l1') - tf_name_tmp = tf_name_tmp.replace('/recurrent_kernel', '/weight_hh') - tf_name_tmp = tf_name_tmp.replace('/kernel', '/weight') - tf_name_tmp = tf_name_tmp.replace('/gamma', '/weight') - tf_name_tmp = tf_name_tmp.replace('/beta', '/bias') - tf_name_tmp = tf_name_tmp.replace('/', '.') + tf_name_tmp = tf_name_tmp.replace(":0", "") + tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_1/recurrent_kernel", "/weight_hh_l0") + tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_2/kernel", "/weight_ih_l1") + tf_name_tmp = tf_name_tmp.replace("/recurrent_kernel", "/weight_hh") + tf_name_tmp = tf_name_tmp.replace("/kernel", "/weight") + tf_name_tmp = tf_name_tmp.replace("/gamma", "/weight") + tf_name_tmp = tf_name_tmp.replace("/beta", "/bias") + tf_name_tmp = tf_name_tmp.replace("/", ".") return tf_name_tmp @@ -47,33 +48,35 @@ def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict): print(" > Passing weights from Torch to TF ...") for tf_var in tf_vars: torch_var_name = var_map_dict[tf_var.name] - print(f' | > {tf_var.name} <-- {torch_var_name}') + print(f" | > {tf_var.name} <-- {torch_var_name}") # if tuple, it is a bias variable if not isinstance(torch_var_name, tuple): - torch_layer_name = '.'.join(torch_var_name.split('.')[-2:]) + torch_layer_name = ".".join(torch_var_name.split(".")[-2:]) torch_weight = state_dict[torch_var_name] - if 'convolution1d/kernel' in tf_var.name or 'conv1d/kernel' in tf_var.name: + if "convolution1d/kernel" in tf_var.name or "conv1d/kernel" in tf_var.name: # out_dim, in_dim, filter -> filter, in_dim, out_dim numpy_weight = torch_weight.permute([2, 1, 0]).detach().cpu().numpy() - elif 'lstm_cell' in tf_var.name and 'kernel' in tf_var.name: + elif "lstm_cell" in tf_var.name and "kernel" in tf_var.name: numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() # if variable is for bidirectional lstm and it is a bias vector there # needs to be pre-defined two matching torch bias vectors - elif '_lstm/lstm_cell_' in tf_var.name and 'bias' in tf_var.name: + elif "_lstm/lstm_cell_" in tf_var.name and "bias" in tf_var.name: bias_vectors = [value for key, value in 
state_dict.items() if key in torch_var_name] assert len(bias_vectors) == 2 numpy_weight = bias_vectors[0] + bias_vectors[1] - elif 'rnn' in tf_var.name and 'kernel' in tf_var.name: + elif "rnn" in tf_var.name and "kernel" in tf_var.name: numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() - elif 'rnn' in tf_var.name and 'bias' in tf_var.name: + elif "rnn" in tf_var.name and "bias" in tf_var.name: bias_vectors = [value for key, value in state_dict.items() if torch_var_name[:-2] in key] assert len(bias_vectors) == 2 numpy_weight = bias_vectors[0] + bias_vectors[1] - elif 'linear_layer' in torch_layer_name and 'weight' in torch_var_name: + elif "linear_layer" in torch_layer_name and "weight" in torch_var_name: numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() else: numpy_weight = torch_weight.detach().cpu().numpy() - assert np.all(tf_var.shape == numpy_weight.shape), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}" + assert np.all( + tf_var.shape == numpy_weight.shape + ), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}" tf.keras.backend.set_value(tf_var, numpy_weight) return tf_vars diff --git a/TTS/tts/tf/utils/generic_utils.py b/TTS/tts/tf/utils/generic_utils.py index 7eba946b16..5b8b4ce221 100644 --- a/TTS/tts/tf/utils/generic_utils.py +++ b/TTS/tts/tf/utils/generic_utils.py @@ -1,26 +1,27 @@ import datetime import importlib import pickle + import numpy as np import tensorflow as tf def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs): state = { - 'model': model.weights, - 'optimizer': optimizer, - 'step': current_step, - 'epoch': epoch, - 'date': datetime.date.today().strftime("%B %d, %Y"), - 'r': r + "model": model.weights, + "optimizer": optimizer, + "step": current_step, + "epoch": epoch, + "date": datetime.date.today().strftime("%B %d, %Y"), + "r": r, } state.update(kwargs) - pickle.dump(state, open(output_path, 'wb')) + pickle.dump(state, open(output_path, "wb")) def load_checkpoint(model, checkpoint_path): - checkpoint = pickle.load(open(checkpoint_path, 'rb')) - chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']} + checkpoint = pickle.load(open(checkpoint_path, "rb")) + chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} tf_vars = model.weights for tf_var in tf_vars: layer_name = tf_var.name @@ -32,8 +33,8 @@ def load_checkpoint(model, checkpoint_path): chkp_var_value = chkp_var_dict[layer_name] tf.keras.backend.set_value(tf_var, chkp_var_value) - if 'r' in checkpoint.keys(): - model.decoder.set_r(checkpoint['r']) + if "r" in checkpoint.keys(): + model.decoder.set_r(checkpoint["r"]) return model @@ -45,8 +46,7 @@ def sequence_mask(sequence_length, max_len=None): seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) if sequence_length.is_cuda: seq_range_expand = seq_range_expand.cuda() - seq_length_expand = ( - sequence_length.unsqueeze(1).expand_as(seq_range_expand)) + seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand) # B x T_max return seq_range_expand < seq_length_expand @@ -62,42 +62,42 @@ def count_parameters(model, c): try: return model.count_params() except RuntimeError: - input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype('int32')) - input_lengths = np.random.randint(100, 129, (8, )) + input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype("int32")) + input_lengths = 
np.random.randint(100, 129, (8,)) input_lengths[-1] = 128 - input_lengths = tf.convert_to_tensor(input_lengths.astype('int32')) - mel_spec = np.random.rand(8, 2 * c.r, - c.audio['num_mels']).astype('float32') + input_lengths = tf.convert_to_tensor(input_lengths.astype("int32")) + mel_spec = np.random.rand(8, 2 * c.r, c.audio["num_mels"]).astype("float32") mel_spec = tf.convert_to_tensor(mel_spec) - speaker_ids = np.random.randint( - 0, 5, (8, )) if c.use_speaker_embedding else None + speaker_ids = np.random.randint(0, 5, (8,)) if c.use_speaker_embedding else None _ = model(input_dummy, input_lengths, mel_spec, speaker_ids=speaker_ids) return model.count_params() def setup_model(num_chars, num_speakers, c, enable_tflite=False): print(" > Using model: {}".format(c.model)) - MyModel = importlib.import_module('TTS.tts.tf.models.' + c.model.lower()) + MyModel = importlib.import_module("TTS.tts.tf.models." + c.model.lower()) MyModel = getattr(MyModel, c.model) if c.model.lower() in "tacotron": - raise NotImplementedError(' [!] Tacotron model is not ready.') + raise NotImplementedError(" [!] Tacotron model is not ready.") # tacotron2 - model = MyModel(num_chars=num_chars, - num_speakers=num_speakers, - r=c.r, - postnet_output_dim=c.audio['num_mels'], - decoder_output_dim=c.audio['num_mels'], - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - bidirectional_decoder=c.bidirectional_decoder, - enable_tflite=enable_tflite) + model = MyModel( + num_chars=num_chars, + num_speakers=num_speakers, + r=c.r, + postnet_output_dim=c.audio["num_mels"], + decoder_output_dim=c.audio["num_mels"], + attn_type=c.attention_type, + attn_win=c.windowing, + attn_norm=c.attention_norm, + prenet_type=c.prenet_type, + prenet_dropout=c.prenet_dropout, + forward_attn=c.use_forward_attn, + trans_agent=c.transition_agent, + forward_attn_mask=c.forward_attn_mask, + location_attn=c.location_attn, + attn_K=c.attention_heads, + separate_stopnet=c.separate_stopnet, + bidirectional_decoder=c.bidirectional_decoder, + enable_tflite=enable_tflite, + ) return model diff --git a/TTS/tts/tf/utils/io.py b/TTS/tts/tf/utils/io.py index 143422d279..b2345b00bf 100644 --- a/TTS/tts/tf/utils/io.py +++ b/TTS/tts/tf/utils/io.py @@ -1,24 +1,25 @@ -import pickle import datetime +import pickle + import tensorflow as tf def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs): state = { - 'model': model.weights, - 'optimizer': optimizer, - 'step': current_step, - 'epoch': epoch, - 'date': datetime.date.today().strftime("%B %d, %Y"), - 'r': r + "model": model.weights, + "optimizer": optimizer, + "step": current_step, + "epoch": epoch, + "date": datetime.date.today().strftime("%B %d, %Y"), + "r": r, } state.update(kwargs) - pickle.dump(state, open(output_path, 'wb')) + pickle.dump(state, open(output_path, "wb")) def load_checkpoint(model, checkpoint_path): - checkpoint = pickle.load(open(checkpoint_path, 'rb')) - chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']} + checkpoint = pickle.load(open(checkpoint_path, "rb")) + chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} tf_vars = model.weights for tf_var in tf_vars: layer_name = tf_var.name @@ -30,8 +31,8 @@ def 
load_checkpoint(model, checkpoint_path): chkp_var_value = chkp_var_dict[layer_name] tf.keras.backend.set_value(tf_var, chkp_var_value) - if 'r' in checkpoint.keys(): - model.decoder.set_r(checkpoint['r']) + if "r" in checkpoint.keys(): + model.decoder.set_r(checkpoint["r"]) return model diff --git a/TTS/tts/tf/utils/tflite.py b/TTS/tts/tf/utils/tflite.py index b8daf25429..9701d5910b 100644 --- a/TTS/tts/tf/utils/tflite.py +++ b/TTS/tts/tf/utils/tflite.py @@ -1,25 +1,20 @@ import tensorflow as tf -def convert_tacotron2_to_tflite(model, - output_path=None, - experimental_converter=True): +def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=True): """Convert Tensorflow Tacotron2 model to TFLite. Save a binary file if output_path is provided, else return TFLite model.""" concrete_function = model.inference_tflite.get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions( - [concrete_function]) + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function]) converter.experimental_new_converter = experimental_converter converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.target_spec.supported_ops = [ - tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS - ] + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] tflite_model = converter.convert() - print(f'Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.') + print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.") if output_path is not None: # same model binary if outputpath is provided - with open(output_path, 'wb') as f: + with open(output_path, "wb") as f: f.write(tflite_model) return None return tflite_model diff --git a/TTS/tts/utils/chinese_mandarin/numbers.py b/TTS/tts/utils/chinese_mandarin/numbers.py index 94c8fd03b3..4787ea6100 100644 --- a/TTS/tts/utils/chinese_mandarin/numbers.py +++ b/TTS/tts/utils/chinese_mandarin/numbers.py @@ -1,4 +1,3 @@ - #!/usr/bin/env python3 # -*- coding: utf-8 -*- @@ -6,8 +5,8 @@ # This uses Python 3, but it's easy to port to Python 2 by changing # strings to u'xx'. 
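For orientation, this module exposes `replace_numbers_to_characters_in_text`, which rewrites every digit run in a sentence through `_num2chinese`; the hunks that follow are formatting-only and do not change behaviour. A minimal usage sketch, assuming the TTS package is importable (the sample sentence is illustrative only):

    # Usage sketch for TTS.tts.utils.chinese_mandarin.numbers; the input string is made up.
    from TTS.tts.utils.chinese_mandarin.numbers import replace_numbers_to_characters_in_text

    text = "距离大约是25公里"  # Mandarin text containing Arabic digits
    converted = replace_numbers_to_characters_in_text(text)
    print(converted)  # every digit run is replaced with Chinese numeral characters
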
-import re import itertools +import re def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str: @@ -31,38 +30,37 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str: # check num first nd = str(num) if abs(float(nd)) >= 1e48: - raise ValueError('number out of range') - if 'e' in nd: - raise ValueError('scientific notation is not supported') - c_symbol = '正负点' if simp else '正負點' + raise ValueError("number out of range") + if "e" in nd: + raise ValueError("scientific notation is not supported") + c_symbol = "正负点" if simp else "正負點" if o: # formal twoalt = False if big: - c_basic = '零壹贰叁肆伍陆柒捌玖' if simp else '零壹貳參肆伍陸柒捌玖' - c_unit1 = '拾佰仟' - c_twoalt = '贰' if simp else '貳' + c_basic = "零壹贰叁肆伍陆柒捌玖" if simp else "零壹貳參肆伍陸柒捌玖" + c_unit1 = "拾佰仟" + c_twoalt = "贰" if simp else "貳" else: - c_basic = '〇一二三四五六七八九' if o else '零一二三四五六七八九' - c_unit1 = '十百千' + c_basic = "〇一二三四五六七八九" if o else "零一二三四五六七八九" + c_unit1 = "十百千" if twoalt: - c_twoalt = '两' if simp else '兩' + c_twoalt = "两" if simp else "兩" else: - c_twoalt = '二' - c_unit2 = '万亿兆京垓秭穰沟涧正载' if simp else '萬億兆京垓秭穰溝澗正載' - revuniq = lambda l: ''.join(k for k, g in itertools.groupby(reversed(l))) + c_twoalt = "二" + c_unit2 = "万亿兆京垓秭穰沟涧正载" if simp else "萬億兆京垓秭穰溝澗正載" + revuniq = lambda l: "".join(k for k, g in itertools.groupby(reversed(l))) nd = str(num) result = [] - if nd[0] == '+': + if nd[0] == "+": result.append(c_symbol[0]) - elif nd[0] == '-': + elif nd[0] == "-": result.append(c_symbol[1]) - if '.' in nd: - integer, remainder = nd.lstrip('+-').split('.') + if "." in nd: + integer, remainder = nd.lstrip("+-").split(".") else: - integer, remainder = nd.lstrip('+-'), None + integer, remainder = nd.lstrip("+-"), None if int(integer): - splitted = [integer[max(i - 4, 0):i] - for i in range(len(integer), 0, -4)] + splitted = [integer[max(i - 4, 0) : i] for i in range(len(integer), 0, -4)] intresult = [] for nu, unit in enumerate(splitted): # special cases @@ -75,17 +73,17 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str: ulist = [] unit = unit.zfill(4) for nc, ch in enumerate(reversed(unit)): - if ch == '0': + if ch == "0": if ulist: # ???0 ulist.append(c_basic[0]) elif nc == 0: ulist.append(c_basic[int(ch)]) - elif nc == 1 and ch == '1' and unit[1] == '0': + elif nc == 1 and ch == "1" and unit[1] == "0": # special case for tens # edit the 'elif' if you don't like # 十四, 三千零十四, 三千三百一十四 ulist.append(c_unit1[0]) - elif nc > 1 and ch == '2': + elif nc > 1 and ch == "2": ulist.append(c_twoalt + c_unit1[nc - 1]) else: ulist.append(c_basic[int(ch)] + c_unit1[nc - 1]) @@ -99,10 +97,8 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str: result.append(c_basic[0]) if remainder: result.append(c_symbol[2]) - result.append(''.join(c_basic[int(ch)] for ch in remainder)) - return ''.join(result) - - + result.append("".join(c_basic[int(ch)] for ch in remainder)) + return "".join(result) def _number_replace(match) -> str: @@ -127,5 +123,5 @@ def replace_numbers_to_characters_in_text(text: str) -> str: Returns: str: output text """ - text = re.sub(r'[0-9]+', _number_replace, text) + text = re.sub(r"[0-9]+", _number_replace, text) return text diff --git a/TTS/tts/utils/chinese_mandarin/phonemizer.py b/TTS/tts/utils/chinese_mandarin/phonemizer.py index 7742c49114..29cac1606e 100644 --- a/TTS/tts/utils/chinese_mandarin/phonemizer.py +++ b/TTS/tts/utils/chinese_mandarin/phonemizer.py @@ -1,17 +1,13 @@ from typing import List +import jieba import pypinyin from 
.pinyinToPhonemes import PINYIN_DICT -import jieba - - def _chinese_character_to_pinyin(text: str) -> List[str]: - pinyins = pypinyin.pinyin( - text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True - ) + pinyins = pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True) pinyins_flat_list = [item for sublist in pinyins for item in sublist] return pinyins_flat_list diff --git a/TTS/tts/utils/chinese_mandarin/pinyinToPhonemes.py b/TTS/tts/utils/chinese_mandarin/pinyinToPhonemes.py index a4722ff967..4e25c3a4c9 100644 --- a/TTS/tts/utils/chinese_mandarin/pinyinToPhonemes.py +++ b/TTS/tts/utils/chinese_mandarin/pinyinToPhonemes.py @@ -1,4 +1,3 @@ - PINYIN_DICT = { "a": ["a"], "ai": ["ai"], diff --git a/TTS/tts/utils/data.py b/TTS/tts/utils/data.py index a75410b484..a55d3a86ec 100644 --- a/TTS/tts/utils/data.py +++ b/TTS/tts/utils/data.py @@ -4,8 +4,7 @@ def _pad_data(x, length): _pad = 0 assert x.ndim == 1 - return np.pad( - x, (0, length - x.shape[0]), mode='constant', constant_values=_pad) + return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=_pad) def prepare_data(inputs): @@ -14,12 +13,9 @@ def prepare_data(inputs): def _pad_tensor(x, length): - _pad = 0. + _pad = 0.0 assert x.ndim == 2 - x = np.pad( - x, [[0, 0], [0, length - x.shape[1]]], - mode='constant', - constant_values=_pad) + x = np.pad(x, [[0, 0], [0, length - x.shape[1]]], mode="constant", constant_values=_pad) return x @@ -31,10 +27,9 @@ def prepare_tensor(inputs, out_steps): def _pad_stop_target(x, length): - _pad = 0. + _pad = 0.0 assert x.ndim == 1 - return np.pad( - x, (0, length - x.shape[0]), mode='constant', constant_values=_pad) + return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=_pad) def prepare_stop_target(inputs, out_steps): @@ -46,22 +41,18 @@ def prepare_stop_target(inputs, out_steps): def pad_per_step(inputs, pad_len): - return np.pad( - inputs, [[0, 0], [0, 0], [0, pad_len]], - mode='constant', - constant_values=0.0) + return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0) # pylint: disable=attribute-defined-outside-init -class StandardScaler(): - +class StandardScaler: def set_stats(self, mean, scale): self.mean_ = mean self.scale_ = scale def reset_stats(self): - delattr(self, 'mean_') - delattr(self, 'scale_') + delattr(self, "mean_") + delattr(self, "scale_") def transform(self, X): X = np.asarray(X) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 44d961ec24..8c23dd84ee 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -1,9 +1,10 @@ -import re -import torch import importlib -import numpy as np +import re from collections import Counter +import numpy as np +import torch + from TTS.utils.generic_utils import check_argument @@ -28,277 +29,334 @@ def split_dataset(items): return items_eval, items return items[:eval_split_size], items[eval_split_size:] + # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 def sequence_mask(sequence_length, max_len=None): if max_len is None: max_len = sequence_length.data.max() - seq_range = torch.arange(max_len, - dtype=sequence_length.dtype, - device=sequence_length.device) + seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device) # B x T_max return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1) def to_camel(text): text = text.capitalize() - text = re.sub(r'(?!^)_([a-zA-Z])', lambda m: 
m.group(1).upper(), text) - text = text.replace('Tts', 'TTS') + text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) + text = text.replace("Tts", "TTS") return text def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): print(" > Using model: {}".format(c.model)) - MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower()) + MyModel = importlib.import_module("TTS.tts.models." + c.model.lower()) MyModel = getattr(MyModel, to_camel(c.model)) if c.model.lower() in "tacotron": - model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), - num_speakers=num_speakers, - r=c.r, - postnet_output_dim=int(c.audio['fft_size'] / 2 + 1), - decoder_output_dim=c.audio['num_mels'], - gst=c.use_gst, - gst_embedding_dim=c.gst['gst_embedding_dim'], - gst_num_heads=c.gst['gst_num_heads'], - gst_style_tokens=c.gst['gst_style_tokens'], - gst_use_speaker_embedding=c.gst['gst_use_speaker_embedding'], - memory_size=c.memory_size, - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - bidirectional_decoder=c.bidirectional_decoder, - double_decoder_consistency=c.double_decoder_consistency, - ddc_r=c.ddc_r, - speaker_embedding_dim=speaker_embedding_dim) + model = MyModel( + num_chars=num_chars + getattr(c, "add_blank", False), + num_speakers=num_speakers, + r=c.r, + postnet_output_dim=int(c.audio["fft_size"] / 2 + 1), + decoder_output_dim=c.audio["num_mels"], + gst=c.use_gst, + gst_embedding_dim=c.gst["gst_embedding_dim"], + gst_num_heads=c.gst["gst_num_heads"], + gst_style_tokens=c.gst["gst_style_tokens"], + gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"], + memory_size=c.memory_size, + attn_type=c.attention_type, + attn_win=c.windowing, + attn_norm=c.attention_norm, + prenet_type=c.prenet_type, + prenet_dropout=c.prenet_dropout, + forward_attn=c.use_forward_attn, + trans_agent=c.transition_agent, + forward_attn_mask=c.forward_attn_mask, + location_attn=c.location_attn, + attn_K=c.attention_heads, + separate_stopnet=c.separate_stopnet, + bidirectional_decoder=c.bidirectional_decoder, + double_decoder_consistency=c.double_decoder_consistency, + ddc_r=c.ddc_r, + speaker_embedding_dim=speaker_embedding_dim, + ) elif c.model.lower() == "tacotron2": - model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), - num_speakers=num_speakers, - r=c.r, - postnet_output_dim=c.audio['num_mels'], - decoder_output_dim=c.audio['num_mels'], - gst=c.use_gst, - gst_embedding_dim=c.gst['gst_embedding_dim'], - gst_num_heads=c.gst['gst_num_heads'], - gst_style_tokens=c.gst['gst_style_tokens'], - gst_use_speaker_embedding=c.gst['gst_use_speaker_embedding'], - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - bidirectional_decoder=c.bidirectional_decoder, - double_decoder_consistency=c.double_decoder_consistency, - ddc_r=c.ddc_r, - speaker_embedding_dim=speaker_embedding_dim) + model = MyModel( + num_chars=num_chars + getattr(c, "add_blank", False), + 
num_speakers=num_speakers, + r=c.r, + postnet_output_dim=c.audio["num_mels"], + decoder_output_dim=c.audio["num_mels"], + gst=c.use_gst, + gst_embedding_dim=c.gst["gst_embedding_dim"], + gst_num_heads=c.gst["gst_num_heads"], + gst_style_tokens=c.gst["gst_style_tokens"], + gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"], + attn_type=c.attention_type, + attn_win=c.windowing, + attn_norm=c.attention_norm, + prenet_type=c.prenet_type, + prenet_dropout=c.prenet_dropout, + forward_attn=c.use_forward_attn, + trans_agent=c.transition_agent, + forward_attn_mask=c.forward_attn_mask, + location_attn=c.location_attn, + attn_K=c.attention_heads, + separate_stopnet=c.separate_stopnet, + bidirectional_decoder=c.bidirectional_decoder, + double_decoder_consistency=c.double_decoder_consistency, + ddc_r=c.ddc_r, + speaker_embedding_dim=speaker_embedding_dim, + ) elif c.model.lower() == "glow_tts": - model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), - hidden_channels_enc=c['hidden_channels_encoder'], - hidden_channels_dec=c['hidden_channels_decoder'], - hidden_channels_dp=c['hidden_channels_duration_predictor'], - out_channels=c.audio['num_mels'], - encoder_type=c.encoder_type, - encoder_params=c.encoder_params, - use_encoder_prenet=c["use_encoder_prenet"], - num_flow_blocks_dec=12, - kernel_size_dec=5, - dilation_rate=1, - num_block_layers=4, - dropout_p_dec=0.05, - num_speakers=num_speakers, - c_in_channels=0, - num_splits=4, - num_squeeze=2, - sigmoid_scale=False, - mean_only=True, - external_speaker_embedding_dim=speaker_embedding_dim) + model = MyModel( + num_chars=num_chars + getattr(c, "add_blank", False), + hidden_channels_enc=c["hidden_channels_encoder"], + hidden_channels_dec=c["hidden_channels_decoder"], + hidden_channels_dp=c["hidden_channels_duration_predictor"], + out_channels=c.audio["num_mels"], + encoder_type=c.encoder_type, + encoder_params=c.encoder_params, + use_encoder_prenet=c["use_encoder_prenet"], + num_flow_blocks_dec=12, + kernel_size_dec=5, + dilation_rate=1, + num_block_layers=4, + dropout_p_dec=0.05, + num_speakers=num_speakers, + c_in_channels=0, + num_splits=4, + num_squeeze=2, + sigmoid_scale=False, + mean_only=True, + external_speaker_embedding_dim=speaker_embedding_dim, + ) elif c.model.lower() == "speedy_speech": - model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), - out_channels=c.audio['num_mels'], - hidden_channels=c['hidden_channels'], - positional_encoding=c['positional_encoding'], - encoder_type=c['encoder_type'], - encoder_params=c['encoder_params'], - decoder_type=c['decoder_type'], - decoder_params=c['decoder_params'], - c_in_channels=0) + model = MyModel( + num_chars=num_chars + getattr(c, "add_blank", False), + out_channels=c.audio["num_mels"], + hidden_channels=c["hidden_channels"], + positional_encoding=c["positional_encoding"], + encoder_type=c["encoder_type"], + encoder_params=c["encoder_params"], + decoder_type=c["decoder_type"], + decoder_params=c["decoder_params"], + c_in_channels=0, + ) elif c.model.lower() == "align_tts": - model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), - out_channels=c.audio['num_mels'], - hidden_channels=c['hidden_channels'], - hidden_channels_dp=c['hidden_channels_dp'], - encoder_type=c['encoder_type'], - encoder_params=c['encoder_params'], - decoder_type=c['decoder_type'], - decoder_params=c['decoder_params'], - c_in_channels=0) + model = MyModel( + num_chars=num_chars + getattr(c, "add_blank", False), + out_channels=c.audio["num_mels"], + 
hidden_channels=c["hidden_channels"], + hidden_channels_dp=c["hidden_channels_dp"], + encoder_type=c["encoder_type"], + encoder_params=c["encoder_params"], + decoder_type=c["decoder_type"], + decoder_params=c["decoder_params"], + c_in_channels=0, + ) return model + def is_tacotron(c): - return 'tacotron' in c['model'].lower() + return "tacotron" in c["model"].lower() + def check_config_tts(c): - check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech', 'align_tts'], restricted=True, val_type=str) - check_argument('run_name', c, restricted=True, val_type=str) - check_argument('run_description', c, val_type=str) + check_argument( + "model", + c, + enum_list=["tacotron", "tacotron2", "glow_tts", "speedy_speech", "align_tts"], + restricted=True, + val_type=str, + ) + check_argument("run_name", c, restricted=True, val_type=str) + check_argument("run_description", c, val_type=str) # AUDIO - check_argument('audio', c, restricted=True, val_type=dict) + check_argument("audio", c, restricted=True, val_type=dict) # audio processing parameters - check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056) - check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058) - check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000) - check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length') - check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length') - check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1) - check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10) - check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000) - check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5) - check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000) + check_argument("num_mels", c["audio"], restricted=True, val_type=int, min_val=10, max_val=2056) + check_argument("fft_size", c["audio"], restricted=True, val_type=int, min_val=128, max_val=4058) + check_argument("sample_rate", c["audio"], restricted=True, val_type=int, min_val=512, max_val=100000) + check_argument( + "frame_length_ms", + c["audio"], + restricted=True, + val_type=float, + min_val=10, + max_val=1000, + alternative="win_length", + ) + check_argument( + "frame_shift_ms", c["audio"], restricted=True, val_type=float, min_val=1, max_val=1000, alternative="hop_length" + ) + check_argument("preemphasis", c["audio"], restricted=True, val_type=float, min_val=0, max_val=1) + check_argument("min_level_db", c["audio"], restricted=True, val_type=int, min_val=-1000, max_val=10) + check_argument("ref_level_db", c["audio"], restricted=True, val_type=int, min_val=0, max_val=1000) + check_argument("power", c["audio"], restricted=True, val_type=float, min_val=1, max_val=5) + check_argument("griffin_lim_iters", c["audio"], restricted=True, val_type=int, min_val=10, max_val=1000) # vocabulary parameters - check_argument('characters', c, restricted=False, val_type=dict) - check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) - check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, 
restricted='characters' in c.keys(), val_type=str) - check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) - check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) - check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys() and c['use_phonemes'], val_type=str) - check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + check_argument("characters", c, restricted=False, val_type=dict) + check_argument( + "pad", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str + ) + check_argument( + "eos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str + ) + check_argument( + "bos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str + ) + check_argument( + "characters", + c["characters"] if "characters" in c.keys() else {}, + restricted="characters" in c.keys(), + val_type=str, + ) + check_argument( + "phonemes", + c["characters"] if "characters" in c.keys() else {}, + restricted="characters" in c.keys() and c["use_phonemes"], + val_type=str, + ) + check_argument( + "punctuations", + c["characters"] if "characters" in c.keys() else {}, + restricted="characters" in c.keys(), + val_type=str, + ) # normalization parameters - check_argument('signal_norm', c['audio'], restricted=True, val_type=bool) - check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool) - check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000) - check_argument('clip_norm', c['audio'], restricted=True, val_type=bool) - check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000) - check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0) - check_argument('spec_gain', c['audio'], restricted=True, val_type=[int, float], min_val=1, max_val=100) - check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool) - check_argument('trim_db', c['audio'], restricted=True, val_type=int) + check_argument("signal_norm", c["audio"], restricted=True, val_type=bool) + check_argument("symmetric_norm", c["audio"], restricted=True, val_type=bool) + check_argument("max_norm", c["audio"], restricted=True, val_type=float, min_val=0.1, max_val=1000) + check_argument("clip_norm", c["audio"], restricted=True, val_type=bool) + check_argument("mel_fmin", c["audio"], restricted=True, val_type=float, min_val=0.0, max_val=1000) + check_argument("mel_fmax", c["audio"], restricted=True, val_type=float, min_val=500.0) + check_argument("spec_gain", c["audio"], restricted=True, val_type=[int, float], min_val=1, max_val=100) + check_argument("do_trim_silence", c["audio"], restricted=True, val_type=bool) + check_argument("trim_db", c["audio"], restricted=True, val_type=int) # training parameters - check_argument('batch_size', c, restricted=True, val_type=int, min_val=1) - check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1) - check_argument('r', c, restricted=True, val_type=int, min_val=1) - check_argument('gradual_training', c, restricted=False, val_type=list) - check_argument('mixed_precision', c, restricted=False, val_type=bool) + check_argument("batch_size", c, 
restricted=True, val_type=int, min_val=1) + check_argument("eval_batch_size", c, restricted=True, val_type=int, min_val=1) + check_argument("r", c, restricted=True, val_type=int, min_val=1) + check_argument("gradual_training", c, restricted=False, val_type=list) + check_argument("mixed_precision", c, restricted=False, val_type=bool) # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100) # loss parameters - check_argument('loss_masking', c, restricted=True, val_type=bool) - if c['model'].lower() in ['tacotron', 'tacotron2']: - check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0) - if c['model'].lower in ["speedy_speech", "align_tts"]: - check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0) + check_argument("loss_masking", c, restricted=True, val_type=bool) + if c["model"].lower() in ["tacotron", "tacotron2"]: + check_argument("decoder_loss_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("postnet_loss_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("postnet_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("decoder_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("decoder_ssim_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("postnet_ssim_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("ga_alpha", c, restricted=True, val_type=float, min_val=0) + if c["model"].lower in ["speedy_speech", "align_tts"]: + check_argument("ssim_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("l1_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("huber_alpha", c, restricted=True, val_type=float, min_val=0) # validation parameters - check_argument('run_eval', c, restricted=True, val_type=bool) - check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0) - check_argument('test_sentences_file', c, restricted=False, val_type=str) + check_argument("run_eval", c, restricted=True, val_type=bool) + check_argument("test_delay_epochs", c, restricted=True, val_type=int, min_val=0) + check_argument("test_sentences_file", c, restricted=False, val_type=str) # optimizer - check_argument('noam_schedule', c, restricted=False, val_type=bool) - check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0) - check_argument('epochs', c, restricted=True, val_type=int, min_val=1) - check_argument('lr', c, restricted=True, val_type=float, min_val=0) - check_argument('wd', c, restricted=is_tacotron(c), val_type=float, min_val=0) - check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0) - check_argument('seq_len_norm', c, restricted=is_tacotron(c), val_type=bool) + check_argument("noam_schedule", c, restricted=False, val_type=bool) + 
check_argument("grad_clip", c, restricted=True, val_type=float, min_val=0.0) + check_argument("epochs", c, restricted=True, val_type=int, min_val=1) + check_argument("lr", c, restricted=True, val_type=float, min_val=0) + check_argument("wd", c, restricted=is_tacotron(c), val_type=float, min_val=0) + check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0) + check_argument("seq_len_norm", c, restricted=is_tacotron(c), val_type=bool) # tacotron prenet - check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1) - check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn']) - check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool) + check_argument("memory_size", c, restricted=is_tacotron(c), val_type=int, min_val=-1) + check_argument("prenet_type", c, restricted=is_tacotron(c), val_type=str, enum_list=["original", "bn"]) + check_argument("prenet_dropout", c, restricted=is_tacotron(c), val_type=bool) # attention - check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original', 'dynamic_convolution']) - check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int) - check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax']) - check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool) - check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool) - check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool) - check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool) - check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool) - check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool) - check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool) - check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool) - check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int) - - if c['model'].lower() in ['tacotron', 'tacotron2']: + check_argument( + "attention_type", + c, + restricted=is_tacotron(c), + val_type=str, + enum_list=["graves", "original", "dynamic_convolution"], + ) + check_argument("attention_heads", c, restricted=is_tacotron(c), val_type=int) + check_argument("attention_norm", c, restricted=is_tacotron(c), val_type=str, enum_list=["sigmoid", "softmax"]) + check_argument("windowing", c, restricted=is_tacotron(c), val_type=bool) + check_argument("use_forward_attn", c, restricted=is_tacotron(c), val_type=bool) + check_argument("forward_attn_mask", c, restricted=is_tacotron(c), val_type=bool) + check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool) + check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool) + check_argument("location_attn", c, restricted=is_tacotron(c), val_type=bool) + check_argument("bidirectional_decoder", c, restricted=is_tacotron(c), val_type=bool) + check_argument("double_decoder_consistency", c, restricted=is_tacotron(c), val_type=bool) + check_argument("ddc_r", c, restricted="double_decoder_consistency" in c.keys(), min_val=1, max_val=7, val_type=int) + + if c["model"].lower() in ["tacotron", "tacotron2"]: # stopnet - check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool) - check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool) + 
check_argument("stopnet", c, restricted=is_tacotron(c), val_type=bool) + check_argument("separate_stopnet", c, restricted=is_tacotron(c), val_type=bool) # Model Parameters for non-tacotron models - if c['model'].lower in ["speedy_speech", "align_tts"]: - check_argument('positional_encoding', c, restricted=True, val_type=type) - check_argument('encoder_type', c, restricted=True, val_type=str) - check_argument('encoder_params', c, restricted=True, val_type=dict) - check_argument('decoder_residual_conv_bn_params', c, restricted=True, val_type=dict) + if c["model"].lower in ["speedy_speech", "align_tts"]: + check_argument("positional_encoding", c, restricted=True, val_type=type) + check_argument("encoder_type", c, restricted=True, val_type=str) + check_argument("encoder_params", c, restricted=True, val_type=dict) + check_argument("decoder_residual_conv_bn_params", c, restricted=True, val_type=dict) # GlowTTS parameters - check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str) + check_argument("encoder_type", c, restricted=not is_tacotron(c), val_type=str) # tensorboard - check_argument('print_step', c, restricted=True, val_type=int, min_val=1) - check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1) - check_argument('save_step', c, restricted=True, val_type=int, min_val=1) - check_argument('checkpoint', c, restricted=True, val_type=bool) - check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) + check_argument("print_step", c, restricted=True, val_type=int, min_val=1) + check_argument("tb_plot_step", c, restricted=True, val_type=int, min_val=1) + check_argument("save_step", c, restricted=True, val_type=int, min_val=1) + check_argument("checkpoint", c, restricted=True, val_type=bool) + check_argument("tb_model_param_stats", c, restricted=True, val_type=bool) # dataloading # pylint: disable=import-outside-toplevel from TTS.tts.utils.text import cleaners - check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners)) - check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool) - check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0) - check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0) - check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0) - check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0) - check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10) - check_argument('compute_input_seq_cache', c, restricted=True, val_type=bool) + + check_argument("text_cleaner", c, restricted=True, val_type=str, enum_list=dir(cleaners)) + check_argument("enable_eos_bos_chars", c, restricted=True, val_type=bool) + check_argument("num_loader_workers", c, restricted=True, val_type=int, min_val=0) + check_argument("num_val_loader_workers", c, restricted=True, val_type=int, min_val=0) + check_argument("batch_group_size", c, restricted=True, val_type=int, min_val=0) + check_argument("min_seq_len", c, restricted=True, val_type=int, min_val=0) + check_argument("max_seq_len", c, restricted=True, val_type=int, min_val=10) + check_argument("compute_input_seq_cache", c, restricted=True, val_type=bool) # paths - check_argument('output_path', c, restricted=True, val_type=str) + check_argument("output_path", c, restricted=True, val_type=str) # multi-speaker and gst - check_argument('use_speaker_embedding', c, restricted=True, val_type=bool) - check_argument('use_external_speaker_embedding_file', 
c, restricted=c['use_speaker_embedding'], val_type=bool) - check_argument('external_speaker_embedding_file', c, restricted=c['use_external_speaker_embedding_file'], val_type=str) - if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']: - check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool) - check_argument('gst', c, restricted=is_tacotron(c), val_type=dict) - check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict]) - check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000) - check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool) - check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10) - check_argument('gst_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000) + check_argument("use_speaker_embedding", c, restricted=True, val_type=bool) + check_argument("use_external_speaker_embedding_file", c, restricted=c["use_speaker_embedding"], val_type=bool) + check_argument( + "external_speaker_embedding_file", c, restricted=c["use_external_speaker_embedding_file"], val_type=str + ) + if c["model"].lower() in ["tacotron", "tacotron2"] and c["use_gst"]: + check_argument("use_gst", c, restricted=is_tacotron(c), val_type=bool) + check_argument("gst", c, restricted=is_tacotron(c), val_type=dict) + check_argument("gst_style_input", c["gst"], restricted=is_tacotron(c), val_type=[str, dict]) + check_argument("gst_embedding_dim", c["gst"], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000) + check_argument("gst_use_speaker_embedding", c["gst"], restricted=is_tacotron(c), val_type=bool) + check_argument("gst_num_heads", c["gst"], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10) + check_argument("gst_style_tokens", c["gst"], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000) # datasets - checking only the first entry - check_argument('datasets', c, restricted=True, val_type=list) - for dataset_entry in c['datasets']: - check_argument('name', dataset_entry, restricted=True, val_type=str) - check_argument('path', dataset_entry, restricted=True, val_type=str) - check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list]) - check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str) + check_argument("datasets", c, restricted=True, val_type=list) + for dataset_entry in c["datasets"]: + check_argument("name", dataset_entry, restricted=True, val_type=str) + check_argument("path", dataset_entry, restricted=True, val_type=str) + check_argument("meta_file_train", dataset_entry, restricted=True, val_type=[str, list]) + check_argument("meta_file_val", dataset_entry, restricted=True, val_type=str) diff --git a/TTS/tts/utils/io.py b/TTS/tts/utils/io.py index bcf5ff371d..bb8432fad4 100644 --- a/TTS/tts/utils/io.py +++ b/TTS/tts/utils/io.py @@ -1,10 +1,10 @@ -import os -import torch import datetime +import os import pickle as pickle_tts -from TTS.utils.io import RenamingUnpickler +import torch +from TTS.utils.io import RenamingUnpickler def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False): # pylint: disable=redefined-builtin @@ -20,33 +20,25 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False [type]: [description] """ try: - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + state = torch.load(checkpoint_path, 
map_location=torch.device("cpu")) except ModuleNotFoundError: pickle_tts.Unpickler = RenamingUnpickler - state = torch.load(checkpoint_path, map_location=torch.device('cpu'), pickle_module=pickle_tts) - model.load_state_dict(state['model']) - if amp and 'amp' in state: - amp.load_state_dict(state['amp']) + state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts) + model.load_state_dict(state["model"]) + if amp and "amp" in state: + amp.load_state_dict(state["amp"]) if use_cuda: model.cuda() # set model stepsize - if hasattr(model.decoder, 'r'): - model.decoder.set_r(state['r']) - print(" > Model r: ", state['r']) + if hasattr(model.decoder, "r"): + model.decoder.set_r(state["r"]) + print(" > Model r: ", state["r"]) if eval: model.eval() return model, state -def save_model(model, - optimizer, - current_step, - epoch, - r, - output_path, - characters, - amp_state_dict=None, - **kwargs): +def save_model(model, optimizer, current_step, epoch, r, output_path, characters, amp_state_dict=None, **kwargs): """Save ```TTS.tts.models``` states with extra fields. Args: @@ -59,27 +51,26 @@ def save_model(model, characters (list): list of characters used in the model. amp_state_dict (state_dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None. """ - if hasattr(model, 'module'): + if hasattr(model, "module"): model_state = model.module.state_dict() else: model_state = model.state_dict() state = { - 'model': model_state, - 'optimizer': optimizer.state_dict() if optimizer is not None else None, - 'step': current_step, - 'epoch': epoch, - 'date': datetime.date.today().strftime("%B %d, %Y"), - 'r': r, - 'characters': characters + "model": model_state, + "optimizer": optimizer.state_dict() if optimizer is not None else None, + "step": current_step, + "epoch": epoch, + "date": datetime.date.today().strftime("%B %d, %Y"), + "r": r, + "characters": characters, } if amp_state_dict: - state['amp'] = amp_state_dict + state["amp"] = amp_state_dict state.update(kwargs) torch.save(state, output_path) -def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, - characters, **kwargs): +def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, characters, **kwargs): """Save model checkpoint, intended for saving checkpoints at training. Args: @@ -91,14 +82,15 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, output_path (str): output path to save the model file. characters (list): list of characters used in the model. """ - file_name = 'checkpoint_{}.pth.tar'.format(current_step) + file_name = "checkpoint_{}.pth.tar".format(current_step) checkpoint_path = os.path.join(output_folder, file_name) print(" > CHECKPOINT : {}".format(checkpoint_path)) save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, **kwargs) -def save_best_model(target_loss, best_loss, model, optimizer, current_step, - epoch, r, output_folder, characters, **kwargs): +def save_best_model( + target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, characters, **kwargs +): """Save model checkpoint, intended for saving the best model after each epoch. It compares the current model loss with the best loss so far and saves the model if the current loss is better. @@ -118,9 +110,11 @@ def save_best_model(target_loss, best_loss, model, optimizer, current_step, float: updated current best loss. 
""" if target_loss < best_loss: - file_name = 'best_model.pth.tar' + file_name = "best_model.pth.tar" checkpoint_path = os.path.join(output_folder, file_name) print(" >> BEST MODEL : {}".format(checkpoint_path)) - save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs) + save_model( + model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs + ) best_loss = target_loss return best_loss diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index feb1a8450d..25786d7050 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -1,5 +1,5 @@ -import os import json +import os def make_speakers_json_path(out_path): @@ -10,7 +10,7 @@ def make_speakers_json_path(out_path): def load_speaker_mapping(out_path): """Loads speaker mapping if already present.""" try: - if os.path.splitext(out_path)[1] == '.json': + if os.path.splitext(out_path)[1] == ".json": json_file = out_path else: json_file = make_speakers_json_path(out_path) @@ -19,6 +19,7 @@ def load_speaker_mapping(out_path): except FileNotFoundError: return {} + def save_speaker_mapping(out_path, speaker_mapping): """Saves speaker mapping if not yet present.""" speakers_json_path = make_speakers_json_path(out_path) @@ -31,40 +32,49 @@ def get_speakers(items): speakers = {e[2] for e in items} return sorted(speakers) + def parse_speakers(c, args, meta_data_train, OUT_PATH): """ Returns number of speakers, speaker embedding shape and speaker mapping""" if c.use_speaker_embedding: speakers = get_speakers(meta_data_train) if args.restore_path: - if c.use_external_speaker_embedding_file: # if restore checkpoint and use External Embedding file + if c.use_external_speaker_embedding_file: # if restore checkpoint and use External Embedding file prev_out_path = os.path.dirname(args.restore_path) speaker_mapping = load_speaker_mapping(prev_out_path) if not speaker_mapping: - print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file") + print( + "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file" + ) speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file) if not speaker_mapping: - raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file") - speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding']) - elif not c.use_external_speaker_embedding_file: # if restore checkpoint and don't use External Embedding file + raise RuntimeError( + "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file" + ) + speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"]) + elif ( + not c.use_external_speaker_embedding_file + ): # if restore checkpoint and don't use External Embedding file prev_out_path = os.path.dirname(args.restore_path) speaker_mapping = load_speaker_mapping(prev_out_path) speaker_embedding_dim = None - assert all([speaker in speaker_mapping - for speaker in speakers]), "As of now you, you cannot " \ - "introduce new speakers to " \ - "a previously trained model." 
- elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file: # if start new train using External Embedding file + assert all([speaker in speaker_mapping for speaker in speakers]), ( + "As of now you, you cannot " "introduce new speakers to " "a previously trained model." + ) + elif ( + c.use_external_speaker_embedding_file and c.external_speaker_embedding_file + ): # if start new train using External Embedding file speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file) - speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding']) - elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file: # if start new train using External Embedding file and don't pass external embedding file + speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"]) + elif ( + c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file + ): # if start new train using External Embedding file and don't pass external embedding file raise "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder" - else: # if start new train and don't use External Embedding file + else: # if start new train and don't use External Embedding file speaker_mapping = {name: i for i, name in enumerate(speakers)} speaker_embedding_dim = None save_speaker_mapping(OUT_PATH, speaker_mapping) num_speakers = len(speaker_mapping) - print(" > Training with {} speakers: {}".format( - len(speakers), ", ".join(speakers))) + print(" > Training with {} speakers: {}".format(len(speakers), ", ".join(speakers))) else: num_speakers = 0 speaker_embedding_dim = None diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py index 8f4c4cae5c..11107e47c2 100644 --- a/TTS/tts/utils/ssim.py +++ b/TTS/tts/utils/ssim.py @@ -8,8 +8,9 @@ def gaussian(window_size, sigma): - gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) - return gauss/gauss.sum() + gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + def create_window(window_size, channel): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) @@ -24,25 +25,22 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True): mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) - mu1_mu2 = mu1*mu2 + mu1_mu2 = mu1 * mu2 - sigma1_sq = F.conv2d( - img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq - sigma2_sq = F.conv2d( - img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq - sigma12 = F.conv2d( - img1 * img2, window, padding=window_size // 2, - groups=channel) - mu1_mu2 + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 - C1 = 0.01**2 - C2 = 0.03**2 + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 - ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) if size_average: return ssim_map.mean() return 
ssim_map.mean(1).mean(1).mean(1) + class SSIM(torch.nn.Module): def __init__(self, window_size=11, size_average=True): super().__init__() @@ -66,7 +64,6 @@ def forward(self, img1, img2): self.window = window self.channel = channel - return _ssim(img1, img2, window, self.window_size, channel, self.size_average) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index f825d61c6b..30e5feabc7 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -1,12 +1,16 @@ import os -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import pkg_resources -installed = {pkg.key for pkg in pkg_resources.working_set} #pylint: disable=not-an-iterable -if 'tensorflow' in installed or 'tensorflow-gpu' in installed: + +installed = {pkg.key for pkg in pkg_resources.working_set} # pylint: disable=not-an-iterable +if "tensorflow" in installed or "tensorflow-gpu" in installed: import tensorflow as tf -import torch + import numpy as np -from .text import text_to_sequence, phoneme_to_sequence +import torch + +from .text import phoneme_to_sequence, text_to_sequence def text_to_seqvec(text, CONFIG): @@ -14,19 +18,26 @@ def text_to_seqvec(text, CONFIG): # text ot phonemes to sequence vector if CONFIG.use_phonemes: seq = np.asarray( - phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, - CONFIG.enable_eos_bos_chars, - tp=CONFIG.characters if 'characters' in CONFIG.keys() else None, - add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False), - dtype=np.int32) + phoneme_to_sequence( + text, + text_cleaner, + CONFIG.phoneme_language, + CONFIG.enable_eos_bos_chars, + tp=CONFIG.characters if "characters" in CONFIG.keys() else None, + add_blank=CONFIG["add_blank"] if "add_blank" in CONFIG.keys() else False, + ), + dtype=np.int32, + ) else: - seq = np.asarray(text_to_sequence( - text, - text_cleaner, - tp=CONFIG.characters if 'characters' in CONFIG.keys() else None, - add_blank=CONFIG['add_blank'] - if 'add_blank' in CONFIG.keys() else False), - dtype=np.int32) + seq = np.asarray( + text_to_sequence( + text, + text_cleaner, + tp=CONFIG.characters if "characters" in CONFIG.keys() else None, + add_blank=CONFIG["add_blank"] if "add_blank" in CONFIG.keys() else False, + ), + dtype=np.int32, + ) return seq @@ -47,86 +58,95 @@ def numpy_to_tf(np_array, dtype): def compute_style_mel(style_wav, ap, cuda=False): - style_mel = torch.FloatTensor(ap.melspectrogram( - ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0) + style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0) if cuda: return style_mel.cuda() return style_mel def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None): - if 'tacotron' in CONFIG.model.lower(): + if "tacotron" in CONFIG.model.lower(): if CONFIG.use_gst: decoder_output, postnet_output, alignments, stop_tokens = model.inference( - inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) + inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings + ) else: if truncated: decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated( - inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) + inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings + ) else: decoder_output, postnet_output, alignments, stop_tokens = model.inference( - inputs, speaker_ids=speaker_id, 
speaker_embeddings=speaker_embeddings) - elif 'glow' in CONFIG.model.lower(): + inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings + ) + elif "glow" in CONFIG.model.lower(): inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable - if hasattr(model, 'module'): + if hasattr(model, "module"): # distributed model - postnet_output, _, _, _, alignments, _, _ = model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings) + postnet_output, _, _, _, alignments, _, _ = model.module.inference( + inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings + ) else: - postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings) + postnet_output, _, _, _, alignments, _, _ = model.inference( + inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings + ) postnet_output = postnet_output.permute(0, 2, 1) # these only belong to tacotron models. decoder_output = None stop_tokens = None - elif CONFIG.model.lower() in ['speedy_speech', 'align_tts']: + elif CONFIG.model.lower() in ["speedy_speech", "align_tts"]: inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable - if hasattr(model, 'module'): + if hasattr(model, "module"): # distributed model - postnet_output, alignments = model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings) + postnet_output, alignments = model.module.inference( + inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings + ) else: - postnet_output, alignments = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings) + postnet_output, alignments = model.inference( + inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings + ) postnet_output = postnet_output.permute(0, 2, 1) # these only belong to tacotron models. decoder_output = None stop_tokens = None else: - raise ValueError('[!] Unknown model name.') + raise ValueError("[!] Unknown model name.") return decoder_output, postnet_output, alignments, stop_tokens def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None): if CONFIG.use_gst and style_mel is not None: - raise NotImplementedError(' [!] GST inference not implemented for TF') + raise NotImplementedError(" [!] GST inference not implemented for TF") if truncated: - raise NotImplementedError(' [!] Truncated inference not implemented for TF') + raise NotImplementedError(" [!] Truncated inference not implemented for TF") if speaker_id is not None: - raise NotImplementedError(' [!] Multi-Speaker not implemented for TF') + raise NotImplementedError(" [!] Multi-Speaker not implemented for TF") # TODO: handle multispeaker case - decoder_output, postnet_output, alignments, stop_tokens = model( - inputs, training=False) + decoder_output, postnet_output, alignments, stop_tokens = model(inputs, training=False) return decoder_output, postnet_output, alignments, stop_tokens def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None): if CONFIG.use_gst and style_mel is not None: - raise NotImplementedError(' [!] GST inference not implemented for TfLite') + raise NotImplementedError(" [!] GST inference not implemented for TfLite") if truncated: - raise NotImplementedError(' [!] 
Truncated inference not implemented for TfLite') + raise NotImplementedError(" [!] Truncated inference not implemented for TfLite") if speaker_id is not None: - raise NotImplementedError(' [!] Multi-Speaker not implemented for TfLite') + raise NotImplementedError(" [!] Multi-Speaker not implemented for TfLite") # get input and output details input_details = model.get_input_details() output_details = model.get_output_details() # reshape input tensor for the new input shape - model.resize_tensor_input(input_details[0]['index'], inputs.shape) + model.resize_tensor_input(input_details[0]["index"], inputs.shape) model.allocate_tensors() detail = input_details[0] # input_shape = detail['shape'] - model.set_tensor(detail['index'], inputs) + model.set_tensor(detail["index"], inputs) # run the model model.invoke() # collect outputs - decoder_output = model.get_tensor(output_details[0]['index']) - postnet_output = model.get_tensor(output_details[1]['index']) + decoder_output = model.get_tensor(output_details[0]["index"]) + postnet_output = model.get_tensor(output_details[1]["index"]) # tflite model only returns feature frames return decoder_output, postnet_output, None, None @@ -154,7 +174,7 @@ def parse_outputs_tflite(postnet_output, decoder_output): def trim_silence(wav, ap): - return wav[:ap.find_endpoint(wav)] + return wav[: ap.find_endpoint(wav)] def inv_spectrogram(postnet_output, ap, CONFIG): @@ -186,13 +206,13 @@ def embedding_to_torch(speaker_embedding, cuda=False): # TODO: perform GL with pytorch for batching def apply_griffin_lim(inputs, input_lens, CONFIG, ap): - '''Apply griffin-lim to each sample iterating throught the first dimension. + """Apply griffin-lim to each sample iterating throught the first dimension. Args: inputs (Tensor or np.Array): Features to be converted by GL. First dimension is the batch size. input_lens (Tensor or np.Array): 1D array of sample lengths. CONFIG (Dict): TTS config. ap (AudioProcessor): TTS audio processor. - ''' + """ wavs = [] for idx, spec in enumerate(inputs): wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length # inverse librosa padding @@ -202,39 +222,41 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap): return wavs -def synthesis(model, - text, - CONFIG, - use_cuda, - ap, - speaker_id=None, - style_wav=None, - truncated=False, - enable_eos_bos_chars=False, #pylint: disable=unused-argument - use_griffin_lim=False, - do_trim_silence=False, - speaker_embedding=None, - backend='torch'): +def synthesis( + model, + text, + CONFIG, + use_cuda, + ap, + speaker_id=None, + style_wav=None, + truncated=False, + enable_eos_bos_chars=False, # pylint: disable=unused-argument + use_griffin_lim=False, + do_trim_silence=False, + speaker_embedding=None, + backend="torch", +): """Synthesize voice for the given text. - Args: - model (TTS.tts.models): model to synthesize. - text (str): target text - CONFIG (dict): config dictionary to be loaded from config.json. - use_cuda (bool): enable cuda. - ap (TTS.tts.utils.audio.AudioProcessor): audio processor to process - model outputs. - speaker_id (int): id of speaker - style_wav (str | Dict[str, float]): Uses for style embedding of GST. - truncated (bool): keep model states after inference. It can be used - for continuous inference at long texts. - enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence. - do_trim_silence (bool): trim silence after synthesis. - backend (str): tf or torch + Args: + model (TTS.tts.models): model to synthesize. 
+ text (str): target text + CONFIG (dict): config dictionary to be loaded from config.json. + use_cuda (bool): enable cuda. + ap (TTS.tts.utils.audio.AudioProcessor): audio processor to process + model outputs. + speaker_id (int): id of speaker + style_wav (str | Dict[str, float]): Uses for style embedding of GST. + truncated (bool): keep model states after inference. It can be used + for continuous inference at long texts. + enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence. + do_trim_silence (bool): trim silence after synthesis. + backend (str): tf or torch """ # GST processing style_mel = None - if 'use_gst' in CONFIG.keys() and CONFIG.use_gst and style_wav is not None: + if "use_gst" in CONFIG.keys() and CONFIG.use_gst and style_wav is not None: if isinstance(style_wav, dict): style_mel = style_wav else: @@ -242,7 +264,7 @@ def synthesis(model, # preprocess the given text inputs = text_to_seqvec(text, CONFIG) # pass tensors to backend - if backend == 'torch': + if backend == "torch": if speaker_id is not None: speaker_id = id_to_torch(speaker_id, cuda=use_cuda) @@ -253,31 +275,35 @@ def synthesis(model, style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda) inputs = inputs.unsqueeze(0) - elif backend == 'tf': + elif backend == "tf": # TODO: handle speaker id for tf model style_mel = numpy_to_tf(style_mel, tf.float32) inputs = numpy_to_tf(inputs, tf.int32) inputs = tf.expand_dims(inputs, 0) - elif backend == 'tflite': + elif backend == "tflite": style_mel = numpy_to_tf(style_mel, tf.float32) inputs = numpy_to_tf(inputs, tf.int32) inputs = tf.expand_dims(inputs, 0) # synthesize voice - if backend == 'torch': + if backend == "torch": decoder_output, postnet_output, alignments, stop_tokens = run_model_torch( - model, inputs, CONFIG, truncated, speaker_id, style_mel, speaker_embeddings=speaker_embedding) + model, inputs, CONFIG, truncated, speaker_id, style_mel, speaker_embeddings=speaker_embedding + ) postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch( - postnet_output, decoder_output, alignments, stop_tokens) - elif backend == 'tf': + postnet_output, decoder_output, alignments, stop_tokens + ) + elif backend == "tf": decoder_output, postnet_output, alignments, stop_tokens = run_model_tf( - model, inputs, CONFIG, truncated, speaker_id, style_mel) + model, inputs, CONFIG, truncated, speaker_id, style_mel + ) postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf( - postnet_output, decoder_output, alignments, stop_tokens) - elif backend == 'tflite': + postnet_output, decoder_output, alignments, stop_tokens + ) + elif backend == "tflite": decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite( - model, inputs, CONFIG, truncated, speaker_id, style_mel) - postnet_output, decoder_output = parse_outputs_tflite( - postnet_output, decoder_output) + model, inputs, CONFIG, truncated, speaker_id, style_mel + ) + postnet_output, decoder_output = parse_outputs_tflite(postnet_output, decoder_output) # convert outputs to numpy # plot results wav = None diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py index 2a724650f1..64d15b0154 100644 --- a/TTS/tts/utils/text/__init__.py +++ b/TTS/tts/utils/text/__init__.py @@ -5,11 +5,10 @@ import phonemizer from packaging import version from phonemizer.phonemize import phonemize -from TTS.tts.utils.text import cleaners -from TTS.tts.utils.text.symbols import (_bos, 
_eos, _punctuations, - make_symbols, phonemes, symbols) -from TTS.tts.utils.chinese_mandarin.phonemizer import chinese_text_to_phonemes +from TTS.tts.utils.chinese_mandarin.phonemizer import chinese_text_to_phonemes +from TTS.tts.utils.text import cleaners +from TTS.tts.utils.text.symbols import _bos, _eos, _punctuations, make_symbols, phonemes, symbols # pylint: disable=unnecessary-comprehension # Mappings from symbol to numeric ID and vice versa: @@ -22,14 +21,14 @@ _symbols = symbols _phonemes = phonemes # Regular expression matching text enclosed in curly braces: -_CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)') +_CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)") # Regular expression matching punctuations, ignoring empty space -PHONEME_PUNCTUATION_PATTERN = r'['+_punctuations.replace(' ', '')+']+' +PHONEME_PUNCTUATION_PATTERN = r"[" + _punctuations.replace(" ", "") + "]+" def text2phone(text, language): - '''Convert graphemes to phonemes. For most of the languages, it calls + """Convert graphemes to phonemes. For most of the languages, it calls the phonemizer python library that calls espeak/espeak-ng. For chinese mandarin, it calls pypinyin + custom function for phonemizing Parameters: @@ -38,60 +37,73 @@ def text2phone(text, language): Returns: ph (str): phonemes as a string seperated by "|" ph = "ɪ|g|ˈ|z|æ|m|p|ə|l" - ''' + """ # TO REVIEW : How to have a good implementation for this? if language == "zh-CN": ph = chinese_text_to_phonemes(text) return ph - - seperator = phonemizer.separator.Separator(' |', '', '|') - #try: + seperator = phonemizer.separator.Separator(" |", "", "|") + # try: punctuations = re.findall(PHONEME_PUNCTUATION_PATTERN, text) - if version.parse(phonemizer.__version__) < version.parse('2.1'): - ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language) - ph = ph[:-1].strip() # skip the last empty character + if version.parse(phonemizer.__version__) < version.parse("2.1"): + ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend="espeak", language=language) + ph = ph[:-1].strip() # skip the last empty character # phonemizer does not tackle punctuations. Here we do. # Replace \n with matching punctuations. if punctuations: # if text ends with a punctuation. if text[-1] == punctuations[-1]: for punct in punctuations[:-1]: - ph = ph.replace('| |\n', '|'+punct+'| |', 1) + ph = ph.replace("| |\n", "|" + punct + "| |", 1) ph = ph + punctuations[-1] else: for punct in punctuations: - ph = ph.replace('| |\n', '|'+punct+'| |', 1) - elif version.parse(phonemizer.__version__) >= version.parse('2.1'): - ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language, preserve_punctuation=True, language_switch='remove-flags') + ph = ph.replace("| |\n", "|" + punct + "| |", 1) + elif version.parse(phonemizer.__version__) >= version.parse("2.1"): + ph = phonemize( + text, + separator=seperator, + strip=False, + njobs=1, + backend="espeak", + language=language, + preserve_punctuation=True, + language_switch="remove-flags", + ) # this is a simple fix for phonemizer. # https://github.com/bootphon/phonemizer/issues/32 if punctuations: for punctuation in punctuations: - ph = ph.replace(f"| |{punctuation} ", f"|{punctuation}| |").replace(f"| |{punctuation}", f"|{punctuation}| |") + ph = ph.replace(f"| |{punctuation} ", f"|{punctuation}| |").replace( + f"| |{punctuation}", f"|{punctuation}| |" + ) ph = ph[:-3] else: raise RuntimeError(" [!] 
Use 'phonemizer' version 2.1 or older.") return ph + def intersperse(sequence, token): result = [token] * (len(sequence) * 2 + 1) result[1::2] = sequence return result + def pad_with_eos_bos(phoneme_sequence, tp=None): # pylint: disable=global-statement global _phonemes_to_id, _bos, _eos if tp: - _bos = tp['bos'] - _eos = tp['eos'] + _bos = tp["bos"] + _eos = tp["eos"] _, _phonemes = make_symbols(**tp) _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]] + def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False): # pylint: disable=global-statement global _phonemes_to_id, _phonemes @@ -105,23 +117,23 @@ def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp= if to_phonemes is None: print("!! After phoneme conversion the result is None. -- {} ".format(clean_text)) # iterate by skipping empty strings - NOTE: might be useful to keep it to have a better intonation. - for phoneme in filter(None, to_phonemes.split('|')): + for phoneme in filter(None, to_phonemes.split("|")): sequence += _phoneme_to_sequence(phoneme) # Append EOS char if enable_eos_bos: sequence = pad_with_eos_bos(sequence, tp=tp) if add_blank: - sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes) + sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes) return sequence def sequence_to_phoneme(sequence, tp=None, add_blank=False): # pylint: disable=global-statement - '''Converts a sequence of IDs back to a string''' + """Converts a sequence of IDs back to a string""" global _id_to_phonemes, _phonemes if add_blank: sequence = list(filter(lambda x: x != len(_phonemes), sequence)) - result = '' + result = "" if tp: _, _phonemes = make_symbols(**tp) _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)} @@ -130,22 +142,22 @@ def sequence_to_phoneme(sequence, tp=None, add_blank=False): if symbol_id in _id_to_phonemes: s = _id_to_phonemes[symbol_id] result += s - return result.replace('}{', ' ') + return result.replace("}{", " ") def text_to_sequence(text, cleaner_names, tp=None, add_blank=False): - '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. - The text can optionally have ARPAbet sequences enclosed in curly braces embedded - in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." + The text can optionally have ARPAbet sequences enclosed in curly braces embedded + in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
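A minimal standalone sketch of the reformatted intersperse() helper above, which is what add_blank relies on (symbol ids and the blank id 42 are made up for illustration; in the patched code the blank id is len(_symbols) or len(_phonemes), i.e. one past the last real id):

    def intersperse(sequence, token):
        # place the blank token before, between and after every symbol id
        result = [token] * (len(sequence) * 2 + 1)
        result[1::2] = sequence
        return result

    assert intersperse([7, 8, 9], 42) == [42, 7, 42, 8, 42, 9, 42]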
- Args: - text: string to convert to a sequence - cleaner_names: names of the cleaner functions to run the text through + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through - Returns: - List of integers corresponding to the symbols in the text - ''' + Returns: + List of integers corresponding to the symbols in the text + """ # pylint: disable=global-statement global _symbol_to_id, _symbols if tp: @@ -159,18 +171,17 @@ def text_to_sequence(text, cleaner_names, tp=None, add_blank=False): if not m: sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) break - sequence += _symbols_to_sequence( - _clean_text(m.group(1), cleaner_names)) + sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) sequence += _arpabet_to_sequence(m.group(2)) text = m.group(3) if add_blank: - sequence = intersperse(sequence, len(_symbols)) # add a blank token (new), whose id number is len(_symbols) + sequence = intersperse(sequence, len(_symbols)) # add a blank token (new), whose id number is len(_symbols) return sequence def sequence_to_text(sequence, tp=None, add_blank=False): - '''Converts a sequence of IDs back to a string''' + """Converts a sequence of IDs back to a string""" # pylint: disable=global-statement global _id_to_symbol, _symbols if add_blank: @@ -180,22 +191,22 @@ def sequence_to_text(sequence, tp=None, add_blank=False): _symbols, _ = make_symbols(**tp) _id_to_symbol = {i: s for i, s in enumerate(_symbols)} - result = '' + result = "" for symbol_id in sequence: if symbol_id in _id_to_symbol: s = _id_to_symbol[symbol_id] # Enclose ARPAbet back in curly braces: - if len(s) > 1 and s[0] == '@': - s = '{%s}' % s[1:] + if len(s) > 1 and s[0] == "@": + s = "{%s}" % s[1:] result += s - return result.replace('}{', ' ') + return result.replace("}{", " ") def _clean_text(text, cleaner_names): for name in cleaner_names: cleaner = getattr(cleaners, name) if not cleaner: - raise Exception('Unknown cleaner: %s' % name) + raise Exception("Unknown cleaner: %s" % name) text = cleaner(text) return text @@ -209,12 +220,12 @@ def _phoneme_to_sequence(phons): def _arpabet_to_sequence(text): - return _symbols_to_sequence(['@' + s for s in text.split()]) + return _symbols_to_sequence(["@" + s for s in text.split()]) def _should_keep_symbol(s): - return s in _symbol_to_id and s not in ['~', '^', '_'] + return s in _symbol_to_id and s not in ["~", "^", "_"] def _should_keep_phoneme(p): - return p in _phonemes_to_id and p not in ['~', '^', '_'] + return p in _phonemes_to_id and p not in ["~", "^", "_"] diff --git a/TTS/tts/utils/text/abbreviations.py b/TTS/tts/utils/text/abbreviations.py index 579d7dcdf2..7e44b90c63 100644 --- a/TTS/tts/utils/text/abbreviations.py +++ b/TTS/tts/utils/text/abbreviations.py @@ -1,66 +1,73 @@ import re # List of (regular expression, replacement) pairs for abbreviations in english: -abbreviations_en = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) - for x in [ - ('mrs', 'misess'), - ('mr', 'mister'), - ('dr', 'doctor'), - ('st', 'saint'), - ('co', 'company'), - ('jr', 'junior'), - ('maj', 'major'), - ('gen', 'general'), - ('drs', 'doctors'), - ('rev', 'reverend'), - ('lt', 'lieutenant'), - ('hon', 'honorable'), - ('sgt', 'sergeant'), - ('capt', 'captain'), - ('esq', 'esquire'), - ('ltd', 'limited'), - ('col', 'colonel'), - ('ft', 'fort'), - ]] +abbreviations_en = [ + (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] # List of (regular expression, replacement) pairs for abbreviations in french: -abbreviations_fr = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) - for x in [ - ('M', 'monsieur'), - ('Mlle', 'mademoiselle'), - ('Mlles', 'mesdemoiselles'), - ('Mme', 'Madame'), - ('Mmes', 'Mesdames'), - ('N.B', 'nota bene'), - ('M', 'monsieur'), - ('p.c.q', 'parce que'), - ('Pr', 'professeur'), - ('qqch', 'quelque chose'), - ('rdv', 'rendez-vous'), - ('max', 'maximum'), - ('min', 'minimum'), - ('no', 'numéro'), - ('adr', 'adresse'), - ('dr', 'docteur'), - ('st', 'saint'), - ('co', 'companie'), - ('jr', 'junior'), - ('sgt', 'sergent'), - ('capt', 'capitain'), - ('col', 'colonel'), - ('av', 'avenue'), - ('av. J.-C', 'avant Jésus-Christ'), - ('apr. J.-C', 'après Jésus-Christ'), - ('art', 'article'), - ('boul', 'boulevard'), - ('c.-à-d', 'c’est-à-dire'), - ('etc', 'et cetera'), - ('ex', 'exemple'), - ('excl', 'exclusivement'), - ('boul', 'boulevard'), - ]] + [(re.compile('\\b%s' % x[0]), x[1]) for x in [ - ('Mlle', 'mademoiselle'), - ('Mlles', 'mesdemoiselles'), - ('Mme', 'Madame'), - ('Mmes', 'Mesdames'), - ]] +abbreviations_fr = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("M", "monsieur"), + ("Mlle", "mademoiselle"), + ("Mlles", "mesdemoiselles"), + ("Mme", "Madame"), + ("Mmes", "Mesdames"), + ("N.B", "nota bene"), + ("M", "monsieur"), + ("p.c.q", "parce que"), + ("Pr", "professeur"), + ("qqch", "quelque chose"), + ("rdv", "rendez-vous"), + ("max", "maximum"), + ("min", "minimum"), + ("no", "numéro"), + ("adr", "adresse"), + ("dr", "docteur"), + ("st", "saint"), + ("co", "companie"), + ("jr", "junior"), + ("sgt", "sergent"), + ("capt", "capitain"), + ("col", "colonel"), + ("av", "avenue"), + ("av. J.-C", "avant Jésus-Christ"), + ("apr. J.-C", "après Jésus-Christ"), + ("art", "article"), + ("boul", "boulevard"), + ("c.-à-d", "c’est-à-dire"), + ("etc", "et cetera"), + ("ex", "exemple"), + ("excl", "exclusivement"), + ("boul", "boulevard"), + ] +] + [ + (re.compile("\\b%s" % x[0]), x[1]) + for x in [ + ("Mlle", "mademoiselle"), + ("Mlles", "mesdemoiselles"), + ("Mme", "Madame"), + ("Mmes", "Mesdames"), + ] +] diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py index 4e1c6d4380..d61738a6f1 100644 --- a/TTS/tts/utils/text/cleaners.py +++ b/TTS/tts/utils/text/cleaners.py @@ -1,4 +1,4 @@ -''' +""" Cleaners are transformations that run over the input text at both training and eval time. Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" @@ -8,24 +8,26 @@ the Unidecode library (https://pypi.python.org/pypi/Unidecode) 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update the symbols in symbols.py to match your data). 
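A hedged sketch of how a comma-delimited cleaner list is resolved and chained; the registry dict here stands in for the getattr(cleaners, name) lookup done by _clean_text earlier in this patch, and the sample sentence is illustrative only:

    import re

    _whitespace_re = re.compile(r"\s+")

    def lowercase(text):
        return text.lower()

    def collapse_whitespace(text):
        return re.sub(_whitespace_re, " ", text).strip()

    def basic_cleaners(text):
        # lowercase + whitespace collapse, as in the patched basic_cleaners
        return collapse_whitespace(lowercase(text))

    def clean(text, cleaner_names):
        registry = {"basic_cleaners": basic_cleaners}  # stand-in for the cleaners module
        for name in cleaner_names.split(","):
            text = registry[name.strip()](text)
        return text

    print(clean("  Hello   WORLD ", "basic_cleaners"))  # -> "hello world"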
-''' +""" import re + from unidecode import unidecode -from .number_norm import normalize_numbers -from .abbreviations import abbreviations_en, abbreviations_fr -from .time import expand_time_english + from TTS.tts.utils.chinese_mandarin.numbers import replace_numbers_to_characters_in_text +from .abbreviations import abbreviations_en, abbreviations_fr +from .number_norm import normalize_numbers +from .time import expand_time_english # Regular expression matching whitespace: -_whitespace_re = re.compile(r'\s+') +_whitespace_re = re.compile(r"\s+") -def expand_abbreviations(text, lang='en'): - if lang == 'en': +def expand_abbreviations(text, lang="en"): + if lang == "en": _abbreviations = abbreviations_en - elif lang == 'fr': + elif lang == "fr": _abbreviations = abbreviations_fr for regex, replacement in _abbreviations: text = re.sub(regex, replacement, text) @@ -41,7 +43,7 @@ def lowercase(text): def collapse_whitespace(text): - return re.sub(_whitespace_re, ' ', text).strip() + return re.sub(_whitespace_re, " ", text).strip() def convert_to_ascii(text): @@ -49,30 +51,32 @@ def convert_to_ascii(text): def remove_aux_symbols(text): - text = re.sub(r'[\<\>\(\)\[\]\"]+', '', text) + text = re.sub(r"[\<\>\(\)\[\]\"]+", "", text) return text -def replace_symbols(text, lang='en'): - text = text.replace(';', ',') - text = text.replace('-', ' ') - text = text.replace(':', ',') - if lang == 'en': - text = text.replace('&', ' and ') - elif lang == 'fr': - text = text.replace('&', ' et ') - elif lang == 'pt': - text = text.replace('&', ' e ') + +def replace_symbols(text, lang="en"): + text = text.replace(";", ",") + text = text.replace("-", " ") + text = text.replace(":", ",") + if lang == "en": + text = text.replace("&", " and ") + elif lang == "fr": + text = text.replace("&", " et ") + elif lang == "pt": + text = text.replace("&", " e ") return text + def basic_cleaners(text): - '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" text = lowercase(text) text = collapse_whitespace(text) return text def transliteration_cleaners(text): - '''Pipeline for non-English text that transliterates to ASCII.''' + """Pipeline for non-English text that transliterates to ASCII.""" text = convert_to_ascii(text) text = lowercase(text) text = collapse_whitespace(text) @@ -80,7 +84,7 @@ def transliteration_cleaners(text): def basic_german_cleaners(text): - '''Pipeline for German text''' + """Pipeline for German text""" text = lowercase(text) text = collapse_whitespace(text) return text @@ -88,7 +92,7 @@ def basic_german_cleaners(text): # TODO: elaborate it def basic_turkish_cleaners(text): - '''Pipeline for Turkish text''' + """Pipeline for Turkish text""" text = text.replace("I", "ı") text = lowercase(text) text = collapse_whitespace(text) @@ -96,7 +100,7 @@ def basic_turkish_cleaners(text): def english_cleaners(text): - '''Pipeline for English text, including number and abbreviation expansion.''' + """Pipeline for English text, including number and abbreviation expansion.""" text = convert_to_ascii(text) text = lowercase(text) text = expand_time_english(text) @@ -109,33 +113,33 @@ def english_cleaners(text): def french_cleaners(text): - '''Pipeline for French text. There is no need to expand numbers, phonemizer already does that''' - text = expand_abbreviations(text, lang='fr') + """Pipeline for French text. 
There is no need to expand numbers, phonemizer already does that""" + text = expand_abbreviations(text, lang="fr") text = lowercase(text) - text = replace_symbols(text, lang='fr') + text = replace_symbols(text, lang="fr") text = remove_aux_symbols(text) text = collapse_whitespace(text) return text def portuguese_cleaners(text): - '''Basic pipeline for Portuguese text. There is no need to expand abbreviation and - numbers, phonemizer already does that''' + """Basic pipeline for Portuguese text. There is no need to expand abbreviation and + numbers, phonemizer already does that""" text = lowercase(text) - text = replace_symbols(text, lang='pt') + text = replace_symbols(text, lang="pt") text = remove_aux_symbols(text) text = collapse_whitespace(text) return text def chinese_mandarin_cleaners(text: str) -> str: - '''Basic pipeline for chinese''' + """Basic pipeline for chinese""" text = replace_numbers_to_characters_in_text(text) return text def phoneme_cleaners(text): - '''Pipeline for phonemes mode, including number and abbreviation expansion.''' + """Pipeline for phonemes mode, including number and abbreviation expansion.""" text = expand_numbers(text) text = convert_to_ascii(text) text = expand_abbreviations(text) diff --git a/TTS/tts/utils/text/cmudict.py b/TTS/tts/utils/text/cmudict.py index c0f23406f5..f206fb043b 100644 --- a/TTS/tts/utils/text/cmudict.py +++ b/TTS/tts/utils/text/cmudict.py @@ -3,43 +3,116 @@ import re VALID_SYMBOLS = [ - 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', - 'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', - 'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', - 'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', - 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', - 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', - 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W', - 'Y', 'Z', 'ZH' + "AA", + "AA0", + "AA1", + "AA2", + "AE", + "AE0", + "AE1", + "AE2", + "AH", + "AH0", + "AH1", + "AH2", + "AO", + "AO0", + "AO1", + "AO2", + "AW", + "AW0", + "AW1", + "AW2", + "AY", + "AY0", + "AY1", + "AY2", + "B", + "CH", + "D", + "DH", + "EH", + "EH0", + "EH1", + "EH2", + "ER", + "ER0", + "ER1", + "ER2", + "EY", + "EY0", + "EY1", + "EY2", + "F", + "G", + "HH", + "IH", + "IH0", + "IH1", + "IH2", + "IY", + "IY0", + "IY1", + "IY2", + "JH", + "K", + "L", + "M", + "N", + "NG", + "OW", + "OW0", + "OW1", + "OW2", + "OY", + "OY0", + "OY1", + "OY2", + "P", + "R", + "S", + "SH", + "T", + "TH", + "UH", + "UH0", + "UH1", + "UH2", + "UW", + "UW0", + "UW1", + "UW2", + "V", + "W", + "Y", + "Z", + "ZH", ] class CMUDict: - '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' + """Thin wrapper around CMUDict data. 
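A hedged sketch of what the get_arpabet() wrapper further below does on a dictionary hit: strip surrounding punctuation, look the word up, and wrap the first pronunciation in curly braces so the text frontend treats it as ARPAbet (the single dictionary entry is hard-coded for illustration; the real class parses the cmudict file):

    def get_arpabet(word, entries, punctuation_symbols=",.!?"):
        first, last = "", ""
        if word and word[0] in punctuation_symbols:
            first, word = word[0], word[1:]
        if word and word[-1] in punctuation_symbols:
            last, word = word[-1], word[:-1]
        pron = entries.get(word.upper())
        if pron is not None:
            return first + "{%s}" % pron[0] + last
        return first + word + last

    entries = {"HOUSTON": ["HH Y UW1 S T AH0 N"]}
    print(get_arpabet("Houston,", entries))  # -> "{HH Y UW1 S T AH0 N},"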
http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" def __init__(self, file_or_path, keep_ambiguous=True): if isinstance(file_or_path, str): - with open(file_or_path, encoding='latin-1') as f: + with open(file_or_path, encoding="latin-1") as f: entries = _parse_cmudict(f) else: entries = _parse_cmudict(file_or_path) if not keep_ambiguous: - entries = { - word: pron - for word, pron in entries.items() if len(pron) == 1 - } + entries = {word: pron for word, pron in entries.items() if len(pron) == 1} self._entries = entries def __len__(self): return len(self._entries) def lookup(self, word): - '''Returns list of ARPAbet pronunciations of the given word.''' + """Returns list of ARPAbet pronunciations of the given word.""" return self._entries.get(word.upper()) @staticmethod def get_arpabet(word, cmudict, punctuation_symbols): - first_symbol, last_symbol = '', '' + first_symbol, last_symbol = "", "" if word and word[0] in punctuation_symbols: first_symbol = word[0] word = word[1:] @@ -48,19 +121,19 @@ def get_arpabet(word, cmudict, punctuation_symbols): word = word[:-1] arpabet = cmudict.lookup(word) if arpabet is not None: - return first_symbol + '{%s}' % arpabet[0] + last_symbol + return first_symbol + "{%s}" % arpabet[0] + last_symbol return first_symbol + word + last_symbol -_alt_re = re.compile(r'\([0-9]+\)') +_alt_re = re.compile(r"\([0-9]+\)") def _parse_cmudict(file): cmudict = {} for line in file: - if line and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): - parts = line.split(' ') - word = re.sub(_alt_re, '', parts[0]) + if line and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): + parts = line.split(" ") + word = re.sub(_alt_re, "", parts[0]) pronunciation = _get_pronunciation(parts[1]) if pronunciation: if word in cmudict: @@ -71,8 +144,8 @@ def _parse_cmudict(file): def _get_pronunciation(s): - parts = s.strip().split(' ') + parts = s.strip().split(" ") for part in parts: if part not in VALID_SYMBOLS: return None - return ' '.join(parts) + return " ".join(parts) diff --git a/TTS/tts/utils/text/number_norm.py b/TTS/tts/utils/text/number_norm.py index 2b83c271cd..e8377ede87 100644 --- a/TTS/tts/utils/text/number_norm.py +++ b/TTS/tts/utils/text/number_norm.py @@ -1,27 +1,28 @@ """ from https://github.com/keithito/tacotron """ -import inflect import re from typing import Dict +import inflect + _inflect = inflect.engine() -_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') -_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') -_currency_re = re.compile(r'(£|\$|¥)([0-9\,\.]*[0-9]+)') -_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') -_number_re = re.compile(r'-?[0-9]+') +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_currency_re = re.compile(r"(£|\$|¥)([0-9\,\.]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"-?[0-9]+") def _remove_commas(m): - return m.group(1).replace(',', '') + return m.group(1).replace(",", "") def _expand_decimal_point(m): - return m.group(1).replace('.', ' point ') + return m.group(1).replace(".", " point ") def __expand_currency(value: str, inflection: Dict[float, str]) -> str: - parts = value.replace(",", "").split('.') + parts = value.replace(",", "").split(".") if len(parts) > 2: return f"{value} {inflection[2]}" # Unexpected format text = [] @@ -31,7 +32,7 @@ def __expand_currency(value: str, inflection: Dict[float, str]) -> str: text.append(f"{integer} {integer_unit}") fraction = int(parts[1]) if len(parts) > 1 and parts[1] else 0 if 
fraction > 0: - fraction_unit = inflection.get(fraction/100, inflection[0.02]) + fraction_unit = inflection.get(fraction / 100, inflection[0.02]) text.append(f"{fraction} {fraction_unit}") if len(text) == 0: return f"zero {inflection[2]}" @@ -62,7 +63,7 @@ def _expand_currency(m: "re.Match") -> str: # TODO rin 0.02: "sen", 2: "yen", - } + }, } unit = m.group(1) currency = currencies[unit] @@ -78,16 +79,13 @@ def _expand_number(m): num = int(m.group(0)) if 1000 < num < 3000: if num == 2000: - return 'two thousand' + return "two thousand" if 2000 < num < 2010: - return 'two thousand ' + _inflect.number_to_words(num % 100) + return "two thousand " + _inflect.number_to_words(num % 100) if num % 100 == 0: - return _inflect.number_to_words(num // 100) + ' hundred' - return _inflect.number_to_words(num, - andword='', - zero='oh', - group=2).replace(', ', ' ') - return _inflect.number_to_words(num, andword='') + return _inflect.number_to_words(num // 100) + " hundred" + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + return _inflect.number_to_words(num, andword="") def normalize_numbers(text): diff --git a/TTS/tts/utils/text/symbols.py b/TTS/tts/utils/text/symbols.py index 834359176e..a531849d00 100644 --- a/TTS/tts/utils/text/symbols.py +++ b/TTS/tts/utils/text/symbols.py @@ -1,38 +1,41 @@ # -*- coding: utf-8 -*- -''' +""" Defines the set of symbols used in text input to the model. The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. -''' +""" -def make_symbols(characters, phonemes=None, punctuations='!\'(),-.:;? ', pad='_', eos='~', bos='^'):# pylint: disable=redefined-outer-name - ''' Function to create symbols and phonemes ''' +def make_symbols( + characters, phonemes=None, punctuations="!'(),-.:;? ", pad="_", eos="~", bos="^" +): # pylint: disable=redefined-outer-name + """ Function to create symbols and phonemes """ _symbols = [pad, eos, bos] + list(characters) _phonemes = None if phonemes is not None: _phonemes_sorted = sorted(list(set(phonemes))) # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): - _arpabet = ['@' + s for s in _phonemes_sorted] + _arpabet = ["@" + s for s in _phonemes_sorted] # Export all symbols: _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations) _symbols += _arpabet return _symbols, _phonemes -_pad = '_' -_eos = '~' -_bos = '^' -_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? ' -_punctuations = '!\'(),-.:;? ' + +_pad = "_" +_eos = "~" +_bos = "^" +_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? " +_punctuations = "!'(),-.:;? 
" # Phonemes definition (All IPA characters) -_vowels = 'iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ' -_non_pulmonic_consonants = 'ʘɓǀɗǃʄǂɠǁʛ' -_pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ' -_suprasegmentals = 'ˈˌːˑ' -_other_symbols = 'ʍwɥʜʢʡɕʑɺɧʲ' -_diacrilics = 'ɚ˞ɫ' +_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ" +_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ" +_pulmonic_consonants = "pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ" +_suprasegmentals = "ˈˌːˑ" +_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ" +_diacrilics = "ɚ˞ɫ" _phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos) @@ -43,16 +46,18 @@ def make_symbols(characters, phonemes=None, punctuations='!\'(),-.:;? ', pad='_' def parse_symbols(): - return {'pad': _pad, - 'eos': _eos, - 'bos': _bos, - 'characters': _characters, - 'punctuations': _punctuations, - 'phonemes': _phonemes} + return { + "pad": _pad, + "eos": _eos, + "bos": _bos, + "characters": _characters, + "punctuations": _punctuations, + "phonemes": _phonemes, + } -if __name__ == '__main__': +if __name__ == "__main__": print(" > TTS symbols {}".format(len(symbols))) print(symbols) print(" > TTS phonemes {}".format(len(phonemes))) - print(''.join(sorted(phonemes))) + print("".join(sorted(phonemes))) diff --git a/TTS/tts/utils/text/time.py b/TTS/tts/utils/text/time.py index 55ecbd8c06..c8ac09e79d 100644 --- a/TTS/tts/utils/text/time.py +++ b/TTS/tts/utils/text/time.py @@ -1,15 +1,18 @@ import re + import inflect _inflect = inflect.engine() -_time_re = re.compile(r"""\b +_time_re = re.compile( + r"""\b ((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3])) # hours : ([0-5][0-9]) # minutes \s*(a\\.m\\.|am|pm|p\\.m\\.|a\\.m|p\\.m)? 
# am/pm \b""", - re.IGNORECASE | re.X) + re.IGNORECASE | re.X, +) def _expand_num(n: int) -> str: diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py index e5bb589146..b27dca2cd2 100644 --- a/TTS/tts/utils/visual.py +++ b/TTS/tts/utils/visual.py @@ -3,33 +3,26 @@ import numpy as np import torch -matplotlib.use('Agg') +matplotlib.use("Agg") import matplotlib.pyplot as plt + from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme -def plot_alignment(alignment, - info=None, - fig_size=(16, 10), - title=None, - output_fig=False): +def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False): if isinstance(alignment, torch.Tensor): alignment_ = alignment.detach().cpu().numpy().squeeze() else: alignment_ = alignment - alignment_ = alignment_.astype( - np.float32) if alignment_.dtype == np.float16 else alignment_ + alignment_ = alignment_.astype(np.float32) if alignment_.dtype == np.float16 else alignment_ fig, ax = plt.subplots(figsize=fig_size) - im = ax.imshow(alignment_.T, - aspect='auto', - origin='lower', - interpolation='none') + im = ax.imshow(alignment_.T, aspect="auto", origin="lower", interpolation="none") fig.colorbar(im, ax=ax) - xlabel = 'Decoder timestep' + xlabel = "Decoder timestep" if info is not None: - xlabel += '\n\n' + info + xlabel += "\n\n" + info plt.xlabel(xlabel) - plt.ylabel('Encoder timestep') + plt.ylabel("Encoder timestep") # plt.yticks(range(len(text)), list(text)) plt.tight_layout() if title is not None: @@ -39,16 +32,12 @@ def plot_alignment(alignment, return fig -def plot_spectrogram(spectrogram, - ap=None, - fig_size=(16, 10), - output_fig=False): +def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False): if isinstance(spectrogram, torch.Tensor): spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T else: spectrogram_ = spectrogram.T - spectrogram_ = spectrogram_.astype( - np.float32) if spectrogram_.dtype == np.float16 else spectrogram_ + spectrogram_ = spectrogram_.astype(np.float32) if spectrogram_.dtype == np.float16 else spectrogram_ if ap is not None: spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access fig = plt.figure(figsize=fig_size) @@ -60,16 +49,18 @@ def plot_spectrogram(spectrogram, return fig -def visualize(alignment, - postnet_output, - text, - hop_length, - CONFIG, - stop_tokens=None, - decoder_output=None, - output_path=None, - figsize=(8, 24), - output_fig=False): +def visualize( + alignment, + postnet_output, + text, + hop_length, + CONFIG, + stop_tokens=None, + decoder_output=None, + output_path=None, + figsize=(8, 24), + output_fig=False, +): if decoder_output is not None: num_plot = 4 @@ -86,13 +77,13 @@ def visualize(alignment, # compute phoneme representation and back if CONFIG.use_phonemes: seq = phoneme_to_sequence( - text, [CONFIG.text_cleaner], + text, + [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, - tp=CONFIG.characters if 'characters' in CONFIG.keys() else None) - text = sequence_to_phoneme( - seq, - tp=CONFIG.characters if 'characters' in CONFIG.keys() else None) + tp=CONFIG.characters if "characters" in CONFIG.keys() else None, + ) + text = sequence_to_phoneme(seq, tp=CONFIG.characters if "characters" in CONFIG.keys() else None) print(text) plt.yticks(range(len(text)), list(text)) plt.colorbar() @@ -104,13 +95,15 @@ def visualize(alignment, # plot postnet spectrogram plt.subplot(num_plot, 1, 3) - librosa.display.specshow(postnet_output.T, - sr=CONFIG.audio['sample_rate'], - 
hop_length=hop_length, - x_axis="time", - y_axis="linear", - fmin=CONFIG.audio['mel_fmin'], - fmax=CONFIG.audio['mel_fmax']) + librosa.display.specshow( + postnet_output.T, + sr=CONFIG.audio["sample_rate"], + hop_length=hop_length, + x_axis="time", + y_axis="linear", + fmin=CONFIG.audio["mel_fmin"], + fmax=CONFIG.audio["mel_fmax"], + ) plt.xlabel("Time", fontsize=label_fontsize) plt.ylabel("Hz", fontsize=label_fontsize) @@ -119,13 +112,15 @@ def visualize(alignment, if decoder_output is not None: plt.subplot(num_plot, 1, 4) - librosa.display.specshow(decoder_output.T, - sr=CONFIG.audio['sample_rate'], - hop_length=hop_length, - x_axis="time", - y_axis="linear", - fmin=CONFIG.audio['mel_fmin'], - fmax=CONFIG.audio['mel_fmax']) + librosa.display.specshow( + decoder_output.T, + sr=CONFIG.audio["sample_rate"], + hop_length=hop_length, + x_axis="time", + y_axis="linear", + fmin=CONFIG.audio["mel_fmin"], + fmax=CONFIG.audio["mel_fmax"], + ) plt.xlabel("Time", fontsize=label_fontsize) plt.ylabel("Hz", fontsize=label_fontsize) plt.tight_layout() diff --git a/TTS/utils/arguments.py b/TTS/utils/arguments.py index 17a76aa681..af0a15986b 100644 --- a/TTS/utils/arguments.py +++ b/TTS/utils/arguments.py @@ -4,11 +4,12 @@ import argparse import glob +import json import os import re -import json import torch + from TTS.tts.utils.text.symbols import parse_symbols from TTS.utils.console_logger import ConsoleLogger from TTS.utils.generic_utils import create_experiment_folder, get_git_branch @@ -29,41 +30,31 @@ def parse_arguments(argv): parser.add_argument( "--continue_path", type=str, - help=("Training output folder to continue training. Used to continue " - "a training. If it is used, 'config_path' is ignored."), + help=( + "Training output folder to continue training. Used to continue " + "a training. If it is used, 'config_path' is ignored." + ), default="", - required="--config_path" not in argv) + required="--config_path" not in argv, + ) parser.add_argument( - "--restore_path", - type=str, - help="Model file to be restored. Use to finetune a model.", - default="") + "--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default="" + ) parser.add_argument( "--best_path", type=str, - help=("Best model file to be used for extracting best loss." - "If not specified, the latest best model in continue path is used"), - default="") - parser.add_argument( - "--config_path", - type=str, - help="Path to config file for training.", - required="--continue_path" not in argv) - parser.add_argument( - "--debug", - type=bool, - default=False, - help="Do not verify commit integrity to run training.") - parser.add_argument( - "--rank", - type=int, - default=0, - help="DISTRIBUTED: process rank for distributed training.") - parser.add_argument( - "--group_id", - type=str, + help=( + "Best model file to be used for extracting best loss." 
+ "If not specified, the latest best model in continue path is used" + ), default="", - help="DISTRIBUTED: process group id.") + ) + parser.add_argument( + "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv + ) + parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.") + parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.") + parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.") return parser.parse_args() @@ -86,7 +77,7 @@ def get_last_checkpoint(path): file_names = glob.glob(os.path.join(path, "*.pth.tar")) last_models = {} last_model_nums = {} - for key in ['checkpoint', 'best_model']: + for key in ["checkpoint", "best_model"]: last_model_num = None last_model = None # pass all the checkpoint files and find @@ -105,7 +96,7 @@ def get_last_checkpoint(path): key_file_names = [fn for fn in file_names if key in fn] if last_model is None and len(key_file_names) > 0: last_model = max(key_file_names, key=os.path.getctime) - last_model_num = torch.load(last_model)['step'] + last_model_num = torch.load(last_model)["step"] if last_model is not None: last_models[key] = last_model @@ -114,16 +105,16 @@ def get_last_checkpoint(path): # check what models were found if not last_models: raise ValueError(f"No models found in continue path {path}!") - if 'checkpoint' not in last_models: # no checkpoint just best model - last_models['checkpoint'] = last_models['best_model'] - elif 'best_model' not in last_models: # no best model + if "checkpoint" not in last_models: # no checkpoint just best model + last_models["checkpoint"] = last_models["best_model"] + elif "best_model" not in last_models: # no best model # this shouldn't happen, but let's handle it just in case - last_models['best_model'] = None + last_models["best_model"] = None # finally check if last best model is more recent than checkpoint - elif last_model_nums['best_model'] > last_model_nums['checkpoint']: - last_models['checkpoint'] = last_models['best_model'] + elif last_model_nums["best_model"] > last_model_nums["checkpoint"]: + last_models["checkpoint"] = last_models["best_model"] - return last_models['checkpoint'], last_models['best_model'] + return last_models["checkpoint"], last_models["best_model"] def process_args(args, model_class): @@ -157,13 +148,12 @@ def process_args(args, model_class): c = load_config(args.config_path) _ = os.path.dirname(os.path.realpath(__file__)) - if 'mixed_precision' in c and c.mixed_precision: + if "mixed_precision" in c and c.mixed_precision: print(" > Mixed precision mode is ON") out_path = args.continue_path if not out_path: - out_path = create_experiment_folder(c.output_path, c.run_name, - args.debug) + out_path = create_experiment_folder(c.output_path, c.run_name, args.debug) audio_path = os.path.join(out_path, "test_audios") @@ -179,11 +169,10 @@ def process_args(args, model_class): # if model characters are not set in the config file # save the default set to the config file for future # compatibility. 
- if model_class == 'tts' and 'characters' not in c: + if model_class == "tts" and "characters" not in c: used_characters = parse_symbols() - new_fields['characters'] = used_characters - copy_model_files(c, args.config_path, - out_path, new_fields) + new_fields["characters"] = used_characters + copy_model_files(c, args.config_path, out_path, new_fields) os.chmod(audio_path, 0o775) os.chmod(out_path, 0o775) diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index fa4b4a53d4..a656efc0c8 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -1,13 +1,15 @@ import librosa -import soundfile as sf import numpy as np import scipy.io.wavfile import scipy.signal -# import pyworld as pw +import soundfile as sf from TTS.tts.utils.data import StandardScaler -#pylint: disable=too-many-public-methods +# import pyworld as pw + + +# pylint: disable=too-many-public-methods class AudioProcessor(object): """Audio Processor for TTS used by all the data pipelines. @@ -43,35 +45,38 @@ class AudioProcessor(object): stats_path (str, optional): Path to the computed stats file. Defaults to None. verbose (bool, optional): enable/disable logging. Defaults to True. """ - def __init__(self, - sample_rate=None, - resample=False, - num_mels=None, - log_func='np.log10', - min_level_db=None, - frame_shift_ms=None, - frame_length_ms=None, - hop_length=None, - win_length=None, - ref_level_db=None, - fft_size=1024, - power=None, - preemphasis=0.0, - signal_norm=None, - symmetric_norm=None, - max_norm=None, - mel_fmin=None, - mel_fmax=None, - spec_gain=20, - stft_pad_mode='reflect', - clip_norm=True, - griffin_lim_iters=None, - do_trim_silence=False, - trim_db=60, - do_sound_norm=False, - stats_path=None, - verbose=True, - **_): + + def __init__( + self, + sample_rate=None, + resample=False, + num_mels=None, + log_func="np.log10", + min_level_db=None, + frame_shift_ms=None, + frame_length_ms=None, + hop_length=None, + win_length=None, + ref_level_db=None, + fft_size=1024, + power=None, + preemphasis=0.0, + signal_norm=None, + symmetric_norm=None, + max_norm=None, + mel_fmin=None, + mel_fmax=None, + spec_gain=20, + stft_pad_mode="reflect", + clip_norm=True, + griffin_lim_iters=None, + do_trim_silence=False, + trim_db=60, + do_sound_norm=False, + stats_path=None, + verbose=True, + **_, + ): # setup class attributed self.sample_rate = sample_rate @@ -98,14 +103,14 @@ def __init__(self, self.do_sound_norm = do_sound_norm self.stats_path = stats_path # setup exp_func for db to amp conversion - print(f'self.log_func = {log_func}') - exec(f'self.log_func = {log_func}') #pylint: disable=exec-used - if self.log_func.__name__ == 'log': + print(f"self.log_func = {log_func}") + exec(f"self.log_func = {log_func}") # pylint: disable=exec-used + if self.log_func.__name__ == "log": self.exp_func = np.exp - elif self.log_func.__name__ == 'log10': + elif self.log_func.__name__ == "log10": self.exp_func = lambda x: 10 ** x else: - raise ValueError(' [!] unknown `log_func` value.') + raise ValueError(" [!] 
unknown `log_func` value.") # setup stft parameters if hop_length is None: # compute stft parameters from given time values @@ -134,17 +139,18 @@ def __init__(self, self.symmetric_norm = None ### setting up the parameters ### - def _build_mel_basis(self, ): + def _build_mel_basis( + self, + ): if self.mel_fmax is not None: assert self.mel_fmax <= self.sample_rate // 2 return librosa.filters.mel( - self.sample_rate, - self.fft_size, - n_mels=self.num_mels, - fmin=self.mel_fmin, - fmax=self.mel_fmax) + self.sample_rate, self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax + ) - def _stft_parameters(self, ): + def _stft_parameters( + self, + ): """Compute necessary stft parameters with given time values""" factor = self.frame_length_ms / self.frame_shift_ms assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms" @@ -155,24 +161,26 @@ def _stft_parameters(self, ): ### normalization ### def normalize(self, S): """Put values in [0, self.max_norm] or [-self.max_norm, self.max_norm]""" - #pylint: disable=no-else-return + # pylint: disable=no-else-return S = S.copy() if self.signal_norm: # mean-var scaling - if hasattr(self, 'mel_scaler'): + if hasattr(self, "mel_scaler"): if S.shape[0] == self.num_mels: return self.mel_scaler.transform(S.T).T elif S.shape[0] == self.fft_size / 2: return self.linear_scaler.transform(S.T).T else: - raise RuntimeError(' [!] Mean-Var stats does not match the given feature dimensions.') + raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") # range normalization S -= self.ref_level_db # discard certain range of DB assuming it is air noise - S_norm = ((S - self.min_level_db) / (-self.min_level_db)) + S_norm = (S - self.min_level_db) / (-self.min_level_db) if self.symmetric_norm: S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm if self.clip_norm: - S_norm = np.clip(S_norm, -self.max_norm, self.max_norm) # pylint: disable=invalid-unary-operand-type + S_norm = np.clip( + S_norm, -self.max_norm, self.max_norm + ) # pylint: disable=invalid-unary-operand-type return S_norm else: S_norm = self.max_norm * S_norm @@ -184,47 +192,49 @@ def normalize(self, S): def denormalize(self, S): """denormalize values""" - #pylint: disable=no-else-return + # pylint: disable=no-else-return S_denorm = S.copy() if self.signal_norm: # mean-var scaling - if hasattr(self, 'mel_scaler'): + if hasattr(self, "mel_scaler"): if S_denorm.shape[0] == self.num_mels: return self.mel_scaler.inverse_transform(S_denorm.T).T elif S_denorm.shape[0] == self.fft_size / 2: return self.linear_scaler.inverse_transform(S_denorm.T).T else: - raise RuntimeError(' [!] Mean-Var stats does not match the given feature dimensions.') + raise RuntimeError(" [!] 
Mean-Var stats does not match the given feature dimensions.") if self.symmetric_norm: if self.clip_norm: - S_denorm = np.clip(S_denorm, -self.max_norm, self.max_norm) #pylint: disable=invalid-unary-operand-type + S_denorm = np.clip( + S_denorm, -self.max_norm, self.max_norm + ) # pylint: disable=invalid-unary-operand-type S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db return S_denorm + self.ref_level_db else: if self.clip_norm: S_denorm = np.clip(S_denorm, 0, self.max_norm) - S_denorm = (S_denorm * -self.min_level_db / - self.max_norm) + self.min_level_db + S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db return S_denorm + self.ref_level_db else: return S_denorm ### Mean-STD scaling ### def load_stats(self, stats_path): - stats = np.load(stats_path, allow_pickle=True).item() #pylint: disable=unexpected-keyword-arg - mel_mean = stats['mel_mean'] - mel_std = stats['mel_std'] - linear_mean = stats['linear_mean'] - linear_std = stats['linear_std'] - stats_config = stats['audio_config'] + stats = np.load(stats_path, allow_pickle=True).item() # pylint: disable=unexpected-keyword-arg + mel_mean = stats["mel_mean"] + mel_std = stats["mel_std"] + linear_mean = stats["linear_mean"] + linear_std = stats["linear_std"] + stats_config = stats["audio_config"] # check all audio parameters used for computing stats - skip_parameters = ['griffin_lim_iters', 'stats_path', 'do_trim_silence', 'ref_level_db', 'power'] + skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"] for key in stats_config.keys(): if key in skip_parameters: continue - if key not in ['sample_rate', 'trim_db']: - assert stats_config[key] == self.__dict__[key],\ - f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" + if key not in ["sample_rate", "trim_db"]: + assert ( + stats_config[key] == self.__dict__[key] + ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. 
{stats_config[key]} vs {self.__dict__[key]}" return mel_mean, mel_std, linear_mean, linear_std, stats_config # pylint: disable=attribute-defined-outside-init @@ -243,7 +253,6 @@ def _amp_to_db(self, x): def _db_to_amp(self, x): return self.exp_func(x / self.spec_gain) - ### Preemphasis ### def apply_preemphasis(self, x): if self.preemphasis == 0: @@ -284,17 +293,17 @@ def inv_spectrogram(self, spectrogram): S = self._db_to_amp(S) # Reconstruct phase if self.preemphasis != 0: - return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) - return self._griffin_lim(S**self.power) + return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) + return self._griffin_lim(S ** self.power) def inv_melspectrogram(self, mel_spectrogram): - '''Converts melspectrogram to waveform using librosa''' + """Converts melspectrogram to waveform using librosa""" D = self.denormalize(mel_spectrogram) S = self._db_to_amp(D) S = self._mel_to_linear(S) # Convert back to linear if self.preemphasis != 0: - return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) - return self._griffin_lim(S**self.power) + return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) + return self._griffin_lim(S ** self.power) def out_linear_to_mel(self, linear_spec): S = self.denormalize(linear_spec) @@ -306,22 +315,24 @@ def out_linear_to_mel(self, linear_spec): ### STFT and ISTFT ### def _stft(self, y): - return librosa.stft(y=y, - n_fft=self.fft_size, - hop_length=self.hop_length, - win_length=self.win_length, - pad_mode=self.stft_pad_mode, - window='hann', - center=True, - ) + return librosa.stft( + y=y, + n_fft=self.fft_size, + hop_length=self.hop_length, + win_length=self.win_length, + pad_mode=self.stft_pad_mode, + window="hann", + center=True, + ) def _istft(self, y): - return librosa.istft(y, - hop_length=self.hop_length, - win_length=self.win_length, - window='hann', - center=True, - ) + return librosa.istft( + y, + hop_length=self.hop_length, + win_length=self.win_length, + window="hann", + center=True, + ) def _griffin_lim(self, S): angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) @@ -333,8 +344,7 @@ def _griffin_lim(self, S): return y def compute_stft_paddings(self, x, pad_sides=1): - '''compute right padding (final frame) or both sides padding (first and final frames) - ''' + """compute right padding (final frame) or both sides padding (first and final frames)""" assert pad_sides in (1, 2) pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0] if pad_sides == 1: @@ -359,7 +369,7 @@ def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8): hop_length = int(window_length / 4) threshold = self._db_to_amp(threshold_db) for x in range(hop_length, len(wav) - window_length, hop_length): - if np.max(wav[x:x + window_length]) < threshold: + if np.max(wav[x : x + window_length]) < threshold: return x + hop_length return len(wav) @@ -367,8 +377,9 @@ def trim_silence(self, wav): """ Trim silent parts with a threshold and 0.01 sec margin """ margin = int(self.sample_rate * 0.01) wav = wav[margin:-margin] - return librosa.effects.trim( - wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[0] + return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[ + 0 + ] @staticmethod def sound_norm(x): @@ -380,14 +391,14 @@ def load_wav(self, filename, sr=None): x, sr = librosa.load(filename, sr=self.sample_rate) elif sr is None: x, sr = sf.read(filename) - assert self.sample_rate == sr, 
"%s vs %s"%(self.sample_rate, sr) + assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr) else: x, sr = librosa.load(filename, sr=sr) if self.do_trim_silence: try: x = self.trim_silence(x) except ValueError: - print(f' [!] File cannot be trimmed for silence - {filename}') + print(f" [!] File cannot be trimmed for silence - {filename}") if self.do_sound_norm: x = self.sound_norm(x) return x @@ -400,10 +411,12 @@ def save_wav(self, wav, path, sr=None): def mulaw_encode(wav, qc): mu = 2 ** qc - 1 # wav_abs = np.minimum(np.abs(wav), 1.0) - signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1. + mu) + signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) # Quantize signal to the specified number of levels. signal = (signal + 1) / 2 * mu + 0.5 - return np.floor(signal,) + return np.floor( + signal, + ) @staticmethod def mulaw_decode(wav, qc): @@ -412,15 +425,14 @@ def mulaw_decode(wav, qc): x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) return x - @staticmethod def encode_16bits(x): - return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16) + return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16) @staticmethod def quantize(x, bits): - return (x + 1.) * (2**bits - 1) / 2 + return (x + 1.0) * (2 ** bits - 1) / 2 @staticmethod def dequantize(x, bits): - return 2 * x / (2**bits - 1) - 1 + return 2 * x / (2 ** bits - 1) - 1 diff --git a/TTS/utils/console_logger.py b/TTS/utils/console_logger.py index 3affd6afc3..7d6e1968f9 100644 --- a/TTS/utils/console_logger.py +++ b/TTS/utils/console_logger.py @@ -1,20 +1,22 @@ import datetime -from TTS.utils.io import AttrDict +from TTS.utils.io import AttrDict -tcolors = AttrDict({ - 'OKBLUE': '\033[94m', - 'HEADER': '\033[95m', - 'OKGREEN': '\033[92m', - 'WARNING': '\033[93m', - 'FAIL': '\033[91m', - 'ENDC': '\033[0m', - 'BOLD': '\033[1m', - 'UNDERLINE': '\033[4m' -}) +tcolors = AttrDict( + { + "OKBLUE": "\033[94m", + "HEADER": "\033[95m", + "OKGREEN": "\033[92m", + "WARNING": "\033[93m", + "FAIL": "\033[91m", + "ENDC": "\033[0m", + "BOLD": "\033[1m", + "UNDERLINE": "\033[4m", + } +) -class ConsoleLogger(): +class ConsoleLogger: def __init__(self): # TODO: color code for value changes # use these to compare values between iterations @@ -28,23 +30,24 @@ def get_time(self): return now.strftime("%Y-%m-%d %H:%M:%S") def print_epoch_start(self, epoch, max_epoch): - print("\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, - epoch, max_epoch, tcolors.ENDC), - flush=True) + print( + "\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC), + flush=True, + ) def print_train_start(self): print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}") - def print_train_step(self, batch_steps, step, global_step, log_dict, - loss_dict, avg_loss_dict): + def print_train_step(self, batch_steps, step, global_step, log_dict, loss_dict, avg_loss_dict): indent = " | > " print() log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format( - tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC) + tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC + ) for key, value in loss_dict.items(): # print the avg value if given - if f'avg_{key}' in avg_loss_dict.keys(): - log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}']) + if f"avg_{key}" in avg_loss_dict.keys(): + log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"]) else: log_text += "{}{}: {:.5f} 
\n".format(indent, key, value) for idx, (key, value) in enumerate(log_dict.items()): @@ -52,13 +55,12 @@ def print_train_step(self, batch_steps, step, global_step, log_dict, log_text += f"{indent}{key}: {value[0]:.{value[1]}f}" else: log_text += f"{indent}{key}: {value}" - if idx < len(log_dict)-1: + if idx < len(log_dict) - 1: log_text += "\n" print(log_text, flush=True) # pylint: disable=unused-argument - def print_train_epoch_end(self, global_step, epoch, epoch_time, - print_dict): + def print_train_epoch_end(self, global_step, epoch, epoch_time, print_dict): indent = " | > " log_text = f"\n{tcolors.BOLD} --> TRAIN PERFORMACE -- EPOCH TIME: {epoch_time:.2f} sec -- GLOBAL_STEP: {global_step}{tcolors.ENDC}\n" for key, value in print_dict.items(): @@ -74,29 +76,28 @@ def print_eval_step(self, step, loss_dict, avg_loss_dict): log_text = f"{tcolors.BOLD} --> STEP: {step}{tcolors.ENDC}\n" for key, value in loss_dict.items(): # print the avg value if given - if f'avg_{key}' in avg_loss_dict.keys(): - log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}']) + if f"avg_{key}" in avg_loss_dict.keys(): + log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"]) else: log_text += "{}{}: {:.5f} \n".format(indent, key, value) print(log_text, flush=True) def print_epoch_end(self, epoch, avg_loss_dict): indent = " | > " - log_text = " {}--> EVAL PERFORMANCE{}\n".format( - tcolors.BOLD, tcolors.ENDC) + log_text = " {}--> EVAL PERFORMANCE{}\n".format(tcolors.BOLD, tcolors.ENDC) for key, value in avg_loss_dict.items(): # print the avg value if given - color = '' - sign = '+' + color = "" + sign = "+" diff = 0 if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict: diff = value - self.old_eval_loss_dict[key] if diff < 0: color = tcolors.OKGREEN - sign = '' + sign = "" elif diff > 0: color = tcolors.FAIL - sign = '+' + sign = "+" log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff) self.old_eval_loss_dict = avg_loss_dict print(log_text, flush=True) diff --git a/TTS/utils/distribute.py b/TTS/utils/distribute.py index 89d4efec00..7703ab4e42 100644 --- a/TTS/utils/distribute.py +++ b/TTS/utils/distribute.py @@ -34,11 +34,11 @@ def __iter__(self): indices = torch.arange(len(self.dataset)).tolist() # add extra samples to make it evenly divisible - indices += indices[:(self.total_size - len(indices))] + indices += indices[: (self.total_size - len(indices))] assert len(indices) == self.total_size # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples return iter(indices) @@ -64,12 +64,7 @@ def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url): torch.cuda.set_device(rank % torch.cuda.device_count()) # Initialize distributed communication - dist.init_process_group( - dist_backend, - init_method=dist_url, - world_size=num_gpus, - rank=rank, - group_name=group_name) + dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name) def apply_gradient_allreduce(module): @@ -97,14 +92,13 @@ def allreduce_params(): coalesced = _flatten_dense_tensors(grads) dist.all_reduce(coalesced, op=dist.reduce_op.SUM) coalesced /= dist.get_world_size() - for buf, synced in zip( - grads, _unflatten_dense_tensors(coalesced, grads)): + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, 
grads)): buf.copy_(synced) for param in list(module.parameters()): def allreduce_hook(*_): - Variable._execution_engine.queue_callback(allreduce_params) #pylint: disable=protected-access + Variable._execution_engine.queue_callback(allreduce_params) # pylint: disable=protected-access if param.requires_grad: param.register_hook(allreduce_hook) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index 60721364ad..57e227074d 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -10,8 +10,7 @@ def get_git_branch(): try: out = subprocess.check_output(["git", "branch"]).decode("utf8") - current = next(line for line in out.split("\n") - if line.startswith("*")) + current = next(line for line in out.split("\n") if line.startswith("*")) current.replace("* ", "") except subprocess.CalledProcessError: current = "inside_docker" @@ -29,12 +28,11 @@ def get_commit_hash(): # raise RuntimeError( # " !! Commit before training to get the commit hash.") try: - commit = subprocess.check_output( - ['git', 'rev-parse', '--short', 'HEAD']).decode().strip() + commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip() # Not copying .git folder into docker container except (subprocess.CalledProcessError, FileNotFoundError): commit = "0000000" - print(' > Git Hash: {}'.format(commit)) + print(" > Git Hash: {}".format(commit)) return commit @@ -42,11 +40,10 @@ def create_experiment_folder(root_path, model_name, debug): """ Create a folder with the current date and time """ date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") if debug: - commit_hash = 'debug' + commit_hash = "debug" else: commit_hash = get_commit_hash() - output_folder = os.path.join( - root_path, model_name + '-' + date_str + '-' + commit_hash) + output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash) os.makedirs(output_folder, exist_ok=True) print(" > Experiment folder: {}".format(output_folder)) return output_folder @@ -72,16 +69,16 @@ def count_parameters(model): def get_user_data_dir(appname): if sys.platform == "win32": import winreg # pylint: disable=import-outside-toplevel + key = winreg.OpenKey( - winreg.HKEY_CURRENT_USER, - r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" + winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" ) dir_, _ = winreg.QueryValueEx(key, "Local AppData") ans = Path(dir_).resolve(strict=False) - elif sys.platform == 'darwin': - ans = Path('~/Library/Application Support/').expanduser() + elif sys.platform == "darwin": + ans = Path("~/Library/Application Support/").expanduser() else: - ans = Path.home().joinpath('.local/share') + ans = Path.home().joinpath(".local/share") return ans.joinpath(appname) @@ -91,32 +88,20 @@ def set_init_dict(model_dict, checkpoint_state, c): if k not in model_dict: print(" | > Layer missing in the model definition: {}".format(k)) # 1. filter out unnecessary keys - pretrained_dict = { - k: v - for k, v in checkpoint_state.items() if k in model_dict - } + pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} # 2. filter out different size layers - pretrained_dict = { - k: v - for k, v in pretrained_dict.items() - if v.numel() == model_dict[k].numel() - } + pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} # 3. 
skip reinit layers if c.reinit_layers is not None: for reinit_layer_name in c.reinit_layers: - pretrained_dict = { - k: v - for k, v in pretrained_dict.items() - if reinit_layer_name not in k - } + pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} # 4. overwrite entries in the existing state dict model_dict.update(pretrained_dict) - print(" | > {} / {} layers are restored.".format(len(pretrained_dict), - len(model_dict))) + print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) return model_dict -class KeepAverage(): +class KeepAverage: def __init__(self): self.avg_values = {} self.iters = {} @@ -141,8 +126,7 @@ def update_value(self, name, value, weighted_avg=False): self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value self.iters[name] += 1 else: - self.avg_values[name] = self.avg_values[name] * \ - self.iters[name] + value + self.avg_values[name] = self.avg_values[name] * self.iters[name] + value self.iters[name] += 1 self.avg_values[name] /= self.iters[name] @@ -155,23 +139,27 @@ def update_values(self, value_dict): self.update_value(key, value) -def check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None): +def check_argument( + name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None +): if alternative in c.keys() and c[alternative] is not None: return if restricted: - assert name in c.keys(), f' [!] {name} not defined in config.json' + assert name in c.keys(), f" [!] {name} not defined in config.json" if name in c.keys(): if max_val: - assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}' + assert c[name] <= max_val, f" [!] {name} is larger than max value {max_val}" if min_val: - assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}' + assert c[name] >= min_val, f" [!] {name} is smaller than min value {min_val}" if enum_list: - assert c[name].lower() in enum_list, f' [!] {name} is not a valid value' + assert c[name].lower() in enum_list, f" [!] {name} is not a valid value" if isinstance(val_type, list): is_valid = False for typ in val_type: if isinstance(c[name], typ): is_valid = True - assert is_valid or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}' + assert is_valid or c[name] is None, f" [!] {name} has wrong type - {type(c[name])} vs {val_type}" elif val_type: - assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}' + assert ( + isinstance(c[name], val_type) or c[name] is None + ), f" [!] 
{name} has wrong type - {type(c[name])} vs {val_type}" diff --git a/TTS/utils/io.py b/TTS/utils/io.py index 1703de6f36..93f8b7497a 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -1,20 +1,23 @@ -import os -import re import json -import yaml +import os import pickle as pickle_tts +import re from shutil import copyfile +import yaml + class RenamingUnpickler(pickle_tts.Unpickler): """Overload default pickler to solve module renaming problem""" + def find_class(self, module, name): - return super().find_class(module.replace('mozilla_voice_tts', 'TTS'), name) + return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name) class AttrDict(dict): """A custom dict which converts dict keys to class attributes""" + def __init__(self, *args, **kwargs): super(AttrDict, self).__init__(*args, **kwargs) self.__dict__ = self @@ -25,11 +28,12 @@ def read_json_with_comments(json_path): with open(json_path, "r", encoding="utf-8") as f: input_str = f.read() # handle comments - input_str = re.sub(r'\\\n', '', input_str) - input_str = re.sub(r'//.*\n', '\n', input_str) + input_str = re.sub(r"\\\n", "", input_str) + input_str = re.sub(r"//.*\n", "\n", input_str) data = json.loads(input_str) return data + def load_config(config_path: str) -> AttrDict: """Load config files and discard comments @@ -60,7 +64,7 @@ def copy_model_files(c, config_file, out_path, new_fields): in the config file. """ # copy config.json - copy_config_path = os.path.join(out_path, 'config.json') + copy_config_path = os.path.join(out_path, "config.json") config_lines = open(config_file, "r", encoding="utf-8").readlines() # add extra information fields for key, value in new_fields.items(): @@ -73,7 +77,10 @@ def copy_model_files(c, config_file, out_path, new_fields): config_out_file.writelines(config_lines) config_out_file.close() # copy model stats file if available - if c.audio['stats_path'] is not None: - copy_stats_path = os.path.join(out_path, 'scale_stats.npy') + if c.audio["stats_path"] is not None: + copy_stats_path = os.path.join(out_path, "scale_stats.npy") if not os.path.exists(copy_stats_path): - copyfile(c.audio['stats_path'], copy_stats_path, ) + copyfile( + c.audio["stats_path"], + copy_stats_path, + ) diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index ef77ca4eb0..f0a8122762 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -7,6 +7,7 @@ import gdown import requests + from TTS.utils.generic_utils import get_user_data_dir from TTS.utils.io import load_config @@ -22,12 +23,13 @@ class ModelManager(object): Args: models_file (str): path to .model.json """ + def __init__(self, models_file=None, output_prefix=None): super().__init__() if output_prefix is None: - self.output_prefix = get_user_data_dir('tts') + self.output_prefix = get_user_data_dir("tts") else: - self.output_prefix = os.path.join(output_prefix, 'tts') + self.output_prefix = os.path.join(output_prefix, "tts") self.url_prefix = "https://drive.google.com/uc?id=" self.models_dict = None if models_file is not None: @@ -72,7 +74,7 @@ def list_models(self): print(f" >: {model_type}/{lang}/{dataset}/{model} [already downloaded]") else: print(f" >: {model_type}/{lang}/{dataset}/{model}") - models_name_list.append(f'{model_type}/{lang}/{dataset}/{model}') + models_name_list.append(f"{model_type}/{lang}/{dataset}/{model}") return models_name_list def download_model(self, model_name): @@ -104,25 +106,25 @@ def download_model(self, model_name): else: os.makedirs(output_path, exist_ok=True) print(f" > Downloading model to {output_path}") - 
output_stats_path = os.path.join(output_path, 'scale_stats.npy') + output_stats_path = os.path.join(output_path, "scale_stats.npy") # download files to the output path - if self._check_dict_key(model_item, 'github_rls_url'): + if self._check_dict_key(model_item, "github_rls_url"): # download from github release # TODO: pass output_path - self._download_zip_file(model_item['github_rls_url'], output_path) + self._download_zip_file(model_item["github_rls_url"], output_path) else: # download from gdrive - self._download_gdrive_file(model_item['model_file'], output_model_path) - self._download_gdrive_file(model_item['config_file'], output_config_path) - if self._check_dict_key(model_item, 'stats_file'): - self._download_gdrive_file(model_item['stats_file'], output_stats_path) + self._download_gdrive_file(model_item["model_file"], output_model_path) + self._download_gdrive_file(model_item["config_file"], output_config_path) + if self._check_dict_key(model_item, "stats_file"): + self._download_gdrive_file(model_item["stats_file"], output_stats_path) # set the scale_path.npy file path in the model config.json - if self._check_dict_key(model_item, 'stats_file') or os.path.exists(output_stats_path): + if self._check_dict_key(model_item, "stats_file") or os.path.exists(output_stats_path): # set scale stats path in config.json config_path = output_config_path config = load_config(config_path) - config["audio"]['stats_path'] = output_stats_path + config["audio"]["stats_path"] = output_stats_path with open(config_path, "w") as jf: json.dump(config, jf) return output_model_path, output_config_path, model_item diff --git a/TTS/utils/radam.py b/TTS/utils/radam.py index 58cec9205a..40d8cec9ab 100644 --- a/TTS/utils/radam.py +++ b/TTS/utils/radam.py @@ -1,12 +1,12 @@ # from https://github.com/LiyuanLucasLiu/RAdam import math + import torch from torch.optim.optimizer import Optimizer class RAdam(Optimizer): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): if lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) @@ -20,9 +20,11 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0 self.degenerated_to_sgd = degenerated_to_sgd if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): for param in params: - if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): - param['buffer'] = [[None, None, None] for _ in range(10)] - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) + if "betas" in param and (param["betas"][0] != betas[0] or param["betas"][1] != betas[1]): + param["buffer"] = [[None, None, None] for _ in range(10)] + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)] + ) super(RAdam, self).__init__(params, defaults) def __setstate__(self, state): # pylint: disable=useless-super-delegation @@ -36,62 +38,70 @@ def step(self, closure=None): for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue grad = p.grad.data.float() if grad.is_sparse: - raise RuntimeError('RAdam does not support sparse gradients') + raise RuntimeError("RAdam does not support sparse gradients") p_data_fp32 = p.data.float() state = self.state[p] if len(state) == 0: - state['step'] = 0 - state['exp_avg'] = torch.zeros_like(p_data_fp32) - state['exp_avg_sq'] = 
torch.zeros_like(p_data_fp32) + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) else: - state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) - state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - beta1, beta2 = group['betas'] + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - state['step'] += 1 - buffered = group['buffer'][int(state['step'] % 10)] - if state['step'] == buffered[0]: + state["step"] += 1 + buffered = group["buffer"][int(state["step"] % 10)] + if state["step"] == buffered[0]: N_sma, step_size = buffered[1], buffered[2] else: - buffered[0] = state['step'] - beta2_t = beta2 ** state['step'] + buffered[0] = state["step"] + beta2_t = beta2 ** state["step"] N_sma_max = 2 / (1 - beta2) - 1 - N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) buffered[1] = N_sma # more conservative since it's an approximated value if N_sma >= 5: - step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) + step_size = math.sqrt( + (1 - beta2_t) + * (N_sma - 4) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) / (1 - beta1 ** state["step"]) elif self.degenerated_to_sgd: - step_size = 1.0 / (1 - beta1 ** state['step']) + step_size = 1.0 / (1 - beta1 ** state["step"]) else: step_size = -1 buffered[2] = step_size # more conservative since it's an approximated value if N_sma >= 5: - if group['weight_decay'] != 0: - p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) - denom = exp_avg_sq.sqrt().add_(group['eps']) - p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"]) p.data.copy_(p_data_fp32) elif step_size > 0: - if group['weight_decay'] != 0: - p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) - p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr']) + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"]) p.data.copy_(p_data_fp32) return loss diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index a7a82d1387..64a1b88cb6 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -1,18 +1,19 @@ import time import numpy as np -import torch import pysbd +import torch -from TTS.utils.audio import AudioProcessor -from TTS.utils.io import load_config from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.speakers import load_speaker_mapping -from TTS.vocoder.utils.generic_utils import setup_generator, interpolate_vocoder_input + # pylint: disable=unused-wildcard-import # pylint: disable=wildcard-import from TTS.tts.utils.synthesis import synthesis, trim_silence from TTS.tts.utils.text import make_symbols, phonemes, symbols +from TTS.utils.audio 
import AudioProcessor +from TTS.utils.io import load_config +from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input, setup_generator class Synthesizer(object): @@ -49,12 +50,11 @@ def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint=None, vocoder_ self.use_cuda = use_cuda if self.use_cuda: assert torch.cuda.is_available(), "CUDA is not availabe on this machine." - self.load_tts(tts_checkpoint, tts_config, - use_cuda) - self.output_sample_rate = self.tts_config.audio['sample_rate'] + self.load_tts(tts_checkpoint, tts_config, use_cuda) + self.output_sample_rate = self.tts_config.audio["sample_rate"] if vocoder_checkpoint: self.load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda) - self.output_sample_rate = self.vocoder_config.audio['sample_rate'] + self.output_sample_rate = self.vocoder_config.audio["sample_rate"] @staticmethod def get_segmenter(lang): @@ -69,16 +69,18 @@ def load_speakers(self): self.num_speakers = 0 # set external speaker embedding if self.tts_config.use_external_speaker_embedding_file: - speaker_embedding = self.tts_speakers[list(self.tts_speakers.keys())[0]]['embedding'] + speaker_embedding = self.tts_speakers[list(self.tts_speakers.keys())[0]]["embedding"] self.speaker_embedding_dim = len(speaker_embedding) def init_speaker(self, speaker_idx): # load speakers speaker_embedding = None - if hasattr(self, 'tts_speakers') and speaker_idx is not None: - assert speaker_idx < len(self.tts_speakers), f" [!] speaker_idx is out of the range. {speaker_idx} vs {len(self.tts_speakers)}" + if hasattr(self, "tts_speakers") and speaker_idx is not None: + assert speaker_idx < len( + self.tts_speakers + ), f" [!] speaker_idx is out of the range. {speaker_idx} vs {len(self.tts_speakers)}" if self.tts_config.use_external_speaker_embedding_file: - speaker_embedding = self.tts_speakers[speaker_idx]['embedding'] + speaker_embedding = self.tts_speakers[speaker_idx]["embedding"] return speaker_embedding def load_tts(self, tts_checkpoint, tts_config, use_cuda): @@ -90,7 +92,7 @@ def load_tts(self, tts_checkpoint, tts_config, use_cuda): self.use_phonemes = self.tts_config.use_phonemes self.ap = AudioProcessor(verbose=False, **self.tts_config.audio) - if 'characters' in self.tts_config.keys(): + if "characters" in self.tts_config.keys(): symbols, phonemes = make_symbols(**self.tts_config.characters) if self.use_phonemes: @@ -105,7 +107,7 @@ def load_tts(self, tts_checkpoint, tts_config, use_cuda): def load_vocoder(self, model_file, model_config, use_cuda): self.vocoder_config = load_config(model_config) - self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config['audio']) + self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config["audio"]) self.vocoder_model = setup_generator(self.vocoder_config) self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True) if use_cuda: @@ -141,7 +143,8 @@ def tts(self, text, speaker_idx=None): False, self.tts_config.enable_eos_bos_chars, use_gl, - speaker_embedding=speaker_embedding) + speaker_embedding=speaker_embedding, + ) if not use_gl: # denormalize tts output based on tts audio config mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T @@ -149,7 +152,7 @@ def tts(self, text, speaker_idx=None): # renormalize spectrogram based on vocoder config vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) # compute scale factor for possible sample rate mismatch - scale_factor = [1, self.vocoder_config['audio']['sample_rate'] / self.ap.sample_rate] + scale_factor = [1, 
self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate] if scale_factor[1] != 1: print(" > interpolating tts model output.") vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) @@ -172,7 +175,7 @@ def tts(self, text, speaker_idx=None): # compute stats process_time = time.time() - start_time - audio_time = len(wavs) / self.tts_config.audio['sample_rate'] + audio_time = len(wavs) / self.tts_config.audio["sample_rate"] print(f" > Processing time: {process_time}") print(f" > Real-time factor: {process_time / audio_time}") return wavs diff --git a/TTS/utils/tensorboard_logger.py b/TTS/utils/tensorboard_logger.py index 4ee12d74ec..3874a42b95 100644 --- a/TTS/utils/tensorboard_logger.py +++ b/TTS/utils/tensorboard_logger.py @@ -1,4 +1,5 @@ import traceback + from tensorboardX import SummaryWriter @@ -13,40 +14,28 @@ def tb_model_weights(self, model, step): layer_num = 1 for name, param in model.named_parameters(): if param.numel() == 1: - self.writer.add_scalar( - "layer{}-{}/value".format(layer_num, name), - param.max(), step) + self.writer.add_scalar("layer{}-{}/value".format(layer_num, name), param.max(), step) else: - self.writer.add_scalar( - "layer{}-{}/max".format(layer_num, name), - param.max(), step) - self.writer.add_scalar( - "layer{}-{}/min".format(layer_num, name), - param.min(), step) - self.writer.add_scalar( - "layer{}-{}/mean".format(layer_num, name), - param.mean(), step) - self.writer.add_scalar( - "layer{}-{}/std".format(layer_num, name), - param.std(), step) - self.writer.add_histogram( - "layer{}-{}/param".format(layer_num, name), param, step) - self.writer.add_histogram( - "layer{}-{}/grad".format(layer_num, name), param.grad, step) + self.writer.add_scalar("layer{}-{}/max".format(layer_num, name), param.max(), step) + self.writer.add_scalar("layer{}-{}/min".format(layer_num, name), param.min(), step) + self.writer.add_scalar("layer{}-{}/mean".format(layer_num, name), param.mean(), step) + self.writer.add_scalar("layer{}-{}/std".format(layer_num, name), param.std(), step) + self.writer.add_histogram("layer{}-{}/param".format(layer_num, name), param, step) + self.writer.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step) layer_num += 1 def dict_to_tb_scalar(self, scope_name, stats, step): for key, value in stats.items(): - self.writer.add_scalar('{}/{}'.format(scope_name, key), value, step) + self.writer.add_scalar("{}/{}".format(scope_name, key), value, step) def dict_to_tb_figure(self, scope_name, figures, step): for key, value in figures.items(): - self.writer.add_figure('{}/{}'.format(scope_name, key), value, step) + self.writer.add_figure("{}/{}".format(scope_name, key), value, step) def dict_to_tb_audios(self, scope_name, audios, step, sample_rate): for key, value in audios.items(): try: - self.writer.add_audio('{}/{}'.format(scope_name, key), value, step, sample_rate=sample_rate) + self.writer.add_audio("{}/{}".format(scope_name, key), value, step, sample_rate=sample_rate) except RuntimeError: traceback.print_exc() diff --git a/TTS/utils/training.py b/TTS/utils/training.py index 8166562ccb..a1581b3296 100644 --- a/TTS/utils/training.py +++ b/TTS/utils/training.py @@ -1,5 +1,5 @@ -import torch import numpy as np +import torch def setup_torch_training_env(cudnn_enable, cudnn_benchmark): @@ -14,12 +14,13 @@ def setup_torch_training_env(cudnn_enable, cudnn_benchmark): def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): - r'''Check model gradient against unexpected jumps and failures''' + 
r"""Check model gradient against unexpected jumps and failures""" skip_flag = False if ignore_stopnet: if not amp_opt_params: grad_norm = torch.nn.utils.clip_grad_norm_( - [param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip) + [param for name, param in model.named_parameters() if "stopnet" not in name], grad_clip + ) else: grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip) else: @@ -41,11 +42,10 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): def lr_decay(init_lr, global_step, warmup_steps): - r'''from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py''' + r"""from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py""" warmup_steps = float(warmup_steps) - step = global_step + 1. - lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5, - step**-0.5) + step = global_step + 1.0 + lr = init_lr * warmup_steps ** 0.5 * np.minimum(step * warmup_steps ** -1.5, step ** -0.5) return lr @@ -54,14 +54,14 @@ def adam_weight_decay(optimizer): Custom weight decay operation, not effecting grad values. """ for group in optimizer.param_groups: - for param in group['params']: - current_lr = group['lr'] - weight_decay = group['weight_decay'] - factor = -weight_decay * group['lr'] - param.data = param.data.add(param.data, - alpha=factor) + for param in group["params"]: + current_lr = group["lr"] + weight_decay = group["weight_decay"] + factor = -weight_decay * group["lr"] + param.data = param.data.add(param.data, alpha=factor) return optimizer, current_lr + # pylint: disable=dangerous-default-value def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}): """ @@ -78,13 +78,7 @@ def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn no_decay.append(param) else: decay.append(param) - return [{ - 'params': no_decay, - 'weight_decay': 0. - }, { - 'params': decay, - 'weight_decay': weight_decay - }] + return [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": weight_decay}] # pylint: disable=protected-access @@ -96,8 +90,7 @@ def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1): def get_lr(self): step = max(self.last_epoch, 1) return [ - base_lr * self.warmup_steps**0.5 * - min(step * self.warmup_steps**-1.5, step**-0.5) + base_lr * self.warmup_steps ** 0.5 * min(step * self.warmup_steps ** -1.5, step ** -0.5) for base_lr in self.base_lrs ] diff --git a/TTS/vocoder/datasets/gan_dataset.py b/TTS/vocoder/datasets/gan_dataset.py index 455ea95c85..85d27e8bb7 100644 --- a/TTS/vocoder/datasets/gan_dataset.py +++ b/TTS/vocoder/datasets/gan_dataset.py @@ -1,10 +1,11 @@ -import os import glob -import torch +import os import random +from multiprocessing import Manager + import numpy as np +import torch from torch.utils.data import Dataset -from multiprocessing import Manager class GANDataset(Dataset): @@ -13,19 +14,22 @@ class GANDataset(Dataset): and converts them to acoustic features on the fly and returns random segments of (audio, feature) couples. 
""" - def __init__(self, - ap, - items, - seq_len, - hop_len, - pad_short, - conv_pad=2, - return_pairs=False, - is_training=True, - return_segments=True, - use_noise_augment=False, - use_cache=False, - verbose=False): + + def __init__( + self, + ap, + items, + seq_len, + hop_len, + pad_short, + conv_pad=2, + return_pairs=False, + is_training=True, + return_segments=True, + use_noise_augment=False, + use_cache=False, + verbose=False, + ): super(GANDataset, self).__init__() self.ap = ap self.item_list = items @@ -59,14 +63,14 @@ def create_feature_cache(self): @staticmethod def find_wav_files(path): - return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True) + return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True) def __len__(self): return len(self.item_list) def __getitem__(self, idx): - """ Return different items for Generator and Discriminator and - cache acoustic features """ + """Return different items for Generator and Discriminator and + cache acoustic features""" # set the seed differently for each worker if torch.utils.data.get_worker_info(): @@ -85,13 +89,16 @@ def __getitem__(self, idx): def _pad_short_samples(self, audio, mel=None): """Pad samples shorter than the output sequence length""" if len(audio) < self.seq_len: - audio = np.pad(audio, (0, self.seq_len - len(audio)), - mode='constant', - constant_values=0.0) + audio = np.pad(audio, (0, self.seq_len - len(audio)), mode="constant", constant_values=0.0) if mel is not None and mel.shape[1] < self.feat_frame_len: pad_value = self.ap.melspectrogram(np.zeros([self.ap.win_length]))[:, 0] - mel = np.pad(mel, ([0, 0], [0, self.feat_frame_len - mel.shape[1]]), mode='constant', constant_values=pad_value.mean()) + mel = np.pad( + mel, + ([0, 0], [0, self.feat_frame_len - mel.shape[1]]), + mode="constant", + constant_values=pad_value.mean(), + ) return audio, mel def shuffle_mapping(self): @@ -124,8 +131,10 @@ def load_item(self, idx): # correct the audio length wrt padding applied in stft audio = np.pad(audio, (0, self.hop_len), mode="edge") - audio = audio[:mel.shape[-1] * self.hop_len] - assert mel.shape[-1] * self.hop_len == audio.shape[-1], f' [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}' + audio = audio[: mel.shape[-1] * self.hop_len] + assert ( + mel.shape[-1] * self.hop_len == audio.shape[-1] + ), f" [!] 
{mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}" audio = torch.from_numpy(audio).float().unsqueeze(0) mel = torch.from_numpy(mel).float().squeeze(0) @@ -137,8 +146,7 @@ def load_item(self, idx): mel = mel[:, mel_start:mel_end] audio_start = mel_start * self.hop_len - audio = audio[:, audio_start:audio_start + - self.seq_len] + audio = audio[:, audio_start : audio_start + self.seq_len] if self.use_noise_augment and self.is_training and self.return_segments: audio = audio + (1 / 32768) * torch.randn_like(audio) diff --git a/TTS/vocoder/datasets/preprocess.py b/TTS/vocoder/datasets/preprocess.py index afea45fd32..d99ee1479f 100644 --- a/TTS/vocoder/datasets/preprocess.py +++ b/TTS/vocoder/datasets/preprocess.py @@ -1,9 +1,9 @@ import glob import os from pathlib import Path -from tqdm import tqdm import numpy as np +from tqdm import tqdm def preprocess_wav_files(out_path, config, ap): @@ -18,11 +18,7 @@ def preprocess_wav_files(out_path, config, ap): mel = ap.melspectrogram(y) np.save(mel_path, mel) if isinstance(config.mode, int): - quant = ( - ap.mulaw_encode(y, qc=config.mode) - if config.mulaw - else ap.quantize(y, bits=config.mode) - ) + quant = ap.mulaw_encode(y, qc=config.mode) if config.mulaw else ap.quantize(y, bits=config.mode) np.save(quant_path, quant) diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py index 6cd5862aa5..51767b565d 100644 --- a/TTS/vocoder/datasets/wavegrad_dataset.py +++ b/TTS/vocoder/datasets/wavegrad_dataset.py @@ -1,10 +1,11 @@ -import os import glob -import torch +import os import random +from multiprocessing import Manager + import numpy as np +import torch from torch.utils.data import Dataset -from multiprocessing import Manager class WaveGradDataset(Dataset): @@ -13,18 +14,21 @@ class WaveGradDataset(Dataset): and converts them to acoustic features on the fly and returns random segments of (audio, feature) couples. 
""" - def __init__(self, - ap, - items, - seq_len, - hop_len, - pad_short, - conv_pad=2, - is_training=True, - return_segments=True, - use_noise_augment=False, - use_cache=False, - verbose=False): + + def __init__( + self, + ap, + items, + seq_len, + hop_len, + pad_short, + conv_pad=2, + is_training=True, + return_segments=True, + use_noise_augment=False, + use_cache=False, + verbose=False, + ): super().__init__() self.ap = ap @@ -54,7 +58,7 @@ def create_feature_cache(self): @staticmethod def find_wav_files(path): - return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True) + return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True) def __len__(self): return len(self.item_list) @@ -86,13 +90,16 @@ def load_item(self, idx): if self.return_segments: # correct audio length wrt segment length if audio.shape[-1] < self.seq_len + self.pad_short: - audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \ - mode='constant', constant_values=0.0) - assert audio.shape[-1] >= self.seq_len + self.pad_short, f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}" + audio = np.pad( + audio, (0, self.seq_len + self.pad_short - len(audio)), mode="constant", constant_values=0.0 + ) + assert ( + audio.shape[-1] >= self.seq_len + self.pad_short + ), f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}" # correct the audio length wrt hop length p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1] - audio = np.pad(audio, (0, p), mode='constant', constant_values=0.0) + audio = np.pad(audio, (0, p), mode="constant", constant_values=0.0) if self.use_cache: self.cache[idx] = audio @@ -126,7 +133,7 @@ def collate_full_clips(batch): for idx, b in enumerate(batch): mel = b[0] audio = b[1] - mels[idx, :, :mel.shape[1]] = mel - audios[idx, :audio.shape[0]] = audio + mels[idx, :, : mel.shape[1]] = mel + audios[idx, : audio.shape[0]] = audio return mels, audios diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py index d45932c99c..4ab8a17467 100644 --- a/TTS/vocoder/datasets/wavernn_dataset.py +++ b/TTS/vocoder/datasets/wavernn_dataset.py @@ -1,5 +1,5 @@ -import torch import numpy as np +import torch from torch.utils.data import Dataset @@ -9,17 +9,18 @@ class WaveRNNDataset(Dataset): and converts them to acoustic features on the fly. """ - def __init__(self, - ap, - items, - seq_len, - hop_len, - pad, - mode, - mulaw, - is_training=True, - verbose=False, - ): + def __init__( + self, + ap, + items, + seq_len, + hop_len, + pad, + mode, + mulaw, + is_training=True, + verbose=False, + ): super(WaveRNNDataset, self).__init__() self.ap = ap @@ -61,8 +62,9 @@ def load_item(self, index): if self.mode in ["gauss", "mold"]: x_input = audio elif isinstance(self.mode, int): - x_input = (self.ap.mulaw_encode(audio, qc=self.mode) - if self.mulaw else self.ap.quantize(audio, bits=self.mode)) + x_input = ( + self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode) + ) else: raise RuntimeError("Unknown dataset mode - ", self.mode) @@ -71,7 +73,7 @@ def load_item(self, index): wavpath, feat_path = self.item_list[index] mel = np.load(feat_path.replace("/quant/", "/mel/")) - if mel.shape[-1] < self.mel_len + 2 * self.pad: + if mel.shape[-1] < self.mel_len + 2 * self.pad: print(" [!] Instance is too short! 
: {}".format(wavpath)) self.item_list[index] = self.item_list[index + 1] feat_path = self.item_list[index] @@ -87,22 +89,14 @@ def load_item(self, index): def collate(self, batch): mel_win = self.seq_len // self.hop_len + 2 * self.pad - max_offsets = [x[0].shape[-1] - - (mel_win + 2 * self.pad) for x in batch] + max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch] mel_offsets = [np.random.randint(0, offset) for offset in max_offsets] - sig_offsets = [(offset + self.pad) * - self.hop_len for offset in mel_offsets] + sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets] - mels = [ - x[0][:, mel_offsets[i]: mel_offsets[i] + mel_win] - for i, x in enumerate(batch) - ] + mels = [x[0][:, mel_offsets[i] : mel_offsets[i] + mel_win] for i, x in enumerate(batch)] - coarse = [ - x[1][sig_offsets[i]: sig_offsets[i] + self.seq_len + 1] - for i, x in enumerate(batch) - ] + coarse = [x[1][sig_offsets[i] : sig_offsets[i] + self.seq_len + 1] for i, x in enumerate(batch)] mels = np.stack(mels).astype(np.float32) if self.mode in ["gauss", "mold"]: @@ -112,8 +106,7 @@ def collate(self, batch): elif isinstance(self.mode, int): coarse = np.stack(coarse).astype(np.int64) coarse = torch.LongTensor(coarse) - x_input = (2 * coarse[:, : self.seq_len].float() / - (2 ** self.mode - 1.0) - 1.0) + x_input = 2 * coarse[:, : self.seq_len].float() / (2 ** self.mode - 1.0) - 1.0 y_coarse = coarse[:, 1:] mels = torch.FloatTensor(mels) return x_input, mels, y_coarse diff --git a/TTS/vocoder/layers/hifigan.py b/TTS/vocoder/layers/hifigan.py index ffd4058841..11f2ea0c3e 100644 --- a/TTS/vocoder/layers/hifigan.py +++ b/TTS/vocoder/layers/hifigan.py @@ -10,20 +10,14 @@ def __init__(self, kernel, channel, padding, dilations=[1, 3, 5]): resstack += [ nn.LeakyReLU(0.2), nn.ReflectionPad1d(dilation), - nn.utils.weight_norm( - nn.Conv1d(channel, - channel, - kernel_size=kernel, - dilation=dilation)), + nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)), nn.LeakyReLU(0.2), nn.ReflectionPad1d(padding), - nn.utils.weight_norm(nn.Conv1d(channel, channel, - kernel_size=1)), + nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)), ] self.resstack = nn.Sequential(*resstack) - self.shortcut = nn.utils.weight_norm( - nn.Conv1d(channel, channel, kernel_size=1)) + self.shortcut = nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)) def forward(self, x): x1 = self.shortcut(x) diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py index c7495cb812..a6e438f920 100644 --- a/TTS/vocoder/layers/losses.py +++ b/TTS/vocoder/layers/losses.py @@ -1,22 +1,25 @@ -import torch import librosa +import torch from torch import nn from torch.nn import functional as F class TorchSTFT(nn.Module): # pylint: disable=abstract-method """TODO: Merge this with audio.py""" - def __init__(self, - n_fft, - hop_length, - win_length, - pad_wav=False, - window='hann_window', - sample_rate=None, - mel_fmin=0, - mel_fmax=None, - n_mels=80, - use_mel=False): + + def __init__( + self, + n_fft, + hop_length, + win_length, + pad_wav=False, + window="hann_window", + sample_rate=None, + mel_fmin=0, + mel_fmax=None, + n_mels=80, + use_mel=False, + ): """ Torch based STFT operation """ super(TorchSTFT, self).__init__() self.n_fft = n_fft @@ -28,8 +31,7 @@ def __init__(self, self.mel_fmax = mel_fmax self.n_mels = n_mels self.use_mel = use_mel - self.window = nn.Parameter(getattr(torch, window)(win_length), - requires_grad=False) + self.window = 
nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) self.mel_basis = None if use_mel: self._build_mel_basis() @@ -50,7 +52,7 @@ def __call__(self, x): x = x.unsqueeze(1) if self.pad_wav: padding = int((self.n_fft - self.hop_length) / 2) - x = torch.nn.functional.pad(x, (padding, padding), mode='reflect') + x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") # B x D x T x 2 o = torch.stft( x.squeeze(1), @@ -62,33 +64,32 @@ def __call__(self, x): pad_mode="reflect", # compatible with audio.py normalized=False, onesided=True, - return_complex=False) + return_complex=False, + ) M = o[:, :, :, 0] P = o[:, :, :, 1] - S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8)) + S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8)) if self.use_mel: S = torch.matmul(self.mel_basis.to(x), S) return S def _build_mel_basis(self): - mel_basis = librosa.filters.mel(self.sample_rate, - self.n_fft, - n_mels=self.n_mels, - fmin=self.mel_fmin, - fmax=self.mel_fmax) + mel_basis = librosa.filters.mel( + self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax + ) self.mel_basis = torch.from_numpy(mel_basis).float() - ################################# # GENERATOR LOSSES ################################# class STFTLoss(nn.Module): - """ STFT loss. Input generate and real waveforms are converted + """STFT loss. Input generate and real waveforms are converted to spectrograms compared with L1 and Spectral convergence losses. It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf""" + def __init__(self, n_fft, hop_length, win_length): super(STFTLoss, self).__init__() self.n_fft = n_fft @@ -105,14 +106,13 @@ def forward(self, y_hat, y): loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro") return loss_mag, loss_sc + class MultiScaleSTFTLoss(torch.nn.Module): - """ Multi-scale STFT loss. Input generate and real waveforms are converted + """Multi-scale STFT loss. Input generate and real waveforms are converted to spectrograms compared with L1 and Spectral convergence losses. 
It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf""" - def __init__(self, - n_ffts=(1024, 2048, 512), - hop_lengths=(120, 240, 50), - win_lengths=(600, 1200, 240)): + + def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), win_lengths=(600, 1200, 240)): super(MultiScaleSTFTLoss, self).__init__() self.loss_funcs = torch.nn.ModuleList() for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths): @@ -130,19 +130,25 @@ def forward(self, y_hat, y): loss_mag /= N return loss_mag, loss_sc + class L1SpecLoss(nn.Module): """ L1 Loss over Spectrograms as described in HiFiGAN paper https://arxiv.org/pdf/2010.05646.pdf""" - def __init__(self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True): + + def __init__( + self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True + ): super().__init__() self.use_mel = use_mel - self.stft = TorchSTFT(n_fft, - hop_length, - win_length, - sample_rate=sample_rate, - mel_fmin=mel_fmin, - mel_fmax=mel_fmax, - n_mels=n_mels, - use_mel=use_mel) + self.stft = TorchSTFT( + n_fft, + hop_length, + win_length, + sample_rate=sample_rate, + mel_fmin=mel_fmin, + mel_fmax=mel_fmax, + n_mels=n_mels, + use_mel=use_mel, + ) def forward(self, y_hat, y): y_hat_M = self.stft(y_hat) @@ -151,9 +157,11 @@ def forward(self, y_hat, y): loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M)) return loss_mag + class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss): - """ Multiscale STFT loss for multi band model outputs. + """Multiscale STFT loss for multi band model outputs. From MultiBand-MelGAN paper https://arxiv.org/abs/2005.05106""" + # pylint: disable=no-self-use def forward(self, y_hat, y): y_hat = y_hat.view(-1, 1, y_hat.shape[2]) @@ -163,6 +171,7 @@ def forward(self, y_hat, y): class MSEGLoss(nn.Module): """ Mean Squared Generator Loss """ + # pylint: disable=no-self-use def forward(self, score_real): loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape)) @@ -171,10 +180,11 @@ def forward(self, score_real): class HingeGLoss(nn.Module): """ Hinge Discriminator Loss """ + # pylint: disable=no-self-use def forward(self, score_real): # TODO: this might be wrong - loss_fake = torch.mean(F.relu(1. - score_real)) + loss_fake = torch.mean(F.relu(1.0 - score_real)) return loss_fake @@ -185,7 +195,10 @@ def forward(self, score_real): class MSEDLoss(nn.Module): """ Mean Squared Discriminator Loss """ - def __init__(self,): + + def __init__( + self, + ): super(MSEDLoss, self).__init__() self.loss_func = nn.MSELoss() @@ -199,16 +212,19 @@ def forward(self, score_fake, score_real): class HingeDLoss(nn.Module): """ Hinge Discriminator Loss """ + # pylint: disable=no-self-use def forward(self, score_fake, score_real): - loss_real = torch.mean(F.relu(1. - score_real)) - loss_fake = torch.mean(F.relu(1. 
+ score_fake)) + loss_real = torch.mean(F.relu(1.0 - score_real)) + loss_fake = torch.mean(F.relu(1.0 + score_fake)) loss_d = loss_real + loss_fake return loss_d, loss_real, loss_fake class MelganFeatureLoss(nn.Module): - def __init__(self,): + def __init__( + self, + ): super(MelganFeatureLoss, self).__init__() self.loss_func = nn.L1Loss() @@ -230,8 +246,8 @@ def forward(self, fake_feats, real_feats): def _apply_G_adv_loss(scores_fake, loss_func): - """ Compute G adversarial loss function - and normalize values """ + """Compute G adversarial loss function + and normalize values""" adv_loss = 0 if isinstance(scores_fake, list): for score_fake in scores_fake: @@ -280,24 +296,26 @@ class GeneratorLoss(nn.Module): Args: C (AttrDict): model configuration. """ + def __init__(self, C): super().__init__() - assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\ - " [!] Cannot use HingeGANLoss and MSEGANLoss together." - - self.use_stft_loss = C.use_stft_loss if 'use_stft_loss' in C else False - self.use_subband_stft_loss = C.use_subband_stft_loss if 'use_subband_stft_loss' in C else False - self.use_mse_gan_loss = C.use_mse_gan_loss if 'use_mse_gan_loss' in C else False - self.use_hinge_gan_loss = C.use_hinge_gan_loss if 'use_hinge_gan_loss' in C else False - self.use_feat_match_loss = C.use_feat_match_loss if 'use_feat_match_loss' in C else False - self.use_l1_spec_loss = C.use_l1_spec_loss if 'use_l1_spec_loss' in C else False - - self.stft_loss_weight = C.stft_loss_weight if 'stft_loss_weight' in C else 0.0 - self.subband_stft_loss_weight = C.subband_stft_loss_weight if 'subband_stft_loss_weight' in C else 0.0 - self.mse_gan_loss_weight = C.mse_G_loss_weight if 'mse_G_loss_weight' in C else 0.0 - self.hinge_gan_loss_weight = C.hinge_G_loss_weight if 'hinde_G_loss_weight' in C else 0.0 - self.feat_match_loss_weight = C.feat_match_loss_weight if 'feat_match_loss_weight' in C else 0.0 - self.l1_spec_loss_weight = C.l1_spec_loss_weight if 'l1_spec_loss_weight' in C else 0.0 + assert not ( + C.use_mse_gan_loss and C.use_hinge_gan_loss + ), " [!] Cannot use HingeGANLoss and MSEGANLoss together." 
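# --- illustrative sketch, not part of the patch above ---
# Toy example of what the reformatted GAN loss modules compute. The tensors and
# values below are made up for illustration; only the loss formulas mirror the
# HingeDLoss / MSEGLoss definitions shown in the hunk.
import torch
import torch.nn.functional as F

score_real = torch.tensor([0.9, 1.2, 0.4])   # discriminator scores for real audio
score_fake = torch.tensor([-0.3, 0.1, 0.8])  # discriminator scores for generated audio

# Hinge discriminator loss: penalize real scores below 1 and fake scores above -1
loss_real = torch.mean(F.relu(1.0 - score_real))
loss_fake = torch.mean(F.relu(1.0 + score_fake))
loss_d = loss_real + loss_fake

# MSE generator loss: push the discriminator scores of generated audio toward 1
loss_g = F.mse_loss(score_fake, torch.ones_like(score_fake))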
+ + self.use_stft_loss = C.use_stft_loss if "use_stft_loss" in C else False + self.use_subband_stft_loss = C.use_subband_stft_loss if "use_subband_stft_loss" in C else False + self.use_mse_gan_loss = C.use_mse_gan_loss if "use_mse_gan_loss" in C else False + self.use_hinge_gan_loss = C.use_hinge_gan_loss if "use_hinge_gan_loss" in C else False + self.use_feat_match_loss = C.use_feat_match_loss if "use_feat_match_loss" in C else False + self.use_l1_spec_loss = C.use_l1_spec_loss if "use_l1_spec_loss" in C else False + + self.stft_loss_weight = C.stft_loss_weight if "stft_loss_weight" in C else 0.0 + self.subband_stft_loss_weight = C.subband_stft_loss_weight if "subband_stft_loss_weight" in C else 0.0 + self.mse_gan_loss_weight = C.mse_G_loss_weight if "mse_G_loss_weight" in C else 0.0 + self.hinge_gan_loss_weight = C.hinge_G_loss_weight if "hinde_G_loss_weight" in C else 0.0 + self.feat_match_loss_weight = C.feat_match_loss_weight if "feat_match_loss_weight" in C else 0.0 + self.l1_spec_loss_weight = C.l1_spec_loss_weight if "l1_spec_loss_weight" in C else 0.0 if C.use_stft_loss: self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params) @@ -310,69 +328,73 @@ def __init__(self, C): if C.use_feat_match_loss: self.feat_match_loss = MelganFeatureLoss() if C.use_l1_spec_loss: - assert C.audio['sample_rate'] == C.l1_spec_loss_params['sample_rate'] + assert C.audio["sample_rate"] == C.l1_spec_loss_params["sample_rate"] self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params) - def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None): + def forward( + self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None + ): gen_loss = 0 adv_loss = 0 return_dict = {} # STFT Loss if self.use_stft_loss: - stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1)) - return_dict['G_stft_loss_mg'] = stft_loss_mg - return_dict['G_stft_loss_sc'] = stft_loss_sc + stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1)) + return_dict["G_stft_loss_mg"] = stft_loss_mg + return_dict["G_stft_loss_sc"] = stft_loss_sc gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc) # L1 Spec loss if self.use_l1_spec_loss: l1_spec_loss = self.l1_spec_loss(y_hat, y) - return_dict['G_l1_spec_loss'] = l1_spec_loss + return_dict["G_l1_spec_loss"] = l1_spec_loss gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss # L1 Spec loss if self.use_l1_spec_loss: l1_spec_loss = self.l1_spec_loss(y_hat, y) - return_dict['G_l1_spec_loss'] = l1_spec_loss + return_dict["G_l1_spec_loss"] = l1_spec_loss gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss # subband STFT Loss if self.use_subband_stft_loss: subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub) - return_dict['G_subband_stft_loss_mg'] = subband_stft_loss_mg - return_dict['G_subband_stft_loss_sc'] = subband_stft_loss_sc + return_dict["G_subband_stft_loss_mg"] = subband_stft_loss_mg + return_dict["G_subband_stft_loss_sc"] = subband_stft_loss_sc gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc) # multiscale MSE adversarial loss if self.use_mse_gan_loss and scores_fake is not None: mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss) - return_dict['G_mse_fake_loss'] = mse_fake_loss + return_dict["G_mse_fake_loss"] = mse_fake_loss adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss 
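# --- illustrative sketch, not part of the patch above ---
# How the adversarial term above is assembled: _apply_G_adv_loss receives one
# score tensor per discriminator scale and, per its docstring, normalizes the
# summed loss; GeneratorLoss.forward then scales it with the configured weight.
# All tensors and the 2.5 weight below are placeholders for illustration.
import torch
import torch.nn.functional as F

def mse_g_loss(score_fake):
    # same criterion as MSEGLoss: drive fake scores toward 1
    return F.mse_loss(score_fake, torch.ones_like(score_fake))

scores_fake = [torch.rand(4, 1) for _ in range(3)]  # one tensor per discriminator scale
adv_loss = sum(mse_g_loss(s) for s in scores_fake) / len(scores_fake)
weighted_adv = 2.5 * adv_loss  # stand-in for self.mse_gan_loss_weight from the config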
# multiscale Hinge adversarial loss if self.use_hinge_gan_loss and not scores_fake is not None: hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss) - return_dict['G_hinge_fake_loss'] = hinge_fake_loss + return_dict["G_hinge_fake_loss"] = hinge_fake_loss adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss # Feature Matching Loss if self.use_feat_match_loss and not feats_fake is None: feat_match_loss = self.feat_match_loss(feats_fake, feats_real) - return_dict['G_feat_match_loss'] = feat_match_loss + return_dict["G_feat_match_loss"] = feat_match_loss adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss - return_dict['G_loss'] = gen_loss + adv_loss - return_dict['G_gen_loss'] = gen_loss - return_dict['G_adv_loss'] = adv_loss + return_dict["G_loss"] = gen_loss + adv_loss + return_dict["G_gen_loss"] = gen_loss + return_dict["G_adv_loss"] = adv_loss return return_dict class DiscriminatorLoss(nn.Module): """Like ```GeneratorLoss```""" + def __init__(self, C): super().__init__() - assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\ - " [!] Cannot use HingeGANLoss and MSEGANLoss together." + assert not ( + C.use_mse_gan_loss and C.use_hinge_gan_loss + ), " [!] Cannot use HingeGANLoss and MSEGANLoss together." self.use_mse_gan_loss = C.use_mse_gan_loss self.use_hinge_gan_loss = C.use_hinge_gan_loss @@ -388,23 +410,21 @@ def forward(self, scores_fake, scores_real): if self.use_mse_gan_loss: mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss( - scores_fake=scores_fake, - scores_real=scores_real, - loss_func=self.mse_loss) - return_dict['D_mse_gan_loss'] = mse_D_loss - return_dict['D_mse_gan_real_loss'] = mse_D_real_loss - return_dict['D_mse_gan_fake_loss'] = mse_D_fake_loss + scores_fake=scores_fake, scores_real=scores_real, loss_func=self.mse_loss + ) + return_dict["D_mse_gan_loss"] = mse_D_loss + return_dict["D_mse_gan_real_loss"] = mse_D_real_loss + return_dict["D_mse_gan_fake_loss"] = mse_D_fake_loss loss += mse_D_loss if self.use_hinge_gan_loss: hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss( - scores_fake=scores_fake, - scores_real=scores_real, - loss_func=self.hinge_loss) - return_dict['D_hinge_gan_loss'] = hinge_D_loss - return_dict['D_hinge_gan_real_loss'] = hinge_D_real_loss - return_dict['D_hinge_gan_fake_loss'] = hinge_D_fake_loss + scores_fake=scores_fake, scores_real=scores_real, loss_func=self.hinge_loss + ) + return_dict["D_hinge_gan_loss"] = hinge_D_loss + return_dict["D_hinge_gan_real_loss"] = hinge_D_real_loss + return_dict["D_hinge_gan_fake_loss"] = hinge_D_fake_loss loss += hinge_D_loss - return_dict['D_loss'] = loss + return_dict["D_loss"] = loss return return_dict diff --git a/TTS/vocoder/layers/melgan.py b/TTS/vocoder/layers/melgan.py index 58c12a2eb8..67f98c1351 100644 --- a/TTS/vocoder/layers/melgan.py +++ b/TTS/vocoder/layers/melgan.py @@ -12,26 +12,23 @@ def __init__(self, channels, num_res_blocks, kernel_size): self.blocks = nn.ModuleList() for idx in range(num_res_blocks): layer_kernel_size = kernel_size - layer_dilation = layer_kernel_size**idx + layer_dilation = layer_kernel_size ** idx layer_padding = base_padding * layer_dilation - self.blocks += [nn.Sequential( - nn.LeakyReLU(0.2), - nn.ReflectionPad1d(layer_padding), - weight_norm( - nn.Conv1d(channels, - channels, - kernel_size=kernel_size, - dilation=layer_dilation, - bias=True)), - nn.LeakyReLU(0.2), - weight_norm( - nn.Conv1d(channels, channels, kernel_size=1, bias=True)), - )] + self.blocks += [ + nn.Sequential( + nn.LeakyReLU(0.2), 
+ nn.ReflectionPad1d(layer_padding), + weight_norm( + nn.Conv1d(channels, channels, kernel_size=kernel_size, dilation=layer_dilation, bias=True) + ), + nn.LeakyReLU(0.2), + weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)), + ) + ] - self.shortcuts = nn.ModuleList([ - weight_norm(nn.Conv1d(channels, channels, kernel_size=1, - bias=True)) for i in range(num_res_blocks) - ]) + self.shortcuts = nn.ModuleList( + [weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for i in range(num_res_blocks)] + ) def forward(self, x): for block, shortcut in zip(self.blocks, self.shortcuts): diff --git a/TTS/vocoder/layers/parallel_wavegan.py b/TTS/vocoder/layers/parallel_wavegan.py index bedfe5519a..427a2f3db4 100644 --- a/TTS/vocoder/layers/parallel_wavegan.py +++ b/TTS/vocoder/layers/parallel_wavegan.py @@ -4,54 +4,44 @@ class ResidualBlock(torch.nn.Module): """Residual block module in WaveNet.""" - def __init__(self, - kernel_size=3, - res_channels=64, - gate_channels=128, - skip_channels=64, - aux_channels=80, - dropout=0.0, - dilation=1, - bias=True, - use_causal_conv=False): + + def __init__( + self, + kernel_size=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + dropout=0.0, + dilation=1, + bias=True, + use_causal_conv=False, + ): super(ResidualBlock, self).__init__() self.dropout = dropout # no future time stamps available if use_causal_conv: padding = (kernel_size - 1) * dilation else: - assert (kernel_size - - 1) % 2 == 0, "Not support even number kernel size." + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." padding = (kernel_size - 1) // 2 * dilation self.use_causal_conv = use_causal_conv # dilation conv - self.conv = torch.nn.Conv1d(res_channels, - gate_channels, - kernel_size, - padding=padding, - dilation=dilation, - bias=bias) + self.conv = torch.nn.Conv1d( + res_channels, gate_channels, kernel_size, padding=padding, dilation=dilation, bias=bias + ) # local conditioning if aux_channels > 0: - self.conv1x1_aux = torch.nn.Conv1d(aux_channels, - gate_channels, - 1, - bias=False) + self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False) else: self.conv1x1_aux = None # conv output is split into two groups gate_out_channels = gate_channels // 2 - self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, - res_channels, - 1, - bias=bias) - self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, - skip_channels, - 1, - bias=bias) + self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias) + self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias) def forward(self, x, c): """ @@ -63,7 +53,7 @@ def forward(self, x, c): x = self.conv(x) # remove future time steps if use_causal_conv conv - x = x[:, :, :residual.size(-1)] if self.use_causal_conv else x + x = x[:, :, : residual.size(-1)] if self.use_causal_conv else x # split into two part for gated activation splitdim = 1 @@ -82,6 +72,6 @@ def forward(self, x, c): s = self.conv1x1_skip(x) # for residual connection - x = (self.conv1x1_out(x) + residual) * (0.5**2) + x = (self.conv1x1_out(x) + residual) * (0.5 ** 2) return x, s diff --git a/TTS/vocoder/layers/pqmf.py b/TTS/vocoder/layers/pqmf.py index d31953d628..38d28fc1d9 100644 --- a/TTS/vocoder/layers/pqmf.py +++ b/TTS/vocoder/layers/pqmf.py @@ -1,7 +1,6 @@ import numpy as np import torch import torch.nn.functional as F - from scipy import signal as sig @@ -16,14 +15,14 @@ def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0): 
self.cutoff = cutoff self.beta = beta - QMF = sig.firwin(taps + 1, cutoff, window=('kaiser', beta)) + QMF = sig.firwin(taps + 1, cutoff, window=("kaiser", beta)) H = np.zeros((N, len(QMF))) G = np.zeros((N, len(QMF))) for k in range(N): - constant_factor = (2 * k + 1) * (np.pi / - (2 * N)) * (np.arange(taps + 1) - - ((taps - 1) / 2)) # TODO: (taps - 1) -> taps - phase = (-1)**k * np.pi / 4 + constant_factor = ( + (2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - ((taps - 1) / 2)) + ) # TODO: (taps - 1) -> taps + phase = (-1) ** k * np.pi / 4 H[k] = 2 * QMF * np.cos(constant_factor + phase) G[k] = 2 * QMF * np.cos(constant_factor - phase) @@ -49,8 +48,6 @@ def analysis(self, x): return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N) def synthesis(self, x): - x = F.conv_transpose1d(x, - self.updown_filter * self.N, - stride=self.N) + x = F.conv_transpose1d(x, self.updown_filter * self.N, stride=self.N) x = F.conv1d(x, self.G, padding=self.taps // 2) return x diff --git a/TTS/vocoder/layers/upsample.py b/TTS/vocoder/layers/upsample.py index 1340687582..9bf2a6b424 100644 --- a/TTS/vocoder/layers/upsample.py +++ b/TTS/vocoder/layers/upsample.py @@ -11,23 +11,23 @@ def __init__(self, x_scale, y_scale, mode="nearest"): def forward(self, x): """ - x (Tensor): Input tensor (B, C, F, T). - Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale), + x (Tensor): Input tensor (B, C, F, T). + Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale), """ - return F.interpolate( - x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) + return F.interpolate(x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) class UpsampleNetwork(torch.nn.Module): # pylint: disable=dangerous-default-value - def __init__(self, - upsample_factors, - nonlinear_activation=None, - nonlinear_activation_params={}, - interpolate_mode="nearest", - freq_axis_kernel_size=1, - use_causal_conv=False, - ): + def __init__( + self, + upsample_factors, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + use_causal_conv=False, + ): super(UpsampleNetwork, self).__init__() self.use_causal_conv = use_causal_conv self.up_layers = torch.nn.ModuleList() @@ -54,8 +54,8 @@ def __init__(self, def forward(self, c): """ - c : (B, C, T_in). - Tensor: (B, C, T_upsample) + c : (B, C, T_in). 
+ Tensor: (B, C, T_upsample) """ c = c.unsqueeze(1) # (B, 1, C, T) for f in self.up_layers: @@ -65,16 +65,17 @@ def forward(self, c): class ConvUpsample(torch.nn.Module): # pylint: disable=dangerous-default-value - def __init__(self, - upsample_factors, - nonlinear_activation=None, - nonlinear_activation_params={}, - interpolate_mode="nearest", - freq_axis_kernel_size=1, - aux_channels=80, - aux_context_window=0, - use_causal_conv=False - ): + def __init__( + self, + upsample_factors, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + aux_channels=80, + aux_context_window=0, + use_causal_conv=False, + ): super(ConvUpsample, self).__init__() self.aux_context_window = aux_context_window self.use_causal_conv = use_causal_conv and aux_context_window > 0 @@ -97,5 +98,5 @@ def forward(self, c): Tensor: (B, C, T_upsampled), """ c_ = self.conv_in(c) - c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_ + c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_ return self.upsample(c) diff --git a/TTS/vocoder/layers/wavegrad.py b/TTS/vocoder/layers/wavegrad.py index 81f0312476..83cd4233ef 100644 --- a/TTS/vocoder/layers/wavegrad.py +++ b/TTS/vocoder/layers/wavegrad.py @@ -13,6 +13,7 @@ def __init__(self, *args, **kwargs): class PositionalEncoding(nn.Module): """Positional encoding with noise level conditioning""" + def __init__(self, n_channels, max_len=10000): super().__init__() self.n_channels = n_channels @@ -23,9 +24,7 @@ def __init__(self, n_channels, max_len=10000): def forward(self, x, noise_level): if x.shape[2] > self.pe.shape[1]: self.init_pe_matrix(x.shape[1], x.shape[2], x) - return x + noise_level[..., None, - None] + self.pe[:, :x.size(2)].repeat( - x.shape[0], 1, 1) / self.C + return x + noise_level[..., None, None] + self.pe[:, : x.size(2)].repeat(x.shape[0], 1, 1) / self.C def init_pe_matrix(self, n_channels, max_len, x): pe = torch.zeros(max_len, n_channels) @@ -79,30 +78,18 @@ def __init__(self, input_size, hidden_size, factor, dilation): self.factor = factor self.res_block = Conv1d(input_size, hidden_size, 1) - self.main_block = nn.ModuleList([ - Conv1d(input_size, - hidden_size, - 3, - dilation=dilation[0], - padding=dilation[0]), - Conv1d(hidden_size, - hidden_size, - 3, - dilation=dilation[1], - padding=dilation[1]) - ]) - self.out_block = nn.ModuleList([ - Conv1d(hidden_size, - hidden_size, - 3, - dilation=dilation[2], - padding=dilation[2]), - Conv1d(hidden_size, - hidden_size, - 3, - dilation=dilation[3], - padding=dilation[3]) - ]) + self.main_block = nn.ModuleList( + [ + Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]), + Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1]), + ] + ) + self.out_block = nn.ModuleList( + [ + Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]), + Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]), + ] + ) def forward(self, x, shift, scale): x_inter = F.interpolate(x, size=x.shape[-1] * self.factor) @@ -147,11 +134,13 @@ def __init__(self, input_size, hidden_size, factor): super().__init__() self.factor = factor self.res_block = Conv1d(input_size, hidden_size, 1) - self.main_block = nn.ModuleList([ - Conv1d(input_size, hidden_size, 3, dilation=1, padding=1), - Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2), - Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4), - ]) + self.main_block = nn.ModuleList( + [ + 
Conv1d(input_size, hidden_size, 3, dilation=1, padding=1), + Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2), + Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4), + ] + ) def forward(self, x): size = x.shape[-1] // self.factor diff --git a/TTS/vocoder/models/fullband_melgan_generator.py b/TTS/vocoder/models/fullband_melgan_generator.py index 52dcc75ee1..ee25559af0 100644 --- a/TTS/vocoder/models/fullband_melgan_generator.py +++ b/TTS/vocoder/models/fullband_melgan_generator.py @@ -4,27 +4,30 @@ class FullbandMelganGenerator(MelganGenerator): - def __init__(self, - in_channels=80, - out_channels=1, - proj_kernel=7, - base_channels=512, - upsample_factors=(2, 8, 2, 2), - res_kernel=3, - num_res_blocks=4): - super().__init__(in_channels=in_channels, - out_channels=out_channels, - proj_kernel=proj_kernel, - base_channels=base_channels, - upsample_factors=upsample_factors, - res_kernel=res_kernel, - num_res_blocks=num_res_blocks) + def __init__( + self, + in_channels=80, + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=(2, 8, 2, 2), + res_kernel=3, + num_res_blocks=4, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + proj_kernel=proj_kernel, + base_channels=base_channels, + upsample_factors=upsample_factors, + res_kernel=res_kernel, + num_res_blocks=num_res_blocks, + ) @torch.no_grad() def inference(self, cond_features): cond_features = cond_features.to(self.layers[1].weight.device) cond_features = torch.nn.functional.pad( - cond_features, - (self.inference_padding, self.inference_padding), - 'replicate') + cond_features, (self.inference_padding, self.inference_padding), "replicate" + ) return self.layers(cond_features) diff --git a/TTS/vocoder/models/hifigan_discriminator.py b/TTS/vocoder/models/hifigan_discriminator.py index 8e824b6b45..49c877af62 100644 --- a/TTS/vocoder/models/hifigan_discriminator.py +++ b/TTS/vocoder/models/hifigan_discriminator.py @@ -3,7 +3,6 @@ from torch import nn from torch.nn import functional as F - LRELU_SLOPE = 0.1 @@ -26,18 +25,21 @@ class DiscriminatorP(torch.nn.Module): Shapes: x: [B, 1, T] """ + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): super().__init__() self.period = period - get_padding = lambda k, d: int((k*d - d)/2) + get_padding = lambda k, d: int((k * d - d) / 2) norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm - self.convs = nn.ModuleList([ - norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), - norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), - ]) + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) def forward(self, x): @@ -56,7 +58,7 @@ def forward(self, x): # 1d to 2d b, c, t = 
x.shape - if t % self.period != 0: # pad first + if t % self.period != 0: # pad first n_pad = self.period - (t % self.period) x = F.pad(x, (0, n_pad), "reflect") t = t + n_pad @@ -78,15 +80,18 @@ class MultiPeriodDiscriminator(torch.nn.Module): Wrapper for the `PeriodDiscriminator` to apply it in different periods. Periods are suggested to be prime numbers to reduce the overlap between each discriminator. """ + def __init__(self): super(MultiPeriodDiscriminator, self).__init__() - self.discriminators = nn.ModuleList([ - DiscriminatorP(2), - DiscriminatorP(3), - DiscriminatorP(5), - DiscriminatorP(7), - DiscriminatorP(11), - ]) + self.discriminators = nn.ModuleList( + [ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ] + ) def forward(self, x): """ @@ -117,18 +122,21 @@ class DiscriminatorS(torch.nn.Module): use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. """ + def __init__(self, use_spectral_norm=False): super(DiscriminatorS, self).__init__() norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm - self.convs = nn.ModuleList([ - norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)), - norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)), - norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)), - norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)), - norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)), - norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), - norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), - ]) + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)), + norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) def forward(self, x): @@ -155,17 +163,17 @@ class MultiScaleDiscriminator(torch.nn.Module): """HiFiGAN Multi-Scale Discriminator. It is similar to `MultiScaleMelganDiscriminator` but specially tailored for HiFiGAN as in the paper. """ + def __init__(self): super(MultiScaleDiscriminator, self).__init__() - self.discriminators = nn.ModuleList([ - DiscriminatorS(use_spectral_norm=True), - DiscriminatorS(), - DiscriminatorS(), - ]) - self.meanpools = nn.ModuleList([ - nn.AvgPool1d(4, 2, padding=2), - nn.AvgPool1d(4, 2, padding=2) - ]) + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + self.meanpools = nn.ModuleList([nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]) def forward(self, x): """ @@ -180,7 +188,7 @@ def forward(self, x): feats = [] for i, d in enumerate(self.discriminators): if i != 0: - x = self.meanpools[i-1](x) + x = self.meanpools[i - 1](x) score, feat = d(x) scores.append(score) feats.append(feat) @@ -188,8 +196,8 @@ def forward(self, x): class HifiganDiscriminator(nn.Module): - """HiFiGAN discriminator wrapping MPD and MSD. 
- """ + """HiFiGAN discriminator wrapping MPD and MSD.""" + def __init__(self): super().__init__() self.mpd = MultiPeriodDiscriminator() diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py index c838fb8f85..8d595a6380 100644 --- a/TTS/vocoder/models/hifigan_generator.py +++ b/TTS/vocoder/models/hifigan_generator.py @@ -1,9 +1,9 @@ # adopted from https://github.com/jik876/hifi-gan/blob/master/models.py import torch -import torch.nn.functional as F import torch.nn as nn +import torch.nn.functional as F from torch.nn import Conv1d, ConvTranspose1d -from torch.nn.utils import weight_norm, remove_weight_norm +from torch.nn.utils import remove_weight_norm, weight_norm LRELU_SLOPE = 0.1 @@ -26,55 +26,57 @@ class ResBlock1(torch.nn.Module): kernel_size (int): size of the convolution filter in each layer. dilations (list): list of dilation value for each conv layer in a block. """ + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): super().__init__() - self.convs1 = nn.ModuleList([ - weight_norm( - Conv1d(channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]))), - weight_norm( - Conv1d(channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]))), - weight_norm( - Conv1d(channels, - channels, - kernel_size, - 1, - dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]))) - ]) - - self.convs2 = nn.ModuleList([ - weight_norm( - Conv1d(channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1))), - weight_norm( - Conv1d(channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1))), - weight_norm( - Conv1d(channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1))) - ]) + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + ] + ) def forward(self, x): """ @@ -114,24 +116,33 @@ class ResBlock2(torch.nn.Module): kernel_size (int): size of the convolution filter in each layer. dilations (list): list of dilation value for each conv layer in a block. 
""" + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): super().__init__() - self.convs = nn.ModuleList([ - weight_norm( - Conv1d(channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]))), - weight_norm( - Conv1d(channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]))) - ]) + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) def forward(self, x): for c in self.convs: @@ -146,10 +157,18 @@ def remove_weight_norm(self): class HifiganGenerator(torch.nn.Module): - def __init__(self, in_channels, out_channels, resblock_type, - resblock_dilation_sizes, resblock_kernel_sizes, - upsample_kernel_sizes, upsample_initial_channel, - upsample_factors, inference_padding=5): + def __init__( + self, + in_channels, + out_channels, + resblock_type, + resblock_dilation_sizes, + resblock_kernel_sizes, + upsample_kernel_sizes, + upsample_initial_channel, + upsample_factors, + inference_padding=5, + ): r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) Network: @@ -174,26 +193,27 @@ def __init__(self, in_channels, out_channels, resblock_type, self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_factors) # initial upsampling layers - self.conv_pre = weight_norm( - Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) - resblock = ResBlock1 if resblock_type == '1' else ResBlock2 + self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if resblock_type == "1" else ResBlock2 # upsampling layers self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(upsample_factors, - upsample_kernel_sizes)): + for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): self.ups.append( weight_norm( - ConvTranspose1d(upsample_initial_channel // (2**i), - upsample_initial_channel // (2**(i + 1)), - k, - u, - padding=(k - u) // 2))) + ConvTranspose1d( + upsample_initial_channel // (2 ** i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) # MRF blocks self.resblocks = nn.ModuleList() for i in range(len(self.ups)): - ch = upsample_initial_channel // (2**(i + 1)) - for _, (k, d) in enumerate( - zip(resblock_kernel_sizes, resblock_dilation_sizes)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): self.resblocks.append(resblock(ch, k, d)) # post convolution layer self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3)) @@ -240,12 +260,11 @@ def inference(self, c): Tensor: [B, 1, T] """ c = c.to(self.conv_pre.weight.device) - c = torch.nn.functional.pad( - c, (self.inference_padding, self.inference_padding), 'replicate') + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") return self.forward(c) def remove_weight_norm(self): - print('Removing weight norm...') + print("Removing weight norm...") for l in self.ups: remove_weight_norm(l) for l in self.resblocks: @@ -253,9 +272,11 @@ def remove_weight_norm(self): remove_weight_norm(self.conv_pre) remove_weight_norm(self.conv_post) - def load_checkpoint(self, config, 
checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) if eval: self.eval() assert not self.training diff --git a/TTS/vocoder/models/melgan_discriminator.py b/TTS/vocoder/models/melgan_discriminator.py index 5e32d5695f..4849938937 100644 --- a/TTS/vocoder/models/melgan_discriminator.py +++ b/TTS/vocoder/models/melgan_discriminator.py @@ -4,14 +4,16 @@ class MelganDiscriminator(nn.Module): - def __init__(self, - in_channels=1, - out_channels=1, - kernel_sizes=(5, 3), - base_channels=16, - max_channels=1024, - downsample_factors=(4, 4, 4, 4), - groups_denominator=4): + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_sizes=(5, 3), + base_channels=16, + max_channels=1024, + downsample_factors=(4, 4, 4, 4), + groups_denominator=4, + ): super(MelganDiscriminator, self).__init__() self.layers = nn.ModuleList() @@ -22,31 +24,32 @@ def __init__(self, self.layers += [ nn.Sequential( nn.ReflectionPad1d(layer_padding), - weight_norm( - nn.Conv1d(in_channels, - base_channels, - layer_kernel_size, - stride=1)), nn.LeakyReLU(0.2, inplace=True)) + weight_norm(nn.Conv1d(in_channels, base_channels, layer_kernel_size, stride=1)), + nn.LeakyReLU(0.2, inplace=True), + ) ] # downsampling layers layer_in_channels = base_channels for downsample_factor in downsample_factors: - layer_out_channels = min(layer_in_channels * downsample_factor, - max_channels) + layer_out_channels = min(layer_in_channels * downsample_factor, max_channels) layer_kernel_size = downsample_factor * 10 + 1 layer_padding = (layer_kernel_size - 1) // 2 layer_groups = layer_in_channels // groups_denominator self.layers += [ nn.Sequential( weight_norm( - nn.Conv1d(layer_in_channels, - layer_out_channels, - kernel_size=layer_kernel_size, - stride=downsample_factor, - padding=layer_padding, - groups=layer_groups)), - nn.LeakyReLU(0.2, inplace=True)) + nn.Conv1d( + layer_in_channels, + layer_out_channels, + kernel_size=layer_kernel_size, + stride=downsample_factor, + padding=layer_padding, + groups=layer_groups, + ) + ), + nn.LeakyReLU(0.2, inplace=True), + ) ] layer_in_channels = layer_out_channels @@ -56,19 +59,21 @@ def __init__(self, self.layers += [ nn.Sequential( weight_norm( - nn.Conv1d(layer_out_channels, - layer_out_channels, - kernel_size=kernel_sizes[0], - stride=1, - padding=layer_padding1)), + nn.Conv1d( + layer_out_channels, + layer_out_channels, + kernel_size=kernel_sizes[0], + stride=1, + padding=layer_padding1, + ) + ), nn.LeakyReLU(0.2, inplace=True), ), weight_norm( - nn.Conv1d(layer_out_channels, - out_channels, - kernel_size=kernel_sizes[1], - stride=1, - padding=layer_padding2)), + nn.Conv1d( + layer_out_channels, out_channels, kernel_size=kernel_sizes[1], stride=1, padding=layer_padding2 + ) + ), ] def forward(self, x): diff --git a/TTS/vocoder/models/melgan_generator.py b/TTS/vocoder/models/melgan_generator.py index 3070eac779..ea7a90b1bb 100644 --- a/TTS/vocoder/models/melgan_generator.py +++ b/TTS/vocoder/models/melgan_generator.py @@ -6,19 +6,20 @@ class MelganGenerator(nn.Module): - def __init__(self, - in_channels=80, - out_channels=1, - proj_kernel=7, - base_channels=512, - upsample_factors=(8, 8, 2, 2), - res_kernel=3, - 
num_res_blocks=3): + def __init__( + self, + in_channels=80, + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=(8, 8, 2, 2), + res_kernel=3, + num_res_blocks=3, + ): super(MelganGenerator, self).__init__() # assert model parameters - assert (proj_kernel - - 1) % 2 == 0, " [!] proj_kernel should be an odd number." + assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number." # setup additional model parameters base_padding = (proj_kernel - 1) // 2 @@ -29,18 +30,13 @@ def __init__(self, layers = [] layers += [ nn.ReflectionPad1d(base_padding), - weight_norm( - nn.Conv1d(in_channels, - base_channels, - kernel_size=proj_kernel, - stride=1, - bias=True)) + weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)), ] # upsampling layers and residual stacks for idx, upsample_factor in enumerate(upsample_factors): - layer_in_channels = base_channels // (2**idx) - layer_out_channels = base_channels // (2**(idx + 1)) + layer_in_channels = base_channels // (2 ** idx) + layer_out_channels = base_channels // (2 ** (idx + 1)) layer_filter_size = upsample_factor * 2 layer_stride = upsample_factor layer_output_padding = upsample_factor % 2 @@ -48,18 +44,17 @@ def __init__(self, layers += [ nn.LeakyReLU(act_slope), weight_norm( - nn.ConvTranspose1d(layer_in_channels, - layer_out_channels, - layer_filter_size, - stride=layer_stride, - padding=layer_padding, - output_padding=layer_output_padding, - bias=True)), - ResidualStack( - channels=layer_out_channels, - num_res_blocks=num_res_blocks, - kernel_size=res_kernel - ) + nn.ConvTranspose1d( + layer_in_channels, + layer_out_channels, + layer_filter_size, + stride=layer_stride, + padding=layer_padding, + output_padding=layer_output_padding, + bias=True, + ) + ), + ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel), ] layers += [nn.LeakyReLU(act_slope)] @@ -67,13 +62,8 @@ def __init__(self, # final layer layers += [ nn.ReflectionPad1d(base_padding), - weight_norm( - nn.Conv1d(layer_out_channels, - out_channels, - proj_kernel, - stride=1, - bias=True)), - nn.Tanh() + weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)), + nn.Tanh(), ] self.layers = nn.Sequential(*layers) @@ -82,10 +72,7 @@ def forward(self, c): def inference(self, c): c = c.to(self.layers[1].weight.device) - c = torch.nn.functional.pad( - c, - (self.inference_padding, self.inference_padding), - 'replicate') + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") return self.layers(c) def remove_weight_norm(self): @@ -96,9 +83,11 @@ def remove_weight_norm(self): except ValueError: layer.remove_weight_norm() - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) if eval: self.eval() assert not self.training diff --git a/TTS/vocoder/models/melgan_multiscale_discriminator.py b/TTS/vocoder/models/melgan_multiscale_discriminator.py index b01ab91fa2..dc907040cd 100644 --- a/TTS/vocoder/models/melgan_multiscale_discriminator.py +++ b/TTS/vocoder/models/melgan_multiscale_discriminator.py @@ -4,35 
+4,40 @@ class MelganMultiscaleDiscriminator(nn.Module): - def __init__(self, - in_channels=1, - out_channels=1, - num_scales=3, - kernel_sizes=(5, 3), - base_channels=16, - max_channels=1024, - downsample_factors=(4, 4, 4), - pooling_kernel_size=4, - pooling_stride=2, - pooling_padding=2, - groups_denominator=4): + def __init__( + self, + in_channels=1, + out_channels=1, + num_scales=3, + kernel_sizes=(5, 3), + base_channels=16, + max_channels=1024, + downsample_factors=(4, 4, 4), + pooling_kernel_size=4, + pooling_stride=2, + pooling_padding=2, + groups_denominator=4, + ): super(MelganMultiscaleDiscriminator, self).__init__() - self.discriminators = nn.ModuleList([ - MelganDiscriminator(in_channels=in_channels, - out_channels=out_channels, - kernel_sizes=kernel_sizes, - base_channels=base_channels, - max_channels=max_channels, - downsample_factors=downsample_factors, - groups_denominator=groups_denominator) - for _ in range(num_scales) - ]) + self.discriminators = nn.ModuleList( + [ + MelganDiscriminator( + in_channels=in_channels, + out_channels=out_channels, + kernel_sizes=kernel_sizes, + base_channels=base_channels, + max_channels=max_channels, + downsample_factors=downsample_factors, + groups_denominator=groups_denominator, + ) + for _ in range(num_scales) + ] + ) - self.pooling = nn.AvgPool1d(kernel_size=pooling_kernel_size, - stride=pooling_stride, - padding=pooling_padding, - count_include_pad=False) + self.pooling = nn.AvgPool1d( + kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding, count_include_pad=False + ) def forward(self, x): scores = list() diff --git a/TTS/vocoder/models/multiband_melgan_generator.py b/TTS/vocoder/models/multiband_melgan_generator.py index 15e7426e74..8b61db1878 100644 --- a/TTS/vocoder/models/multiband_melgan_generator.py +++ b/TTS/vocoder/models/multiband_melgan_generator.py @@ -1,26 +1,29 @@ import torch -from TTS.vocoder.models.melgan_generator import MelganGenerator from TTS.vocoder.layers.pqmf import PQMF +from TTS.vocoder.models.melgan_generator import MelganGenerator class MultibandMelganGenerator(MelganGenerator): - def __init__(self, - in_channels=80, - out_channels=4, - proj_kernel=7, - base_channels=384, - upsample_factors=(2, 8, 2, 2), - res_kernel=3, - num_res_blocks=3): - super(MultibandMelganGenerator, - self).__init__(in_channels=in_channels, - out_channels=out_channels, - proj_kernel=proj_kernel, - base_channels=base_channels, - upsample_factors=upsample_factors, - res_kernel=res_kernel, - num_res_blocks=num_res_blocks) + def __init__( + self, + in_channels=80, + out_channels=4, + proj_kernel=7, + base_channels=384, + upsample_factors=(2, 8, 2, 2), + res_kernel=3, + num_res_blocks=3, + ): + super(MultibandMelganGenerator, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + proj_kernel=proj_kernel, + base_channels=base_channels, + upsample_factors=upsample_factors, + res_kernel=res_kernel, + num_res_blocks=num_res_blocks, + ) self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0) def pqmf_analysis(self, x): @@ -33,7 +36,6 @@ def pqmf_synthesis(self, x): def inference(self, cond_features): cond_features = cond_features.to(self.layers[1].weight.device) cond_features = torch.nn.functional.pad( - cond_features, - (self.inference_padding, self.inference_padding), - 'replicate') + cond_features, (self.inference_padding, self.inference_padding), "replicate" + ) return self.pqmf_synthesis(self.layers(cond_features)) diff --git a/TTS/vocoder/models/parallel_wavegan_discriminator.py 
b/TTS/vocoder/models/parallel_wavegan_discriminator.py index 37c2269550..6cd6a82e76 100644 --- a/TTS/vocoder/models/parallel_wavegan_discriminator.py +++ b/TTS/vocoder/models/parallel_wavegan_discriminator.py @@ -1,4 +1,5 @@ import math + import torch from torch import nn @@ -11,18 +12,20 @@ class ParallelWaveganDiscriminator(nn.Module): of predictions. It is a stack of convolutional blocks with dilation. """ + # pylint: disable=dangerous-default-value - def __init__(self, - in_channels=1, - out_channels=1, - kernel_size=3, - num_layers=10, - conv_channels=64, - dilation_factor=1, - nonlinear_activation="LeakyReLU", - nonlinear_activation_params={"negative_slope": 0.2}, - bias=True, - ): + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=10, + conv_channels=64, + dilation_factor=1, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + bias=True, + ): super(ParallelWaveganDiscriminator, self).__init__() assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size." assert dilation_factor > 0, " [!] dilation factor must be > 0." @@ -36,21 +39,19 @@ def __init__(self, conv_in_channels = conv_channels padding = (kernel_size - 1) // 2 * dilation conv_layer = [ - nn.Conv1d(conv_in_channels, - conv_channels, - kernel_size=kernel_size, - padding=padding, - dilation=dilation, - bias=bias), - getattr(nn, - nonlinear_activation)(inplace=True, - **nonlinear_activation_params) + nn.Conv1d( + conv_in_channels, + conv_channels, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + bias=bias, + ), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), ] self.conv_layers += conv_layer padding = (kernel_size - 1) // 2 - last_conv_layer = nn.Conv1d( - conv_in_channels, out_channels, - kernel_size=kernel_size, padding=padding, bias=bias) + last_conv_layer = nn.Conv1d(conv_in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) self.conv_layers += [last_conv_layer] self.apply_weight_norm() @@ -68,6 +69,7 @@ def apply_weight_norm(self): def _apply_weight_norm(m): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): torch.nn.utils.weight_norm(m) + self.apply(_apply_weight_norm) def remove_weight_norm(self): @@ -77,25 +79,27 @@ def _remove_weight_norm(m): nn.utils.remove_weight_norm(m) except ValueError: # this module didn't have weight norm return + self.apply(_remove_weight_norm) class ResidualParallelWaveganDiscriminator(nn.Module): # pylint: disable=dangerous-default-value - def __init__(self, - in_channels=1, - out_channels=1, - kernel_size=3, - num_layers=30, - stacks=3, - res_channels=64, - gate_channels=128, - skip_channels=64, - dropout=0.0, - bias=True, - nonlinear_activation="LeakyReLU", - nonlinear_activation_params={"negative_slope": 0.2}, - ): + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=30, + stacks=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + dropout=0.0, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + ): super(ResidualParallelWaveganDiscriminator, self).__init__() assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." 
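Both WaveGAN discriminators reject even kernel sizes because their "same" padding rule, padding = (kernel_size - 1) // 2 * dilation, only preserves the time dimension for odd kernels. A minimal check of that identity, using a toy single-channel convolution (not code from the patch):

    import torch
    import torch.nn as nn

    # With an odd kernel, symmetric padding of (k - 1) // 2 * dilation keeps
    # the sequence length unchanged regardless of the dilation factor.
    kernel_size, dilation = 3, 4
    conv = nn.Conv1d(1, 1, kernel_size,
                     padding=(kernel_size - 1) // 2 * dilation,
                     dilation=dilation)
    x = torch.randn(1, 1, 100)
    assert conv(x).shape[-1] == x.shape[-1]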
@@ -112,14 +116,8 @@ def __init__(self, # define first convolution self.first_conv = nn.Sequential( - nn.Conv1d(in_channels, - res_channels, - kernel_size=1, - padding=0, - dilation=1, - bias=True), - getattr(nn, nonlinear_activation)(inplace=True, - **nonlinear_activation_params), + nn.Conv1d(in_channels, res_channels, kernel_size=1, padding=0, dilation=1, bias=True), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), ) # define residual blocks @@ -140,24 +138,14 @@ def __init__(self, self.conv_layers += [conv] # define output layers - self.last_conv_layers = nn.ModuleList([ - getattr(nn, nonlinear_activation)(inplace=True, - **nonlinear_activation_params), - nn.Conv1d(skip_channels, - skip_channels, - kernel_size=1, - padding=0, - dilation=1, - bias=True), - getattr(nn, nonlinear_activation)(inplace=True, - **nonlinear_activation_params), - nn.Conv1d(skip_channels, - out_channels, - kernel_size=1, - padding=0, - dilation=1, - bias=True), - ]) + self.last_conv_layers = nn.ModuleList( + [ + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + nn.Conv1d(skip_channels, skip_channels, kernel_size=1, padding=0, dilation=1, bias=True), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + nn.Conv1d(skip_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=True), + ] + ) # apply weight norm self.apply_weight_norm() @@ -184,6 +172,7 @@ def apply_weight_norm(self): def _apply_weight_norm(m): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): torch.nn.utils.weight_norm(m) + self.apply(_apply_weight_norm) def remove_weight_norm(self): diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py index 1d1bcdcbf8..08e5237155 100644 --- a/TTS/vocoder/models/parallel_wavegan_generator.py +++ b/TTS/vocoder/models/parallel_wavegan_generator.py @@ -1,4 +1,5 @@ import math + import numpy as np import torch @@ -12,22 +13,25 @@ class ParallelWaveganGenerator(torch.nn.Module): It is conditioned on an aux feature (spectrogram) to generate an output waveform from an input noise. 
""" + # pylint: disable=dangerous-default-value - def __init__(self, - in_channels=1, - out_channels=1, - kernel_size=3, - num_res_blocks=30, - stacks=3, - res_channels=64, - gate_channels=128, - skip_channels=64, - aux_channels=80, - dropout=0.0, - bias=True, - use_weight_norm=True, - upsample_factors=[4, 4, 4, 4], - inference_padding=2): + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_res_blocks=30, + stacks=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + dropout=0.0, + bias=True, + use_weight_norm=True, + upsample_factors=[4, 4, 4, 4], + inference_padding=2, + ): super(ParallelWaveganGenerator, self).__init__() self.in_channels = in_channels @@ -46,10 +50,7 @@ def __init__(self, layers_per_stack = num_res_blocks // stacks # define first convolution - self.first_conv = torch.nn.Conv1d(in_channels, - res_channels, - kernel_size=1, - bias=True) + self.first_conv = torch.nn.Conv1d(in_channels, res_channels, kernel_size=1, bias=True) # define conv + upsampling network self.upsample_net = ConvUpsample(upsample_factors=upsample_factors) @@ -57,7 +58,7 @@ def __init__(self, # define residual blocks self.conv_layers = torch.nn.ModuleList() for layer in range(num_res_blocks): - dilation = 2**(layer % layers_per_stack) + dilation = 2 ** (layer % layers_per_stack) conv = ResidualBlock( kernel_size=kernel_size, res_channels=res_channels, @@ -71,18 +72,14 @@ def __init__(self, self.conv_layers += [conv] # define output layers - self.last_conv_layers = torch.nn.ModuleList([ - torch.nn.ReLU(inplace=True), - torch.nn.Conv1d(skip_channels, - skip_channels, - kernel_size=1, - bias=True), - torch.nn.ReLU(inplace=True), - torch.nn.Conv1d(skip_channels, - out_channels, - kernel_size=1, - bias=True), - ]) + self.last_conv_layers = torch.nn.ModuleList( + [ + torch.nn.ReLU(inplace=True), + torch.nn.Conv1d(skip_channels, skip_channels, kernel_size=1, bias=True), + torch.nn.ReLU(inplace=True), + torch.nn.Conv1d(skip_channels, out_channels, kernel_size=1, bias=True), + ] + ) # apply weight norm if use_weight_norm: @@ -90,8 +87,8 @@ def __init__(self, def forward(self, c): """ - c: (B, C ,T'). - o: Output tensor (B, out_channels, T) + c: (B, C ,T'). + o: Output tensor (B, out_channels, T) """ # random noise x = torch.randn([c.shape[0], 1, c.shape[2] * self.upsample_scale]) @@ -100,8 +97,9 @@ def forward(self, c): # perform upsampling if c is not None and self.upsample_net is not None: c = self.upsample_net(c) - assert c.shape[-1] == x.shape[ - -1], f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}" + assert ( + c.shape[-1] == x.shape[-1] + ), f" [!] Upsampling scale does not match the expected output. 
{c.shape} vs {x.shape}" # encode to hidden representation x = self.first_conv(x) @@ -121,8 +119,7 @@ def forward(self, c): @torch.no_grad() def inference(self, c): c = c.to(self.first_conv.weight.device) - c = torch.nn.functional.pad( - c, (self.inference_padding, self.inference_padding), 'replicate') + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") return self.forward(c) def remove_weight_norm(self): @@ -144,10 +141,7 @@ def _apply_weight_norm(m): self.apply(_apply_weight_norm) @staticmethod - def _get_receptive_field_size(layers, - stacks, - kernel_size, - dilation=lambda x: 2**x): + def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2 ** x): assert layers % stacks == 0 layers_per_cycle = layers // stacks dilations = [dilation(i % layers_per_cycle) for i in range(layers)] @@ -155,12 +149,13 @@ def _get_receptive_field_size(layers, @property def receptive_field_size(self): - return self._get_receptive_field_size(self.layers, self.stacks, - self.kernel_size) + return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size) - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) if eval: self.eval() assert not self.training diff --git a/TTS/vocoder/models/random_window_discriminator.py b/TTS/vocoder/models/random_window_discriminator.py index 3efd395e1d..0a2e2887c8 100644 --- a/TTS/vocoder/models/random_window_discriminator.py +++ b/TTS/vocoder/models/random_window_discriminator.py @@ -13,20 +13,16 @@ def __init__(self, in_channels, cond_channels, downsample_factor): self.start = nn.Sequential( nn.AvgPool1d(downsample_factor, stride=downsample_factor), nn.ReLU(), - nn.Conv1d(in_channels, in_channels * 2, kernel_size=3, padding=1)) - self.lc_conv1d = nn.Conv1d(cond_channels, - in_channels * 2, - kernel_size=1) + nn.Conv1d(in_channels, in_channels * 2, kernel_size=3, padding=1), + ) + self.lc_conv1d = nn.Conv1d(cond_channels, in_channels * 2, kernel_size=1) self.end = nn.Sequential( - nn.ReLU(), - nn.Conv1d(in_channels * 2, - in_channels * 2, - kernel_size=3, - dilation=2, - padding=2)) + nn.ReLU(), nn.Conv1d(in_channels * 2, in_channels * 2, kernel_size=3, dilation=2, padding=2) + ) self.residual = nn.Sequential( nn.Conv1d(in_channels, in_channels * 2, kernel_size=1), - nn.AvgPool1d(downsample_factor, stride=downsample_factor)) + nn.AvgPool1d(downsample_factor, stride=downsample_factor), + ) def forward(self, inputs, conditions): outputs = self.start(inputs) + self.lc_conv1d(conditions) @@ -45,35 +41,27 @@ def __init__(self, in_channels, out_channels, downsample_factor): self.downsample_factor = downsample_factor self.out_channels = out_channels - self.donwsample_layer = nn.AvgPool1d(downsample_factor, - stride=downsample_factor) + self.donwsample_layer = nn.AvgPool1d(downsample_factor, stride=downsample_factor) self.layers = nn.Sequential( nn.ReLU(), nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(), - nn.Conv1d(out_channels, - out_channels, - kernel_size=3, - dilation=2, - padding=2)) + nn.Conv1d(out_channels, out_channels, kernel_size=3, dilation=2, padding=2), + ) 
self.residual = nn.Sequential( - nn.Conv1d(in_channels, out_channels, kernel_size=1), ) + nn.Conv1d(in_channels, out_channels, kernel_size=1), + ) def forward(self, inputs): if self.downsample_factor > 1: - outputs = self.layers(self.donwsample_layer(inputs))\ - + self.donwsample_layer(self.residual(inputs)) + outputs = self.layers(self.donwsample_layer(inputs)) + self.donwsample_layer(self.residual(inputs)) else: outputs = self.layers(inputs) + self.residual(inputs) return outputs class ConditionalDiscriminator(nn.Module): - def __init__(self, - in_channels, - cond_channels, - downsample_factors=(2, 2, 2), - out_channels=(128, 256)): + def __init__(self, in_channels, cond_channels, downsample_factors=(2, 2, 2), out_channels=(128, 256)): super(ConditionalDiscriminator, self).__init__() assert len(downsample_factors) == len(out_channels) + 1 @@ -90,13 +78,11 @@ def __init__(self, self.pre_cond_layers += [DBlock(in_channels, 64, 1)] in_channels = 64 for (i, channel) in enumerate(out_channels): - self.pre_cond_layers.append( - DBlock(in_channels, channel, downsample_factors[i])) + self.pre_cond_layers.append(DBlock(in_channels, channel, downsample_factors[i])) in_channels = channel # condition block - self.cond_block = GBlock(in_channels, cond_channels, - downsample_factors[-1]) + self.cond_block = GBlock(in_channels, cond_channels, downsample_factors[-1]) # layers after condition block self.post_cond_layers += [ @@ -119,11 +105,7 @@ def forward(self, inputs, conditions): class UnconditionalDiscriminator(nn.Module): - def __init__(self, - in_channels, - base_channels=64, - downsample_factors=(8, 4), - out_channels=(128, 256)): + def __init__(self, in_channels, base_channels=64, downsample_factors=(8, 4), out_channels=(128, 256)): super(UnconditionalDiscriminator, self).__init__() self.downsample_factors = downsample_factors @@ -155,15 +137,16 @@ def forward(self, inputs): class RandomWindowDiscriminator(nn.Module): """Random Window Discriminator as described in http://arxiv.org/abs/1909.11646""" - def __init__(self, - cond_channels, - hop_length, - uncond_disc_donwsample_factors=(8, 4), - cond_disc_downsample_factors=((8, 4, 2, 2, 2), (8, 4, 2, 2), - (8, 4, 2), (8, 4), (4, 2, 2)), - cond_disc_out_channels=((128, 128, 256, 256), (128, 256, 256), - (128, 256), (256, ), (128, 256)), - window_sizes=(512, 1024, 2048, 4096, 8192)): + + def __init__( + self, + cond_channels, + hop_length, + uncond_disc_donwsample_factors=(8, 4), + cond_disc_downsample_factors=((8, 4, 2, 2, 2), (8, 4, 2, 2), (8, 4, 2), (8, 4), (4, 2, 2)), + cond_disc_out_channels=((128, 128, 256, 256), (128, 256, 256), (128, 256), (256,), (128, 256)), + window_sizes=(512, 1024, 2048, 4096, 8192), + ): super(RandomWindowDiscriminator, self).__init__() self.cond_channels = cond_channels @@ -173,8 +156,7 @@ def __init__(self, self.ks = [ws // self.base_window_size for ws in window_sizes] # check arguments - assert len(cond_disc_downsample_factors) == len( - cond_disc_out_channels) == len(window_sizes) + assert len(cond_disc_downsample_factors) == len(cond_disc_out_channels) == len(window_sizes) for ws in window_sizes: assert ws % hop_length == 0 @@ -185,9 +167,8 @@ def __init__(self, self.unconditional_discriminators = nn.ModuleList([]) for k in self.ks: layer = UnconditionalDiscriminator( - in_channels=k, - base_channels=64, - downsample_factors=uncond_disc_donwsample_factors) + in_channels=k, base_channels=64, downsample_factors=uncond_disc_donwsample_factors + ) self.unconditional_discriminators.append(layer) 
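The conditional discriminators set up next score randomly cropped windows, keeping the waveform slice aligned with the corresponding conditioning frames. The cropping pattern follows this sketch; the hop length of 256, the 80 mel channels, and the tensor sizes are illustrative assumptions, not values taken from the patch:

    import numpy as np

    # A window of `frame_size` conditioning frames is drawn at random and the
    # matching `frame_size * hop_length` waveform samples are sliced out, so
    # both views of the signal stay aligned.
    hop_length, window_size = 256, 2048
    frame_size = window_size // hop_length
    T_frames = 400                                      # toy spectrogram length
    x = np.random.randn(1, 1, T_frames * hop_length)    # fake waveform batch
    c = np.random.randn(1, 80, T_frames)                # fake conditioning
    lc_index = np.random.randint(c.shape[-1] - frame_size)
    x_sub = x[:, :, lc_index * hop_length : (lc_index + frame_size) * hop_length]
    c_sub = c[:, :, lc_index : lc_index + frame_size]
    assert x_sub.shape[-1] == frame_size * hop_length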
self.conditional_discriminators = nn.ModuleList([]) @@ -196,29 +177,27 @@ def __init__(self, in_channels=k, cond_channels=cond_channels, downsample_factors=cond_disc_downsample_factors[idx], - out_channels=cond_disc_out_channels[idx]) + out_channels=cond_disc_out_channels[idx], + ) self.conditional_discriminators.append(layer) def forward(self, x, c): scores = [] feats = [] # unconditional pass - for (window_size, layer) in zip(self.window_sizes, - self.unconditional_discriminators): + for (window_size, layer) in zip(self.window_sizes, self.unconditional_discriminators): index = np.random.randint(x.shape[-1] - window_size) - score = layer(x[:, :, index:index + window_size]) + score = layer(x[:, :, index : index + window_size]) scores.append(score) # conditional pass - for (window_size, layer) in zip(self.window_sizes, - self.conditional_discriminators): + for (window_size, layer) in zip(self.window_sizes, self.conditional_discriminators): frame_size = window_size // self.hop_length lc_index = np.random.randint(c.shape[-1] - frame_size) sample_index = lc_index * self.hop_length - x_sub = x[:, :, - sample_index:(lc_index + frame_size) * self.hop_length] - c_sub = c[:, :, lc_index:lc_index + frame_size] + x_sub = x[:, :, sample_index : (lc_index + frame_size) * self.hop_length] + c_sub = c[:, :, lc_index : lc_index + frame_size] score = layer(x_sub, c_sub) scores.append(score) diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 96951ad180..d19c2a5ee9 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -3,22 +3,23 @@ from torch import nn from torch.nn.utils import weight_norm -from ..layers.wavegrad import DBlock, FiLM, UBlock, Conv1d +from ..layers.wavegrad import Conv1d, DBlock, FiLM, UBlock class Wavegrad(nn.Module): # pylint: disable=dangerous-default-value - def __init__(self, - in_channels=80, - out_channels=1, - use_weight_norm=False, - y_conv_channels=32, - x_conv_channels=768, - dblock_out_channels=[128, 128, 256, 512], - ublock_out_channels=[512, 512, 256, 128, 128], - upsample_factors=[5, 5, 3, 2, 2], - upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], - [1, 2, 4, 8], [1, 2, 4, 8]]): + def __init__( + self, + in_channels=80, + out_channels=1, + use_weight_norm=False, + y_conv_channels=32, + x_conv_channels=768, + dblock_out_channels=[128, 128, 256, 512], + ublock_out_channels=[512, 512, 256, 128, 128], + upsample_factors=[5, 5, 3, 2, 2], + upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]], + ): super().__init__() self.use_weight_norm = use_weight_norm @@ -72,14 +73,13 @@ def forward(self, x, spectrogram, noise_scale): shift_and_scale.append(film(x, noise_scale)) x = self.x_conv(spectrogram) - for layer, (film_shift, film_scale) in zip(self.ublocks, - reversed(shift_and_scale)): + for layer, (film_shift, film_scale) in zip(self.ublocks, reversed(shift_and_scale)): x = layer(x, film_shift, film_scale) x = self.out_conv(x) return x def load_noise_schedule(self, path): - beta = np.load(path, allow_pickle=True).item()['beta'] # pylint: disable=unexpected-keyword-arg + beta = np.load(path, allow_pickle=True).item()["beta"] # pylint: disable=unexpected-keyword-arg self.compute_noise_level(beta) @torch.no_grad() @@ -91,26 +91,24 @@ def inference(self, x, y_n=None): y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x) sqrt_alpha_hat = self.noise_level.to(x) for n in range(len(self.alpha) - 1, -1, -1): - y_n = self.c1[n] * (y_n - self.c2[n] * self.forward( - y_n, x, 
sqrt_alpha_hat[n].repeat(x.shape[0]))) + y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0]))) if n > 0: z = torch.randn_like(y_n) y_n += self.sigma[n - 1] * z y_n.clamp_(-1.0, 1.0) return y_n - def compute_y_n(self, y_0): """Compute noisy audio based on noise schedule""" self.noise_level = self.noise_level.to(y_0) if len(y_0.shape) == 3: y_0 = y_0.squeeze(1) s = torch.randint(0, self.num_steps - 1, [y_0.shape[0]]) - l_a, l_b = self.noise_level[s], self.noise_level[s+1] + l_a, l_b = self.noise_level[s], self.noise_level[s + 1] noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a) noise_scale = noise_scale.unsqueeze(1) noise = torch.randn_like(y_0) - noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2)**0.5 * noise + noisy_audio = noise_scale * y_0 + (1.0 - noise_scale ** 2) ** 0.5 * noise return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0] def compute_noise_level(self, beta): @@ -127,9 +125,9 @@ def compute_noise_level(self, beta): self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32)) self.noise_level = torch.tensor(noise_level.astype(np.float32)) - self.c1 = 1 / self.alpha**0.5 - self.c2 = (1 - self.alpha) / (1 - self.alpha_hat)**0.5 - self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:])**0.5 + self.c1 = 1 / self.alpha ** 0.5 + self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5 + self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5 def remove_weight_norm(self): for _, layer in enumerate(self.dblocks): @@ -146,7 +144,6 @@ def remove_weight_norm(self): except ValueError: layer.remove_weight_norm() - for _, layer in enumerate(self.ublocks): if len(layer.state_dict()) != 0: try: @@ -167,7 +164,6 @@ def apply_weight_norm(self): if len(layer.state_dict()) != 0: layer.apply_weight_norm() - for _, layer in enumerate(self.ublocks): if len(layer.state_dict()) != 0: layer.apply_weight_norm() @@ -176,21 +172,26 @@ def apply_weight_norm(self): self.out_conv = weight_norm(self.out_conv) self.y_conv = weight_norm(self.y_conv) - - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) if eval: self.eval() assert not self.training if self.use_weight_norm: self.remove_weight_norm() - betas = np.linspace(config['test_noise_schedule']['min_val'], - config['test_noise_schedule']['max_val'], - config['test_noise_schedule']['num_steps']) + betas = np.linspace( + config["test_noise_schedule"]["min_val"], + config["test_noise_schedule"]["max_val"], + config["test_noise_schedule"]["num_steps"], + ) self.compute_noise_level(betas) else: - betas = np.linspace(config['train_noise_schedule']['min_val'], - config['train_noise_schedule']['max_val'], - config['train_noise_schedule']['num_steps']) + betas = np.linspace( + config["train_noise_schedule"]["min_val"], + config["train_noise_schedule"]["max_val"], + config["train_noise_schedule"]["num_steps"], + ) self.compute_noise_level(betas) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index dbcaea660b..994244dc31 100644 --- a/TTS/vocoder/models/wavernn.py +++ 
b/TTS/vocoder/models/wavernn.py @@ -1,21 +1,20 @@ import sys +import time + +import numpy as np import torch import torch.nn as nn -import numpy as np import torch.nn.functional as F -import time # fix this from TTS.utils.audio import AudioProcessor as ap -from TTS.vocoder.utils.distribution import ( - sample_from_gaussian, - sample_from_discretized_mix_logistic, -) +from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian def stream(string, variables): sys.stdout.write(f"\r{string}" % variables) + # pylint: disable=abstract-method # relates https://github.com/pytorch/pytorch/issues/42305 class ResBlock(nn.Module): @@ -40,8 +39,7 @@ class MelResNet(nn.Module): def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad): super().__init__() k_size = pad * 2 + 1 - self.conv_in = nn.Conv1d( - in_dims, compute_dims, kernel_size=k_size, bias=False) + self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False) self.batch_norm = nn.BatchNorm1d(compute_dims) self.layers = nn.ModuleList() for _ in range(num_res_blocks): @@ -73,31 +71,28 @@ def forward(self, x): class UpsampleNetwork(nn.Module): def __init__( - self, - feat_dims, - upsample_scales, - compute_dims, - num_res_blocks, - res_out_dims, - pad, - use_aux_net, - ): + self, + feat_dims, + upsample_scales, + compute_dims, + num_res_blocks, + res_out_dims, + pad, + use_aux_net, + ): super().__init__() self.total_scale = np.cumproduct(upsample_scales)[-1] self.indent = pad * self.total_scale self.use_aux_net = use_aux_net if use_aux_net: - self.resnet = MelResNet( - num_res_blocks, feat_dims, compute_dims, res_out_dims, pad - ) + self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad) self.resnet_stretch = Stretch2d(self.total_scale, 1) self.up_layers = nn.ModuleList() for scale in upsample_scales: k_size = (1, scale * 2 + 1) padding = (0, scale) stretch = Stretch2d(scale, 1) - conv = nn.Conv2d(1, 1, kernel_size=k_size, - padding=padding, bias=False) + conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False) conv.weight.data.fill_(1.0 / k_size[1]) self.up_layers.append(stretch) self.up_layers.append(conv) @@ -113,56 +108,51 @@ def forward(self, m): m = m.unsqueeze(1) for f in self.up_layers: m = f(m) - m = m.squeeze(1)[:, :, self.indent: -self.indent] + m = m.squeeze(1)[:, :, self.indent : -self.indent] return m.transpose(1, 2), aux class Upsample(nn.Module): - def __init__(self, scale, pad, num_res_blocks, feat_dims, compute_dims, - res_out_dims, use_aux_net): + def __init__(self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net): super().__init__() self.scale = scale self.pad = pad self.indent = pad * scale self.use_aux_net = use_aux_net - self.resnet = MelResNet(num_res_blocks, feat_dims, - compute_dims, res_out_dims, pad) + self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad) def forward(self, m): if self.use_aux_net: aux = self.resnet(m) - aux = torch.nn.functional.interpolate( - aux, scale_factor=self.scale, mode="linear", align_corners=True - ) + aux = torch.nn.functional.interpolate(aux, scale_factor=self.scale, mode="linear", align_corners=True) aux = aux.transpose(1, 2) else: aux = None - m = torch.nn.functional.interpolate( - m, scale_factor=self.scale, mode="linear", align_corners=True - ) - m = m[:, :, self.indent: -self.indent] + m = torch.nn.functional.interpolate(m, scale_factor=self.scale, mode="linear", align_corners=True) + m = m[:, :, 
self.indent : -self.indent] m = m * 0.045 # empirically found return m.transpose(1, 2), aux class WaveRNN(nn.Module): - def __init__(self, - rnn_dims, - fc_dims, - mode, - mulaw, - pad, - use_aux_net, - use_upsample_net, - upsample_factors, - feat_dims, - compute_dims, - res_out_dims, - num_res_blocks, - hop_length, - sample_rate, - ): + def __init__( + self, + rnn_dims, + fc_dims, + mode, + mulaw, + pad, + use_aux_net, + use_upsample_net, + upsample_factors, + feat_dims, + compute_dims, + res_out_dims, + num_res_blocks, + hop_length, + sample_rate, + ): super().__init__() self.mode = mode self.mulaw = mulaw @@ -209,8 +199,7 @@ def __init__(self, if self.use_aux_net: self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims) self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True) - self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, - rnn_dims, batch_first=True) + self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True) self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims) self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims) self.fc3 = nn.Linear(fc_dims, self.n_classes) @@ -230,10 +219,10 @@ def forward(self, x, mels): if self.use_aux_net: aux_idx = [self.aux_dims * i for i in range(5)] - a1 = aux[:, :, aux_idx[0]: aux_idx[1]] - a2 = aux[:, :, aux_idx[1]: aux_idx[2]] - a3 = aux[:, :, aux_idx[2]: aux_idx[3]] - a4 = aux[:, :, aux_idx[3]: aux_idx[4]] + a1 = aux[:, :, aux_idx[0] : aux_idx[1]] + a2 = aux[:, :, aux_idx[1] : aux_idx[2]] + a3 = aux[:, :, aux_idx[2] : aux_idx[3]] + a4 = aux[:, :, aux_idx[3] : aux_idx[4]] x = ( torch.cat([x.unsqueeze(-1), mels, a1], dim=2) @@ -276,8 +265,7 @@ def inference(self, mels, batched=None, target=None, overlap=None): mels = mels.unsqueeze(0) wave_len = (mels.size(-1) - 1) * self.hop_length - mels = self.pad_tensor(mels.transpose( - 1, 2), pad=self.pad, side="both") + mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side="both") mels, aux = self.upsample(mels.transpose(1, 2)) if batched: @@ -293,7 +281,7 @@ def inference(self, mels, batched=None, target=None, overlap=None): if self.use_aux_net: d = self.aux_dims - aux_split = [aux[:, :, d * i: d * (i + 1)] for i in range(4)] + aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)] for i in range(seq_len): @@ -302,11 +290,7 @@ def inference(self, mels, batched=None, target=None, overlap=None): if self.use_aux_net: a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split) - x = ( - torch.cat([x, m_t, a1_t], dim=1) - if self.use_aux_net - else torch.cat([x, m_t], dim=1) - ) + x = torch.cat([x, m_t, a1_t], dim=1) if self.use_aux_net else torch.cat([x, m_t], dim=1) x = self.I(x) h1 = rnn1(x, h1) @@ -324,14 +308,11 @@ def inference(self, mels, batched=None, target=None, overlap=None): logits = self.fc3(x) if self.mode == "mold": - sample = sample_from_discretized_mix_logistic( - logits.unsqueeze(0).transpose(1, 2) - ) + sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2)) output.append(sample.view(-1)) x = sample.transpose(0, 1).to(device) elif self.mode == "gauss": - sample = sample_from_gaussian( - logits.unsqueeze(0).transpose(1, 2)) + sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2)) output.append(sample.view(-1)) x = sample.transpose(0, 1).to(device) elif isinstance(self.mode, int): @@ -342,8 +323,7 @@ def inference(self, mels, batched=None, target=None, overlap=None): output.append(sample) x = sample.unsqueeze(-1) else: - raise RuntimeError( - "Unknown model mode value - ", self.mode) + raise RuntimeError("Unknown model mode value 
- ", self.mode) if i % 100 == 0: self.gen_display(i, seq_len, b_size, start) @@ -366,7 +346,7 @@ def inference(self, mels, batched=None, target=None, overlap=None): output = output[:wave_len] if wave_len > len(fade_out): - output[-20 * self.hop_length:] *= fade_out + output[-20 * self.hop_length :] *= fade_out self.train() return output @@ -411,8 +391,7 @@ def fold_with_overlap(self, x, target, overlap): padding = target + 2 * overlap - remaining x = self.pad_tensor(x, padding, side="after") - folded = torch.zeros(num_folds, target + 2 * - overlap, features).to(x.device) + folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device) # Get the values for the folded tensor for i in range(num_folds): @@ -439,7 +418,7 @@ def pad_tensor(x, pad, side="both"): total = t + 2 * pad if side == "both" else t + pad padded = torch.zeros(b, total, c).to(x.device) if side in ("before", "both"): - padded[:, pad: pad + t, :] = x + padded[:, pad : pad + t, :] = x elif side == "after": padded[:, :t, :] = x return padded @@ -500,9 +479,11 @@ def xfade_and_unfold(y, target, overlap): return unfolded - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) if eval: self.eval() assert not self.training diff --git a/TTS/vocoder/tf/layers/melgan.py b/TTS/vocoder/tf/layers/melgan.py index 34b25d65cf..36d8724c27 100644 --- a/TTS/vocoder/tf/layers/melgan.py +++ b/TTS/vocoder/tf/layers/melgan.py @@ -21,29 +21,27 @@ def __init__(self, channels, num_res_blocks, kernel_size, name): num_layers = 2 for idx in range(num_res_blocks): layer_kernel_size = kernel_size - layer_dilation = layer_kernel_size**idx + layer_dilation = layer_kernel_size ** idx layer_padding = base_padding * layer_dilation block = [ tf.keras.layers.LeakyReLU(0.2), ReflectionPad1d(layer_padding), - tf.keras.layers.Conv2D(filters=channels, - kernel_size=(kernel_size, 1), - dilation_rate=(layer_dilation, 1), - use_bias=True, - padding='valid', - name=f'blocks.{idx}.{num_layers}'), + tf.keras.layers.Conv2D( + filters=channels, + kernel_size=(kernel_size, 1), + dilation_rate=(layer_dilation, 1), + use_bias=True, + padding="valid", + name=f"blocks.{idx}.{num_layers}", + ), tf.keras.layers.LeakyReLU(0.2), - tf.keras.layers.Conv2D(filters=channels, - kernel_size=(1, 1), - use_bias=True, - name=f'blocks.{idx}.{num_layers + 2}') + tf.keras.layers.Conv2D( + filters=channels, kernel_size=(1, 1), use_bias=True, name=f"blocks.{idx}.{num_layers + 2}" + ), ] self.blocks.append(block) self.shortcuts = [ - tf.keras.layers.Conv2D(channels, - kernel_size=1, - use_bias=True, - name=f'shortcuts.{i}') + tf.keras.layers.Conv2D(channels, kernel_size=1, use_bias=True, name=f"shortcuts.{i}") for i in range(num_res_blocks) ] diff --git a/TTS/vocoder/tf/layers/pqmf.py b/TTS/vocoder/tf/layers/pqmf.py index c018971f3e..06800b6981 100644 --- a/TTS/vocoder/tf/layers/pqmf.py +++ b/TTS/vocoder/tf/layers/pqmf.py @@ -1,6 +1,5 @@ import numpy as np import tensorflow as tf - from scipy import signal as sig @@ -13,21 +12,19 @@ def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0): self.cutoff = cutoff self.beta = beta - QMF = sig.firwin(taps + 1, cutoff, window=('kaiser', 
beta)) + QMF = sig.firwin(taps + 1, cutoff, window=("kaiser", beta)) H = np.zeros((N, len(QMF))) G = np.zeros((N, len(QMF))) for k in range(N): - constant_factor = (2 * k + 1) * (np.pi / - (2 * N)) * (np.arange(taps + 1) - - ((taps - 1) / 2)) - phase = (-1)**k * np.pi / 4 + constant_factor = (2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - ((taps - 1) / 2)) + phase = (-1) ** k * np.pi / 4 H[k] = 2 * QMF * np.cos(constant_factor + phase) G[k] = 2 * QMF * np.cos(constant_factor - phase) # [N, 1, taps + 1] == [filter_width, in_channels, out_channels] - self.H = np.transpose(H[:, None, :], (2, 1, 0)).astype('float32') - self.G = np.transpose(G[None, :, :], (2, 1, 0)).astype('float32') + self.H = np.transpose(H[:, None, :], (2, 1, 0)).astype("float32") + self.G = np.transpose(G[None, :, :], (2, 1, 0)).astype("float32") # filter for downsampling & upsampling updown_filter = np.zeros((N, N, N), dtype=np.float32) @@ -41,11 +38,8 @@ def analysis(self, x): """ x = tf.transpose(x, perm=[0, 2, 1]) x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]], constant_values=0.0) - x = tf.nn.conv1d(x, self.H, stride=1, padding='VALID') - x = tf.nn.conv1d(x, - self.updown_filter, - stride=self.N, - padding='VALID') + x = tf.nn.conv1d(x, self.H, stride=1, padding="VALID") + x = tf.nn.conv1d(x, self.updown_filter, stride=self.N, padding="VALID") x = tf.transpose(x, perm=[0, 2, 1]) return x @@ -58,8 +52,8 @@ def synthesis(self, x): x, self.updown_filter * self.N, strides=self.N, - output_shape=(tf.shape(x)[0], tf.shape(x)[1] * self.N, - self.N)) + output_shape=(tf.shape(x)[0], tf.shape(x)[1] * self.N, self.N), + ) x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]], constant_values=0.0) x = tf.nn.conv1d(x, self.G, stride=1, padding="VALID") x = tf.transpose(x, perm=[0, 2, 1]) diff --git a/TTS/vocoder/tf/models/melgan_generator.py b/TTS/vocoder/tf/models/melgan_generator.py index 9a029df45a..eb8b6eecaa 100644 --- a/TTS/vocoder/tf/models/melgan_generator.py +++ b/TTS/vocoder/tf/models/melgan_generator.py @@ -1,33 +1,36 @@ import logging import os -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # FATAL -logging.getLogger('tensorflow').setLevel(logging.FATAL) +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # FATAL +logging.getLogger("tensorflow").setLevel(logging.FATAL) import tensorflow as tf -from TTS.vocoder.tf.layers.melgan import ResidualStack, ReflectionPad1d +from TTS.vocoder.tf.layers.melgan import ReflectionPad1d, ResidualStack -#pylint: disable=too-many-ancestors -#pylint: disable=abstract-method + +# pylint: disable=too-many-ancestors +# pylint: disable=abstract-method class MelganGenerator(tf.keras.models.Model): - """ Melgan Generator TF implementation dedicated for inference with no - weight norm """ - def __init__(self, - in_channels=80, - out_channels=1, - proj_kernel=7, - base_channels=512, - upsample_factors=(8, 8, 2, 2), - res_kernel=3, - num_res_blocks=3): + """Melgan Generator TF implementation dedicated for inference with no + weight norm""" + + def __init__( + self, + in_channels=80, + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=(8, 8, 2, 2), + res_kernel=3, + num_res_blocks=3, + ): super(MelganGenerator, self).__init__() self.in_channels = in_channels # assert model parameters - assert (proj_kernel - - 1) % 2 == 0, " [!] proj_kernel should be an odd number." + assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number." 
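The batched inference path above depends on fold_with_overlap() and xfade_and_unfold(): the conditioning features are cut into overlapping folds, decoded in parallel, then cross-faded back into a single waveform. A minimal standalone sketch of the folding step, with arbitrary target/overlap values and a dummy mel tensor standing in for the real inputs (the method in the patch also routes padding through pad_tensor and keeps tensors on x.device):

import torch

def fold_with_overlap(x, target, overlap):
    """Split (1, T, C) features into (num_folds, target + 2 * overlap, C)."""
    _, total_len, features = x.size()
    num_folds = (total_len - overlap) // (target + overlap)
    extended_len = num_folds * (overlap + target) + overlap
    remaining = total_len - extended_len
    if remaining != 0:  # pad so the last fold is full-length
        num_folds += 1
        padding = target + 2 * overlap - remaining
        x = torch.cat([x, torch.zeros(1, padding, features)], dim=1)
    folded = torch.zeros(num_folds, target + 2 * overlap, features)
    for i in range(num_folds):
        start = i * (target + overlap)
        folded[i] = x[0, start : start + target + 2 * overlap, :]
    return folded

mels = torch.rand(1, 550, 80)                      # (batch, frames, num_mels)
folds = fold_with_overlap(mels, target=128, overlap=16)
print(folds.shape)                                 # torch.Size([4, 160, 80])

The overlap regions are what xfade_and_unfold() later fades in and out, which hides the seams between independently generated folds.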
# setup additional model parameters base_padding = (proj_kernel - 1) // 2 @@ -37,19 +40,16 @@ def __init__(self, # initial layer self.initial_layer = [ ReflectionPad1d(base_padding), - tf.keras.layers.Conv2D(filters=base_channels, - kernel_size=(proj_kernel, 1), - strides=1, - padding='valid', - use_bias=True, - name="1") + tf.keras.layers.Conv2D( + filters=base_channels, kernel_size=(proj_kernel, 1), strides=1, padding="valid", use_bias=True, name="1" + ), ] num_layers = 3 # count number of layers for layer naming # upsampling layers and residual stacks self.upsample_layers = [] for idx, upsample_factor in enumerate(upsample_factors): - layer_out_channels = base_channels // (2**(idx + 1)) + layer_out_channels = base_channels // (2 ** (idx + 1)) layer_filter_size = upsample_factor * 2 layer_stride = upsample_factor # layer_output_padding = upsample_factor % 2 @@ -59,14 +59,17 @@ def __init__(self, filters=layer_out_channels, kernel_size=(layer_filter_size, 1), strides=(layer_stride, 1), - padding='same', + padding="same", # output_padding=layer_output_padding, use_bias=True, - name=f'{num_layers}'), - ResidualStack(channels=layer_out_channels, - num_res_blocks=num_res_blocks, - kernel_size=res_kernel, - name=f'layers.{num_layers + 1}') + name=f"{num_layers}", + ), + ResidualStack( + channels=layer_out_channels, + num_res_blocks=num_res_blocks, + kernel_size=res_kernel, + name=f"layers.{num_layers + 1}", + ), ] num_layers += num_res_blocks - 1 @@ -75,11 +78,10 @@ def __init__(self, # final layer self.final_layers = [ ReflectionPad1d(base_padding), - tf.keras.layers.Conv2D(filters=out_channels, - kernel_size=(proj_kernel, 1), - use_bias=True, - name=f'layers.{num_layers + 1}'), - tf.keras.layers.Activation("tanh") + tf.keras.layers.Conv2D( + filters=out_channels, kernel_size=(proj_kernel, 1), use_bias=True, name=f"layers.{num_layers + 1}" + ), + tf.keras.layers.Activation("tanh"), ] # self.model_layers = tf.keras.models.Sequential(self.initial_layer + self.upsample_layers + self.final_layers, name="layers") @@ -114,7 +116,8 @@ def build_inference(self): experimental_relax_shapes=True, input_signature=[ tf.TensorSpec([1, None, None], dtype=tf.float32), - ],) + ], + ) def inference_tflite(self, c): c = tf.transpose(c, perm=[0, 2, 1]) c = tf.expand_dims(c, 2) diff --git a/TTS/vocoder/tf/models/multiband_melgan_generator.py b/TTS/vocoder/tf/models/multiband_melgan_generator.py index bdd333ed3a..51d5cbc3ac 100644 --- a/TTS/vocoder/tf/models/multiband_melgan_generator.py +++ b/TTS/vocoder/tf/models/multiband_melgan_generator.py @@ -1,27 +1,31 @@ import tensorflow as tf -from TTS.vocoder.tf.models.melgan_generator import MelganGenerator from TTS.vocoder.tf.layers.pqmf import PQMF +from TTS.vocoder.tf.models.melgan_generator import MelganGenerator + -#pylint: disable=too-many-ancestors -#pylint: disable=abstract-method +# pylint: disable=too-many-ancestors +# pylint: disable=abstract-method class MultibandMelganGenerator(MelganGenerator): - def __init__(self, - in_channels=80, - out_channels=4, - proj_kernel=7, - base_channels=384, - upsample_factors=(2, 8, 2, 2), - res_kernel=3, - num_res_blocks=3): - super(MultibandMelganGenerator, - self).__init__(in_channels=in_channels, - out_channels=out_channels, - proj_kernel=proj_kernel, - base_channels=base_channels, - upsample_factors=upsample_factors, - res_kernel=res_kernel, - num_res_blocks=num_res_blocks) + def __init__( + self, + in_channels=80, + out_channels=4, + proj_kernel=7, + base_channels=384, + upsample_factors=(2, 8, 2, 2), + res_kernel=3, + 
num_res_blocks=3, + ): + super(MultibandMelganGenerator, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + proj_kernel=proj_kernel, + base_channels=base_channels, + upsample_factors=upsample_factors, + res_kernel=res_kernel, + num_res_blocks=num_res_blocks, + ) self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0) def pqmf_analysis(self, x): @@ -46,7 +50,8 @@ def inference(self, c): experimental_relax_shapes=True, input_signature=[ tf.TensorSpec([1, 80, None], dtype=tf.float32), - ],) + ], + ) def inference_tflite(self, c): c = tf.transpose(c, perm=[0, 2, 1]) c = tf.expand_dims(c, 2) diff --git a/TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py b/TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py index 25139cc3ce..5e0427b183 100644 --- a/TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py +++ b/TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py @@ -10,14 +10,14 @@ def compare_torch_tf(torch_tensor, tf_tensor): def convert_tf_name(tf_name): """ Convert certain patterns in TF layer names to Torch patterns """ tf_name_tmp = tf_name - tf_name_tmp = tf_name_tmp.replace(':0', '') - tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_1/recurrent_kernel', '/weight_hh_l0') - tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_2/kernel', '/weight_ih_l1') - tf_name_tmp = tf_name_tmp.replace('/recurrent_kernel', '/weight_hh') - tf_name_tmp = tf_name_tmp.replace('/kernel', '/weight') - tf_name_tmp = tf_name_tmp.replace('/gamma', '/weight') - tf_name_tmp = tf_name_tmp.replace('/beta', '/bias') - tf_name_tmp = tf_name_tmp.replace('/', '.') + tf_name_tmp = tf_name_tmp.replace(":0", "") + tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_1/recurrent_kernel", "/weight_hh_l0") + tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_2/kernel", "/weight_ih_l1") + tf_name_tmp = tf_name_tmp.replace("/recurrent_kernel", "/weight_hh") + tf_name_tmp = tf_name_tmp.replace("/kernel", "/weight") + tf_name_tmp = tf_name_tmp.replace("/gamma", "/weight") + tf_name_tmp = tf_name_tmp.replace("/beta", "/bias") + tf_name_tmp = tf_name_tmp.replace("/", ".") return tf_name_tmp @@ -26,15 +26,17 @@ def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict): print(" > Passing weights from Torch to TF ...") for tf_var in tf_vars: torch_var_name = var_map_dict[tf_var.name] - print(f' | > {tf_var.name} <-- {torch_var_name}') + print(f" | > {tf_var.name} <-- {torch_var_name}") # if tuple, it is a bias variable - if 'kernel' in tf_var.name: + if "kernel" in tf_var.name: torch_weight = state_dict[torch_var_name] numpy_weight = torch_weight.permute([2, 1, 0]).numpy()[:, None, :, :] - if 'bias' in tf_var.name: + if "bias" in tf_var.name: torch_weight = state_dict[torch_var_name] numpy_weight = torch_weight - assert np.all(tf_var.shape == numpy_weight.shape), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}" + assert np.all( + tf_var.shape == numpy_weight.shape + ), f" [!] 
weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}" tf.keras.backend.set_value(tf_var, numpy_weight) return tf_vars diff --git a/TTS/vocoder/tf/utils/generic_utils.py b/TTS/vocoder/tf/utils/generic_utils.py index 0daf2d6e13..94364ab478 100644 --- a/TTS/vocoder/tf/utils/generic_utils.py +++ b/TTS/vocoder/tf/utils/generic_utils.py @@ -1,35 +1,36 @@ -import re import importlib +import re def to_camel(text): text = text.capitalize() - return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text) + return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) def setup_generator(c): print(" > Generator Model: {}".format(c.generator_model)) - MyModel = importlib.import_module('TTS.vocoder.tf.models.' + - c.generator_model.lower()) + MyModel = importlib.import_module("TTS.vocoder.tf.models." + c.generator_model.lower()) MyModel = getattr(MyModel, to_camel(c.generator_model)) - if c.generator_model in 'melgan_generator': + if c.generator_model in "melgan_generator": model = MyModel( - in_channels=c.audio['num_mels'], + in_channels=c.audio["num_mels"], out_channels=1, proj_kernel=7, base_channels=512, - upsample_factors=c.generator_model_params['upsample_factors'], + upsample_factors=c.generator_model_params["upsample_factors"], res_kernel=3, - num_res_blocks=c.generator_model_params['num_res_blocks']) - if c.generator_model in 'melgan_fb_generator': + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + if c.generator_model in "melgan_fb_generator": pass - if c.generator_model in 'multiband_melgan_generator': + if c.generator_model in "multiband_melgan_generator": model = MyModel( - in_channels=c.audio['num_mels'], + in_channels=c.audio["num_mels"], out_channels=4, proj_kernel=7, base_channels=384, - upsample_factors=c.generator_model_params['upsample_factors'], + upsample_factors=c.generator_model_params["upsample_factors"], res_kernel=3, - num_res_blocks=c.generator_model_params['num_res_blocks']) + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) return model diff --git a/TTS/vocoder/tf/utils/io.py b/TTS/vocoder/tf/utils/io.py index c73c9cd86a..4f69ad27bf 100644 --- a/TTS/vocoder/tf/utils/io.py +++ b/TTS/vocoder/tf/utils/io.py @@ -1,24 +1,25 @@ import datetime import pickle + import tensorflow as tf def save_checkpoint(model, current_step, epoch, output_path, **kwargs): """ Save TF Vocoder model """ state = { - 'model': model.weights, - 'step': current_step, - 'epoch': epoch, - 'date': datetime.date.today().strftime("%B %d, %Y"), + "model": model.weights, + "step": current_step, + "epoch": epoch, + "date": datetime.date.today().strftime("%B %d, %Y"), } state.update(kwargs) - pickle.dump(state, open(output_path, 'wb')) + pickle.dump(state, open(output_path, "wb")) def load_checkpoint(model, checkpoint_path): """ Load TF Vocoder model """ - checkpoint = pickle.load(open(checkpoint_path, 'rb')) - chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']} + checkpoint = pickle.load(open(checkpoint_path, "rb")) + chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} tf_vars = model.weights for tf_var in tf_vars: layer_name = tf_var.name diff --git a/TTS/vocoder/tf/utils/tflite.py b/TTS/vocoder/tf/utils/tflite.py index d62a081a24..e0c630b9ed 100644 --- a/TTS/vocoder/tf/utils/tflite.py +++ b/TTS/vocoder/tf/utils/tflite.py @@ -1,25 +1,20 @@ import tensorflow as tf -def convert_melgan_to_tflite(model, - output_path=None, - experimental_converter=True): +def 
convert_melgan_to_tflite(model, output_path=None, experimental_converter=True): """Convert Tensorflow MelGAN model to TFLite. Save a binary file if output_path is provided, else return TFLite model.""" concrete_function = model.inference_tflite.get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions( - [concrete_function]) + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function]) converter.experimental_new_converter = experimental_converter converter.optimizations = [] - converter.target_spec.supported_ops = [ - tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS - ] + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] tflite_model = converter.convert() - print(f'Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.') + print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.") if output_path is not None: # same model binary if outputpath is provided - with open(output_path, 'wb') as f: + with open(output_path, "wb") as f: f.write(tflite_model) return None return tflite_model diff --git a/TTS/vocoder/utils/distribution.py b/TTS/vocoder/utils/distribution.py index b0553ed0d9..9b37aa12c0 100644 --- a/TTS/vocoder/utils/distribution.py +++ b/TTS/vocoder/utils/distribution.py @@ -1,8 +1,9 @@ -import numpy as np import math + +import numpy as np import torch -from torch.distributions.normal import Normal import torch.nn.functional as F +from torch.distributions.normal import Normal def gaussian_loss(y_hat, y, log_std_min=-7.0): @@ -11,11 +12,7 @@ def gaussian_loss(y_hat, y, log_std_min=-7.0): mean = y_hat[:, :, :1] log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min) # TODO: replace with pytorch dist - log_probs = -0.5 * ( - -math.log(2.0 * math.pi) - - 2.0 * log_std - - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std)) - ) + log_probs = -0.5 * (-math.log(2.0 * math.pi) - 2.0 * log_std - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std))) return log_probs.squeeze().mean() @@ -28,8 +25,7 @@ def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0): torch.exp(log_std), ) sample = dist.sample() - sample = torch.clamp(torch.clamp( - sample, min=-scale_factor), max=scale_factor) + sample = torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor) del dist return sample @@ -44,11 +40,7 @@ def log_sum_exp(x): # It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py -def discretized_mix_logistic_loss(y_hat, - y, - num_classes=65536, - log_scale_min=None, - reduce=True): +def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): if log_scale_min is None: log_scale_min = float(np.log(1e-14)) y_hat = y_hat.permute(0, 2, 1) @@ -61,9 +53,8 @@ def discretized_mix_logistic_loss(y_hat, # unpack parameters. (B, T, num_mixtures) x 3 logit_probs = y_hat[:, :, :nr_mix] - means = y_hat[:, :, nr_mix: 2 * nr_mix] - log_scales = torch.clamp( - y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=log_scale_min) + means = y_hat[:, :, nr_mix : 2 * nr_mix] + log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min) # B x T x 1 -> B x T x num_mixtures y = y.expand_as(means) @@ -103,14 +94,11 @@ def discretized_mix_logistic_loss(y_hat, # for num_classes=65536 case? 1e-7? not sure.. 
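gaussian_loss() above carries a TODO to switch to torch.distributions. A quick numerical check, with arbitrary shapes, that the hand-written expression is the negative Gaussian log-likelihood and hence interchangeable with Normal(...).log_prob():

import math
import torch
from torch.distributions.normal import Normal

torch.manual_seed(0)
y_hat = 0.5 * torch.randn(2, 100, 2)               # (B, T, 2): predicted mean and log_std
y = torch.randn(2, 100, 1)
mean = y_hat[:, :, :1]
log_std = torch.clamp(y_hat[:, :, 1:], min=-7.0)

# expression used in gaussian_loss()
manual = -0.5 * (-math.log(2.0 * math.pi) - 2.0 * log_std - torch.pow(y - mean, 2) * torch.exp(-2.0 * log_std))
# equivalent negative log-probability from torch.distributions
builtin = -Normal(mean, torch.exp(log_std)).log_prob(y)
print(torch.allclose(manual, builtin, atol=1e-4))  # True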
inner_inner_cond = (cdf_delta > 1e-5).float() - inner_inner_out = inner_inner_cond * torch.log( - torch.clamp(cdf_delta, min=1e-12) - ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2)) - inner_cond = (y > 0.999).float() - inner_out = ( - inner_cond * log_one_minus_cdf_min + - (1.0 - inner_cond) * inner_inner_out + inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1.0 - inner_inner_cond) * ( + log_pdf_mid - np.log((num_classes - 1) / 2) ) + inner_cond = (y > 0.999).float() + inner_out = inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out cond = (y < -0.999).float() log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out @@ -147,10 +135,8 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None): # (B, T) -> (B, T, nr_mix) one_hot = to_one_hot(argmax, nr_mix) # select logistic parameters - means = torch.sum(y[:, :, nr_mix: 2 * nr_mix] * one_hot, dim=-1) - log_scales = torch.clamp( - torch.sum(y[:, :, 2 * nr_mix: 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min - ) + means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1) + log_scales = torch.clamp(torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min) # sample from logistic & clip to interval # we don't actually round to the nearest 8bit value when sampling u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index 77386d306b..04c72b02ab 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -1,7 +1,8 @@ -import re -import torch import importlib +import re + import numpy as np +import torch from matplotlib import pyplot as plt from TTS.tts.utils.visual import plot_spectrogram @@ -21,11 +22,9 @@ def interpolate_vocoder_input(scale_factor, spec): """ print(" > before interpolation :", spec.shape) spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable - spec = torch.nn.functional.interpolate(spec, - scale_factor=scale_factor, - recompute_scale_factor=True, - mode='bilinear', - align_corners=False).squeeze(0) + spec = torch.nn.functional.interpolate( + spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False + ).squeeze(0) print(" > after interpolation :", spec.shape) return spec @@ -63,139 +62,135 @@ def plot_results(y_hat, y, ap, global_step, name_prefix): def to_camel(text): text = text.capitalize() - return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text) + return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) def setup_generator(c): print(" > Generator Model: {}".format(c.generator_model)) - MyModel = importlib.import_module('TTS.vocoder.models.' + - c.generator_model.lower()) + MyModel = importlib.import_module("TTS.vocoder.models." 
+ c.generator_model.lower()) # this is to preserve the WaveRNN class name (instead of Wavernn) - if c.generator_model.lower() == 'wavernn': - MyModel = getattr(MyModel, 'WaveRNN') + if c.generator_model.lower() == "wavernn": + MyModel = getattr(MyModel, "WaveRNN") else: MyModel = getattr(MyModel, to_camel(c.generator_model)) - if c.generator_model.lower() in 'wavernn': + if c.generator_model.lower() in "wavernn": model = MyModel( - rnn_dims=c.wavernn_model_params['rnn_dims'], - fc_dims=c.wavernn_model_params['fc_dims'], + rnn_dims=c.wavernn_model_params["rnn_dims"], + fc_dims=c.wavernn_model_params["fc_dims"], mode=c.mode, mulaw=c.mulaw, pad=c.padding, - use_aux_net=c.wavernn_model_params['use_aux_net'], - use_upsample_net=c.wavernn_model_params['use_upsample_net'], - upsample_factors=c.wavernn_model_params['upsample_factors'], - feat_dims=c.audio['num_mels'], - compute_dims=c.wavernn_model_params['compute_dims'], - res_out_dims=c.wavernn_model_params['res_out_dims'], - num_res_blocks=c.wavernn_model_params['num_res_blocks'], + use_aux_net=c.wavernn_model_params["use_aux_net"], + use_upsample_net=c.wavernn_model_params["use_upsample_net"], + upsample_factors=c.wavernn_model_params["upsample_factors"], + feat_dims=c.audio["num_mels"], + compute_dims=c.wavernn_model_params["compute_dims"], + res_out_dims=c.wavernn_model_params["res_out_dims"], + num_res_blocks=c.wavernn_model_params["num_res_blocks"], hop_length=c.audio["hop_length"], - sample_rate=c.audio["sample_rate"],) - elif c.generator_model.lower() in 'hifigan_generator': - model = MyModel( - in_channels=c.audio['num_mels'], - out_channels=1, - **c.generator_model_params) - elif c.generator_model.lower() in 'melgan_generator': + sample_rate=c.audio["sample_rate"], + ) + elif c.generator_model.lower() in "hifigan_generator": + model = MyModel(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params) + elif c.generator_model.lower() in "melgan_generator": model = MyModel( - in_channels=c.audio['num_mels'], + in_channels=c.audio["num_mels"], out_channels=1, proj_kernel=7, base_channels=512, - upsample_factors=c.generator_model_params['upsample_factors'], + upsample_factors=c.generator_model_params["upsample_factors"], res_kernel=3, - num_res_blocks=c.generator_model_params['num_res_blocks']) - elif c.generator_model in 'melgan_fb_generator': - raise ValueError( - 'melgan_fb_generator is now fullband_melgan_generator') - elif c.generator_model.lower() in 'multiband_melgan_generator': + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + elif c.generator_model in "melgan_fb_generator": + raise ValueError("melgan_fb_generator is now fullband_melgan_generator") + elif c.generator_model.lower() in "multiband_melgan_generator": model = MyModel( - in_channels=c.audio['num_mels'], + in_channels=c.audio["num_mels"], out_channels=4, proj_kernel=7, base_channels=384, - upsample_factors=c.generator_model_params['upsample_factors'], + upsample_factors=c.generator_model_params["upsample_factors"], res_kernel=3, - num_res_blocks=c.generator_model_params['num_res_blocks']) - elif c.generator_model.lower() in 'fullband_melgan_generator': + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + elif c.generator_model.lower() in "fullband_melgan_generator": model = MyModel( - in_channels=c.audio['num_mels'], + in_channels=c.audio["num_mels"], out_channels=1, proj_kernel=7, base_channels=512, - upsample_factors=c.generator_model_params['upsample_factors'], + 
upsample_factors=c.generator_model_params["upsample_factors"], res_kernel=3, - num_res_blocks=c.generator_model_params['num_res_blocks']) - elif c.generator_model.lower() in 'parallel_wavegan_generator': + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + elif c.generator_model.lower() in "parallel_wavegan_generator": model = MyModel( in_channels=1, out_channels=1, kernel_size=3, - num_res_blocks=c.generator_model_params['num_res_blocks'], - stacks=c.generator_model_params['stacks'], + num_res_blocks=c.generator_model_params["num_res_blocks"], + stacks=c.generator_model_params["stacks"], res_channels=64, gate_channels=128, skip_channels=64, - aux_channels=c.audio['num_mels'], + aux_channels=c.audio["num_mels"], dropout=0.0, bias=True, use_weight_norm=True, - upsample_factors=c.generator_model_params['upsample_factors']) - elif c.generator_model.lower() in 'wavegrad': + upsample_factors=c.generator_model_params["upsample_factors"], + ) + elif c.generator_model.lower() in "wavegrad": model = MyModel( - in_channels=c['audio']['num_mels'], + in_channels=c["audio"]["num_mels"], out_channels=1, - use_weight_norm=c['model_params']['use_weight_norm'], - x_conv_channels=c['model_params']['x_conv_channels'], - y_conv_channels=c['model_params']['y_conv_channels'], - dblock_out_channels=c['model_params']['dblock_out_channels'], - ublock_out_channels=c['model_params']['ublock_out_channels'], - upsample_factors=c['model_params']['upsample_factors'], - upsample_dilations=c['model_params']['upsample_dilations']) + use_weight_norm=c["model_params"]["use_weight_norm"], + x_conv_channels=c["model_params"]["x_conv_channels"], + y_conv_channels=c["model_params"]["y_conv_channels"], + dblock_out_channels=c["model_params"]["dblock_out_channels"], + ublock_out_channels=c["model_params"]["ublock_out_channels"], + upsample_factors=c["model_params"]["upsample_factors"], + upsample_dilations=c["model_params"]["upsample_dilations"], + ) else: - raise NotImplementedError( - f'Model {c.generator_model} not implemented!') + raise NotImplementedError(f"Model {c.generator_model} not implemented!") return model def setup_discriminator(c): print(" > Discriminator Model: {}".format(c.discriminator_model)) - if 'parallel_wavegan' in c.discriminator_model: - MyModel = importlib.import_module( - 'TTS.vocoder.models.parallel_wavegan_discriminator') + if "parallel_wavegan" in c.discriminator_model: + MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator") else: - MyModel = importlib.import_module('TTS.vocoder.models.' + - c.discriminator_model.lower()) + MyModel = importlib.import_module("TTS.vocoder.models." + c.discriminator_model.lower()) MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower())) - if c.discriminator_model in 'hifigan_discriminator': + if c.discriminator_model in "hifigan_discriminator": model = MyModel() - if c.discriminator_model in 'random_window_discriminator': + if c.discriminator_model in "random_window_discriminator": model = MyModel( - cond_channels=c.audio['num_mels'], - hop_length=c.audio['hop_length'], - uncond_disc_donwsample_factors=c. - discriminator_model_params['uncond_disc_donwsample_factors'], - cond_disc_downsample_factors=c. - discriminator_model_params['cond_disc_downsample_factors'], - cond_disc_out_channels=c. 
- discriminator_model_params['cond_disc_out_channels'], - window_sizes=c.discriminator_model_params['window_sizes']) - if c.discriminator_model in 'melgan_multiscale_discriminator': + cond_channels=c.audio["num_mels"], + hop_length=c.audio["hop_length"], + uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"], + cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"], + cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"], + window_sizes=c.discriminator_model_params["window_sizes"], + ) + if c.discriminator_model in "melgan_multiscale_discriminator": model = MyModel( in_channels=1, out_channels=1, kernel_sizes=(5, 3), - base_channels=c.discriminator_model_params['base_channels'], - max_channels=c.discriminator_model_params['max_channels'], - downsample_factors=c. - discriminator_model_params['downsample_factors']) - if c.discriminator_model == 'residual_parallel_wavegan_discriminator': + base_channels=c.discriminator_model_params["base_channels"], + max_channels=c.discriminator_model_params["max_channels"], + downsample_factors=c.discriminator_model_params["downsample_factors"], + ) + if c.discriminator_model == "residual_parallel_wavegan_discriminator": model = MyModel( in_channels=1, out_channels=1, kernel_size=3, - num_layers=c.discriminator_model_params['num_layers'], - stacks=c.discriminator_model_params['stacks'], + num_layers=c.discriminator_model_params["num_layers"], + stacks=c.discriminator_model_params["stacks"], res_channels=64, gate_channels=128, skip_channels=64, @@ -204,17 +199,17 @@ def setup_discriminator(c): nonlinear_activation="LeakyReLU", nonlinear_activation_params={"negative_slope": 0.2}, ) - if c.discriminator_model == 'parallel_wavegan_discriminator': + if c.discriminator_model == "parallel_wavegan_discriminator": model = MyModel( in_channels=1, out_channels=1, kernel_size=3, - num_layers=c.discriminator_model_params['num_layers'], + num_layers=c.discriminator_model_params["num_layers"], conv_channels=64, dilation_factor=1, nonlinear_activation="LeakyReLU", nonlinear_activation_params={"negative_slope": 0.2}, - bias=True + bias=True, ) return model diff --git a/TTS/vocoder/utils/io.py b/TTS/vocoder/utils/io.py index f3bc9bad6d..9c67535f14 100644 --- a/TTS/vocoder/utils/io.py +++ b/TTS/vocoder/utils/io.py @@ -1,19 +1,20 @@ -import os -import glob -import torch import datetime +import glob +import os import pickle as pickle_tts +import torch + from TTS.utils.io import RenamingUnpickler def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin try: - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) except ModuleNotFoundError: pickle_tts.Unpickler = RenamingUnpickler - state = torch.load(checkpoint_path, map_location=torch.device('cpu'), pickle_module=pickle_tts) - model.load_state_dict(state['model']) + state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts) + model.load_state_dict(state["model"]) if use_cuda: model.cuda() if eval: @@ -21,76 +22,104 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pyli return model, state -def save_model(model, optimizer, scheduler, model_disc, optimizer_disc, - scheduler_disc, current_step, epoch, output_path, **kwargs): - if hasattr(model, 'module'): +def save_model( + model, optimizer, scheduler, model_disc, optimizer_disc, 
scheduler_disc, current_step, epoch, output_path, **kwargs +): + if hasattr(model, "module"): model_state = model.module.state_dict() else: model_state = model.state_dict() - model_disc_state = model_disc.state_dict()\ - if model_disc is not None else None - optimizer_state = optimizer.state_dict()\ - if optimizer is not None else None - optimizer_disc_state = optimizer_disc.state_dict()\ - if optimizer_disc is not None else None - scheduler_state = scheduler.state_dict()\ - if scheduler is not None else None - scheduler_disc_state = scheduler_disc.state_dict()\ - if scheduler_disc is not None else None + model_disc_state = model_disc.state_dict() if model_disc is not None else None + optimizer_state = optimizer.state_dict() if optimizer is not None else None + optimizer_disc_state = optimizer_disc.state_dict() if optimizer_disc is not None else None + scheduler_state = scheduler.state_dict() if scheduler is not None else None + scheduler_disc_state = scheduler_disc.state_dict() if scheduler_disc is not None else None state = { - 'model': model_state, - 'optimizer': optimizer_state, - 'scheduler': scheduler_state, - 'model_disc': model_disc_state, - 'optimizer_disc': optimizer_disc_state, - 'scheduler_disc': scheduler_disc_state, - 'step': current_step, - 'epoch': epoch, - 'date': datetime.date.today().strftime("%B %d, %Y"), + "model": model_state, + "optimizer": optimizer_state, + "scheduler": scheduler_state, + "model_disc": model_disc_state, + "optimizer_disc": optimizer_disc_state, + "scheduler_disc": scheduler_disc_state, + "step": current_step, + "epoch": epoch, + "date": datetime.date.today().strftime("%B %d, %Y"), } state.update(kwargs) torch.save(state, output_path) -def save_checkpoint(model, optimizer, scheduler, model_disc, optimizer_disc, - scheduler_disc, current_step, epoch, output_folder, - **kwargs): - file_name = 'checkpoint_{}.pth.tar'.format(current_step) +def save_checkpoint( + model, + optimizer, + scheduler, + model_disc, + optimizer_disc, + scheduler_disc, + current_step, + epoch, + output_folder, + **kwargs, +): + file_name = "checkpoint_{}.pth.tar".format(current_step) checkpoint_path = os.path.join(output_folder, file_name) print(" > CHECKPOINT : {}".format(checkpoint_path)) - save_model(model, optimizer, scheduler, model_disc, optimizer_disc, - scheduler_disc, current_step, epoch, checkpoint_path, **kwargs) + save_model( + model, + optimizer, + scheduler, + model_disc, + optimizer_disc, + scheduler_disc, + current_step, + epoch, + checkpoint_path, + **kwargs, + ) -def save_best_model(current_loss, best_loss, model, optimizer, scheduler, - model_disc, optimizer_disc, scheduler_disc, current_step, - epoch, out_path, keep_all_best=False, keep_after=10000, - **kwargs): +def save_best_model( + current_loss, + best_loss, + model, + optimizer, + scheduler, + model_disc, + optimizer_disc, + scheduler_disc, + current_step, + epoch, + out_path, + keep_all_best=False, + keep_after=10000, + **kwargs, +): if current_loss < best_loss: - best_model_name = f'best_model_{current_step}.pth.tar' + best_model_name = f"best_model_{current_step}.pth.tar" checkpoint_path = os.path.join(out_path, best_model_name) print(" > BEST MODEL : {}".format(checkpoint_path)) - save_model(model, - optimizer, - scheduler, - model_disc, - optimizer_disc, - scheduler_disc, - current_step, - epoch, - checkpoint_path, - model_loss=current_loss, - **kwargs) + save_model( + model, + optimizer, + scheduler, + model_disc, + optimizer_disc, + scheduler_disc, + current_step, + epoch, + checkpoint_path, + 
model_loss=current_loss, + **kwargs, + ) # only delete previous if current is saved successfully if not keep_all_best or (current_step < keep_after): - model_names = glob.glob( - os.path.join(out_path, 'best_model*.pth.tar')) + model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar")) for model_name in model_names: if os.path.basename(model_name) == best_model_name: continue os.remove(model_name) # create symlink to best model for convinience - link_name = 'best_model.pth.tar' + link_name = "best_model.pth.tar" link_path = os.path.join(out_path, link_name) if os.path.islink(link_path) or os.path.isfile(link_path): os.remove(link_path) diff --git a/notebooks/dataset_analysis/analyze.py b/notebooks/dataset_analysis/analyze.py index 161e2ae3be..6c6bc5827e 100644 --- a/notebooks/dataset_analysis/analyze.py +++ b/notebooks/dataset_analysis/analyze.py @@ -1,16 +1,17 @@ # visualisation tools for mimic2 -import matplotlib.pyplot as plt -from statistics import stdev, mode, mean, median -from statistics import StatisticsError import argparse -import os import csv -import seaborn as sns +import os import random +from statistics import StatisticsError, mean, median, mode, stdev + +import matplotlib.pyplot as plt +import seaborn as sns from text.cmudict import CMUDict + def get_audio_seconds(frames): - return (frames*12.5)/1000 + return (frames * 12.5) / 1000 def append_data_statistics(meta_data): @@ -27,9 +28,7 @@ def append_data_statistics(meta_data): median_audio_len = median(audio_len_list) try: - std = stdev( - d["audio_len"] for d in data - ) + std = stdev(d["audio_len"] for d in data) except StatisticsError: std = 0 @@ -44,24 +43,22 @@ def process_meta_data(path): meta_data = {} # load meta data - with open(path, 'r') as f: - data = csv.reader(f, delimiter='|') + with open(path, "r") as f: + data = csv.reader(f, delimiter="|") for row in data: frames = int(row[2]) utt = row[3] audio_len = get_audio_seconds(frames) char_count = len(utt) if not meta_data.get(char_count): - meta_data[char_count] = { - "data": [] - } + meta_data[char_count] = {"data": []} meta_data[char_count]["data"].append( { "utt": utt, "frames": frames, "audio_len": audio_len, - "row": "{}|{}|{}|{}".format(row[0], row[1], row[2], row[3]) + "row": "{}|{}|{}|{}".format(row[0], row[1], row[2], row[3]), } ) @@ -72,30 +69,30 @@ def process_meta_data(path): def get_data_points(meta_data): x = meta_data - y_avg = [meta_data[d]['mean'] for d in meta_data] - y_mode = [meta_data[d]['mode'] for d in meta_data] - y_median = [meta_data[d]['median'] for d in meta_data] - y_std = [meta_data[d]['std'] for d in meta_data] - y_num_samples = [len(meta_data[d]['data']) for d in meta_data] + y_avg = [meta_data[d]["mean"] for d in meta_data] + y_mode = [meta_data[d]["mode"] for d in meta_data] + y_median = [meta_data[d]["median"] for d in meta_data] + y_std = [meta_data[d]["std"] for d in meta_data] + y_num_samples = [len(meta_data[d]["data"]) for d in meta_data] return { "x": x, "y_avg": y_avg, "y_mode": y_mode, "y_median": y_median, "y_std": y_std, - "y_num_samples": y_num_samples + "y_num_samples": y_num_samples, } def save_training(file_path, meta_data): rows = [] for char_cnt in meta_data: - data = meta_data[char_cnt]['data'] + data = meta_data[char_cnt]["data"] for d in data: - rows.append(d['row'] + "\n") + rows.append(d["row"] + "\n") random.shuffle(rows) - with open(file_path, 'w+') as f: + with open(file_path, "w+") as f: for row in rows: f.write(row) @@ -106,15 +103,15 @@ def plot(meta_data, save_path=None): save = True 
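save_best_model() above follows a simple bookkeeping pattern: write the new best checkpoint first, prune older best_model_* files unless keep_all_best applies, then refresh the best_model.pth.tar symlink. A minimal sketch of that pattern with a dummy state dict; the helper name save_best and the temporary directory are illustrative only:

import glob
import os
import tempfile

import torch

def save_best(state, out_path, current_step, keep_all_best=False, keep_after=10000):
    best_model_name = f"best_model_{current_step}.pth.tar"
    torch.save(state, os.path.join(out_path, best_model_name))
    # only prune previous checkpoints once the current one is saved
    if not keep_all_best or (current_step < keep_after):
        for model_name in glob.glob(os.path.join(out_path, "best_model*.pth.tar")):
            if os.path.basename(model_name) != best_model_name:
                os.remove(model_name)
    # refresh the convenience symlink to the latest best checkpoint
    link_path = os.path.join(out_path, "best_model.pth.tar")
    if os.path.islink(link_path) or os.path.isfile(link_path):
        os.remove(link_path)
    os.symlink(best_model_name, link_path)

out_dir = tempfile.mkdtemp()
save_best({"model": {}, "step": 1000}, out_dir, current_step=1000)
print(sorted(os.listdir(out_dir)))                 # ['best_model.pth.tar', 'best_model_1000.pth.tar']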
graph_data = get_data_points(meta_data) - x = graph_data['x'] - y_avg = graph_data['y_avg'] - y_std = graph_data['y_std'] - y_mode = graph_data['y_mode'] - y_median = graph_data['y_median'] - y_num_samples = graph_data['y_num_samples'] + x = graph_data["x"] + y_avg = graph_data["y_avg"] + y_std = graph_data["y_std"] + y_mode = graph_data["y_mode"] + y_median = graph_data["y_median"] + y_num_samples = graph_data["y_num_samples"] plt.figure() - plt.plot(x, y_avg, 'ro') + plt.plot(x, y_avg, "ro") plt.xlabel("character lengths", fontsize=30) plt.ylabel("avg seconds", fontsize=30) if save: @@ -122,7 +119,7 @@ def plot(meta_data, save_path=None): plt.savefig(os.path.join(save_path, name)) plt.figure() - plt.plot(x, y_mode, 'ro') + plt.plot(x, y_mode, "ro") plt.xlabel("character lengths", fontsize=30) plt.ylabel("mode seconds", fontsize=30) if save: @@ -130,7 +127,7 @@ def plot(meta_data, save_path=None): plt.savefig(os.path.join(save_path, name)) plt.figure() - plt.plot(x, y_median, 'ro') + plt.plot(x, y_median, "ro") plt.xlabel("character lengths", fontsize=30) plt.ylabel("median seconds", fontsize=30) if save: @@ -138,7 +135,7 @@ def plot(meta_data, save_path=None): plt.savefig(os.path.join(save_path, name)) plt.figure() - plt.plot(x, y_std, 'ro') + plt.plot(x, y_std, "ro") plt.xlabel("character lengths", fontsize=30) plt.ylabel("standard deviation", fontsize=30) if save: @@ -146,7 +143,7 @@ def plot(meta_data, save_path=None): plt.savefig(os.path.join(save_path, name)) plt.figure() - plt.plot(x, y_num_samples, 'ro') + plt.plot(x, y_num_samples, "ro") plt.xlabel("character lengths", fontsize=30) plt.ylabel("number of samples", fontsize=30) if save: @@ -159,8 +156,8 @@ def plot_phonemes(train_path, cmu_dict_path, save_path): phonemes = {} - with open(train_path, 'r') as f: - data = csv.reader(f, delimiter='|') + with open(train_path, "r") as f: + data = csv.reader(f, delimiter="|") phonemes["None"] = 0 for row in data: words = row[3].split() @@ -192,15 +189,12 @@ def plot_phonemes(train_path, cmu_dict_path, save_path): def main(): parser = argparse.ArgumentParser() parser.add_argument( - '--train_file_path', required=True, - help='this is the path to the train.txt file that the preprocess.py script creates' - ) - parser.add_argument( - '--save_to', help='path to save charts of data to' - ) - parser.add_argument( - '--cmu_dict_path', help='give cmudict-0.7b to see phoneme distribution' + "--train_file_path", + required=True, + help="this is the path to the train.txt file that the preprocess.py script creates", ) + parser.add_argument("--save_to", help="path to save charts of data to") + parser.add_argument("--cmu_dict_path", help="give cmudict-0.7b to see phoneme distribution") args = parser.parse_args() meta_data = process_meta_data(args.train_file_path) plt.rcParams["figure.figsize"] = (10, 5) @@ -211,5 +205,6 @@ def main(): plt.show() -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/pyproject.toml b/pyproject.toml index b6c632d896..5c742966fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,33 @@ [build-system] requires = ["setuptools", "wheel", "Cython", "numpy==1.17.5"] + +[flake8] +max-line-length=120 + +[tool.black] +line-length = 120 +target-version = ['py38'] +exclude = ''' + +( + /( + \.eggs # exclude a few common directories in the + | \.git # root of the project + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + )/ + | foo.py # also separately exclude a file named foo.py in + # the root of the project +) 
+''' + +[tool.isort] +line_length = 120 +profile = "black" +multi_line_output = 3 \ No newline at end of file diff --git a/tests/test_audio.py b/tests/test_audio.py index c00cd8f8fe..5eb0826207 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -2,7 +2,6 @@ import unittest from tests import get_tests_input_path, get_tests_output_path, get_tests_path - from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_config @@ -11,7 +10,7 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") os.makedirs(OUT_PATH, exist_ok=True) -conf = load_config(os.path.join(get_tests_input_path(), 'test_config.json')) +conf = load_config(os.path.join(get_tests_input_path(), "test_config.json")) # pylint: disable=protected-access @@ -21,10 +20,10 @@ def __init__(self, *args, **kwargs): self.ap = AudioProcessor(**conf.audio) def test_audio_synthesis(self): - """ 1. load wav - 2. set normalization parameters - 3. extract mel-spec - 4. invert to wav and save the output + """1. load wav + 2. set normalization parameters + 3. extract mel-spec + 4. invert to wav and save the output """ print(" > Sanity check for the process wav -> mel -> wav") @@ -36,23 +35,24 @@ def _test(max_norm, signal_norm, symmetric_norm, clip_norm): wav = self.ap.load_wav(WAV_FILE) mel = self.ap.melspectrogram(wav) wav_ = self.ap.inv_melspectrogram(mel) - file_name = "/audio_test-melspec_max_norm_{}-signal_norm_{}-symmetric_{}-clip_norm_{}.wav"\ - .format(max_norm, signal_norm, symmetric_norm, clip_norm) + file_name = "/audio_test-melspec_max_norm_{}-signal_norm_{}-symmetric_{}-clip_norm_{}.wav".format( + max_norm, signal_norm, symmetric_norm, clip_norm + ) print(" | > Creating wav file at : ", file_name) self.ap.save_wav(wav_, OUT_PATH + file_name) # maxnorm = 1.0 - _test(1., False, False, False) - _test(1., True, False, False) - _test(1., True, True, False) - _test(1., True, False, True) - _test(1., True, True, True) + _test(1.0, False, False, False) + _test(1.0, True, False, False) + _test(1.0, True, True, False) + _test(1.0, True, False, True) + _test(1.0, True, True, True) # maxnorm = 4.0 - _test(4., False, False, False) - _test(4., True, False, False) - _test(4., True, True, False) - _test(4., True, False, True) - _test(4., True, True, True) + _test(4.0, False, False, False) + _test(4.0, True, False, False) + _test(4.0, True, True, False) + _test(4.0, True, False, True) + _test(4.0, True, True, True) def test_normalize(self): """Check normalization and denormalization for range values and consistency """ @@ -68,7 +68,9 @@ def test_normalize(self): self.ap.clip_norm = False self.ap.max_norm = 4.0 x_norm = self.ap.normalize(x) - print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}") + print( + f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}" + ) assert (x_old - x).sum() == 0 # check value range assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max() @@ -82,8 +84,9 @@ def test_normalize(self): self.ap.clip_norm = True self.ap.max_norm = 4.0 x_norm = self.ap.normalize(x) - print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}") - + print( + f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, 
SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}" + ) assert (x_old - x).sum() == 0 # check value range @@ -98,13 +101,14 @@ def test_normalize(self): self.ap.clip_norm = False self.ap.max_norm = 4.0 x_norm = self.ap.normalize(x) - print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}") - + print( + f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}" + ) assert (x_old - x).sum() == 0 # check value range assert x_norm.max() <= self.ap.max_norm + 1, x_norm.max() - assert x_norm.min() >= -self.ap.max_norm - 2, x_norm.min() #pylint: disable=invalid-unary-operand-type + assert x_norm.min() >= -self.ap.max_norm - 2, x_norm.min() # pylint: disable=invalid-unary-operand-type assert x_norm.min() <= 0, x_norm.min() # check denorm. x_ = self.ap.denormalize(x_norm) @@ -115,13 +119,14 @@ def test_normalize(self): self.ap.clip_norm = True self.ap.max_norm = 4.0 x_norm = self.ap.normalize(x) - print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}") - + print( + f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}" + ) assert (x_old - x).sum() == 0 # check value range assert x_norm.max() <= self.ap.max_norm, x_norm.max() - assert x_norm.min() >= -self.ap.max_norm, x_norm.min() #pylint: disable=invalid-unary-operand-type + assert x_norm.min() >= -self.ap.max_norm, x_norm.min() # pylint: disable=invalid-unary-operand-type assert x_norm.min() <= 0, x_norm.min() # check denorm. 
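The assertions in test_normalize above pin down how max_norm, symmetric_norm and clip_norm interact. A tiny illustration of the mapping being tested, not the actual AudioProcessor implementation; the helper names and min_level_db=-100 are assumptions for the sake of the example:

import numpy as np

def normalize(S_db, max_norm=4.0, symmetric=True, clip=True, min_level_db=-100):
    x = (S_db - min_level_db) / -min_level_db          # dB range -> roughly [0, 1]
    if symmetric:
        x = 2 * max_norm * x - max_norm                # -> [-max_norm, max_norm]
        return np.clip(x, -max_norm, max_norm) if clip else x
    x = max_norm * x                                   # -> [0, max_norm]
    return np.clip(x, 0, max_norm) if clip else x

def denormalize(x, max_norm=4.0, symmetric=True, min_level_db=-100):
    x = (x + max_norm) / (2 * max_norm) if symmetric else x / max_norm
    return x * -min_level_db + min_level_db

S_db = np.random.uniform(-100, 0, size=(80, 50))       # fake mel spectrogram in dB
x = normalize(S_db)
assert x.max() <= 4.0 and x.min() >= -4.0
assert np.abs(denormalize(x) - S_db).max() < 1e-6      # round trip recovers the input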
x_ = self.ap.denormalize(x_norm) @@ -131,8 +136,9 @@ def test_normalize(self): self.ap.symmetric_norm = False self.ap.max_norm = 1.0 x_norm = self.ap.normalize(x) - print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}") - + print( + f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}" + ) assert (x_old - x).sum() == 0 assert x_norm.max() <= self.ap.max_norm, x_norm.max() @@ -144,22 +150,23 @@ def test_normalize(self): self.ap.symmetric_norm = True self.ap.max_norm = 1.0 x_norm = self.ap.normalize(x) - print(f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}") - + print( + f" > MaxNorm: {self.ap.max_norm}, ClipNorm:{self.ap.clip_norm}, SymmetricNorm:{self.ap.symmetric_norm}, SignalNorm:{self.ap.signal_norm} Range-> {x_norm.max()} -- {x_norm.min()}" + ) assert (x_old - x).sum() == 0 assert x_norm.max() <= self.ap.max_norm, x_norm.max() - assert x_norm.min() >= -self.ap.max_norm, x_norm.min() #pylint: disable=invalid-unary-operand-type + assert x_norm.min() >= -self.ap.max_norm, x_norm.min() # pylint: disable=invalid-unary-operand-type assert x_norm.min() < 0, x_norm.min() x_ = self.ap.denormalize(x_norm) assert (x - x_).sum() < 1e-3 def test_scaler(self): - scaler_stats_path = os.path.join(get_tests_input_path(), 'scale_stats.npy') - conf.audio['stats_path'] = scaler_stats_path - conf.audio['preemphasis'] = 0.0 - conf.audio['do_trim_silence'] = True - conf.audio['signal_norm'] = True + scaler_stats_path = os.path.join(get_tests_input_path(), "scale_stats.npy") + conf.audio["stats_path"] = scaler_stats_path + conf.audio["preemphasis"] = 0.0 + conf.audio["do_trim_silence"] = True + conf.audio["signal_norm"] = True ap = AudioProcessor(**conf.audio) mel_mean, mel_std, linear_mean, linear_std, _ = ap.load_stats(scaler_stats_path) diff --git a/tests/test_feed_forward_layers.py b/tests/test_feed_forward_layers.py index 7dd54e564f..1db980a3b2 100644 --- a/tests/test_feed_forward_layers.py +++ b/tests/test_feed_forward_layers.py @@ -1,4 +1,5 @@ import torch + from TTS.tts.layers.feed_forward.decoder import Decoder from TTS.tts.layers.feed_forward.encoder import Encoder from TTS.tts.utils.generic_utils import sequence_mask @@ -8,99 +9,99 @@ def test_encoder(): input_dummy = torch.rand(8, 14, 37).to(device) - input_lengths = torch.randint(31, 37, (8, )).long().to(device) + input_lengths = torch.randint(31, 37, (8,)).long().to(device) input_lengths[-1] = 37 - input_mask = torch.unsqueeze( - sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device) + input_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device) # relative positional transformer encoder - layer = Encoder(out_channels=11, - in_hidden_channels=14, - encoder_type='relative_position_transformer', - encoder_params={ - 'hidden_channels_ffn': 768, - 'num_heads': 2, - "kernel_size": 3, - "dropout_p": 0.1, - "num_layers": 6, - "rel_attn_window_size": 4, - "input_length": None - }).to(device) + layer = Encoder( + out_channels=11, + in_hidden_channels=14, + encoder_type="relative_position_transformer", + encoder_params={ + "hidden_channels_ffn": 768, + "num_heads": 2, + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + 
"rel_attn_window_size": 4, + "input_length": None, + }, + ).to(device) output = layer(input_dummy, input_mask) assert list(output.shape) == [8, 11, 37] # residual conv bn encoder - layer = Encoder(out_channels=11, - in_hidden_channels=14, - encoder_type='residual_conv_bn', - encoder_params={ - "kernel_size": 4, - "dilations": 4 * [1, 2, 4] + [1], - "num_conv_blocks": 2, - "num_res_blocks": 13 - }).to(device) + layer = Encoder( + out_channels=11, + in_hidden_channels=14, + encoder_type="residual_conv_bn", + encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13}, + ).to(device) output = layer(input_dummy, input_mask) assert list(output.shape) == [8, 11, 37] # FFTransformer encoder - layer = Encoder(out_channels=14, - in_hidden_channels=14, - encoder_type='fftransformer', - encoder_params={ - "hidden_channels_ffn": 31, - "num_heads": 2, - "num_layers": 2, - "dropout_p": 0.1 - }).to(device) + layer = Encoder( + out_channels=14, + in_hidden_channels=14, + encoder_type="fftransformer", + encoder_params={"hidden_channels_ffn": 31, "num_heads": 2, "num_layers": 2, "dropout_p": 0.1}, + ).to(device) output = layer(input_dummy, input_mask) assert list(output.shape) == [8, 14, 37] def test_decoder(): input_dummy = torch.rand(8, 128, 37).to(device) - input_lengths = torch.randint(31, 37, (8, )).long().to(device) + input_lengths = torch.randint(31, 37, (8,)).long().to(device) input_lengths[-1] = 37 - input_mask = torch.unsqueeze( - sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device) + input_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device) # residual bn conv decoder layer = Decoder(out_channels=11, in_hidden_channels=128).to(device) output = layer(input_dummy, input_mask) assert list(output.shape) == [8, 11, 37] # transformer decoder - layer = Decoder(out_channels=11, - in_hidden_channels=128, - decoder_type='relative_position_transformer', - decoder_params={ - 'hidden_channels_ffn': 128, - 'num_heads': 2, - "kernel_size": 3, - "dropout_p": 0.1, - "num_layers": 8, - "rel_attn_window_size": 4, - "input_length": None - }).to(device) + layer = Decoder( + out_channels=11, + in_hidden_channels=128, + decoder_type="relative_position_transformer", + decoder_params={ + "hidden_channels_ffn": 128, + "num_heads": 2, + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 8, + "rel_attn_window_size": 4, + "input_length": None, + }, + ).to(device) output = layer(input_dummy, input_mask) assert list(output.shape) == [8, 11, 37] # wavenet decoder - layer = Decoder(out_channels=11, - in_hidden_channels=128, - decoder_type='wavenet', - decoder_params={ - "num_blocks": 12, - "hidden_channels": 192, - "kernel_size": 5, - "dilation_rate": 1, - "num_layers": 4, - "dropout_p": 0.05 - }).to(device) + layer = Decoder( + out_channels=11, + in_hidden_channels=128, + decoder_type="wavenet", + decoder_params={ + "num_blocks": 12, + "hidden_channels": 192, + "kernel_size": 5, + "dilation_rate": 1, + "num_layers": 4, + "dropout_p": 0.05, + }, + ).to(device) output = layer(input_dummy, input_mask) # FFTransformer decoder - layer = Decoder(out_channels=11, - in_hidden_channels=128, - decoder_type='fftransformer', - decoder_params={ - 'hidden_channels_ffn': 31, - 'num_heads': 2, - "dropout_p": 0.1, - "num_layers": 2, - }).to(device) + layer = Decoder( + out_channels=11, + in_hidden_channels=128, + decoder_type="fftransformer", + decoder_params={ + "hidden_channels_ffn": 31, + "num_heads": 2, + "dropout_p": 0.1, + "num_layers": 
2, + }, + ).to(device) output = layer(input_dummy, input_mask) assert list(output.shape) == [8, 11, 37] diff --git a/tests/test_glow_tts.py b/tests/test_glow_tts.py index e9fdc7615d..8e699faf13 100644 --- a/tests/test_glow_tts.py +++ b/tests/test_glow_tts.py @@ -3,21 +3,21 @@ import unittest import torch -from tests import get_tests_input_path from torch import optim +from tests import get_tests_input_path from TTS.tts.layers.losses import GlowTTSLoss from TTS.tts.models.glow_tts import GlowTTS -from TTS.utils.io import load_config from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_config -#pylint: disable=unused-variable +# pylint: disable=unused-variable torch.manual_seed(1) use_cuda = torch.cuda.is_available() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -c = load_config(os.path.join(get_tests_input_path(), 'test_config.json')) +c = load_config(os.path.join(get_tests_input_path(), "test_config.json")) ap = AudioProcessor(**c.audio) WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") @@ -32,11 +32,11 @@ class GlowTTSTrainTest(unittest.TestCase): @staticmethod def test_train_step(): input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (8, )).long().to(device) + input_lengths = torch.randint(100, 129, (8,)).long().to(device) input_lengths[-1] = 128 - mel_spec = torch.rand(8, c.audio['num_mels'], 30).to(device) - mel_lengths = torch.randint(20, 30, (8, )).long().to(device) - speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + mel_spec = torch.rand(8, c.audio["num_mels"], 30).to(device) + mel_lengths = torch.randint(20, 30, (8,)).long().to(device) + speaker_ids = torch.randint(0, 5, (8,)).long().to(device) criterion = GlowTTSLoss() @@ -47,27 +47,28 @@ def test_train_step(): hidden_channels_dec=48, hidden_channels_dp=32, out_channels=80, - encoder_type='rel_pos_transformer', + encoder_type="rel_pos_transformer", encoder_params={ - 'kernel_size': 3, - 'dropout_p': 0.1, - 'num_layers': 6, - 'num_heads': 2, - 'hidden_channels_ffn': 16, # 4 times the hidden_channels - 'input_length': None + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "num_heads": 2, + "hidden_channels_ffn": 16, # 4 times the hidden_channels + "input_length": None, }, use_encoder_prenet=True, num_flow_blocks_dec=12, kernel_size_dec=5, dilation_rate=1, num_block_layers=4, - dropout_p_dec=0., + dropout_p_dec=0.0, num_speakers=0, c_in_channels=0, num_splits=4, num_squeeze=1, sigmoid_scale=False, - mean_only=False).to(device) + mean_only=False, + ).to(device) # reference model to compare model weights model_ref = GlowTTS( @@ -76,38 +77,37 @@ def test_train_step(): hidden_channels_dec=48, hidden_channels_dp=32, out_channels=80, - encoder_type='rel_pos_transformer', + encoder_type="rel_pos_transformer", encoder_params={ - 'kernel_size': 3, - 'dropout_p': 0.1, - 'num_layers': 6, - 'num_heads': 2, - 'hidden_channels_ffn': 16, # 4 times the hidden_channels - 'input_length': None + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "num_heads": 2, + "hidden_channels_ffn": 16, # 4 times the hidden_channels + "input_length": None, }, use_encoder_prenet=True, num_flow_blocks_dec=12, kernel_size_dec=5, dilation_rate=1, num_block_layers=4, - dropout_p_dec=0., + dropout_p_dec=0.0, num_speakers=0, c_in_channels=0, num_splits=4, num_squeeze=1, sigmoid_scale=False, - mean_only=False).to(device) + mean_only=False, + ).to(device) model.train() - print(" > Num parameters for GlowTTS model:%s" % - 
(count_parameters(model))) + print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model))) # pass the state to ref model model_ref.load_state_dict(copy.deepcopy(model.state_dict())) count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): assert (param - param_ref).sum() == 0, param count += 1 @@ -115,18 +115,17 @@ def test_train_step(): for _ in range(5): optimizer.zero_grad() z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( - input_dummy, input_lengths, mel_spec, mel_lengths, None) - loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, - o_dur_log, o_total_dur, input_lengths) - loss = loss_dict['loss'] + input_dummy, input_lengths, mel_spec, mel_lengths, None + ) + loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, input_lengths) + loss = loss_dict["loss"] loss.backward() optimizer.step() # check parameter changes count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): - assert (param != param_ref).any( - ), "param {} with shape {} not updated!! \n{}\n{}".format( - count, param.shape, param, param_ref) + for param, param_ref in zip(model.parameters(), model_ref.parameters()): + assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format( + count, param.shape, param, param_ref + ) count += 1 diff --git a/tests/test_layers.py b/tests/test_layers.py index 582ca8bedd..9b89e645ba 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -1,15 +1,16 @@ import unittest + import torch as T -from TTS.tts.layers.tacotron.tacotron import Prenet, CBHG, Decoder, Encoder from TTS.tts.layers.losses import L1LossMasked, SSIMLoss +from TTS.tts.layers.tacotron.tacotron import CBHG, Decoder, Encoder, Prenet from TTS.tts.utils.generic_utils import sequence_mask # pylint: disable=unused-variable class PrenetTests(unittest.TestCase): - def test_in_out(self): #pylint: disable=no-self-use + def test_in_out(self): # pylint: disable=no-self-use layer = Prenet(128, out_features=[256, 128]) dummy_input = T.rand(4, 128) @@ -21,7 +22,7 @@ def test_in_out(self): #pylint: disable=no-self-use class CBHGTests(unittest.TestCase): def test_in_out(self): - #pylint: disable=attribute-defined-outside-init + # pylint: disable=attribute-defined-outside-init layer = self.cbhg = CBHG( 128, K=8, @@ -29,7 +30,8 @@ def test_in_out(self): conv_projections=[160, 128], highway_features=80, gru_features=80, - num_highways=4) + num_highways=4, + ) # B x D x T dummy_input = T.rand(4, 128, 8) @@ -52,26 +54,27 @@ def test_in_out(): attn_norm="sigmoid", attn_K=5, attn_type="original", - prenet_type='original', + prenet_type="original", prenet_dropout=True, forward_attn=True, trans_agent=True, forward_attn_mask=True, location_attn=True, - separate_stopnet=True) + separate_stopnet=True, + ) dummy_input = T.rand(4, 8, 256) dummy_memory = T.rand(4, 2, 80) - output, alignment, stop_tokens = layer( - dummy_input, dummy_memory, mask=None) + output, alignment, stop_tokens = layer(dummy_input, dummy_memory, mask=None) assert output.shape[0] == 4 assert output.shape[1] == 80, "size not {}".format(output.shape[1]) assert output.shape[2] == 2, "size not {}".format(output.shape[2]) assert stop_tokens.shape[0] == 4 + class EncoderTests(unittest.TestCase): - def test_in_out(self): #pylint: disable=no-self-use + def test_in_out(self): # pylint: disable=no-self-use layer = Encoder(128) 
dummy_input = T.rand(4, 8, 128) @@ -84,7 +87,7 @@ def test_in_out(self): #pylint: disable=no-self-use class L1LossMaskedTests(unittest.TestCase): - def test_in_out(self): #pylint: disable=no-self-use + def test_in_out(self): # pylint: disable=no-self-use # test input == target layer = L1LossMasked(seq_len_norm=False) dummy_input = T.ones(4, 8, 128).float() @@ -104,16 +107,14 @@ def test_in_out(self): #pylint: disable=no-self-use dummy_input = T.ones(4, 8, 128).float() dummy_target = T.zeros(4, 8, 128).float() dummy_length = (T.arange(5, 9)).long() - mask = ( - (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) + mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) output = layer(dummy_input + mask, dummy_target, dummy_length) assert output.item() == 1.0, "1.0 vs {}".format(output.item()) dummy_input = T.rand(4, 8, 128).float() dummy_target = dummy_input.detach() dummy_length = (T.arange(5, 9)).long() - mask = ( - (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) + mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) output = layer(dummy_input + mask, dummy_target, dummy_length) assert output.item() == 0, "0 vs {}".format(output.item()) @@ -137,22 +138,20 @@ def test_in_out(self): #pylint: disable=no-self-use dummy_input = T.ones(4, 8, 128).float() dummy_target = T.zeros(4, 8, 128).float() dummy_length = (T.arange(5, 9)).long() - mask = ( - (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) + mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) output = layer(dummy_input + mask, dummy_target, dummy_length) assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item()) dummy_input = T.rand(4, 8, 128).float() dummy_target = dummy_input.detach() dummy_length = (T.arange(5, 9)).long() - mask = ( - (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) + mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) output = layer(dummy_input + mask, dummy_target, dummy_length) assert output.item() == 0, "0 vs {}".format(output.item()) class SSIMLossTests(unittest.TestCase): - def test_in_out(self): #pylint: disable=no-self-use + def test_in_out(self): # pylint: disable=no-self-use # test input == target layer = SSIMLoss() dummy_input = T.ones(4, 8, 128).float() @@ -172,16 +171,14 @@ def test_in_out(self): #pylint: disable=no-self-use dummy_input = T.ones(4, 8, 128).float() dummy_target = T.zeros(4, 8, 128).float() dummy_length = (T.arange(5, 9)).long() - mask = ( - (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) + mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) output = layer(dummy_input + mask, dummy_target, dummy_length) assert abs(output.item() - 1.0) < 1e-4, "1.0 vs {}".format(output.item()) dummy_input = T.rand(4, 8, 128).float() dummy_target = dummy_input.detach() dummy_length = (T.arange(5, 9)).long() - mask = ( - (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) + mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) output = layer(dummy_input + mask, dummy_target, dummy_length) assert output.item() == 0, "0 vs {}".format(output.item()) @@ -205,15 +202,13 @@ def test_in_out(self): #pylint: disable=no-self-use dummy_input = T.ones(4, 8, 128).float() dummy_target = T.zeros(4, 8, 128).float() dummy_length = (T.arange(5, 9)).long() - mask = ( - (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) + mask = ((sequence_mask(dummy_length).float() - 1.0) * 
100.0).unsqueeze(2) output = layer(dummy_input + mask, dummy_target, dummy_length) assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item()) dummy_input = T.rand(4, 8, 128).float() dummy_target = dummy_input.detach() dummy_length = (T.arange(5, 9)).long() - mask = ( - (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) + mask = ((sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) output = layer(dummy_input + mask, dummy_target, dummy_length) assert output.item() == 0, "0 vs {}".format(output.item()) diff --git a/tests/test_loader.py b/tests/test_loader.py index b79aad191d..b7cf73023e 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -4,19 +4,19 @@ import numpy as np import torch -from tests import get_tests_input_path, get_tests_output_path from torch.utils.data import DataLoader +from tests import get_tests_input_path, get_tests_output_path from TTS.tts.datasets import TTSDataset from TTS.tts.datasets.preprocess import ljspeech from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_config -#pylint: disable=unused-variable +# pylint: disable=unused-variable OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/") os.makedirs(OUTPATH, exist_ok=True) -c = load_config(os.path.join(get_tests_input_path(), 'test_config.json')) +c = load_config(os.path.join(get_tests_input_path(), "test_config.json")) ok_ljspeech = os.path.exists(c.data_path) DATA_EXIST = True @@ -33,25 +33,27 @@ def __init__(self, *args, **kwargs): self.ap = AudioProcessor(**c.audio) def _create_dataloader(self, batch_size, r, bgs): - items = ljspeech(c.data_path, 'metadata.csv') + items = ljspeech(c.data_path, "metadata.csv") dataset = TTSDataset.MyDataset( r, c.text_cleaner, compute_linear_spec=True, ap=self.ap, meta_data=items, - tp=c.characters if 'characters' in c.keys() else None, + tp=c.characters if "characters" in c.keys() else None, batch_group_size=bgs, min_seq_len=c.min_seq_len, max_seq_len=float("inf"), - use_phonemes=False) + use_phonemes=False, + ) dataloader = DataLoader( dataset, batch_size=batch_size, shuffle=False, collate_fn=dataset.collate_fn, drop_last=True, - num_workers=c.num_loader_workers) + num_workers=c.num_loader_workers, + ) return dataloader, dataset def test_loader(self): @@ -72,18 +74,17 @@ def test_loader(self): neg_values = text_input[text_input < 0] check_count = len(neg_values) - assert check_count == 0, \ - " !! Negative values in text_input: {}".format(check_count) + assert check_count == 0, " !! 
Negative values in text_input: {}".format(check_count) # TODO: more assertion here assert isinstance(speaker_name[0], str) assert linear_input.shape[0] == c.batch_size assert linear_input.shape[2] == self.ap.fft_size // 2 + 1 assert mel_input.shape[0] == c.batch_size - assert mel_input.shape[2] == c.audio['num_mels'] + assert mel_input.shape[2] == c.audio["num_mels"] # check normalization ranges if self.ap.symmetric_norm: assert mel_input.max() <= self.ap.max_norm - assert mel_input.min() >= -self.ap.max_norm #pylint: disable=invalid-unary-operand-type + assert mel_input.min() >= -self.ap.max_norm # pylint: disable=invalid-unary-operand-type assert mel_input.min() < 0 else: assert mel_input.max() <= self.ap.max_norm @@ -134,7 +135,7 @@ def test_padding_and_spec(self): # check mel_spec consistency wav = np.asarray(self.ap.load_wav(item_idx[0]), dtype=np.float32) - mel = self.ap.melspectrogram(wav).astype('float32') + mel = self.ap.melspectrogram(wav).astype("float32") mel = torch.FloatTensor(mel).contiguous() mel_dl = mel_input[0] # NOTE: Below needs to check == 0 but due to an unknown reason @@ -145,15 +146,14 @@ def test_padding_and_spec(self): # check mel-spec correctness mel_spec = mel_input[0].cpu().numpy() wav = self.ap.inv_melspectrogram(mel_spec.T) - self.ap.save_wav(wav, OUTPATH + '/mel_inv_dataloader.wav') - shutil.copy(item_idx[0], OUTPATH + '/mel_target_dataloader.wav') + self.ap.save_wav(wav, OUTPATH + "/mel_inv_dataloader.wav") + shutil.copy(item_idx[0], OUTPATH + "/mel_target_dataloader.wav") # check linear-spec linear_spec = linear_input[0].cpu().numpy() wav = self.ap.inv_spectrogram(linear_spec.T) - self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader.wav') - shutil.copy(item_idx[0], - OUTPATH + '/linear_target_dataloader.wav') + self.ap.save_wav(wav, OUTPATH + "/linear_inv_dataloader.wav") + shutil.copy(item_idx[0], OUTPATH + "/linear_target_dataloader.wav") # check the last time step to be zero padded assert linear_input[0, -1].sum() != 0 @@ -202,8 +202,8 @@ def test_padding_and_spec(self): # check the second itme in the batch assert linear_input[1 - idx, -1].sum() == 0 assert mel_input[1 - idx, -1].sum() == 0 - assert stop_target[1, mel_lengths[1]-1] == 1 - assert stop_target[1, mel_lengths[1]:].sum() == 0 + assert stop_target[1, mel_lengths[1] - 1] == 1 + assert stop_target[1, mel_lengths[1] :].sum() == 0 assert len(mel_lengths.shape) == 1 # check batch zero-frame conditions (zero-frame disabled) diff --git a/tests/test_preprocessors.py b/tests/test_preprocessors.py index 8c7b16b01c..968e2a29fa 100644 --- a/tests/test_preprocessors.py +++ b/tests/test_preprocessors.py @@ -1,17 +1,16 @@ -import unittest import os -from tests import get_tests_input_path +import unittest +from tests import get_tests_input_path from TTS.tts.datasets.preprocess import common_voice class TestPreprocessors(unittest.TestCase): - - def test_common_voice_preprocessor(self): #pylint: disable=no-self-use + def test_common_voice_preprocessor(self): # pylint: disable=no-self-use root_path = get_tests_input_path() meta_file = "common_voice.tsv" items = common_voice(root_path, meta_file) - assert items[0][0] == 'The applicants are invited for coffee and visa is given immediately.' + assert items[0][0] == "The applicants are invited for coffee and visa is given immediately." assert items[0][1] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_20005954.wav") assert items[-1][0] == "Competition for limited resources has also resulted in some local conflicts." 
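
The preprocessor test above pins down the item convention these tests rely on: a formatter such as common_voice(root_path, meta_file) returns a list of items where item[0] is the transcript and item[1] is the absolute path to the audio clip, and test_loader.py passes such a list straight into TTSDataset.MyDataset as meta_data. As a minimal sketch of that convention only (my_dataset, the "metadata.txt" filename, the "wav_name|transcript" delimiter, and the trailing speaker label are assumptions for illustration, not the project's own format), a custom formatter plus a smoke test in the same style as the tests above could look like this:

    import os
    import tempfile


    def my_dataset(root_path, meta_file, wav_dir="wavs"):
        # Hypothetical formatter for illustration only: the "wav_name|transcript"
        # line layout and the trailing speaker label are assumptions, not the
        # project's own metadata format.
        items = []
        speaker_name = "my_dataset"
        with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as f:
            for line in f:
                wav_name, text = line.strip().split("|", 1)
                wav_path = os.path.join(root_path, wav_dir, wav_name + ".wav")
                items.append([text, wav_path, speaker_name])
        return items


    def test_my_dataset_formatter():  # smoke test mirroring test_common_voice_preprocessor
        with tempfile.TemporaryDirectory() as root_path:
            with open(os.path.join(root_path, "metadata.txt"), "w", encoding="utf-8") as f:
                f.write("clip_0001|Hello world.\n")
            items = my_dataset(root_path, "metadata.txt")
            # same (transcript, wav_path) layout the asserts above check
            assert items[0][0] == "Hello world."
            assert items[0][1] == os.path.join(root_path, "wavs", "clip_0001.wav")

A formatter written this way can be handed to the data loader exactly like the ljspeech items in test_loader.py, since only the positional layout of each item matters to the dataset class.
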
diff --git a/tests/test_speaker_encoder.py b/tests/test_speaker_encoder.py index 4d4dbba116..32ba2924c8 100644 --- a/tests/test_speaker_encoder.py +++ b/tests/test_speaker_encoder.py @@ -2,9 +2,9 @@ import unittest import torch as T -from tests import get_tests_input_path -from TTS.speaker_encoder.losses import GE2ELoss, AngleProtoLoss +from tests import get_tests_input_path +from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss from TTS.speaker_encoder.model import SpeakerEncoder from TTS.utils.io import load_config @@ -17,9 +17,7 @@ class SpeakerEncoderTests(unittest.TestCase): def test_in_out(self): dummy_input = T.rand(4, 20, 80) # B x T x D dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)] - model = SpeakerEncoder( - input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3 - ) + model = SpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3) # computing d vectors output = model.forward(dummy_input) assert output.shape[0] == 4 @@ -36,9 +34,7 @@ def test_in_out(self): output_norm = T.nn.functional.normalize(output, dim=1, p=2) assert_diff = (output_norm - output).sum().item() assert output.type() == "torch.FloatTensor" - assert ( - abs(assert_diff) < 1e-4 - ), f" [!] output_norm has wrong values - {assert_diff}" + assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}" # compute d for a given batch dummy_input = T.rand(1, 240, 80) # B x T x D output = model.compute_embedding(dummy_input, num_frames=160, overlap=0.5) @@ -74,6 +70,7 @@ def test_in_out(self): output = loss.forward(dummy_input) assert output.item() < 0.005 + class AngleProtoLossTests(unittest.TestCase): # pylint: disable=R0201 def test_in_out(self): @@ -103,6 +100,7 @@ def test_in_out(self): output = loss.forward(dummy_input) assert output.item() < 0.005 + # class LoaderTest(unittest.TestCase): # def test_output(self): # items = libri_tts("/home/erogol/Data/Libri-TTS/train-clean-360/") diff --git a/tests/test_speedy_speech_layers.py b/tests/test_speedy_speech_layers.py index 954d5ecade..3473769b20 100644 --- a/tests/test_speedy_speech_layers.py +++ b/tests/test_speedy_speech_layers.py @@ -1,8 +1,8 @@ import torch + from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor -from TTS.tts.utils.generic_utils import sequence_mask from TTS.tts.models.speedy_speech import SpeedySpeech - +from TTS.tts.utils.generic_utils import sequence_mask use_cuda = torch.cuda.is_available() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -10,11 +10,10 @@ def test_duration_predictor(): input_dummy = torch.rand(8, 128, 27).to(device) - input_lengths = torch.randint(20, 27, (8, )).long().to(device) + input_lengths = torch.randint(20, 27, (8,)).long().to(device) input_lengths[-1] = 27 - x_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), - 1).to(device) + x_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device) layer = DurationPredictor(hidden_channels=128).to(device) @@ -29,7 +28,7 @@ def test_speedy_speech(): T_de = 74 x_dummy = torch.randint(0, 7, (B, T_en)).long().to(device) - x_lengths = torch.randint(31, T_en, (B, )).long().to(device) + x_lengths = torch.randint(31, T_en, (B,)).long().to(device) x_lengths[-1] = T_en # set durations. 
max total duration should be equal to T_de @@ -53,34 +52,18 @@ def test_speedy_speech(): assert list(o_dr.shape) == [B, T_en] # with speaker embedding - model = SpeedySpeech(num_chars, - out_channels=80, - hidden_channels=128, - num_speakers=10, - c_in_channels=256).to(device) - model.forward(x_dummy, - x_lengths, - y_lengths, - durations, - g=torch.randint(0, 10, (B,)).to(device)) + model = SpeedySpeech(num_chars, out_channels=80, hidden_channels=128, num_speakers=10, c_in_channels=256).to(device) + model.forward(x_dummy, x_lengths, y_lengths, durations, g=torch.randint(0, 10, (B,)).to(device)) assert list(o_de.shape) == [B, 80, T_de], f"{list(o_de.shape)}" assert list(attn.shape) == [B, T_de, T_en] assert list(o_dr.shape) == [B, T_en] - # with speaker external embedding - model = SpeedySpeech(num_chars, - out_channels=80, - hidden_channels=128, - num_speakers=10, - external_c=True, - c_in_channels=256).to(device) - model.forward(x_dummy, - x_lengths, - y_lengths, - durations, - g=torch.rand((B, 256)).to(device)) + model = SpeedySpeech( + num_chars, out_channels=80, hidden_channels=128, num_speakers=10, external_c=True, c_in_channels=256 + ).to(device) + model.forward(x_dummy, x_lengths, y_lengths, durations, g=torch.rand((B, 256)).to(device)) assert list(o_de.shape) == [B, 80, T_de], f"{list(o_de.shape)}" assert list(attn.shape) == [B, T_de, T_en] diff --git a/tests/test_symbols.py b/tests/test_symbols.py index 4e70b9d550..49b2598649 100644 --- a/tests/test_symbols.py +++ b/tests/test_symbols.py @@ -2,6 +2,7 @@ from TTS.tts.utils.text import phonemes + class SymbolsTest(unittest.TestCase): - def test_uniqueness(self): #pylint: disable=no-self-use + def test_uniqueness(self): # pylint: disable=no-self-use assert sorted(phonemes) == sorted(list(set(phonemes))), " {} vs {} ".format(len(phonemes), len(set(phonemes))) diff --git a/tests/test_synthesizer.py b/tests/test_synthesizer.py index b7d3febc3b..46b9ab74f3 100644 --- a/tests/test_synthesizer.py +++ b/tests/test_synthesizer.py @@ -2,11 +2,11 @@ import unittest from tests import get_tests_input_path, get_tests_output_path -from TTS.utils.synthesizer import Synthesizer from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.io import save_checkpoint from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.utils.io import load_config +from TTS.utils.synthesizer import Synthesizer class SynthesizerTest(unittest.TestCase): @@ -14,8 +14,8 @@ class SynthesizerTest(unittest.TestCase): def _create_random_model(self): # pylint: disable=global-statement global symbols, phonemes - config = load_config(os.path.join(get_tests_output_path(), 'dummy_model_config.json')) - if 'characters' in config.keys(): + config = load_config(os.path.join(get_tests_output_path(), "dummy_model_config.json")) + if "characters" in config.keys(): symbols, phonemes = make_symbols(**config.characters) num_chars = len(phonemes) if config.use_phonemes else len(symbols) @@ -25,11 +25,11 @@ def _create_random_model(self): def test_in_out(self): self._create_random_model() - config = load_config(os.path.join(get_tests_input_path(), 'server_config.json')) + config = load_config(os.path.join(get_tests_input_path(), "server_config.json")) tts_root_path = get_tests_output_path() - config['tts_checkpoint'] = os.path.join(tts_root_path, config['tts_checkpoint']) - config['tts_config'] = os.path.join(tts_root_path, config['tts_config']) - synthesizer = Synthesizer(config['tts_checkpoint'], config['tts_config'], None, None) + 
config["tts_checkpoint"] = os.path.join(tts_root_path, config["tts_checkpoint"]) + config["tts_config"] = os.path.join(tts_root_path, config["tts_config"]) + synthesizer = Synthesizer(config["tts_checkpoint"], config["tts_config"], None, None) synthesizer.tts("Better this test works!!") def test_split_into_sentences(self): @@ -38,20 +38,48 @@ def test_split_into_sentences(self): # pylint: disable=attribute-defined-outside-init self.seg = Synthesizer.get_segmenter("en") sis = Synthesizer.split_into_sentences - assert sis(self, 'Hello. Two sentences') == ['Hello.', 'Two sentences'] - assert sis(self, 'He went to meet the adviser from Scott, Waltman & Co. next morning.') == ['He went to meet the adviser from Scott, Waltman & Co. next morning.'] - assert sis(self, 'Let\'s run it past Sarah and co. They\'ll want to see this.') == ['Let\'s run it past Sarah and co.', 'They\'ll want to see this.'] - assert sis(self, 'Where is Bobby Jr.\'s rabbit?') == ['Where is Bobby Jr.\'s rabbit?'] - assert sis(self, 'Please inform the U.K. authorities right away.') == ['Please inform the U.K. authorities right away.'] - assert sis(self, 'Were David and co. at the event?') == ['Were David and co. at the event?'] - assert sis(self, 'paging dr. green, please come to theatre four immediately.') == ['paging dr. green, please come to theatre four immediately.'] - assert sis(self, 'The email format is Firstname.Lastname@example.com. I think you reversed them.') == ['The email format is Firstname.Lastname@example.com.', 'I think you reversed them.'] - assert sis(self, 'The demo site is: https://top100.example.com/subsection/latestnews.html. Please send us your feedback.') == ['The demo site is: https://top100.example.com/subsection/latestnews.html.', 'Please send us your feedback.'] - assert sis(self, 'Scowling at him, \'You are not done yet!\' she yelled.') == ['Scowling at him, \'You are not done yet!\' she yelled.'] # with the final lowercase "she" we see it's all one sentence - assert sis(self, 'Hey!! So good to see you.') == ['Hey!!', 'So good to see you.'] - assert sis(self, 'He went to Yahoo! but I don\'t know the division.') == ['He went to Yahoo! but I don\'t know the division.'] - assert sis(self, 'If you can\'t remember a quote, “at least make up a memorable one that\'s plausible..."') == ['If you can\'t remember a quote, “at least make up a memorable one that\'s plausible..."'] - assert sis(self, 'The address is not google.com.') == ['The address is not google.com.'] - assert sis(self, '1.) The first item 2.) The second item') == ['1.) The first item', '2.) The second item'] - assert sis(self, '1) The first item 2) The second item') == ['1) The first item', '2) The second item'] - assert sis(self, 'a. The first item b. The second item c. The third list item') == ['a. The first item', 'b. The second item', 'c. The third list item'] + assert sis(self, "Hello. Two sentences") == ["Hello.", "Two sentences"] + assert sis(self, "He went to meet the adviser from Scott, Waltman & Co. next morning.") == [ + "He went to meet the adviser from Scott, Waltman & Co. next morning." + ] + assert sis(self, "Let's run it past Sarah and co. They'll want to see this.") == [ + "Let's run it past Sarah and co.", + "They'll want to see this.", + ] + assert sis(self, "Where is Bobby Jr.'s rabbit?") == ["Where is Bobby Jr.'s rabbit?"] + assert sis(self, "Please inform the U.K. authorities right away.") == [ + "Please inform the U.K. authorities right away." + ] + assert sis(self, "Were David and co. 
at the event?") == ["Were David and co. at the event?"] + assert sis(self, "paging dr. green, please come to theatre four immediately.") == [ + "paging dr. green, please come to theatre four immediately." + ] + assert sis(self, "The email format is Firstname.Lastname@example.com. I think you reversed them.") == [ + "The email format is Firstname.Lastname@example.com.", + "I think you reversed them.", + ] + assert sis( + self, + "The demo site is: https://top100.example.com/subsection/latestnews.html. Please send us your feedback.", + ) == [ + "The demo site is: https://top100.example.com/subsection/latestnews.html.", + "Please send us your feedback.", + ] + assert sis(self, "Scowling at him, 'You are not done yet!' she yelled.") == [ + "Scowling at him, 'You are not done yet!' she yelled." + ] # with the final lowercase "she" we see it's all one sentence + assert sis(self, "Hey!! So good to see you.") == ["Hey!!", "So good to see you."] + assert sis(self, "He went to Yahoo! but I don't know the division.") == [ + "He went to Yahoo! but I don't know the division." + ] + assert sis(self, "If you can't remember a quote, “at least make up a memorable one that's plausible...\"") == [ + "If you can't remember a quote, “at least make up a memorable one that's plausible...\"" + ] + assert sis(self, "The address is not google.com.") == ["The address is not google.com."] + assert sis(self, "1.) The first item 2.) The second item") == ["1.) The first item", "2.) The second item"] + assert sis(self, "1) The first item 2) The second item") == ["1) The first item", "2) The second item"] + assert sis(self, "a. The first item b. The second item c. The third list item") == [ + "a. The first item", + "b. The second item", + "c. The third list item", + ] diff --git a/tests/test_tacotron2_model.py b/tests/test_tacotron2_model.py index 4ac0711811..0e35605f57 100644 --- a/tests/test_tacotron2_model.py +++ b/tests/test_tacotron2_model.py @@ -3,21 +3,21 @@ import unittest import torch -from tests import get_tests_input_path from torch import nn, optim +from tests import get_tests_input_path from TTS.tts.layers.losses import MSELossMasked from TTS.tts.models.tacotron2 import Tacotron2 -from TTS.utils.io import load_config from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_config -#pylint: disable=unused-variable +# pylint: disable=unused-variable torch.manual_seed(1) use_cuda = torch.cuda.is_available() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -c = load_config(os.path.join(get_tests_input_path(), 'test_config.json')) +c = load_config(os.path.join(get_tests_input_path(), "test_config.json")) ap = AudioProcessor(**c.audio) WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") @@ -26,20 +26,19 @@ class TacotronTrainTest(unittest.TestCase): def test_train_step(self): # pylint: disable=no-self-use input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 128, (8, )).long().to(device) + input_lengths = torch.randint(100, 128, (8,)).long().to(device) input_lengths = torch.sort(input_lengths, descending=True)[0] - mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) + mel_postnet_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) + mel_lengths = torch.randint(20, 30, (8,)).long().to(device) mel_lengths[0] = 30 
stop_targets = torch.zeros(8, 30, 1).float().to(device) - speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + speaker_ids = torch.randint(0, 5, (8,)).long().to(device) for idx in mel_lengths: - stop_targets[:, int(idx.item()):, 0] = 1.0 + stop_targets[:, int(idx.item()) :, 0] = 1.0 - stop_targets = stop_targets.view(input_dummy.shape[0], - stop_targets.size(1) // c.r, -1) + stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1) stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() criterion = MSELossMasked(seq_len_norm=False).to(device) @@ -48,14 +47,14 @@ def test_train_step(self): # pylint: disable=no-self-use model.train() model_ref = copy.deepcopy(model) count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): assert (param - param_ref).sum() == 0, param count += 1 optimizer = optim.Adam(model.parameters(), lr=c.lr) for i in range(5): mel_out, mel_postnet_out, align, stop_tokens = model.forward( - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids) + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids + ) assert torch.sigmoid(stop_tokens).data.max() <= 1.0 assert torch.sigmoid(stop_tokens).data.min() >= 0.0 optimizer.zero_grad() @@ -66,13 +65,12 @@ def test_train_step(self): # pylint: disable=no-self-use optimizer.step() # check parameter changes count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): # ignore pre-higway layer since it works conditional # if count not in [145, 59]: - assert (param != param_ref).any( - ), "param {} with shape {} not updated!! \n{}\n{}".format( - count, param.shape, param, param_ref) + assert (param != param_ref).any(), "param {} with shape {} not updated!! 
\n{}\n{}".format( + count, param.shape, param, param_ref + ) count += 1 @@ -80,20 +78,19 @@ class MultiSpeakeTacotronTrainTest(unittest.TestCase): @staticmethod def test_train_step(): input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 128, (8, )).long().to(device) + input_lengths = torch.randint(100, 128, (8,)).long().to(device) input_lengths = torch.sort(input_lengths, descending=True)[0] - mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) + mel_postnet_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) + mel_lengths = torch.randint(20, 30, (8,)).long().to(device) mel_lengths[0] = 30 stop_targets = torch.zeros(8, 30, 1).float().to(device) speaker_embeddings = torch.rand(8, 55).to(device) for idx in mel_lengths: - stop_targets[:, int(idx.item()):, 0] = 1.0 + stop_targets[:, int(idx.item()) :, 0] = 1.0 - stop_targets = stop_targets.view(input_dummy.shape[0], - stop_targets.size(1) // c.r, -1) + stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1) stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() criterion = MSELossMasked(seq_len_norm=False).to(device) @@ -102,14 +99,14 @@ def test_train_step(): model.train() model_ref = copy.deepcopy(model) count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): assert (param - param_ref).sum() == 0, param count += 1 optimizer = optim.Adam(model.parameters(), lr=c.lr) for i in range(5): mel_out, mel_postnet_out, align, stop_tokens = model.forward( - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_embeddings=speaker_embeddings) + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_embeddings=speaker_embeddings + ) assert torch.sigmoid(stop_tokens).data.max() <= 1.0 assert torch.sigmoid(stop_tokens).data.min() >= 0.0 optimizer.zero_grad() @@ -120,39 +117,46 @@ def test_train_step(): optimizer.step() # check parameter changes count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): # ignore pre-higway layer since it works conditional # if count not in [145, 59]: - assert (param != param_ref).any( - ), "param {} with shape {} not updated!! \n{}\n{}".format( - count, param.shape, param, param_ref) + assert (param != param_ref).any(), "param {} with shape {} not updated!! 
\n{}\n{}".format( + count, param.shape, param, param_ref + ) count += 1 + class TacotronGSTTrainTest(unittest.TestCase): - #pylint: disable=no-self-use + # pylint: disable=no-self-use def test_train_step(self): # with random gst mel style input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 128, (8, )).long().to(device) + input_lengths = torch.randint(100, 128, (8,)).long().to(device) input_lengths = torch.sort(input_lengths, descending=True)[0] - mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) + mel_postnet_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) + mel_lengths = torch.randint(20, 30, (8,)).long().to(device) mel_lengths[0] = 30 stop_targets = torch.zeros(8, 30, 1).float().to(device) - speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + speaker_ids = torch.randint(0, 5, (8,)).long().to(device) for idx in mel_lengths: - stop_targets[:, int(idx.item()):, 0] = 1.0 + stop_targets[:, int(idx.item()) :, 0] = 1.0 - stop_targets = stop_targets.view(input_dummy.shape[0], - stop_targets.size(1) // c.r, -1) + stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1) stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() criterion = MSELossMasked(seq_len_norm=False).to(device) criterion_st = nn.BCEWithLogitsLoss().to(device) - model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, gst=True, gst_embedding_dim=c.gst['gst_embedding_dim'], gst_num_heads=c.gst['gst_num_heads'], gst_style_tokens=c.gst['gst_style_tokens']).to(device) + model = Tacotron2( + num_chars=24, + r=c.r, + num_speakers=5, + gst=True, + gst_embedding_dim=c.gst["gst_embedding_dim"], + gst_num_heads=c.gst["gst_num_heads"], + gst_style_tokens=c.gst["gst_style_tokens"], + ).to(device) model.train() model_ref = copy.deepcopy(model) count = 0 @@ -162,7 +166,8 @@ def test_train_step(self): optimizer = optim.Adam(model.parameters(), lr=c.lr) for i in range(10): mel_out, mel_postnet_out, align, stop_tokens = model.forward( - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids) + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids + ) assert torch.sigmoid(stop_tokens).data.max() <= 1.0 assert torch.sigmoid(stop_tokens).data.min() >= 0.0 optimizer.zero_grad() @@ -177,36 +182,45 @@ def test_train_step(self): # ignore pre-higway layer since it works conditional # if count not in [145, 59]: name, param = name_param - if name == 'gst_layer.encoder.recurrence.weight_hh_l0': - #print(param.grad) + if name == "gst_layer.encoder.recurrence.weight_hh_l0": + # print(param.grad) continue - assert (param != param_ref).any( - ), "param {} {} with shape {} not updated!! \n{}\n{}".format( - name, count, param.shape, param, param_ref) + assert (param != param_ref).any(), "param {} {} with shape {} not updated!! 
\n{}\n{}".format( + name, count, param.shape, param, param_ref + ) count += 1 # with file gst style - mel_spec = torch.FloatTensor(ap.melspectrogram(ap.load_wav(WAV_FILE)))[:, :30].unsqueeze(0).transpose(1, 2).to(device) + mel_spec = ( + torch.FloatTensor(ap.melspectrogram(ap.load_wav(WAV_FILE)))[:, :30].unsqueeze(0).transpose(1, 2).to(device) + ) mel_spec = mel_spec.repeat(8, 1, 1) input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 128, (8, )).long().to(device) + input_lengths = torch.randint(100, 128, (8,)).long().to(device) input_lengths = torch.sort(input_lengths, descending=True)[0] - mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + mel_postnet_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) + mel_lengths = torch.randint(20, 30, (8,)).long().to(device) mel_lengths[0] = 30 stop_targets = torch.zeros(8, 30, 1).float().to(device) - speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + speaker_ids = torch.randint(0, 5, (8,)).long().to(device) for idx in mel_lengths: - stop_targets[:, int(idx.item()):, 0] = 1.0 + stop_targets[:, int(idx.item()) :, 0] = 1.0 - stop_targets = stop_targets.view(input_dummy.shape[0], - stop_targets.size(1) // c.r, -1) + stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1) stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() criterion = MSELossMasked(seq_len_norm=False).to(device) criterion_st = nn.BCEWithLogitsLoss().to(device) - model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, gst=True, gst_embedding_dim=c.gst['gst_embedding_dim'], gst_num_heads=c.gst['gst_num_heads'], gst_style_tokens=c.gst['gst_style_tokens']).to(device) + model = Tacotron2( + num_chars=24, + r=c.r, + num_speakers=5, + gst=True, + gst_embedding_dim=c.gst["gst_embedding_dim"], + gst_num_heads=c.gst["gst_num_heads"], + gst_style_tokens=c.gst["gst_style_tokens"], + ).to(device) model.train() model_ref = copy.deepcopy(model) count = 0 @@ -216,7 +230,8 @@ def test_train_step(self): optimizer = optim.Adam(model.parameters(), lr=c.lr) for i in range(10): mel_out, mel_postnet_out, align, stop_tokens = model.forward( - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids) + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids + ) assert torch.sigmoid(stop_tokens).data.max() <= 1.0 assert torch.sigmoid(stop_tokens).data.min() >= 0.0 optimizer.zero_grad() @@ -231,47 +246,57 @@ def test_train_step(self): # ignore pre-higway layer since it works conditional # if count not in [145, 59]: name, param = name_param - if name == 'gst_layer.encoder.recurrence.weight_hh_l0': - #print(param.grad) + if name == "gst_layer.encoder.recurrence.weight_hh_l0": + # print(param.grad) continue - assert (param != param_ref).any( - ), "param {} {} with shape {} not updated!! \n{}\n{}".format( - name, count, param.shape, param, param_ref) + assert (param != param_ref).any(), "param {} {} with shape {} not updated!! 
\n{}\n{}".format( + name, count, param.shape, param, param_ref + ) count += 1 + class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase): @staticmethod def test_train_step(): input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 128, (8, )).long().to(device) + input_lengths = torch.randint(100, 128, (8,)).long().to(device) input_lengths = torch.sort(input_lengths, descending=True)[0] - mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) + mel_postnet_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) + mel_lengths = torch.randint(20, 30, (8,)).long().to(device) mel_lengths[0] = 30 stop_targets = torch.zeros(8, 30, 1).float().to(device) speaker_embeddings = torch.rand(8, 55).to(device) for idx in mel_lengths: - stop_targets[:, int(idx.item()):, 0] = 1.0 + stop_targets[:, int(idx.item()) :, 0] = 1.0 - stop_targets = stop_targets.view(input_dummy.shape[0], - stop_targets.size(1) // c.r, -1) + stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1) stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() criterion = MSELossMasked(seq_len_norm=False).to(device) criterion_st = nn.BCEWithLogitsLoss().to(device) - model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, speaker_embedding_dim=55, gst=True, gst_embedding_dim=c.gst['gst_embedding_dim'], gst_num_heads=c.gst['gst_num_heads'], gst_style_tokens=c.gst['gst_style_tokens'], gst_use_speaker_embedding=c.gst['gst_use_speaker_embedding']).to(device) + model = Tacotron2( + num_chars=24, + r=c.r, + num_speakers=5, + speaker_embedding_dim=55, + gst=True, + gst_embedding_dim=c.gst["gst_embedding_dim"], + gst_num_heads=c.gst["gst_num_heads"], + gst_style_tokens=c.gst["gst_style_tokens"], + gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"], + ).to(device) model.train() model_ref = copy.deepcopy(model) count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): assert (param - param_ref).sum() == 0, param count += 1 optimizer = optim.Adam(model.parameters(), lr=c.lr) for i in range(5): mel_out, mel_postnet_out, align, stop_tokens = model.forward( - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_embeddings=speaker_embeddings) + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_embeddings=speaker_embeddings + ) assert torch.sigmoid(stop_tokens).data.max() <= 1.0 assert torch.sigmoid(stop_tokens).data.min() >= 0.0 optimizer.zero_grad() @@ -282,14 +307,13 @@ def test_train_step(): optimizer.step() # check parameter changes count = 0 - for name_param, param_ref in zip(model.named_parameters(), - model_ref.parameters()): + for name_param, param_ref in zip(model.named_parameters(), model_ref.parameters()): # ignore pre-higway layer since it works conditional # if count not in [145, 59]: name, param = name_param - if name == 'gst_layer.encoder.recurrence.weight_hh_l0': + if name == "gst_layer.encoder.recurrence.weight_hh_l0": continue - assert (param != param_ref).any( - ), "param {} with shape {} not updated!! \n{}\n{}".format( - count, param.shape, param, param_ref) + assert (param != param_ref).any(), "param {} with shape {} not updated!! 
\n{}\n{}".format( + count, param.shape, param, param_ref + ) count += 1 diff --git a/tests/test_tacotron2_tf_model.py b/tests/test_tacotron2_tf_model.py index b792cfa7e9..767e5ffc20 100644 --- a/tests/test_tacotron2_tf_model.py +++ b/tests/test_tacotron2_tf_model.py @@ -4,54 +4,57 @@ import numpy as np import tensorflow as tf import torch + from tests import get_tests_input_path from TTS.tts.tf.models.tacotron2 import Tacotron2 -from TTS.tts.tf.utils.tflite import (convert_tacotron2_to_tflite, - load_tflite_model) +from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model from TTS.utils.io import load_config -tf.get_logger().setLevel('INFO') - +tf.get_logger().setLevel("INFO") -#pylint: disable=unused-variable +# pylint: disable=unused-variable torch.manual_seed(1) use_cuda = torch.cuda.is_available() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -c = load_config(os.path.join(get_tests_input_path(), 'test_config.json')) +c = load_config(os.path.join(get_tests_input_path(), "test_config.json")) class TacotronTFTrainTest(unittest.TestCase): - @staticmethod def generate_dummy_inputs(): chars_seq = torch.randint(0, 24, (8, 128)).long().to(device) - chars_seq_lengths = torch.randint(100, 128, (8, )).long().to(device) + chars_seq_lengths = torch.randint(100, 128, (8,)).long().to(device) chars_seq_lengths = torch.sort(chars_seq_lengths, descending=True)[0] - mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) + mel_postnet_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) + mel_lengths = torch.randint(20, 30, (8,)).long().to(device) stop_targets = torch.zeros(8, 30, 1).float().to(device) - speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + speaker_ids = torch.randint(0, 5, (8,)).long().to(device) chars_seq = tf.convert_to_tensor(chars_seq.cpu().numpy()) chars_seq_lengths = tf.convert_to_tensor(chars_seq_lengths.cpu().numpy()) mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy()) - return chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths,\ - stop_targets, speaker_ids + return chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths, stop_targets, speaker_ids def test_train_step(self): - ''' test forward pass ''' - chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths,\ - stop_targets, speaker_ids = self.generate_dummy_inputs() + """ test forward pass """ + ( + chars_seq, + chars_seq_lengths, + mel_spec, + mel_postnet_spec, + mel_lengths, + stop_targets, + speaker_ids, + ) = self.generate_dummy_inputs() for idx in mel_lengths: - stop_targets[:, int(idx.item()):, 0] = 1.0 + stop_targets[:, int(idx.item()) :, 0] = 1.0 - stop_targets = stop_targets.view(chars_seq.shape[0], - stop_targets.size(1) // c.r, -1) + stop_targets = stop_targets.view(chars_seq.shape[0], stop_targets.size(1) // c.r, -1) stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() model = Tacotron2(num_chars=24, r=c.r, num_speakers=5) @@ -68,15 +71,23 @@ def test_train_step(self): # inference pass output = model(chars_seq, training=False) - def test_forward_attention(self,): - chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths,\ - stop_targets, speaker_ids = self.generate_dummy_inputs() + def test_forward_attention( + self, + ): + ( + chars_seq, + chars_seq_lengths, + mel_spec, 
+ mel_postnet_spec, + mel_lengths, + stop_targets, + speaker_ids, + ) = self.generate_dummy_inputs() for idx in mel_lengths: - stop_targets[:, int(idx.item()):, 0] = 1.0 + stop_targets[:, int(idx.item()) :, 0] = 1.0 - stop_targets = stop_targets.view(chars_seq.shape[0], - stop_targets.size(1) // c.r, -1) + stop_targets = stop_targets.view(chars_seq.shape[0], stop_targets.size(1) // c.r, -1) stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() model = Tacotron2(num_chars=24, r=c.r, num_speakers=5, forward_attn=True) @@ -93,45 +104,51 @@ def test_forward_attention(self,): # inference pass output = model(chars_seq, training=False) - def test_tflite_conversion(self, ): #pylint:disable=no-self-use - model = Tacotron2(num_chars=24, - num_speakers=0, - r=3, - postnet_output_dim=80, - decoder_output_dim=80, - attn_type='original', - attn_win=False, - attn_norm='sigmoid', - prenet_type='original', - prenet_dropout=True, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - attn_K=0, - separate_stopnet=True, - bidirectional_decoder=False, - enable_tflite=True) + def test_tflite_conversion( + self, + ): # pylint:disable=no-self-use + model = Tacotron2( + num_chars=24, + num_speakers=0, + r=3, + postnet_output_dim=80, + decoder_output_dim=80, + attn_type="original", + attn_win=False, + attn_norm="sigmoid", + prenet_type="original", + prenet_dropout=True, + forward_attn=False, + trans_agent=False, + forward_attn_mask=False, + location_attn=True, + attn_K=0, + separate_stopnet=True, + bidirectional_decoder=False, + enable_tflite=True, + ) model.build_inference() - convert_tacotron2_to_tflite(model, output_path='test_tacotron2.tflite', experimental_converter=True) + convert_tacotron2_to_tflite(model, output_path="test_tacotron2.tflite", experimental_converter=True) # init tflite model - tflite_model = load_tflite_model('test_tacotron2.tflite') + tflite_model = load_tflite_model("test_tacotron2.tflite") # fake input - inputs = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32) #pylint:disable=unexpected-keyword-arg + inputs = tf.random.uniform([1, 4], maxval=10, dtype=tf.int32) # pylint:disable=unexpected-keyword-arg # run inference # get input and output details input_details = tflite_model.get_input_details() output_details = tflite_model.get_output_details() # reshape input tensor for the new input shape - tflite_model.resize_tensor_input(input_details[0]['index'], inputs.shape) #pylint:disable=unexpected-keyword-arg + tflite_model.resize_tensor_input( + input_details[0]["index"], inputs.shape + ) # pylint:disable=unexpected-keyword-arg tflite_model.allocate_tensors() detail = input_details[0] - input_shape = detail['shape'] - tflite_model.set_tensor(detail['index'], inputs) + input_shape = detail["shape"] + tflite_model.set_tensor(detail["index"], inputs) # run the tflite_model tflite_model.invoke() # collect outputs - decoder_output = tflite_model.get_tensor(output_details[0]['index']) - postnet_output = tflite_model.get_tensor(output_details[1]['index']) + decoder_output = tflite_model.get_tensor(output_details[0]["index"]) + postnet_output = tflite_model.get_tensor(output_details[1]["index"]) # remove tflite binary - os.remove('test_tacotron2.tflite') + os.remove("test_tacotron2.tflite") diff --git a/tests/test_tacotron_model.py b/tests/test_tacotron_model.py index f8e88160f7..e3ed8ae2a0 100644 --- a/tests/test_tacotron_model.py +++ b/tests/test_tacotron_model.py @@ -3,21 +3,21 @@ import unittest import torch -from tests import 
get_tests_input_path from torch import nn, optim +from tests import get_tests_input_path from TTS.tts.layers.losses import L1LossMasked from TTS.tts.models.tacotron import Tacotron -from TTS.utils.io import load_config from TTS.utils.audio import AudioProcessor +from TTS.utils.io import load_config -#pylint: disable=unused-variable +# pylint: disable=unused-variable torch.manual_seed(1) use_cuda = torch.cuda.is_available() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -c = load_config(os.path.join(get_tests_input_path(), 'test_config.json')) +c = load_config(os.path.join(get_tests_input_path(), "test_config.json")) ap = AudioProcessor(**c.audio) WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") @@ -32,147 +32,140 @@ class TacotronTrainTest(unittest.TestCase): @staticmethod def test_train_step(): input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (8, )).long().to(device) + input_lengths = torch.randint(100, 129, (8,)).long().to(device) input_lengths[-1] = 128 - mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device) - mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) + linear_spec = torch.rand(8, 30, c.audio["fft_size"]).to(device) + mel_lengths = torch.randint(20, 30, (8,)).long().to(device) stop_targets = torch.zeros(8, 30, 1).float().to(device) - speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + speaker_ids = torch.randint(0, 5, (8,)).long().to(device) for idx in mel_lengths: - stop_targets[:, int(idx.item()):, 0] = 1.0 + stop_targets[:, int(idx.item()) :, 0] = 1.0 - stop_targets = stop_targets.view(input_dummy.shape[0], - stop_targets.size(1) // c.r, -1) - stop_targets = (stop_targets.sum(2) > - 0.0).unsqueeze(2).float().squeeze() + stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() criterion = L1LossMasked(seq_len_norm=False).to(device) criterion_st = nn.BCEWithLogitsLoss().to(device) model = Tacotron( num_chars=32, num_speakers=5, - postnet_output_dim=c.audio['fft_size'], - decoder_output_dim=c.audio['num_mels'], + postnet_output_dim=c.audio["fft_size"], + decoder_output_dim=c.audio["num_mels"], r=c.r, - memory_size=c.memory_size - ).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor + memory_size=c.memory_size, + ).to( + device + ) # FIXME: missing num_speakers parameter to Tacotron ctor model.train() - print(" > Num parameters for Tacotron model:%s" % - (count_parameters(model))) + print(" > Num parameters for Tacotron model:%s" % (count_parameters(model))) model_ref = copy.deepcopy(model) count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): assert (param - param_ref).sum() == 0, param count += 1 optimizer = optim.Adam(model.parameters(), lr=c.lr) for _ in range(5): mel_out, linear_out, align, stop_tokens = model.forward( - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids) + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids + ) optimizer.zero_grad() loss = criterion(mel_out, mel_spec, mel_lengths) stop_loss = criterion_st(stop_tokens, stop_targets) - loss = loss + criterion(linear_out, linear_spec, - mel_lengths) + stop_loss + loss = loss + criterion(linear_out, linear_spec, 
mel_lengths) + stop_loss loss.backward() optimizer.step() # check parameter changes count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): # ignore pre-higway layer since it works conditional # if count not in [145, 59]: - assert (param != param_ref).any( - ), "param {} with shape {} not updated!! \n{}\n{}".format( - count, param.shape, param, param_ref) + assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format( + count, param.shape, param, param_ref + ) count += 1 + class MultiSpeakeTacotronTrainTest(unittest.TestCase): @staticmethod def test_train_step(): input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (8, )).long().to(device) + input_lengths = torch.randint(100, 129, (8,)).long().to(device) input_lengths[-1] = 128 - mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device) - mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) + linear_spec = torch.rand(8, 30, c.audio["fft_size"]).to(device) + mel_lengths = torch.randint(20, 30, (8,)).long().to(device) stop_targets = torch.zeros(8, 30, 1).float().to(device) speaker_embeddings = torch.rand(8, 55).to(device) for idx in mel_lengths: - stop_targets[:, int(idx.item()):, 0] = 1.0 + stop_targets[:, int(idx.item()) :, 0] = 1.0 - stop_targets = stop_targets.view(input_dummy.shape[0], - stop_targets.size(1) // c.r, -1) - stop_targets = (stop_targets.sum(2) > - 0.0).unsqueeze(2).float().squeeze() + stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() criterion = L1LossMasked(seq_len_norm=False).to(device) criterion_st = nn.BCEWithLogitsLoss().to(device) model = Tacotron( num_chars=32, num_speakers=5, - postnet_output_dim=c.audio['fft_size'], - decoder_output_dim=c.audio['num_mels'], + postnet_output_dim=c.audio["fft_size"], + decoder_output_dim=c.audio["num_mels"], r=c.r, memory_size=c.memory_size, speaker_embedding_dim=55, - ).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor + ).to( + device + ) # FIXME: missing num_speakers parameter to Tacotron ctor model.train() - print(" > Num parameters for Tacotron model:%s" % - (count_parameters(model))) + print(" > Num parameters for Tacotron model:%s" % (count_parameters(model))) model_ref = copy.deepcopy(model) count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): assert (param - param_ref).sum() == 0, param count += 1 optimizer = optim.Adam(model.parameters(), lr=c.lr) for _ in range(5): mel_out, linear_out, align, stop_tokens = model.forward( - input_dummy, input_lengths, mel_spec, mel_lengths, - speaker_embeddings=speaker_embeddings) + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_embeddings=speaker_embeddings + ) optimizer.zero_grad() loss = criterion(mel_out, mel_spec, mel_lengths) stop_loss = criterion_st(stop_tokens, stop_targets) - loss = loss + criterion(linear_out, linear_spec, - mel_lengths) + stop_loss + loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss loss.backward() optimizer.step() # check parameter changes count = 0 - for param, param_ref in zip(model.parameters(), - 
model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): # ignore pre-higway layer since it works conditional # if count not in [145, 59]: - assert (param != param_ref).any( - ), "param {} with shape {} not updated!! \n{}\n{}".format( - count, param.shape, param, param_ref) + assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format( + count, param.shape, param, param_ref + ) count += 1 + class TacotronGSTTrainTest(unittest.TestCase): @staticmethod def test_train_step(): # with random gst mel style input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (8, )).long().to(device) + input_lengths = torch.randint(100, 129, (8,)).long().to(device) input_lengths[-1] = 128 - mel_spec = torch.rand(8, 120, c.audio['num_mels']).to(device) - linear_spec = torch.rand(8, 120, c.audio['fft_size']).to(device) - mel_lengths = torch.randint(20, 120, (8, )).long().to(device) + mel_spec = torch.rand(8, 120, c.audio["num_mels"]).to(device) + linear_spec = torch.rand(8, 120, c.audio["fft_size"]).to(device) + mel_lengths = torch.randint(20, 120, (8,)).long().to(device) mel_lengths[-1] = 120 stop_targets = torch.zeros(8, 120, 1).float().to(device) - speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + speaker_ids = torch.randint(0, 5, (8,)).long().to(device) for idx in mel_lengths: - stop_targets[:, int(idx.item()):, 0] = 1.0 + stop_targets[:, int(idx.item()) :, 0] = 1.0 - stop_targets = stop_targets.view(input_dummy.shape[0], - stop_targets.size(1) // c.r, -1) - stop_targets = (stop_targets.sum(2) > - 0.0).unsqueeze(2).float().squeeze() + stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() criterion = L1LossMasked(seq_len_norm=False).to(device) criterion_st = nn.BCEWithLogitsLoss().to(device) @@ -180,65 +173,64 @@ def test_train_step(): num_chars=32, num_speakers=5, gst=True, - gst_embedding_dim=c.gst['gst_embedding_dim'], - gst_num_heads=c.gst['gst_num_heads'], - gst_style_tokens=c.gst['gst_style_tokens'], - postnet_output_dim=c.audio['fft_size'], - decoder_output_dim=c.audio['num_mels'], + gst_embedding_dim=c.gst["gst_embedding_dim"], + gst_num_heads=c.gst["gst_num_heads"], + gst_style_tokens=c.gst["gst_style_tokens"], + postnet_output_dim=c.audio["fft_size"], + decoder_output_dim=c.audio["num_mels"], r=c.r, - memory_size=c.memory_size - ).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor + memory_size=c.memory_size, + ).to( + device + ) # FIXME: missing num_speakers parameter to Tacotron ctor model.train() # print(model) - print(" > Num parameters for Tacotron GST model:%s" % - (count_parameters(model))) + print(" > Num parameters for Tacotron GST model:%s" % (count_parameters(model))) model_ref = copy.deepcopy(model) count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): assert (param - param_ref).sum() == 0, param count += 1 optimizer = optim.Adam(model.parameters(), lr=c.lr) for _ in range(10): mel_out, linear_out, align, stop_tokens = model.forward( - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids) + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids + ) optimizer.zero_grad() loss = criterion(mel_out, mel_spec, mel_lengths) stop_loss = criterion_st(stop_tokens, stop_targets) - loss = loss + criterion(linear_out, 
linear_spec, - mel_lengths) + stop_loss + loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss loss.backward() optimizer.step() # check parameter changes count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): # ignore pre-higway layer since it works conditional - assert (param != param_ref).any( - ), "param {} with shape {} not updated!! \n{}\n{}".format( - count, param.shape, param, param_ref) + assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format( + count, param.shape, param, param_ref + ) count += 1 # with file gst style - mel_spec = torch.FloatTensor(ap.melspectrogram(ap.load_wav(WAV_FILE)))[:, :120].unsqueeze(0).transpose(1, 2).to(device) + mel_spec = ( + torch.FloatTensor(ap.melspectrogram(ap.load_wav(WAV_FILE)))[:, :120].unsqueeze(0).transpose(1, 2).to(device) + ) mel_spec = mel_spec.repeat(8, 1, 1) input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (8, )).long().to(device) + input_lengths = torch.randint(100, 129, (8,)).long().to(device) input_lengths[-1] = 128 - linear_spec = torch.rand(8, mel_spec.size(1), c.audio['fft_size']).to(device) - mel_lengths = torch.randint(20, mel_spec.size(1), (8, )).long().to(device) + linear_spec = torch.rand(8, mel_spec.size(1), c.audio["fft_size"]).to(device) + mel_lengths = torch.randint(20, mel_spec.size(1), (8,)).long().to(device) mel_lengths[-1] = mel_spec.size(1) stop_targets = torch.zeros(8, mel_spec.size(1), 1).float().to(device) - speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + speaker_ids = torch.randint(0, 5, (8,)).long().to(device) for idx in mel_lengths: - stop_targets[:, int(idx.item()):, 0] = 1.0 + stop_targets[:, int(idx.item()) :, 0] = 1.0 - stop_targets = stop_targets.view(input_dummy.shape[0], - stop_targets.size(1) // c.r, -1) - stop_targets = (stop_targets.sum(2) > - 0.0).unsqueeze(2).float().squeeze() + stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() criterion = L1LossMasked(seq_len_norm=False).to(device) criterion_st = nn.BCEWithLogitsLoss().to(device) @@ -246,113 +238,109 @@ def test_train_step(): num_chars=32, num_speakers=5, gst=True, - gst_embedding_dim=c.gst['gst_embedding_dim'], - gst_num_heads=c.gst['gst_num_heads'], - gst_style_tokens=c.gst['gst_style_tokens'], - postnet_output_dim=c.audio['fft_size'], - decoder_output_dim=c.audio['num_mels'], + gst_embedding_dim=c.gst["gst_embedding_dim"], + gst_num_heads=c.gst["gst_num_heads"], + gst_style_tokens=c.gst["gst_style_tokens"], + postnet_output_dim=c.audio["fft_size"], + decoder_output_dim=c.audio["num_mels"], r=c.r, - memory_size=c.memory_size - ).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor + memory_size=c.memory_size, + ).to( + device + ) # FIXME: missing num_speakers parameter to Tacotron ctor model.train() # print(model) - print(" > Num parameters for Tacotron GST model:%s" % - (count_parameters(model))) + print(" > Num parameters for Tacotron GST model:%s" % (count_parameters(model))) model_ref = copy.deepcopy(model) count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): assert (param - param_ref).sum() == 0, param count += 1 optimizer = optim.Adam(model.parameters(), lr=c.lr) for _ in range(10): 
mel_out, linear_out, align, stop_tokens = model.forward( - input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids) + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids + ) optimizer.zero_grad() loss = criterion(mel_out, mel_spec, mel_lengths) stop_loss = criterion_st(stop_tokens, stop_targets) - loss = loss + criterion(linear_out, linear_spec, - mel_lengths) + stop_loss + loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss loss.backward() optimizer.step() # check parameter changes count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): # ignore pre-higway layer since it works conditional - assert (param != param_ref).any( - ), "param {} with shape {} not updated!! \n{}\n{}".format( - count, param.shape, param, param_ref) + assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format( + count, param.shape, param, param_ref + ) count += 1 + class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase): @staticmethod def test_train_step(): input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (8, )).long().to(device) + input_lengths = torch.randint(100, 129, (8,)).long().to(device) input_lengths[-1] = 128 - mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device) - mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device) + linear_spec = torch.rand(8, 30, c.audio["fft_size"]).to(device) + mel_lengths = torch.randint(20, 30, (8,)).long().to(device) mel_lengths[-1] = mel_spec.size(1) stop_targets = torch.zeros(8, 30, 1).float().to(device) speaker_embeddings = torch.rand(8, 55).to(device) for idx in mel_lengths: - stop_targets[:, int(idx.item()):, 0] = 1.0 + stop_targets[:, int(idx.item()) :, 0] = 1.0 - stop_targets = stop_targets.view(input_dummy.shape[0], - stop_targets.size(1) // c.r, -1) - stop_targets = (stop_targets.sum(2) > - 0.0).unsqueeze(2).float().squeeze() + stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() criterion = L1LossMasked(seq_len_norm=False).to(device) criterion_st = nn.BCEWithLogitsLoss().to(device) model = Tacotron( num_chars=32, num_speakers=5, - postnet_output_dim=c.audio['fft_size'], - decoder_output_dim=c.audio['num_mels'], + postnet_output_dim=c.audio["fft_size"], + decoder_output_dim=c.audio["num_mels"], gst=True, - gst_embedding_dim=c.gst['gst_embedding_dim'], - gst_num_heads=c.gst['gst_num_heads'], - gst_style_tokens=c.gst['gst_style_tokens'], - gst_use_speaker_embedding=c.gst['gst_use_speaker_embedding'], + gst_embedding_dim=c.gst["gst_embedding_dim"], + gst_num_heads=c.gst["gst_num_heads"], + gst_style_tokens=c.gst["gst_style_tokens"], + gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"], r=c.r, memory_size=c.memory_size, speaker_embedding_dim=55, - ).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor + ).to( + device + ) # FIXME: missing num_speakers parameter to Tacotron ctor model.train() - print(" > Num parameters for Tacotron model:%s" % - (count_parameters(model))) + print(" > Num parameters for Tacotron model:%s" % (count_parameters(model))) model_ref = copy.deepcopy(model) count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + 
for param, param_ref in zip(model.parameters(), model_ref.parameters()): assert (param - param_ref).sum() == 0, param count += 1 optimizer = optim.Adam(model.parameters(), lr=c.lr) for _ in range(5): mel_out, linear_out, align, stop_tokens = model.forward( - input_dummy, input_lengths, mel_spec, mel_lengths, - speaker_embeddings=speaker_embeddings) + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_embeddings=speaker_embeddings + ) optimizer.zero_grad() loss = criterion(mel_out, mel_spec, mel_lengths) stop_loss = criterion_st(stop_tokens, stop_targets) - loss = loss + criterion(linear_out, linear_spec, - mel_lengths) + stop_loss + loss = loss + criterion(linear_out, linear_spec, mel_lengths) + stop_loss loss.backward() optimizer.step() # check parameter changes count = 0 - for name_param, param_ref in zip(model.named_parameters(), - model_ref.parameters()): + for name_param, param_ref in zip(model.named_parameters(), model_ref.parameters()): # ignore pre-higway layer since it works conditional # if count not in [145, 59]: name, param = name_param - if name == 'gst_layer.encoder.recurrence.weight_hh_l0': + if name == "gst_layer.encoder.recurrence.weight_hh_l0": continue - assert (param != param_ref).any( - ), "param {} with shape {} not updated!! \n{}\n{}".format( - count, param.shape, param, param_ref) + assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format( + count, param.shape, param, param_ref + ) count += 1 diff --git a/tests/test_text_cleaners.py b/tests/test_text_cleaners.py index b301fb5aee..fcfa71e77d 100644 --- a/tests/test_text_cleaners.py +++ b/tests/test_text_cleaners.py @@ -17,5 +17,5 @@ def test_currency() -> None: def test_expand_numbers() -> None: - assert phoneme_cleaners("-1") == 'minus one' - assert phoneme_cleaners("1") == 'one' + assert phoneme_cleaners("-1") == "minus one" + assert phoneme_cleaners("1") == "one" diff --git a/tests/test_text_processing.py b/tests/test_text_processing.py index 774ac0aa46..f70056b14d 100644 --- a/tests/test_text_processing.py +++ b/tests/test_text_processing.py @@ -1,13 +1,14 @@ import os + # pylint: disable=unused-wildcard-import # pylint: disable=wildcard-import # pylint: disable=unused-import -from tests import get_tests_input_path +from tests import get_tests_input_path, get_tests_path from TTS.tts.utils.text import * -from tests import get_tests_path from TTS.utils.io import load_config -conf = load_config(os.path.join(get_tests_input_path(), 'test_config.json')) +conf = load_config(os.path.join(get_tests_input_path(), "test_config.json")) + def test_phoneme_to_sequence(): @@ -18,7 +19,7 @@ def test_phoneme_to_sequence(): text_hat = sequence_to_phoneme(sequence) _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters) text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) - gt = 'ɹiːsənt ɹᵻsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪŋkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹᵻspɑːnsᵻbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjʊleɪʃən ænd lɜːnɪŋ!' + gt = "ɹiːsənt ɹᵻsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪŋkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹᵻspɑːnsᵻbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjʊleɪʃən ænd lɜːnɪŋ!" 
assert text_hat == text_hat_with_params == gt # multiple punctuations @@ -87,6 +88,7 @@ def test_phoneme_to_sequence(): print(len(sequence)) assert text_hat == text_hat_with_params == gt + def test_phoneme_to_sequence_with_blank_token(): text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!" @@ -105,7 +107,7 @@ def test_phoneme_to_sequence_with_blank_token(): text_hat = sequence_to_phoneme(sequence) _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True) text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True) - gt = 'biː ɐ vɔɪs, nɑːt æn! ɛkoʊ?' + gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ?" print(text_hat) print(len(sequence)) assert text_hat == text_hat_with_params == gt @@ -116,7 +118,7 @@ def test_phoneme_to_sequence_with_blank_token(): text_hat = sequence_to_phoneme(sequence) _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True) text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True) - gt = 'biː ɐ vɔɪs, nɑːt æn! ɛkoʊ' + gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ" print(text_hat) print(len(sequence)) assert text_hat == text_hat_with_params == gt @@ -127,7 +129,7 @@ def test_phoneme_to_sequence_with_blank_token(): text_hat = sequence_to_phoneme(sequence) _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True) text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True) - gt = 'biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!' + gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!" print(text_hat) print(len(sequence)) assert text_hat == text_hat_with_params == gt @@ -138,7 +140,7 @@ def test_phoneme_to_sequence_with_blank_token(): text_hat = sequence_to_phoneme(sequence) _ = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters, add_blank=True) text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters, add_blank=True) - gt = 'biː ɐ vɔɪs, nɑːt æn! ɛkoʊ.' + gt = "biː ɐ vɔɪs, nɑːt æn! ɛkoʊ." print(text_hat) print(len(sequence)) assert text_hat == text_hat_with_params == gt @@ -165,9 +167,10 @@ def test_phoneme_to_sequence_with_blank_token(): print(len(sequence)) assert text_hat == text_hat_with_params == gt + def test_text2phone(): text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!" - gt = 'ɹ|iː|s|ə|n|t| |ɹ|ᵻ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|ŋ|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ᵻ|s|p|ɑː|n|s|ᵻ|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|ʊ|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!' + gt = "ɹ|iː|s|ə|n|t| |ɹ|ᵻ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|ŋ|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ᵻ|s|p|ɑː|n|s|ᵻ|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|ʊ|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!" 
lang = "en-us" ph = text2phone(text, lang) assert gt == ph diff --git a/tests/test_vocoder_gan_datasets.py b/tests/test_vocoder_gan_datasets.py index 32592d493c..9564da3bc2 100644 --- a/tests/test_vocoder_gan_datasets.py +++ b/tests/test_vocoder_gan_datasets.py @@ -1,9 +1,9 @@ import os import numpy as np -from tests import get_tests_path, get_tests_input_path, get_tests_output_path from torch.utils.data import DataLoader +from tests import get_tests_input_path, get_tests_output_path, get_tests_path from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_config from TTS.vocoder.datasets.gan_dataset import GANDataset @@ -13,32 +13,33 @@ OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/") os.makedirs(OUTPATH, exist_ok=True) -C = load_config(os.path.join(get_tests_input_path(), 'test_config.json')) +C = load_config(os.path.join(get_tests_input_path(), "test_config.json")) test_data_path = os.path.join(get_tests_path(), "data/ljspeech/") ok_ljspeech = os.path.exists(test_data_path) -def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_pairs, return_segments, use_noise_augment, use_cache, num_workers): - '''Run dataloader with given parameters and check conditions ''' +def gan_dataset_case( + batch_size, seq_len, hop_len, conv_pad, return_pairs, return_segments, use_noise_augment, use_cache, num_workers +): + """Run dataloader with given parameters and check conditions """ ap = AudioProcessor(**C.audio) _, train_items = load_wav_data(test_data_path, 10) - dataset = GANDataset(ap, - train_items, - seq_len=seq_len, - hop_len=hop_len, - pad_short=2000, - conv_pad=conv_pad, - return_pairs=return_pairs, - return_segments=return_segments, - use_noise_augment=use_noise_augment, - use_cache=use_cache) - loader = DataLoader(dataset=dataset, - batch_size=batch_size, - shuffle=True, - num_workers=num_workers, - pin_memory=True, - drop_last=True) + dataset = GANDataset( + ap, + train_items, + seq_len=seq_len, + hop_len=hop_len, + pad_short=2000, + conv_pad=conv_pad, + return_pairs=return_pairs, + return_segments=return_segments, + use_noise_augment=use_noise_augment, + use_cache=use_cache, + ) + loader = DataLoader( + dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True + ) max_iter = 10 count_iter = 0 @@ -59,9 +60,8 @@ def check_item(feat, wav): mel = ap.melspectrogram(audio) # the first 2 and the last 2 frames are skipped due to the padding # differences in stft - max_diff = abs((feat - mel[:, :feat.shape[-1]])[:, 2:-2]).max() - assert max_diff <= 0, f' [!] {max_diff}' - + max_diff = abs((feat - mel[:, : feat.shape[-1]])[:, 2:-2]).max() + assert max_diff <= 0, f" [!] 
{max_diff}" # return random segments or return the whole audio if return_segments: @@ -90,18 +90,18 @@ def check_item(feat, wav): def test_parametrized_gan_dataset(): - ''' test dataloader with different parameters ''' + """ test dataloader with different parameters """ params = [ - [32, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, True, False, True, 0], - [32, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, True, False, True, 4], - [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, True, True, True, 0], - [1, C.audio['hop_length'], C.audio['hop_length'], 0, True, True, True, True, 0], - [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, True, True, True, True, 0], - [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, False, True, True, 0], - [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, True, False, True, 0], - [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, False, True, True, False, 0], - [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, False, False, False, 0], - [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 0, True, False, False, False, 0] + [32, C.audio["hop_length"] * 10, C.audio["hop_length"], 0, True, True, False, True, 0], + [32, C.audio["hop_length"] * 10, C.audio["hop_length"], 0, True, True, False, True, 4], + [1, C.audio["hop_length"] * 10, C.audio["hop_length"], 0, True, True, True, True, 0], + [1, C.audio["hop_length"], C.audio["hop_length"], 0, True, True, True, True, 0], + [1, C.audio["hop_length"] * 10, C.audio["hop_length"], 2, True, True, True, True, 0], + [1, C.audio["hop_length"] * 10, C.audio["hop_length"], 0, True, False, True, True, 0], + [1, C.audio["hop_length"] * 10, C.audio["hop_length"], 0, True, True, False, True, 0], + [1, C.audio["hop_length"] * 10, C.audio["hop_length"], 0, False, True, True, False, 0], + [1, C.audio["hop_length"] * 10, C.audio["hop_length"], 0, True, False, False, False, 0], + [1, C.audio["hop_length"] * 10, C.audio["hop_length"], 0, True, False, False, False, 0], ] for param in params: print(param) diff --git a/tests/test_vocoder_losses.py b/tests/test_vocoder_losses.py index 2f38dd5a28..87151a050f 100644 --- a/tests/test_vocoder_losses.py +++ b/tests/test_vocoder_losses.py @@ -1,11 +1,11 @@ import os import torch -from tests import get_tests_input_path, get_tests_output_path, get_tests_path +from tests import get_tests_input_path, get_tests_output_path, get_tests_path from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_config -from TTS.vocoder.layers.losses import MultiScaleSTFTLoss, STFTLoss, TorchSTFT, MelganFeatureLoss +from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT TESTS_PATH = get_tests_path() @@ -14,7 +14,7 @@ WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") -C = load_config(os.path.join(get_tests_input_path(), 'test_config.json')) +C = load_config(os.path.join(get_tests_input_path(), "test_config.json")) ap = AudioProcessor(**C.audio) @@ -22,7 +22,7 @@ def test_torch_stft(): torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length) # librosa stft wav = ap.load_wav(WAV_FILE) - M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access + M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access # torch stft wav = torch.from_numpy(wav[None, :]).float() M_torch = torch_stft(wav) @@ -42,9 +42,11 @@ def test_stft_loss(): def test_multiscale_stft_loss(): - stft_loss = MultiScaleSTFTLoss([ap.fft_size//2, ap.fft_size, 
ap.fft_size*2], - [ap.hop_length // 2, ap.hop_length, ap.hop_length * 2], - [ap.win_length // 2, ap.win_length, ap.win_length * 2]) + stft_loss = MultiScaleSTFTLoss( + [ap.fft_size // 2, ap.fft_size, ap.fft_size * 2], + [ap.hop_length // 2, ap.hop_length, ap.hop_length * 2], + [ap.win_length // 2, ap.win_length, ap.win_length * 2], + ) wav = ap.load_wav(WAV_FILE) wav = torch.from_numpy(wav[None, :]).float() loss_m, loss_sc = stft_loss(wav, wav) diff --git a/tests/test_vocoder_melgan_generator.py b/tests/test_vocoder_melgan_generator.py index fedf630184..f4958de427 100644 --- a/tests/test_vocoder_melgan_generator.py +++ b/tests/test_vocoder_melgan_generator.py @@ -3,6 +3,7 @@ from TTS.vocoder.models.melgan_generator import MelganGenerator + def test_melgan_generator(): model = MelganGenerator() print(model) diff --git a/tests/test_vocoder_parallel_wavegan_discriminator.py b/tests/test_vocoder_parallel_wavegan_discriminator.py index b496e216ba..d4eca0d137 100644 --- a/tests/test_vocoder_parallel_wavegan_discriminator.py +++ b/tests/test_vocoder_parallel_wavegan_discriminator.py @@ -1,7 +1,10 @@ import numpy as np import torch -from TTS.vocoder.models.parallel_wavegan_discriminator import ParallelWaveganDiscriminator, ResidualParallelWaveganDiscriminator +from TTS.vocoder.models.parallel_wavegan_discriminator import ( + ParallelWaveganDiscriminator, + ResidualParallelWaveganDiscriminator, +) def test_pwgan_disciminator(): @@ -14,7 +17,8 @@ def test_pwgan_disciminator(): dilation_factor=1, nonlinear_activation="LeakyReLU", nonlinear_activation_params={"negative_slope": 0.2}, - bias=True) + bias=True, + ) dummy_x = torch.rand((4, 1, 64 * 256)) output = model(dummy_x) assert np.all(output.shape == (4, 1, 64 * 256)) @@ -34,7 +38,8 @@ def test_redisual_pwgan_disciminator(): dropout=0.0, bias=True, nonlinear_activation="LeakyReLU", - nonlinear_activation_params={"negative_slope": 0.2}) + nonlinear_activation_params={"negative_slope": 0.2}, + ) dummy_x = torch.rand((4, 1, 64 * 256)) output = model(dummy_x) assert np.all(output.shape == (4, 1, 64 * 256)) diff --git a/tests/test_vocoder_parallel_wavegan_generator.py b/tests/test_vocoder_parallel_wavegan_generator.py index 9eed0eeef4..21f6f08fd6 100644 --- a/tests/test_vocoder_parallel_wavegan_generator.py +++ b/tests/test_vocoder_parallel_wavegan_generator.py @@ -18,7 +18,8 @@ def test_pwgan_generator(): dropout=0.0, bias=True, use_weight_norm=True, - upsample_factors=[4, 4, 4, 4]) + upsample_factors=[4, 4, 4, 4], + ) dummy_c = torch.rand((2, 80, 5)) output = model(dummy_c) assert np.all(output.shape == (2, 1, 5 * 256)), output.shape diff --git a/tests/test_vocoder_pqmf.py b/tests/test_vocoder_pqmf.py index 74da451fba..afe8d1dc8f 100644 --- a/tests/test_vocoder_pqmf.py +++ b/tests/test_vocoder_pqmf.py @@ -1,13 +1,12 @@ import os -import torch import soundfile as sf +import torch from librosa.core import load -from tests import get_tests_path, get_tests_input_path, get_tests_output_path +from tests import get_tests_input_path, get_tests_output_path, get_tests_path from TTS.vocoder.layers.pqmf import PQMF - TESTS_PATH = get_tests_path() WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") @@ -24,5 +23,4 @@ def test_pqmf(): print(w2_.max()) print(w2_.min()) print(w2_.mean()) - sf.write(os.path.join(get_tests_output_path(), 'pqmf_output.wav'), - w2_.flatten().detach(), sr) + sf.write(os.path.join(get_tests_output_path(), "pqmf_output.wav"), w2_.flatten().detach(), sr) diff --git a/tests/test_vocoder_rwd.py b/tests/test_vocoder_rwd.py index 
424d3b498c..371ad9e41e 100644 --- a/tests/test_vocoder_rwd.py +++ b/tests/test_vocoder_rwd.py @@ -1,18 +1,16 @@ -import torch import numpy as np +import torch from TTS.vocoder.models.random_window_discriminator import RandomWindowDiscriminator def test_rwd(): - layer = RandomWindowDiscriminator(cond_channels=80, - window_sizes=(512, 1024, 2048, 4096, - 8192), - cond_disc_downsample_factors=[ - (8, 4, 2, 2, 2), (8, 4, 2, 2), - (8, 4, 2), (8, 4), (4, 2, 2) - ], - hop_length=256) + layer = RandomWindowDiscriminator( + cond_channels=80, + window_sizes=(512, 1024, 2048, 4096, 8192), + cond_disc_downsample_factors=[(8, 4, 2, 2, 2), (8, 4, 2, 2), (8, 4, 2), (8, 4), (4, 2, 2)], + hop_length=256, + ) x = torch.rand([4, 1, 22050]) c = torch.rand([4, 80, 22050 // 256]) diff --git a/tests/test_vocoder_tf_pqmf.py b/tests/test_vocoder_tf_pqmf.py index 16c46b2a83..f1c3666ba9 100644 --- a/tests/test_vocoder_tf_pqmf.py +++ b/tests/test_vocoder_tf_pqmf.py @@ -1,13 +1,12 @@ import os -import tensorflow as tf import soundfile as sf +import tensorflow as tf from librosa.core import load -from tests import get_tests_path, get_tests_input_path, get_tests_output_path +from tests import get_tests_input_path, get_tests_output_path, get_tests_path from TTS.vocoder.tf.layers.pqmf import PQMF - TESTS_PATH = get_tests_path() WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") @@ -25,5 +24,4 @@ def test_pqmf(): print(w2_.max()) print(w2_.min()) print(w2_.mean()) - sf.write(os.path.join(get_tests_output_path(), 'tf_pqmf_output.wav'), - w2_.flatten(), sr) + sf.write(os.path.join(get_tests_output_path(), "tf_pqmf_output.wav"), w2_.flatten(), sr) diff --git a/tests/test_vocoder_wavernn.py b/tests/test_vocoder_wavernn.py index 2464cfa3de..9c58fa1c84 100644 --- a/tests/test_vocoder_wavernn.py +++ b/tests/test_vocoder_wavernn.py @@ -1,6 +1,8 @@ +import random + import numpy as np import torch -import random + from TTS.vocoder.models.wavernn import WaveRNN diff --git a/tests/test_vocoder_wavernn_datasets.py b/tests/test_vocoder_wavernn_datasets.py index a95e247ab1..7bd4380bad 100644 --- a/tests/test_vocoder_wavernn_datasets.py +++ b/tests/test_vocoder_wavernn_datasets.py @@ -2,20 +2,19 @@ import shutil import numpy as np -from tests import get_tests_path, get_tests_input_path, get_tests_output_path from torch.utils.data import DataLoader +from tests import get_tests_input_path, get_tests_output_path, get_tests_path from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_config -from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset from TTS.vocoder.datasets.preprocess import load_wav_feat_data, preprocess_wav_files +from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset file_path = os.path.dirname(os.path.realpath(__file__)) OUTPATH = os.path.join(get_tests_output_path(), "loader_tests/") os.makedirs(OUTPATH, exist_ok=True) -C = load_config(os.path.join(get_tests_input_path(), - "test_vocoder_wavernn_config.json")) +C = load_config(os.path.join(get_tests_input_path(), "test_vocoder_wavernn_config.json")) test_data_path = os.path.join(get_tests_path(), "data/ljspeech/") test_mel_feat_path = os.path.join(test_data_path, "mel") @@ -33,25 +32,20 @@ def wavernn_dataset_case(batch_size, seq_len, hop_len, pad, mode, mulaw, num_wor C.data_path = test_data_path preprocess_wav_files(test_data_path, C, ap) - _, train_items = load_wav_feat_data( - test_data_path, test_mel_feat_path, 5) + _, train_items = load_wav_feat_data(test_data_path, test_mel_feat_path, 5) - dataset = 
WaveRNNDataset(ap=ap, - items=train_items, - seq_len=seq_len, - hop_len=hop_len, - pad=pad, - mode=mode, - mulaw=mulaw - ) + dataset = WaveRNNDataset( + ap=ap, items=train_items, seq_len=seq_len, hop_len=hop_len, pad=pad, mode=mode, mulaw=mulaw + ) # sampler = DistributedSampler(dataset) if num_gpus > 1 else None - loader = DataLoader(dataset, - shuffle=True, - collate_fn=dataset.collate, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=True, - ) + loader = DataLoader( + dataset, + shuffle=True, + collate_fn=dataset.collate, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=True, + ) max_iter = 10 count_iter = 0 @@ -59,10 +53,8 @@ def wavernn_dataset_case(batch_size, seq_len, hop_len, pad, mode, mulaw, num_wor try: for data in loader: x_input, mels, _ = data - expected_feat_shape = (ap.num_mels, - (x_input.shape[-1] // hop_len) + (pad * 2)) - assert np.all( - mels.shape[1:] == expected_feat_shape), f" [!] {mels.shape} vs {expected_feat_shape}" + expected_feat_shape = (ap.num_mels, (x_input.shape[-1] // hop_len) + (pad * 2)) + assert np.all(mels.shape[1:] == expected_feat_shape), f" [!] {mels.shape} vs {expected_feat_shape}" assert (mels.shape[2] - pad * 2) * hop_len == x_input.shape[1] count_iter += 1 @@ -77,15 +69,15 @@ def wavernn_dataset_case(batch_size, seq_len, hop_len, pad, mode, mulaw, num_wor def test_parametrized_wavernn_dataset(): - ''' test dataloader with different parameters ''' + """ test dataloader with different parameters """ params = [ - [16, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, 10, True, 0], - [16, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, "mold", False, 4], - [1, C.audio['hop_length'] * 10, C.audio['hop_length'], 2, 9, False, 0], - [1, C.audio['hop_length'], C.audio['hop_length'], 2, 10, True, 0], - [1, C.audio['hop_length'], C.audio['hop_length'], 2, "mold", False, 0], - [1, C.audio['hop_length'] * 5, C.audio['hop_length'], 4, 10, False, 2], - [1, C.audio['hop_length'] * 5, C.audio['hop_length'], 2, "mold", False, 0], + [16, C.audio["hop_length"] * 10, C.audio["hop_length"], 2, 10, True, 0], + [16, C.audio["hop_length"] * 10, C.audio["hop_length"], 2, "mold", False, 4], + [1, C.audio["hop_length"] * 10, C.audio["hop_length"], 2, 9, False, 0], + [1, C.audio["hop_length"], C.audio["hop_length"], 2, 10, True, 0], + [1, C.audio["hop_length"], C.audio["hop_length"], 2, "mold", False, 0], + [1, C.audio["hop_length"] * 5, C.audio["hop_length"], 4, 10, False, 2], + [1, C.audio["hop_length"] * 5, C.audio["hop_length"], 2, "mold", False, 0], ] for param in params: print(param) diff --git a/tests/test_wavegrad_layers.py b/tests/test_wavegrad_layers.py index d81ae47d6c..0180eb0a46 100644 --- a/tests/test_wavegrad_layers.py +++ b/tests/test_wavegrad_layers.py @@ -1,6 +1,6 @@ import torch -from TTS.vocoder.layers.wavegrad import PositionalEncoding, FiLM, UBlock, DBlock +from TTS.vocoder.layers.wavegrad import DBlock, FiLM, PositionalEncoding, UBlock from TTS.vocoder.models.wavegrad import Wavegrad @@ -75,12 +75,12 @@ def test_wavegrad_forward(): c = torch.rand(32, 80, 20) noise_scale = torch.rand(32) - model = Wavegrad(in_channels=80, - out_channels=1, - upsample_factors=[5, 5, 3, 2, 2], - upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], - [1, 2, 4, 8], [1, 2, 4, 8], - [1, 2, 4, 8]]) + model = Wavegrad( + in_channels=80, + out_channels=1, + upsample_factors=[5, 5, 3, 2, 2], + upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]], + ) o = model.forward(x, c, noise_scale) assert o.shape[0] 
== 32 diff --git a/tests/test_wavegrad_train.py b/tests/test_wavegrad_train.py index 45f75e3bd0..a28409e58e 100644 --- a/tests/test_wavegrad_train.py +++ b/tests/test_wavegrad_train.py @@ -3,9 +3,10 @@ import numpy as np import torch from torch import optim + from TTS.vocoder.models.wavegrad import Wavegrad -#pylint: disable=unused-variable +# pylint: disable=unused-variable torch.manual_seed(1) use_cuda = torch.cuda.is_available() @@ -19,19 +20,19 @@ def test_train_step(self): # pylint: disable=no-self-use mel_spec = torch.rand(8, 80, 20).to(device) criterion = torch.nn.L1Loss().to(device) - model = Wavegrad(in_channels=80, - out_channels=1, - upsample_factors=[5, 5, 3, 2, 2], - upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], - [1, 2, 4, 8], [1, 2, 4, 8], - [1, 2, 4, 8]]) - - model_ref = Wavegrad(in_channels=80, - out_channels=1, - upsample_factors=[5, 5, 3, 2, 2], - upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], - [1, 2, 4, 8], [1, 2, 4, 8], - [1, 2, 4, 8]]) + model = Wavegrad( + in_channels=80, + out_channels=1, + upsample_factors=[5, 5, 3, 2, 2], + upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]], + ) + + model_ref = Wavegrad( + in_channels=80, + out_channels=1, + upsample_factors=[5, 5, 3, 2, 2], + upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]], + ) model.train() model.to(device) betas = np.linspace(1e-6, 1e-2, 1000) @@ -39,8 +40,7 @@ def test_train_step(self): # pylint: disable=no-self-use model_ref.load_state_dict(model.state_dict()) model_ref.to(device) count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): assert (param - param_ref).sum() == 0, param count += 1 optimizer = optim.Adam(model.parameters(), lr=0.001) @@ -52,11 +52,10 @@ def test_train_step(self): # pylint: disable=no-self-use optimizer.step() # check parameter changes count = 0 - for param, param_ref in zip(model.parameters(), - model_ref.parameters()): + for param, param_ref in zip(model.parameters(), model_ref.parameters()): # ignore pre-higway layer since it works conditional # if count not in [145, 59]: - assert (param != param_ref).any( - ), "param {} with shape {} not updated!! \n{}\n{}".format( - count, param.shape, param, param_ref) + assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format( + count, param.shape, param, param_ref + ) count += 1
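
Note on the test pattern touched throughout this patch: the reformatted train tests all repeat the same parameter-update check — deep-copy the model, run a few optimizer steps, then assert that every parameter moved away from its initial value. A minimal, self-contained sketch of that pattern is shown below; the helper name `assert_all_params_updated`, the toy `nn.Linear` model, and the `step` closure are illustrative only and are not part of the TTS test suite or of this patch (the actual tests inline the check rather than factoring it into a helper).

    import copy

    import torch
    from torch import nn, optim


    def assert_all_params_updated(model, train_step, n_steps=5):
        # Snapshot the initial weights.
        model_ref = copy.deepcopy(model)
        # Sanity check: the copy starts identical to the original.
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
        # Train for a few steps.
        for _ in range(n_steps):
            train_step(model)
        # Every parameter should have moved away from its initial value.
        for count, (param, param_ref) in enumerate(zip(model.parameters(), model_ref.parameters())):
            assert (param != param_ref).any(), "param {} with shape {} not updated!!".format(count, param.shape)


    # Toy usage: a linear layer trained with Adam on an MSE objective.
    model = nn.Linear(4, 2)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    x, y = torch.rand(8, 4), torch.rand(8, 2)


    def step(m):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(m(x), y)
        loss.backward()
        optimizer.step()


    assert_all_params_updated(model, step)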