Skip to content

Commit

Permalink
Refactor: run "hatch run style:fmt"
Browse files Browse the repository at this point in the history
  • Loading branch information
tsukumijima committed Apr 19, 2024
1 parent 797c354 commit 9d15d57
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 17 deletions.
2 changes: 2 additions & 0 deletions style_bert_vits2/models/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def get_text(
# 他の言語は word2ph の調整方法が思いつかないのでエラー
if language_str == Languages.JP:
from style_bert_vits2.nlp.japanese.g2p import adjust_word2ph

word2ph = adjust_word2ph(word2ph, phone, given_phone)
# 上記処理により word2ph の合計が given_phone の長さと一致するはず
# それでも一致しない場合、大半は読み上げテキストと given_phone が著しく乖離していて調整し切れなかったことを意味する
Expand Down Expand Up @@ -302,5 +303,6 @@ def infer(
class InvalidPhoneError(ValueError):
pass


class InvalidToneError(ValueError):
pass
65 changes: 48 additions & 17 deletions style_bert_vits2/nlp/japanese/g2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def text_to_sep_kata(
raise YomiError(f"Cannot read: {word} in:\n{norm_text}")
## 例外を送出しない場合
## 読めない文字は「ん」として扱う
logger.warning(f"Cannot read: {word} in:\n{norm_text}, replaced with 'ん'")
logger.warning(
f"Cannot read: {word} in:\n{norm_text}, replaced with 'ん'"
)
# word の文字数分「ん」を追加
yomi = "ン" * len(word)
else:
Expand All @@ -159,7 +161,9 @@ def text_to_sep_kata(


def adjust_word2ph(
word2ph: list[int], generated_phone: list[str], given_phone: list[str],
word2ph: list[int],
generated_phone: list[str],
given_phone: list[str],
) -> list[int]:
"""
`g2p()` で得られた `word2ph` を、generated_phone と given_phone の差分情報を使っていい感じに調整する。
Expand Down Expand Up @@ -188,16 +192,21 @@ class DiffDetail(TypedDict):
begin_index: int
end_index: int
value: list[str]

class Diff(TypedDict):
generated: DiffDetail
given: DiffDetail

def extract_differences(generated_phone: list[str], given_phone: list[str]) -> list[Diff]:
def extract_differences(
generated_phone: list[str], given_phone: list[str]
) -> list[Diff]:
"""
最長共通部分列を基にして、二つのリストの異なる部分を抽出する。
"""

def longest_common_subsequence(X: list[str], Y: list[str]) -> list[tuple[int, int]]:
def longest_common_subsequence(
X: list[str], Y: list[str]
) -> list[tuple[int, int]]:
"""
二つのリストの最長共通部分列のインデックスのペアを返す。
"""
Expand Down Expand Up @@ -230,21 +239,42 @@ def longest_common_subsequence(X: list[str], Y: list[str]) -> list[tuple[int, in
prev_x, prev_y = -1, -1

# 共通部分のインデックスを基にして差分を抽出
for (x, y) in common_indices:
diff_X = {"begin_index": prev_x + 1, "end_index": x, "value": generated_phone[prev_x + 1:x]}
diff_Y = {"begin_index": prev_y + 1, "end_index": y, "value": given_phone[prev_y + 1:y]}
for x, y in common_indices:
diff_X = {
"begin_index": prev_x + 1,
"end_index": x,
"value": generated_phone[prev_x + 1 : x],
}
diff_Y = {
"begin_index": prev_y + 1,
"end_index": y,
"value": given_phone[prev_y + 1 : y],
}
if diff_X or diff_Y:
differences.append({"generated": diff_X, "given": diff_Y})
prev_x, prev_y = x, y
# 最後の非共通部分を追加
if prev_x < len(generated_phone) - 1 or prev_y < len(given_phone) - 1:
differences.append({
"generated": {"begin_index": prev_x + 1, "end_index": len(generated_phone) - 1, "value": generated_phone[prev_x + 1:len(generated_phone) - 1]},
"given": {"begin_index": prev_y + 1, "end_index": len(given_phone) - 1, "value": given_phone[prev_y + 1:len(given_phone) - 1]}
})
differences.append(
{
"generated": {
"begin_index": prev_x + 1,
"end_index": len(generated_phone) - 1,
"value": generated_phone[prev_x + 1 : len(generated_phone) - 1],
},
"given": {
"begin_index": prev_y + 1,
"end_index": len(given_phone) - 1,
"value": given_phone[prev_y + 1 : len(given_phone) - 1],
},
}
)
# generated.value と given.value の両方が空の要素を diffrences から削除
for diff in differences[:]:
if len(diff["generated"]["value"]) == 0 and len(diff["given"]["value"]) == 0:
if (
len(diff["generated"]["value"]) == 0
and len(diff["given"]["value"]) == 0
):
differences.remove(diff)

return differences
Expand Down Expand Up @@ -275,7 +305,8 @@ def longest_common_subsequence(X: list[str], Y: list[str]) -> list[tuple[int, in
# current_diff が None でない場合、generated_phone から始まる差分がある
if current_diff is not None:
# generated から given で変わった音素数の差分を取得 (2増えた場合は +2 だし、2減った場合は -2)
diff_in_phonemes = len(current_diff["given"]["value"]) - len(current_diff["generated"]["value"])
diff_in_phonemes = \
len(current_diff["given"]["value"]) - len(current_diff["generated"]["value"]) # fmt: skip
# adjusted_word2ph[(読み上げテキストの各文字のインデックス)] に上記差分を反映
adjusted_word2ph[word2ph_element_index] += diff_in_phonemes
# adjusted_word2ph[(読み上げテキストの各文字のインデックス)] に処理が完了した分の音素として 1 を加える
Expand All @@ -284,13 +315,13 @@ def longest_common_subsequence(X: list[str], Y: list[str]) -> list[tuple[int, in
current_generated_index += 1

# この時点で given_phone の長さと adjusted_word2ph に記録されている音素数の合計が一致しているはず
assert len(given_phone) == sum(adjusted_word2ph), f"{len(given_phone)} != {sum(adjusted_word2ph)}"
assert len(given_phone) == sum(adjusted_word2ph), f"{len(given_phone)} != {sum(adjusted_word2ph)}" # fmt: skip

# generated_phone から given_phone の間で音素が減った場合 (例: a, sh, i, t, a -> a, s, u) 、
# adjusted_word2ph の要素の値が 1 未満になることがあるので、1 になるように値を増やす
## この時、adjusted_word2ph に記録されている音素数の合計を変えないために、
## 値を 1 にした分だけ右隣の要素から増やした分の差分を差し引く
for adjusted_word2ph_element_index, adjusted_word2ph_element in enumerate(adjusted_word2ph):
for adjusted_word2ph_element_index, adjusted_word2ph_element in enumerate(adjusted_word2ph): # fmt: skip
# もし現在の要素が 1 未満ならば
if adjusted_word2ph_element < 1:
# 値を 1 にするためにどれだけ足せばいいかを計算
Expand All @@ -316,7 +347,7 @@ def longest_common_subsequence(X: list[str], Y: list[str]) -> list[tuple[int, in
# 逆に、generated_phone から given_phone の間で音素が増えた場合 (例: a, s, u -> a, sh, i, t, a) 、
# 1文字あたり7音素以上も割り当てられてしまう場合があるので、最大6音素にした上で削った分の差分を次の要素に加える
# 次の要素に差分を加えた結果7音素以上になってしまう場合は、その差分をさらに次の要素に加える
for adjusted_word2ph_element_index, adjusted_word2ph_element in enumerate(adjusted_word2ph):
for adjusted_word2ph_element_index, adjusted_word2ph_element in enumerate(adjusted_word2ph): # fmt: skip
if adjusted_word2ph_element > 6:
diff = adjusted_word2ph_element - 6
adjusted_word2ph[adjusted_word2ph_element_index] = 6
Expand Down Expand Up @@ -634,7 +665,7 @@ def __align_tones(
elif phone in PUNCTUATIONS:
# phone が punctuation の場合 → (phone, 0) を追加
result.append((phone, 0))
elif phone == 'N' or phone == 'n':
elif phone == "N" or phone == "n":
# ここに到達するのは raise_yomi_error=False 時に読めない文字を「ん」に代替した場合のみ
# この際、phone_tone_list には「ん」の音素は含まれていないので、(phone, 0) を追加
result.append((phone, 0))
Expand Down

0 comments on commit 9d15d57

Please sign in to comment.