From bbef82fa86cb99ba08c6c71d8144e51689b7bc7e Mon Sep 17 00:00:00 2001 From: KamioRinn Date: Tue, 20 Feb 2024 22:41:39 +0800 Subject: [PATCH 1/2] Refactoring get phones and bert --- GPT_SoVITS/inference_webui.py | 167 +++++++++++----------------------- GPT_SoVITS/text/chinese.py | 2 +- requirements.txt | 2 +- 3 files changed, 55 insertions(+), 116 deletions(-) diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 407437f4..70519dab 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -209,54 +209,8 @@ def get_spepc(hps, filename): } -def splite_en_inf(sentence, language): - pattern = re.compile(r'[a-zA-Z ]+') - textlist = [] - langlist = [] - pos = 0 - for match in pattern.finditer(sentence): - start, end = match.span() - if start > pos: - textlist.append(sentence[pos:start]) - langlist.append(language) - textlist.append(sentence[start:end]) - langlist.append("en") - pos = end - if pos < len(sentence): - textlist.append(sentence[pos:]) - langlist.append(language) - # Merge punctuation into previous word - for i in range(len(textlist)-1, 0, -1): - if re.match(r'^[\W_]+$', textlist[i]): - textlist[i-1] += textlist[i] - del textlist[i] - del langlist[i] - # Merge consecutive words with the same language tag - i = 0 - while i < len(langlist) - 1: - if langlist[i] == langlist[i+1]: - textlist[i] += textlist[i+1] - del textlist[i+1] - del langlist[i+1] - else: - i += 1 - - return textlist, langlist - - def clean_text_inf(text, language): - formattext = "" - language = language.replace("all_","") - for tmp in LangSegment.getTexts(text): - if language == "ja": - if tmp["lang"] == language or tmp["lang"] == "zh": - formattext += tmp["text"] + " " - continue - if tmp["lang"] == language: - formattext += tmp["text"] + " " - while " " in formattext: - formattext = formattext.replace(" ", " ") - phones, word2ph, norm_text = clean_text(formattext, language) + phones, word2ph, norm_text = clean_text(text, language) phones = cleaned_text_to_sequence(phones) return phones, word2ph, norm_text @@ -274,55 +228,6 @@ def get_bert_inf(phones, word2ph, norm_text, language): return bert -def nonen_clean_text_inf(text, language): - if(language!="auto"): - textlist, langlist = splite_en_inf(text, language) - else: - textlist=[] - langlist=[] - for tmp in LangSegment.getTexts(text): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - phones_list = [] - word2ph_list = [] - norm_text_list = [] - for i in range(len(textlist)): - lang = langlist[i] - phones, word2ph, norm_text = clean_text_inf(textlist[i], lang) - phones_list.append(phones) - if lang == "zh": - word2ph_list.append(word2ph) - norm_text_list.append(norm_text) - print(word2ph_list) - phones = sum(phones_list, []) - word2ph = sum(word2ph_list, []) - norm_text = ' '.join(norm_text_list) - - return phones, word2ph, norm_text - - -def nonen_get_bert_inf(text, language): - if(language!="auto"): - textlist, langlist = splite_en_inf(text, language) - else: - textlist=[] - langlist=[] - for tmp in LangSegment.getTexts(text): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - print(textlist) - print(langlist) - bert_list = [] - for i in range(len(textlist)): - lang = langlist[i] - phones, word2ph, norm_text = clean_text_inf(textlist[i], lang) - bert = get_bert_inf(phones, word2ph, norm_text, lang) - bert_list.append(bert) - bert = torch.cat(bert_list, dim=1) - - return bert - - splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", } @@ -332,23 +237,59 @@ def get_first(text): return text -def get_cleaned_text_final(text,language): +def get_phones_and_bert(text,language): if language in {"en","all_zh","all_ja"}: - phones, word2ph, norm_text = clean_text_inf(text, language) + language = language.replace("all_","") + if language == "en": + LangSegment.setfilters(["en"]) + formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text)) + else: + # 因无法区别中日文汉字,以用户输入为准 + formattext = text + while " " in formattext: + formattext = formattext.replace(" ", " ") + phones, word2ph, norm_text = clean_text_inf(formattext, language) + if language == "zh": + bert = get_bert_feature(norm_text, word2ph).to(device) + else: + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) elif language in {"zh", "ja","auto"}: - phones, word2ph, norm_text = nonen_clean_text_inf(text, language) - return phones, word2ph, norm_text + textlist=[] + langlist=[] + LangSegment.setfilters(["zh","ja","en"]) + if language == "auto": + for tmp in LangSegment.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegment.getTexts(text): + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + # 因无法区别中日文汉字,以用户输入为准 + langlist.append(language) + textlist.append(tmp["text"]) + print(textlist) + print(langlist) + phones_list = [] + bert_list = [] + norm_text_list = [] + for i in range(len(textlist)): + lang = langlist[i] + phones, word2ph, norm_text = clean_text_inf(textlist[i], lang) + bert = get_bert_inf(phones, word2ph, norm_text, lang) + phones_list.append(phones) + norm_text_list.append(norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = ' '.join(norm_text_list) + + return phones,bert.to(dtype),norm_text -def get_bert_final(phones, word2ph, text,language,device): - if language == "en": - bert = get_bert_inf(phones, word2ph, text, language) - elif language in {"zh", "ja","auto"}: - bert = nonen_get_bert_inf(text, language) - elif language == "all_zh": - bert = get_bert_feature(text, word2ph).to(device) - else: - bert = torch.zeros((1024, len(phones))).to(device) - return bert def merge_short_text_in_array(texts, threshold): if (len(texts)) < 2: @@ -425,8 +366,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, texts = merge_short_text_in_array(texts, 5) audio_opt = [] if not ref_free: - phones1, word2ph1, norm_text1=get_cleaned_text_final(prompt_text, prompt_language) - bert1=get_bert_final(phones1, word2ph1, norm_text1,prompt_language,device).to(dtype) + phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language) for text in texts: # 解决输入目标文本的空行导致报错的问题 @@ -434,8 +374,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, continue if (text[-1] not in splits): text += "。" if text_language != "en" else "." print(i18n("实际输入的目标文本(每句):"), text) - phones2, word2ph2, norm_text2 = get_cleaned_text_final(text, text_language) - bert2 = get_bert_final(phones2, word2ph2, norm_text2, text_language, device).to(dtype) + phones2,bert2,norm_text2=get_phones_and_bert(text, text_language) if not ref_free: bert = torch.cat([bert1, bert2], 1) all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0) diff --git a/GPT_SoVITS/text/chinese.py b/GPT_SoVITS/text/chinese.py index 5334326e..ea41db1f 100644 --- a/GPT_SoVITS/text/chinese.py +++ b/GPT_SoVITS/text/chinese.py @@ -30,7 +30,7 @@ "\n": ".", "·": ",", "、": ",", - # "...": "…", + "...": "…", "$": ".", "/": ",", "—": "-", diff --git a/requirements.txt b/requirements.txt index fae6198d..75bd945d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,5 +23,5 @@ PyYAML psutil jieba_fast jieba -LangSegment +LangSegment>=0.2.0 Faster_Whisper \ No newline at end of file From 76570cff52ff81e90b6b5f98e80aa657afc70738 Mon Sep 17 00:00:00 2001 From: KamioRinn Date: Tue, 20 Feb 2024 22:45:49 +0800 Subject: [PATCH 2/2] Del a-zA-Z --- GPT_SoVITS/inference_webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 70519dab..c427b25f 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -245,7 +245,7 @@ def get_phones_and_bert(text,language): formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text)) else: # 因无法区别中日文汉字,以用户输入为准 - formattext = text + formattext = re.sub('[a-zA-Z]', '', text) while " " in formattext: formattext = formattext.replace(" ", " ") phones, word2ph, norm_text = clean_text_inf(formattext, language)