Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

音声合成時に読み上げテキストの読みを表す音素列を指定する機能を追加 + 様々な改善 #118

Merged
merged 10 commits into from
May 14, 2024
Merged
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ dist/
/bert/*/*.safetensors
/bert/*/*.msgpack

/configs/paths.yml

/pretrained/*.safetensors
/pretrained/*.pth

Expand Down
20 changes: 10 additions & 10 deletions Server.bat
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
chcp 65001 > NUL
@echo off
pushd %~dp0
echo Running server_fastapi.py
venv\Scripts\python server_fastapi.py
if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% )
popd
chcp 65001 > NUL
@echo off

pushd %~dp0
echo Running server_fastapi.py
venv\Scripts\python server_fastapi.py

if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% )

popd
pause
File renamed without changes.
1 change: 1 addition & 0 deletions gradio_tabs/style_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from style_bert_vits2.constants import DEFAULT_STYLE, GRADIO_THEME
from style_bert_vits2.logging import logger


# Get path settings
with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f:
path_config: dict[str, str] = yaml.safe_load(f.read())
Expand Down
8 changes: 7 additions & 1 deletion initialize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import json
import shutil
from pathlib import Path

import yaml
Expand Down Expand Up @@ -102,11 +103,16 @@ def main():
download_pretrained_models()
download_jp_extra_pretrained_models()

# If configs/paths.yml not exists, create it
default_paths_yml = Path("configs/default_paths.yml")
paths_yml = Path("configs/paths.yml")
if not paths_yml.exists():
shutil.copy(default_paths_yml, paths_yml)

if args.dataset_root is None and args.assets_root is None:
return

# Change default paths if necessary
paths_yml = Path("configs/paths.yml")
with open(paths_yml, "r", encoding="utf-8") as f:
yml_data = yaml.safe_load(f)
if args.assets_root is not None:
Expand Down
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,16 @@ dependencies = [
'cmudict',
'cn2an',
'g2p_en',
'gradio',
'jieba',
'librosa==0.9.2',
'loguru',
'num2words',
'numba',
'numpy',
'pyannote.audio>=3.1.0',
'pydantic>=2.0',
'pyopenjtalk-dict',
'pypinyin',
'pyworld-prebuilt',
'safetensors',
'scipy',
'torch>=2.1',
'transformers',
]
Expand Down
47 changes: 44 additions & 3 deletions style_bert_vits2/models/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def get_text(
device: str,
assist_text: Optional[str] = None,
assist_text_weight: float = 0.7,
given_phone: Optional[list[str]] = None,
given_tone: Optional[list[int]] = None,
):
use_jp_extra = hps.version.endswith("JP-Extra")
Expand All @@ -112,10 +113,44 @@ def get_text(
use_jp_extra=use_jp_extra,
raise_yomi_error=False,
)
if given_tone is not None:
if len(given_tone) != len(phone):
# phone と tone の両方が与えられた場合はそれを使う
if given_phone is not None and given_tone is not None:
# 指定された phone と指定された tone 両方の長さが一致していなければならない
if len(given_phone) != len(given_tone):
raise InvalidPhoneError(
f"Length of given_phone ({len(given_phone)}) != length of given_tone ({len(given_tone)})"
)
# 与えられた音素数と pyopenjtalk で生成した読みの音素数が一致しない
if len(given_phone) != sum(word2ph):
# 日本語の場合、len(given_phone) と sum(word2ph) が一致するように word2ph を適切に調整する
# 他の言語は word2ph の調整方法が思いつかないのでエラー
if language_str == Languages.JP:
from style_bert_vits2.nlp.japanese.g2p import adjust_word2ph

word2ph = adjust_word2ph(word2ph, phone, given_phone)
# 上記処理により word2ph の合計が given_phone の長さと一致するはず
# それでも一致しない場合、大半は読み上げテキストと given_phone が著しく乖離していて調整し切れなかったことを意味する
if len(given_phone) != sum(word2ph):
raise InvalidPhoneError(
f"Length of given_phone ({len(given_phone)}) != sum of word2ph ({sum(word2ph)})"
)
else:
raise InvalidPhoneError(
f"Length of given_phone ({len(given_phone)}) != sum of word2ph ({sum(word2ph)})"
)
phone = given_phone
# 生成あるいは指定された phone と指定された tone 両方の長さが一致していなければならない
if len(phone) != len(given_tone):
raise InvalidToneError(
f"Length of phone ({len(phone)}) != length of given_tone ({len(given_tone)})"
)
tone = given_tone
# tone だけが与えられた場合は clean_text() で生成した phone と合わせて使う
elif given_tone is not None:
# 生成した phone と指定された tone 両方の長さが一致していなければならない
if len(phone) != len(given_tone):
raise InvalidToneError(
f"Length of given_tone ({len(given_tone)}) != length of phone ({len(phone)})"
f"Length of phone ({len(phone)}) != length of given_tone ({len(given_tone)})"
)
tone = given_tone
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
Expand Down Expand Up @@ -179,6 +214,7 @@ def infer(
skip_end: bool = False,
assist_text: Optional[str] = None,
assist_text_weight: float = 0.7,
given_phone: Optional[list[str]] = None,
given_tone: Optional[list[int]] = None,
):
is_jp_extra = hps.version.endswith("JP-Extra")
Expand All @@ -189,6 +225,7 @@ def infer(
device,
assist_text=assist_text,
assist_text_weight=assist_text_weight,
given_phone=given_phone,
given_tone=given_tone,
)
if skip_start:
Expand Down Expand Up @@ -263,5 +300,9 @@ def infer(
return audio


class InvalidPhoneError(ValueError):
pass


class InvalidToneError(ValueError):
pass
8 changes: 7 additions & 1 deletion style_bert_vits2/models/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import numpy as np
import torch
from numpy.typing import NDArray
from scipy.io.wavfile import read

from style_bert_vits2.logging import logger
from style_bert_vits2.models.utils import checkpoints # type: ignore
Expand Down Expand Up @@ -162,6 +161,13 @@ def load_wav_to_torch(full_path: Union[str, Path]) -> tuple[torch.FloatTensor, i
tuple[torch.FloatTensor, int]: 音声データのテンソルとサンプリングレート
"""

# この関数は学習時以外使われないため、ライブラリとしての style_bert_vits2 が
# 重たい scipy に依存しないように遅延 import する
try:
from scipy.io.wavfile import read
except ImportError:
raise ImportError("scipy is required to load wav file")

sampling_rate, data = read(full_path)
return torch.FloatTensor(data.astype(np.float32)), sampling_rate

Expand Down
73 changes: 37 additions & 36 deletions style_bert_vits2/nlp/chinese/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,41 @@
from style_bert_vits2.nlp.symbols import PUNCTUATIONS


__REPLACE_MAP = {
":": ",",
";": ",",
",": ",",
"。": ".",
"!": "!",
"?": "?",
"\n": ".",
"·": ",",
"、": ",",
"...": "…",
"$": ".",
"“": "'",
"”": "'",
'"': "'",
"‘": "'",
"’": "'",
"(": "'",
")": "'",
"(": "'",
")": "'",
"《": "'",
"》": "'",
"【": "'",
"】": "'",
"[": "'",
"]": "'",
"—": "-",
"~": "-",
"~": "-",
"「": "'",
"」": "'",
}


def normalize_text(text: str) -> str:
numbers = re.findall(r"\d+(?:\.?\d+)?", text)
for number in numbers:
Expand All @@ -15,44 +50,10 @@ def normalize_text(text: str) -> str:

def replace_punctuation(text: str) -> str:

REPLACE_MAP = {
":": ",",
";": ",",
",": ",",
"。": ".",
"!": "!",
"?": "?",
"\n": ".",
"·": ",",
"、": ",",
"...": "…",
"$": ".",
"“": "'",
"”": "'",
'"': "'",
"‘": "'",
"’": "'",
"(": "'",
")": "'",
"(": "'",
")": "'",
"《": "'",
"》": "'",
"【": "'",
"】": "'",
"[": "'",
"]": "'",
"—": "-",
"~": "-",
"~": "-",
"「": "'",
"」": "'",
}

text = text.replace("嗯", "恩").replace("呣", "母")
pattern = re.compile("|".join(re.escape(p) for p in REPLACE_MAP.keys()))
pattern = re.compile("|".join(re.escape(p) for p in __REPLACE_MAP.keys()))

replaced_text = pattern.sub(lambda x: REPLACE_MAP[x.group()], text)
replaced_text = pattern.sub(lambda x: __REPLACE_MAP[x.group()], text)

replaced_text = re.sub(
r"[^\u4e00-\u9fa5" + "".join(PUNCTUATIONS) + r"]+", "", replaced_text
Expand Down
Loading