From 675f98355077310f4f0fbb88fdb135730d0426f2 Mon Sep 17 00:00:00 2001 From: Julian Weber Date: Thu, 16 Nov 2023 11:01:11 +0100 Subject: [PATCH] Add sentence splitting (#3227) * Add sentence spliting * update requirements * update default args v2 * Add spanish * Fix return gpt_latents * Update requirements * Fix requirements --- TTS/tts/layers/xtts/tokenizer.py | 78 +++++++++-- TTS/tts/models/xtts.py | 233 +++++++++++++++++-------------- requirements.txt | 1 + 3 files changed, 194 insertions(+), 118 deletions(-) diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index 7726d829ac..56eb78aed4 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -1,10 +1,10 @@ -import json import os import re -from functools import cached_property - -import pypinyin import torch +import pypinyin +import textwrap + +from functools import cached_property from hangul_romanize import Transliter from hangul_romanize.rule import academic from num2words import num2words @@ -12,6 +12,61 @@ from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words +from spacy.lang.en import English +from spacy.lang.zh import Chinese +from spacy.lang.ja import Japanese +from spacy.lang.ar import Arabic +from spacy.lang.es import Spanish + + +def get_spacy_lang(lang): + if lang == "zh": + return Chinese() + elif lang == "ja": + return Japanese() + elif lang == "ar": + return Arabic() + elif lang == "es": + return Spanish() + else: + # For most languages, Enlish does the job + return English() + +def split_sentence(text, lang, text_split_length=250): + """Preprocess the input text""" + text_splits = [] + if text_split_length is not None and len(text) >= text_split_length: + text_splits.append("") + nlp = get_spacy_lang(lang) + nlp.add_pipe("sentencizer") + doc = nlp(text) + for sentence in doc.sents: + if len(text_splits[-1]) + len(str(sentence)) <= text_split_length: + # if the last sentence + the current sentence is less than the text_split_length + # then add the current sentence to the last sentence + text_splits[-1] += " " + str(sentence) + text_splits[-1] = text_splits[-1].lstrip() + elif len(str(sentence)) > text_split_length: + # if the current sentence is greater than the text_split_length + for line in textwrap.wrap( + str(sentence), + width=text_split_length, + drop_whitespace=True, + break_on_hyphens=False, + tabsize=1, + ): + text_splits.append(str(line)) + else: + text_splits.append(str(sentence)) + + if len(text_splits) > 1: + if text_splits[0] == "": + del text_splits[0] + else: + text_splits = [text.lstrip()] + + return text_splits + _whitespace_re = re.compile(r"\s+") # List of (regular expression, replacement) pairs for abbreviations: @@ -464,7 +519,7 @@ def _expand_number(m, lang="en"): def expand_numbers_multilingual(text, lang="en"): - if lang == "zh" or lang == "zh-cn": + if lang == "zh": text = zh_num2words()(text) else: if lang in ["en", "ru"]: @@ -525,7 +580,7 @@ def japanese_cleaners(text, katsu): return text -def korean_cleaners(text): +def korean_transliterate(text): r = Transliter(academic) return r.translit(text) @@ -546,7 +601,7 @@ def __init__(self, vocab_file=None): "it": 213, "pt": 203, "pl": 224, - "zh-cn": 82, + "zh": 82, "ar": 166, "cs": 186, "ru": 182, @@ -571,19 +626,20 @@ def check_input_length(self, txt, lang): ) def preprocess_text(self, txt, lang): - if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}: + if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "ko"}: txt = multilingual_cleaners(txt, lang) - if lang in {"zh", "zh-cn"}: + if lang == "zh": txt = chinese_transliterate(txt) + if lang == "ko": + txt = korean_transliterate(txt) elif lang == "ja": txt = japanese_cleaners(txt, self.katsu) - elif lang == "ko": - txt = korean_cleaners(txt) else: raise NotImplementedError(f"Language '{lang}' is not supported.") return txt def encode(self, txt, lang): + lang = lang.split("-")[0] # remove the region self.check_input_length(txt, lang) txt = self.preprocess_text(txt, lang) txt = f"[{lang}]{txt}" diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index f37f08449d..5ccb26c314 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -10,7 +10,7 @@ from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.stream_generator import init_stream_support -from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer +from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence from TTS.tts.models.base_tts import BaseTTS from TTS.utils.io import load_fsspec @@ -420,9 +420,9 @@ def full_inference( ref_audio_path, language, # GPT inference - temperature=0.65, - length_penalty=1, - repetition_penalty=2.0, + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, top_k=50, top_p=0.85, do_sample=True, @@ -502,71 +502,78 @@ def inference( gpt_cond_latent, speaker_embedding, # GPT inference - temperature=0.65, - length_penalty=1, - repetition_penalty=2.0, + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, top_k=50, top_p=0.85, do_sample=True, num_beams=1, speed=1.0, + enable_text_splitting=False, **hf_generate_kwargs, ): + language = language.split("-")[0] # remove the country code length_scale = 1.0 / max(speed, 0.05) - text = text.strip().lower() - text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) - - # print(" > Input text: ", text) - # print(" > Input text preprocessed: ",self.tokenizer.preprocess_text(text, language)) - # print(" > Input tokens: ", text_tokens) - # print(" > Decoded text: ", self.tokenizer.decode(text_tokens[0].cpu().numpy())) - assert ( - text_tokens.shape[-1] < self.args.gpt_max_text_tokens - ), " ❗ XTTS can only generate text with a maximum of 400 tokens." - - with torch.no_grad(): - gpt_codes = self.gpt.generate( - cond_latents=gpt_cond_latent, - text_inputs=text_tokens, - input_tokens=None, - do_sample=do_sample, - top_p=top_p, - top_k=top_k, - temperature=temperature, - num_return_sequences=self.gpt_batch_size, - num_beams=num_beams, - length_penalty=length_penalty, - repetition_penalty=repetition_penalty, - output_attentions=False, - **hf_generate_kwargs, - ) - expected_output_len = torch.tensor( - [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device - ) + if enable_text_splitting: + text = split_sentence(text, language, self.tokenizer.char_limits[language]) + else: + text = [text] + + wavs = [] + gpt_latents_list = [] + for sent in text: + sent = sent.strip().lower() + text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device) + + assert ( + text_tokens.shape[-1] < self.args.gpt_max_text_tokens + ), " ❗ XTTS can only generate text with a maximum of 400 tokens." + + with torch.no_grad(): + gpt_codes = self.gpt.generate( + cond_latents=gpt_cond_latent, + text_inputs=text_tokens, + input_tokens=None, + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=self.gpt_batch_size, + num_beams=num_beams, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + output_attentions=False, + **hf_generate_kwargs, + ) + expected_output_len = torch.tensor( + [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device + ) - text_len = torch.tensor([text_tokens.shape[-1]], device=self.device) - gpt_latents = self.gpt( - text_tokens, - text_len, - gpt_codes, - expected_output_len, - cond_latents=gpt_cond_latent, - return_attentions=False, - return_latent=True, - ) + text_len = torch.tensor([text_tokens.shape[-1]], device=self.device) + gpt_latents = self.gpt( + text_tokens, + text_len, + gpt_codes, + expected_output_len, + cond_latents=gpt_cond_latent, + return_attentions=False, + return_latent=True, + ) - if length_scale != 1.0: - gpt_latents = F.interpolate( - gpt_latents.transpose(1, 2), - scale_factor=length_scale, - mode="linear" - ).transpose(1, 2) + if length_scale != 1.0: + gpt_latents = F.interpolate( + gpt_latents.transpose(1, 2), + scale_factor=length_scale, + mode="linear" + ).transpose(1, 2) - wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) + gpt_latents_list.append(gpt_latents.cpu()) + wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze()) return { - "wav": wav.cpu().numpy().squeeze(), - "gpt_latents": gpt_latents, + "wav": torch.cat(wavs, dim=0).numpy(), + "gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(), "speaker_embedding": speaker_embedding, } @@ -606,66 +613,78 @@ def inference_stream( stream_chunk_size=20, overlap_wav_len=1024, # GPT inference - temperature=0.65, - length_penalty=1, - repetition_penalty=2.0, + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, top_k=50, top_p=0.85, do_sample=True, speed=1.0, + enable_text_splitting=False, **hf_generate_kwargs, ): + language = language.split("-")[0] # remove the country code length_scale = 1.0 / max(speed, 0.05) - text = text.strip().lower() - text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) + if enable_text_splitting: + text = split_sentence(text, language, self.tokenizer.char_limits[language]) + else: + text = [text] - fake_inputs = self.gpt.compute_embeddings( - gpt_cond_latent.to(self.device), - text_tokens, - ) - gpt_generator = self.gpt.get_generator( - fake_inputs=fake_inputs, - top_k=top_k, - top_p=top_p, - temperature=temperature, - do_sample=do_sample, - num_beams=1, - num_return_sequences=1, - length_penalty=float(length_penalty), - repetition_penalty=float(repetition_penalty), - output_attentions=False, - output_hidden_states=True, - **hf_generate_kwargs, - ) + for sent in text: + sent = sent.strip().lower() + text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device) - last_tokens = [] - all_latents = [] - wav_gen_prev = None - wav_overlap = None - is_end = False - - while not is_end: - try: - x, latent = next(gpt_generator) - last_tokens += [x] - all_latents += [latent] - except StopIteration: - is_end = True - - if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): - gpt_latents = torch.cat(all_latents, dim=0)[None, :] - if length_scale != 1.0: - gpt_latents = F.interpolate( - gpt_latents.transpose(1, 2), - scale_factor=length_scale, - mode="linear" - ).transpose(1, 2) - wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) - wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( - wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len - ) - last_tokens = [] - yield wav_chunk + assert ( + text_tokens.shape[-1] < self.args.gpt_max_text_tokens + ), " ❗ XTTS can only generate text with a maximum of 400 tokens." + + fake_inputs = self.gpt.compute_embeddings( + gpt_cond_latent.to(self.device), + text_tokens, + ) + gpt_generator = self.gpt.get_generator( + fake_inputs=fake_inputs, + top_k=top_k, + top_p=top_p, + temperature=temperature, + do_sample=do_sample, + num_beams=1, + num_return_sequences=1, + length_penalty=float(length_penalty), + repetition_penalty=float(repetition_penalty), + output_attentions=False, + output_hidden_states=True, + **hf_generate_kwargs, + ) + + last_tokens = [] + all_latents = [] + wav_gen_prev = None + wav_overlap = None + is_end = False + + while not is_end: + try: + x, latent = next(gpt_generator) + last_tokens += [x] + all_latents += [latent] + except StopIteration: + is_end = True + + if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): + gpt_latents = torch.cat(all_latents, dim=0)[None, :] + if length_scale != 1.0: + gpt_latents = F.interpolate( + gpt_latents.transpose(1, 2), + scale_factor=length_scale, + mode="linear" + ).transpose(1, 2) + wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) + wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( + wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len + ) + last_tokens = [] + yield wav_chunk def forward(self): raise NotImplementedError( diff --git a/requirements.txt b/requirements.txt index 53e8af590c..836de40ab6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -54,3 +54,4 @@ encodec==0.1.* # deps for XTTS unidecode==1.3.* num2words +spacy[ja]>=3 \ No newline at end of file