Optimize English g2p function for improved speed and efficiency #124

Merged
merged 2 commits into from May 14, 2024

Changes from all commits
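The change hoists the ARPA phoneme set, the G2p() model instance, and the CMU dictionary load (get_dict()) out of the g2p() function to module level, so they are built once at import time instead of on every call. A minimal illustrative sketch of that pattern, using hypothetical function names rather than the PR's exact code:

```python
from g2p_en import G2p

# Before: the heavy object was rebuilt on every call.
def g2p_slow(text: str) -> list[str]:
    model = G2p()  # constructing the model dominates the runtime of each call
    return [p for p in model(text) if p != " "]

# After: build it once at import time and reuse it across calls.
_G2P_MODEL = G2p()

def g2p_fast(text: str) -> list[str]:
    return [p for p in _G2P_MODEL(text) if p != " "]
```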
142 changes: 27 additions & 115 deletions style_bert_vits2/nlp/english/g2p.py
@@ -1,103 +1,34 @@
import re

from g2p_en import G2p

from style_bert_vits2.constants import Languages
from style_bert_vits2.nlp import bert_models
from style_bert_vits2.nlp.english.cmudict import get_dict
from style_bert_vits2.nlp.symbols import PUNCTUATIONS, SYMBOLS

# Initialize global variables once
ARPA = {
"AH0", "S", "AH1", "EY2", "AE2", "EH0", "OW2", "UH0", "NG", "B", "G", "AY0",
"M", "AA0", "F", "AO0", "ER2", "UH1", "IY1", "AH2", "DH", "IY0", "EY1",
"IH0", "K", "N", "W", "IY2", "T", "AA1", "ER1", "EH2", "OY0", "UH2", "UW1",
"Z", "AW2", "AW1", "V", "UW2", "AA2", "ER", "AW0", "UW0", "R", "OW1", "EH1",
"ZH", "AE0", "IH2", "IH", "Y", "JH", "P", "AY1", "EY0", "OY2", "TH", "HH",
"D", "ER0", "CH", "AO1", "AE1", "AO2", "OY1", "AY2", "IH1", "OW0", "L",
"SH"
}
_g2p = G2p()
eng_dict = get_dict()

def g2p(text: str) -> tuple[list[str], list[int], list[int]]:

ARPA = {
"AH0",
"S",
"AH1",
"EY2",
"AE2",
"EH0",
"OW2",
"UH0",
"NG",
"B",
"G",
"AY0",
"M",
"AA0",
"F",
"AO0",
"ER2",
"UH1",
"IY1",
"AH2",
"DH",
"IY0",
"EY1",
"IH0",
"K",
"N",
"W",
"IY2",
"T",
"AA1",
"ER1",
"EH2",
"OY0",
"UH2",
"UW1",
"Z",
"AW2",
"AW1",
"V",
"UW2",
"AA2",
"ER",
"AW0",
"UW0",
"R",
"OW1",
"EH1",
"ZH",
"AE0",
"IH2",
"IH",
"Y",
"JH",
"P",
"AY1",
"EY0",
"OY2",
"TH",
"HH",
"D",
"ER0",
"CH",
"AO1",
"AE1",
"AO2",
"OY1",
"AY2",
"IH1",
"OW0",
"L",
"SH",
}

_g2p = G2p()

phones = []
tones = []
phone_len = []
# tokens = [tokenizer.tokenize(i) for i in words]
words = __text_to_words(text)
eng_dict = get_dict()

for word in words:
temp_phones, temp_tones = [], []
if len(word) > 1:
if "'" in word:
word = ["".join(word)]
if len(word) > 1 and "'" in word:
word = ["".join(word)]

for w in word:
if w in PUNCTUATIONS:
temp_phones.append(w)
@@ -107,11 +38,9 @@ def g2p(text: str) -> tuple[list[str], list[int], list[int]]:
phns, tns = __refine_syllables(eng_dict[w.upper()])
temp_phones += [__post_replace_ph(i) for i in phns]
temp_tones += tns
# w2ph.append(len(phns))
else:
phone_list = list(filter(lambda p: p != " ", _g2p(w))) # type: ignore
phns = []
tns = []
phone_list = list(filter(lambda p: p != " ", _g2p(w)))
phns, tns = [], []
for ph in phone_list:
if ph in ARPA:
ph, tn = __refine_ph(ph)
@@ -122,17 +51,15 @@ def g2p(text: str) -> tuple[list[str], list[int], list[int]]:
tns.append(0)
temp_phones += [__post_replace_ph(i) for i in phns]
temp_tones += tns

phones += temp_phones
tones += temp_tones
phone_len.append(len(temp_phones))
# phones = [post_replace_ph(i) for i in phones]

word2ph = []
for token, pl in zip(words, phone_len):
word_len = len(token)

aaa = __distribute_phone(pl, word_len)
word2ph += aaa
word2ph += __distribute_phone(pl, word_len)

phones = ["_"] + phones + ["_"]
tones = [0] + tones + [0]
@@ -145,27 +72,15 @@ def g2p(text: str) -> tuple[list[str], list[int], list[int]]:

def __post_replace_ph(ph: str) -> str:
REPLACE_MAP = {
":": ",",
";": ",",
",": ",",
"。": ".",
"!": "!",
"?": "?",
"\n": ".",
"·": ",",
"、": ",",
"…": "...",
"···": "...",
"・・・": "...",
"v": "V",
":": ",", ";": ",", ",": ",", "。": ".", "!": "!", "?": "?",
"\n": ".", "·": ",", "、": ",", "…": "...", "···": "...",
"・・・": "...", "v": "V"
}
if ph in REPLACE_MAP.keys():
if ph in REPLACE_MAP:
ph = REPLACE_MAP[ph]
if ph in SYMBOLS:
return ph
if ph not in SYMBOLS:
ph = "UNK"
return ph
return "UNK"


def __refine_ph(phn: str) -> tuple[str, int]:
@@ -182,8 +97,7 @@ def __refine_syllables(syllables: list[list[str]]) -> tuple[list[str], list[int]]:
tones = []
phonemes = []
for phn_list in syllables:
for i in range(len(phn_list)):
phn = phn_list[i]
for phn in phn_list:
phn, tone = __refine_ph(phn)
phonemes.append(phn)
tones.append(tone)
@@ -211,10 +125,7 @@ def __text_to_words(text: str) -> list[list[str]]:
if idx == len(tokens) - 1:
words.append([f"{t}"])
else:
if (
not tokens[idx + 1].startswith("▁")
and tokens[idx + 1] not in PUNCTUATIONS
):
if not tokens[idx + 1].startswith("▁") and tokens[idx + 1] not in PUNCTUATIONS:
if idx == 0:
words.append([])
words[-1].append(f"{t}")
@@ -238,3 +149,4 @@ def __text_to_words(text: str) -> list[list[str]]:
# for ph in group:
# all_phones.add(ph)
# print(all_phones)
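For reference, a hedged usage sketch of the public g2p() function shown in this diff. It assumes style_bert_vits2 is installed and that any tokenizer the English pipeline needs (loaded via bert_models) has been set up as the package expects; the output shapes follow the declared signature tuple[list[str], list[int], list[int]]:

```python
from style_bert_vits2.nlp.english.g2p import g2p

phones, tones, word2ph = g2p("Hello, world!")
print(phones)    # phoneme symbols, padded with "_" at both ends
print(tones)     # one stress value per phoneme, padded with 0 at both ends
print(word2ph)   # number of phonemes assigned to each sub-token
```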