diff --git a/TTS/tts/layers/tortoise/diffusion.py b/TTS/tts/layers/tortoise/diffusion.py index fcdaa9d76e..7bea02ca08 100644 --- a/TTS/tts/layers/tortoise/diffusion.py +++ b/TTS/tts/layers/tortoise/diffusion.py @@ -17,7 +17,6 @@ from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper - try: from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index d914ebf90f..e7b186b858 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -441,7 +441,9 @@ def forward( audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token) # Pad mel codes with stop_audio_token - audio_codes = self.set_mel_padding(audio_codes, code_lengths - 3) # -3 to get the real code lengths without consider start and stop tokens that was not added yet + audio_codes = self.set_mel_padding( + audio_codes, code_lengths - 3 + ) # -3 to get the real code lengths without consider start and stop tokens that was not added yet # Build input and target tensors # Prepend start token to inputs and append stop token to targets diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index 1ef655a3cc..5284874397 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -1,23 +1,22 @@ import os import re -import torch -import pypinyin import textwrap - from functools import cached_property + +import pypinyin +import torch from hangul_romanize import Transliter from hangul_romanize.rule import academic from num2words import num2words +from spacy.lang.ar import Arabic +from spacy.lang.en import English +from spacy.lang.es import Spanish +from spacy.lang.ja import Japanese +from spacy.lang.zh import Chinese from tokenizers import Tokenizer from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words -from spacy.lang.en import English -from spacy.lang.zh import Chinese -from spacy.lang.ja import Japanese -from spacy.lang.ar import Arabic -from spacy.lang.es import Spanish - def get_spacy_lang(lang): if lang == "zh": @@ -32,6 +31,7 @@ def get_spacy_lang(lang): # For most languages, Enlish does the job return English() + def split_sentence(text, lang, text_split_length=250): """Preprocess the input text""" text_splits = [] @@ -67,6 +67,7 @@ def split_sentence(text, lang, text_split_length=250): return text_splits + _whitespace_re = re.compile(r"\s+") # List of (regular expression, replacement) pairs for abbreviations: @@ -619,7 +620,7 @@ def katsu(self): return cutlet.Cutlet() def check_input_length(self, txt, lang): - lang = lang.split("-")[0] # remove the region + lang = lang.split("-")[0] # remove the region limit = self.char_limits.get(lang, 250) if len(txt) > limit: print( @@ -640,7 +641,7 @@ def preprocess_text(self, txt, lang): return txt def encode(self, txt, lang): - lang = lang.split("-")[0] # remove the region + lang = lang.split("-")[0] # remove the region self.check_input_length(txt, lang) txt = self.preprocess_text(txt, lang) lang = "zh-cn" if lang == "zh" else lang diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 3583591f8b..208ec4d561 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -513,13 +513,13 @@ def inference( enable_text_splitting=False, **hf_generate_kwargs, ): - language = language.split("-")[0] # remove the country code + language = language.split("-")[0] # remove the country code length_scale = 1.0 / max(speed, 0.05) if enable_text_splitting: text = split_sentence(text, language, self.tokenizer.char_limits[language]) else: text = [text] - + wavs = [] gpt_latents_list = [] for sent in text: @@ -563,9 +563,7 @@ def inference( if length_scale != 1.0: gpt_latents = F.interpolate( - gpt_latents.transpose(1, 2), - scale_factor=length_scale, - mode="linear" + gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear" ).transpose(1, 2) gpt_latents_list.append(gpt_latents.cpu()) @@ -623,7 +621,7 @@ def inference_stream( enable_text_splitting=False, **hf_generate_kwargs, ): - language = language.split("-")[0] # remove the country code + language = language.split("-")[0] # remove the country code length_scale = 1.0 / max(speed, 0.05) if enable_text_splitting: text = split_sentence(text, language, self.tokenizer.char_limits[language]) @@ -675,9 +673,7 @@ def inference_stream( gpt_latents = torch.cat(all_latents, dim=0)[None, :] if length_scale != 1.0: gpt_latents = F.interpolate( - gpt_latents.transpose(1, 2), - scale_factor=length_scale, - mode="linear" + gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear" ).transpose(1, 2) wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index a5aad5c1ea..8fa56e287a 100644 --- a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -186,7 +186,7 @@ def test_xtts_v2_streaming(): "en", gpt_cond_latent, speaker_embedding, - speed=1.5 + speed=1.5, ) wav_chuncks = [] for i, chunk in enumerate(chunks): @@ -198,7 +198,7 @@ def test_xtts_v2_streaming(): "en", gpt_cond_latent, speaker_embedding, - speed=0.66 + speed=0.66, ) wav_chuncks = [] for i, chunk in enumerate(chunks):