Skip to content

Commit

Permalink
Make style
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Dec 20, 2021
1 parent 54b7fb4 commit 704dddc
Show file tree
Hide file tree
Showing 18 changed files with 94 additions and 73 deletions.
1 change: 0 additions & 1 deletion TTS/bin/extract_tts_spectrograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ def main(args): # pylint: disable=redefined-outer-name
else:
speaker_manager = None


# setup model
model = setup_model(c)

Expand Down
1 change: 1 addition & 0 deletions TTS/bin/find_unique_phonemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def compute_phonemes(item):
return []
return list(set(ph))


def main():
# pylint: disable=W0601
global c
Expand Down
14 changes: 6 additions & 8 deletions TTS/bin/remove_silence_using_vad.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import glob
import pathlib
import argparse
import glob
import multiprocessing
import os
import pathlib

from tqdm.contrib.concurrent import process_map

from TTS.utils.vad import read_wave, write_wave, get_vad_speech_segments
from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave


def remove_silence(filepath):
output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
Expand Down Expand Up @@ -69,10 +70,7 @@ def preprocess_audios():
parser.add_argument(
"-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir"
)
parser.add_argument("-f", "--force",
default=False,
action='store_true',
help='Force the replace of exists files')
parser.add_argument("-f", "--force", default=False, action="store_true", help="Force the replace of exists files")
parser.add_argument(
"-g",
"--glob",
Expand Down
2 changes: 1 addition & 1 deletion TTS/bin/train_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor


Expand Down
2 changes: 1 addition & 1 deletion TTS/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def convert_boolean(x):

# load models
synthesizer = Synthesizer(
model_path, config_path, speakers_file_path, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
model_path, config_path, speakers_file_path, None, vocoder_path, vocoder_config_path, use_cuda=args.use_cuda
)

use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and synthesizer.tts_model.num_speakers > 1
Expand Down
11 changes: 6 additions & 5 deletions TTS/speaker_encoder/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import torch
from torch import nn

from TTS.utils.audio import TorchSTFT
from TTS.utils.io import load_fsspec

# import torchaudio

from TTS.utils.audio import TorchSTFT

from TTS.utils.io import load_fsspec


class PreEmphasis(torch.nn.Module):
Expand Down Expand Up @@ -126,16 +127,16 @@ def __init__(
n_mels=audio_config["num_mels"],
power=2.0,
use_mel=True,
mel_norm=None
mel_norm=None,
),
'''torchaudio.transforms.MelSpectrogram(
"""torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
),'''
),""",
)
else:
self.torch_spec = None
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def collate_fn(self, batch):
"waveform": wav_padded,
"raw_text": batch["raw_text"],
"pitch": pitch,
"language_ids": language_ids
"language_ids": language_ids,
}

raise TypeError(
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/layers/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):

@staticmethod
def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
l = - torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
l = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
return l

def forward(
Expand Down
8 changes: 2 additions & 6 deletions TTS/tts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@
from TTS.utils.generic_utils import find_module


def setup_model(
config,
speaker_manager: "SpeakerManager" = None,
language_manager: "LanguageManager" = None
):
def setup_model(config, speaker_manager: "SpeakerManager" = None, language_manager: "LanguageManager" = None):
print(" > Using model: {}".format(config.model))
# fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None:
Expand Down Expand Up @@ -35,7 +31,7 @@ def setup_model(
config.model_params.num_chars = num_chars
if "model_args" in config:
config.model_args.num_chars = num_chars
if config.model.lower() in ["vits"]: # If model supports multiple languages
if config.model.lower() in ["vits"]: # If model supports multiple languages
model = MyModel(config, speaker_manager=speaker_manager, language_manager=language_manager)
else:
model = MyModel(config, speaker_manager=speaker_manager)
Expand Down
21 changes: 14 additions & 7 deletions TTS/tts/models/base_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from TTS.model import BaseModel
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
Expand Down Expand Up @@ -150,7 +150,13 @@ def get_aux_input_from_test_setences(self, sentence_info):
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.language_id_mapping[language_name]

return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id}
return {
"text": text,
"speaker_id": speaker_id,
"style_wav": style_wav,
"d_vector": d_vector,
"language_id": language_id,
}

def format_batch(self, batch: Dict) -> Dict:
"""Generic batch formatting for `TTSDataset`.
Expand Down Expand Up @@ -337,14 +343,16 @@ def get_data_loader(
if config.compute_f0:
dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))



# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None

# Weighted samplers
assert not (num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)), "language_weighted_sampler is not supported with DistributedSampler"
assert not (num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)), "speaker_weighted_sampler is not supported with DistributedSampler"
assert not (
num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)
), "language_weighted_sampler is not supported with DistributedSampler"
assert not (
num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)
), "speaker_weighted_sampler is not supported with DistributedSampler"

if sampler is None:
if getattr(config, "use_language_weighted_sampler", False):
Expand All @@ -354,7 +362,6 @@ def get_data_loader(
print(" > Using Language weighted sampler")
sampler = get_speaker_weighted_sampler(dataset.items)


loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
Expand Down
7 changes: 4 additions & 3 deletions TTS/tts/models/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict, List, Tuple

import torch

# import torchaudio
from coqpit import Coqpit
from torch import nn
Expand Down Expand Up @@ -420,8 +421,9 @@ def init_multispeaker(self, config: Coqpit):
):
# TODO: change this with torchaudio Resample
raise RuntimeError(
' [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!'
.format(self.config.audio["sample_rate"], self.speaker_encoder.audio_config["sample_rate"])
" [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format(
self.config.audio["sample_rate"], self.speaker_encoder.audio_config["sample_rate"]
)
)
# pylint: disable=W0101,W0105
""" self.audio_transform = torchaudio.transforms.Resample(
Expand Down Expand Up @@ -675,7 +677,6 @@ def forward(
)
return outputs


def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}):
"""
Shapes:
Expand Down
10 changes: 8 additions & 2 deletions TTS/utils/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
spec_gain=1.0,
power=None,
use_htk=False,
mel_norm="slaney"
mel_norm="slaney",
):
super().__init__()
self.n_fft = n_fft
Expand Down Expand Up @@ -155,7 +155,13 @@ def __call__(self, x):

def _build_mel_basis(self):
mel_basis = librosa.filters.mel(
self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax, htk=self.use_htk, norm=self.mel_norm
self.sample_rate,
self.n_fft,
n_mels=self.n_mels,
fmin=self.mel_fmin,
fmax=self.mel_fmax,
htk=self.use_htk,
norm=self.mel_norm,
)
self.mel_basis = torch.from_numpy(mel_basis).float()

Expand Down
13 changes: 5 additions & 8 deletions TTS/utils/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from TTS.config import load_config
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager

# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
Expand Down Expand Up @@ -200,12 +200,7 @@ def save_wav(self, wav: List[int], path: str) -> None:
self.ap.save_wav(wav, path, self.output_sample_rate)

def tts(
self,
text: str,
speaker_idx: str = "",
language_idx: str = "",
speaker_wav=None,
style_wav=None
self, text: str, speaker_idx: str = "", language_idx: str = "", speaker_wav=None, style_wav=None
) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech.
Expand Down Expand Up @@ -254,7 +249,9 @@ def tts(

# handle multi-lingual
language_id = None
if self.tts_languages_file or (hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None):
if self.tts_languages_file or (
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
):
if language_idx and isinstance(language_idx, str):
language_id = self.tts_model.language_manager.language_id_mapping[language_idx]

Expand Down
8 changes: 5 additions & 3 deletions TTS/utils/vad.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# This code is adapted from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
import collections
import contextlib
import wave

import webrtcvad
import contextlib
import collections


def read_wave(path):
Expand Down Expand Up @@ -37,7 +38,7 @@ class Frame(object):
"""Represents a "frame" of audio data."""

def __init__(self, _bytes, timestamp, duration):
self.bytes =_bytes
self.bytes = _bytes
self.timestamp = timestamp
self.duration = duration

Expand Down Expand Up @@ -133,6 +134,7 @@ def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, fram
if voiced_frames:
yield b"".join([f.bytes for f in voiced_frames])


def get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300):

vad = webrtcvad.Vad(int(aggressiveness))
Expand Down
35 changes: 24 additions & 11 deletions recipes/multilingual/vits_tts/train_vits_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits, VitsArgs
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor

output_path = os.path.dirname(os.path.abspath(__file__))

mailabs_path = '/home/julian/workspace/mailabs/**'
mailabs_path = "/home/julian/workspace/mailabs/**"
dataset_paths = glob(mailabs_path)
dataset_config = [BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split('/')[-1]) for path in dataset_paths]
dataset_config = [
BaseDatasetConfig(name="mailabs", meta_file_train=None, path=path, language=path.split("/")[-1])
for path in dataset_paths
]

audio_config = BaseAudioConfig(
sample_rate=16000,
Expand Down Expand Up @@ -61,29 +64,39 @@
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
compute_input_seq_cache=True,
print_step=25,
use_language_weighted_sampler= True,
use_language_weighted_sampler=True,
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
min_seq_len=32 * 256 * 4,
max_seq_len=160000,
output_path=output_path,
datasets=dataset_config,
characters= {
characters={
"pad": "_",
"eos": "&",
"bos": "*",
"characters": "!¡'(),-.:;¿?abcdefghijklmnopqrstuvwxyzµßàáâäåæçèéêëìíîïñòóôöùúûüąćęłńœśşźżƒабвгдежзийклмнопрстуфхцчшщъыьэюяёєіїґӧ «°±µ»$%&‘’‚“`”„",
"punctuations": "!¡'(),-.:;¿? ",
"phonemes": None,
"unique": True
"unique": True,
},
test_sentences=[
["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", 'mary_ann', None, 'en_US'],
["Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.", "ezwa", None, 'fr_FR'],
["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, 'de_DE'],
["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, 'ru_RU'],
]
[
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"mary_ann",
None,
"en_US",
],
[
"Il m'a fallu beaucoup de temps pour d\u00e9velopper une voix, et maintenant que je l'ai, je ne vais pas me taire.",
"ezwa",
None,
"fr_FR",
],
["Ich finde, dieses Startup ist wirklich unglaublich.", "eva_k", None, "de_DE"],
["Я думаю, что этот стартап действительно удивительный.", "oblomov", None, "ru_RU"],
],
)

# init audio processor
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ unidic-lite==1.0.8
gruut[cs,de,es,fr,it,nl,pt,ru,sv]~=2.0.0
fsspec>=2021.04.0
pyworld
webrtcvad
Loading

0 comments on commit 704dddc

Please sign in to comment.