Skip to content

Commit

Permalink
Merge pull request #124 from gordon0414/eng_g2p_fix
Browse files Browse the repository at this point in the history
Optimize English g2p function for improved speed and efficiency
  • Loading branch information
litagin02 authored May 14, 2024
2 parents e721c09 + 7de367b commit ae7c4ea
Showing 1 changed file with 27 additions and 115 deletions.
142 changes: 27 additions & 115 deletions style_bert_vits2/nlp/english/g2p.py
Original file line number Diff line number Diff line change
@@ -1,103 +1,34 @@
import re

from g2p_en import G2p

from style_bert_vits2.constants import Languages
from style_bert_vits2.nlp import bert_models
from style_bert_vits2.nlp.english.cmudict import get_dict
from style_bert_vits2.nlp.symbols import PUNCTUATIONS, SYMBOLS

# Initialize global variables once
ARPA = {
"AH0", "S", "AH1", "EY2", "AE2", "EH0", "OW2", "UH0", "NG", "B", "G", "AY0",
"M", "AA0", "F", "AO0", "ER2", "UH1", "IY1", "AH2", "DH", "IY0", "EY1",
"IH0", "K", "N", "W", "IY2", "T", "AA1", "ER1", "EH2", "OY0", "UH2", "UW1",
"Z", "AW2", "AW1", "V", "UW2", "AA2", "ER", "AW0", "UW0", "R", "OW1", "EH1",
"ZH", "AE0", "IH2", "IH", "Y", "JH", "P", "AY1", "EY0", "OY2", "TH", "HH",
"D", "ER0", "CH", "AO1", "AE1", "AO2", "OY1", "AY2", "IH1", "OW0", "L",
"SH"
}
_g2p = G2p()
eng_dict = get_dict()

def g2p(text: str) -> tuple[list[str], list[int], list[int]]:

ARPA = {
"AH0",
"S",
"AH1",
"EY2",
"AE2",
"EH0",
"OW2",
"UH0",
"NG",
"B",
"G",
"AY0",
"M",
"AA0",
"F",
"AO0",
"ER2",
"UH1",
"IY1",
"AH2",
"DH",
"IY0",
"EY1",
"IH0",
"K",
"N",
"W",
"IY2",
"T",
"AA1",
"ER1",
"EH2",
"OY0",
"UH2",
"UW1",
"Z",
"AW2",
"AW1",
"V",
"UW2",
"AA2",
"ER",
"AW0",
"UW0",
"R",
"OW1",
"EH1",
"ZH",
"AE0",
"IH2",
"IH",
"Y",
"JH",
"P",
"AY1",
"EY0",
"OY2",
"TH",
"HH",
"D",
"ER0",
"CH",
"AO1",
"AE1",
"AO2",
"OY1",
"AY2",
"IH1",
"OW0",
"L",
"SH",
}

_g2p = G2p()

phones = []
tones = []
phone_len = []
# tokens = [tokenizer.tokenize(i) for i in words]
words = __text_to_words(text)
eng_dict = get_dict()

for word in words:
temp_phones, temp_tones = [], []
if len(word) > 1:
if "'" in word:
word = ["".join(word)]
if len(word) > 1 and "'" in word:
word = ["".join(word)]

for w in word:
if w in PUNCTUATIONS:
temp_phones.append(w)
Expand All @@ -107,11 +38,9 @@ def g2p(text: str) -> tuple[list[str], list[int], list[int]]:
phns, tns = __refine_syllables(eng_dict[w.upper()])
temp_phones += [__post_replace_ph(i) for i in phns]
temp_tones += tns
# w2ph.append(len(phns))
else:
phone_list = list(filter(lambda p: p != " ", _g2p(w))) # type: ignore
phns = []
tns = []
phone_list = list(filter(lambda p: p != " ", _g2p(w)))
phns, tns = [], []
for ph in phone_list:
if ph in ARPA:
ph, tn = __refine_ph(ph)
Expand All @@ -122,17 +51,15 @@ def g2p(text: str) -> tuple[list[str], list[int], list[int]]:
tns.append(0)
temp_phones += [__post_replace_ph(i) for i in phns]
temp_tones += tns

phones += temp_phones
tones += temp_tones
phone_len.append(len(temp_phones))
# phones = [post_replace_ph(i) for i in phones]

word2ph = []
for token, pl in zip(words, phone_len):
word_len = len(token)

aaa = __distribute_phone(pl, word_len)
word2ph += aaa
word2ph += __distribute_phone(pl, word_len)

phones = ["_"] + phones + ["_"]
tones = [0] + tones + [0]
Expand All @@ -145,27 +72,15 @@ def g2p(text: str) -> tuple[list[str], list[int], list[int]]:

def __post_replace_ph(ph: str) -> str:
REPLACE_MAP = {
":": ",",
";": ",",
",": ",",
"。": ".",
"!": "!",
"?": "?",
"\n": ".",
"·": ",",
"、": ",",
"…": "...",
"···": "...",
"・・・": "...",
"v": "V",
":": ",", ";": ",", ",": ",", "。": ".", "!": "!", "?": "?",
"\n": ".", "·": ",", "、": ",", "…": "...", "···": "...",
"・・・": "...", "v": "V"
}
if ph in REPLACE_MAP.keys():
if ph in REPLACE_MAP:
ph = REPLACE_MAP[ph]
if ph in SYMBOLS:
return ph
if ph not in SYMBOLS:
ph = "UNK"
return ph
return "UNK"


def __refine_ph(phn: str) -> tuple[str, int]:
Expand All @@ -182,8 +97,7 @@ def __refine_syllables(syllables: list[list[str]]) -> tuple[list[str], list[int]
tones = []
phonemes = []
for phn_list in syllables:
for i in range(len(phn_list)):
phn = phn_list[i]
for phn in phn_list:
phn, tone = __refine_ph(phn)
phonemes.append(phn)
tones.append(tone)
Expand Down Expand Up @@ -211,10 +125,7 @@ def __text_to_words(text: str) -> list[list[str]]:
if idx == len(tokens) - 1:
words.append([f"{t}"])
else:
if (
not tokens[idx + 1].startswith("▁")
and tokens[idx + 1] not in PUNCTUATIONS
):
if not tokens[idx + 1].startswith("▁") and tokens[idx + 1] not in PUNCTUATIONS:
if idx == 0:
words.append([])
words[-1].append(f"{t}")
Expand All @@ -238,3 +149,4 @@ def __text_to_words(text: str) -> list[list[str]]:
# for ph in group:
# all_phones.add(ph)
# print(all_phones)

0 comments on commit ae7c4ea

Please sign in to comment.