From 81baabbb30c66aaf32b09b216af882aba93ce410 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Mon, 4 Mar 2024 11:46:09 +0900 Subject: [PATCH 001/148] wip --- app.py | 518 +----------------- webui/__init__.py | 16 + webui_dataset.py => webui/dataset.py | 428 ++++++++------- webui/inference.py | 504 +++++++++++++++++ webui_merge.py => webui/merge.py | 345 ++++++------ .../style_vectors.py | 295 +++++----- webui_train.py => webui/train.py | 6 +- 7 files changed, 1091 insertions(+), 1021 deletions(-) create mode 100644 webui/__init__.py rename webui_dataset.py => webui/dataset.py (50%) create mode 100644 webui/inference.py rename webui_merge.py => webui/merge.py (65%) rename webui_style_vectors.py => webui/style_vectors.py (65%) rename webui_train.py => webui/train.py (99%) diff --git a/app.py b/app.py index 1514444f7..7057f459c 100644 --- a/app.py +++ b/app.py @@ -1,502 +1,30 @@ -import argparse -import datetime -import json -import os -import sys -from pathlib import Path -from typing import Optional - +import pyopenjtalk import gradio as gr -import torch -import yaml - -from common.constants import ( - DEFAULT_ASSIST_TEXT_WEIGHT, - DEFAULT_LENGTH, - DEFAULT_LINE_SPLIT, - DEFAULT_NOISE, - DEFAULT_NOISEW, - DEFAULT_SDP_RATIO, - DEFAULT_SPLIT_INTERVAL, - DEFAULT_STYLE, - DEFAULT_STYLE_WEIGHT, - GRADIO_THEME, - LATEST_VERSION, - Languages, +from webui import ( + create_dataset_app, + create_train_app, + create_merge_app, + create_style_vectors_app, ) -from common.log import logger -from common.tts_model import ModelHolder -from infer import InvalidToneError -from text.japanese import g2kata_tone, kata_tone2phone_tone, text_normalize - -# Get path settings -with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: - path_config: dict[str, str] = yaml.safe_load(f.read()) - # dataset_root = path_config["dataset_root"] - assets_root = path_config["assets_root"] - -languages = [l.value for l in Languages] - - -def tts_fn( - model_name, - model_path, - text, - language, - reference_audio_path, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - line_split, - split_interval, - assist_text, - assist_text_weight, - use_assist_text, - style, - style_weight, - kata_tone_json_str, - use_tone, - speaker, - pitch_scale, - intonation_scale, -): - model_holder.load_model_gr(model_name, model_path) - - wrong_tone_message = "" - kata_tone: Optional[list[tuple[str, int]]] = None - if use_tone and kata_tone_json_str != "": - if language != "JP": - logger.warning("Only Japanese is supported for tone generation.") - wrong_tone_message = "アクセント指定は現在日本語のみ対応しています。" - if line_split: - logger.warning("Tone generation is not supported for line split.") - wrong_tone_message = ( - "アクセント指定は改行で分けて生成を使わない場合のみ対応しています。" - ) - try: - kata_tone = [] - json_data = json.loads(kata_tone_json_str) - # tupleを使うように変換 - for kana, tone in json_data: - assert isinstance(kana, str) and tone in (0, 1), f"{kana}, {tone}" - kata_tone.append((kana, tone)) - except Exception as e: - logger.warning(f"Error occurred when parsing kana_tone_json: {e}") - wrong_tone_message = f"アクセント指定が不正です: {e}" - kata_tone = None - - # toneは実際に音声合成に代入される際のみnot Noneになる - tone: Optional[list[int]] = None - if kata_tone is not None: - phone_tone = kata_tone2phone_tone(kata_tone) - tone = [t for _, t in phone_tone] - - speaker_id = model_holder.current_model.spk2id[speaker] - - start_time = datetime.datetime.now() - - assert model_holder.current_model is not None - - try: - sr, audio = model_holder.current_model.infer( - text=text, - 
language=language, - reference_audio_path=reference_audio_path, - sdp_ratio=sdp_ratio, - noise=noise_scale, - noisew=noise_scale_w, - length=length_scale, - line_split=line_split, - split_interval=split_interval, - assist_text=assist_text, - assist_text_weight=assist_text_weight, - use_assist_text=use_assist_text, - style=style, - style_weight=style_weight, - given_tone=tone, - sid=speaker_id, - pitch_scale=pitch_scale, - intonation_scale=intonation_scale, - ) - except InvalidToneError as e: - logger.error(f"Tone error: {e}") - return f"Error: アクセント指定が不正です:\n{e}", None, kata_tone_json_str - except ValueError as e: - logger.error(f"Value error: {e}") - return f"Error: {e}", None, kata_tone_json_str - - end_time = datetime.datetime.now() - duration = (end_time - start_time).total_seconds() - - if tone is None and language == "JP": - # アクセント指定に使えるようにアクセント情報を返す - norm_text = text_normalize(text) - kata_tone = g2kata_tone(norm_text) - kata_tone_json_str = json.dumps(kata_tone, ensure_ascii=False) - elif tone is None: - kata_tone_json_str = "" - message = f"Success, time: {duration} seconds." - if wrong_tone_message != "": - message = wrong_tone_message + "\n" + message - return message, (sr, audio), kata_tone_json_str - - -initial_text = "こんにちは、初めまして。あなたの名前はなんていうの?" - -examples = [ - [initial_text, "JP"], - [ - """あなたがそんなこと言うなんて、私はとっても嬉しい。 -あなたがそんなこと言うなんて、私はとっても怒ってる。 -あなたがそんなこと言うなんて、私はとっても驚いてる。 -あなたがそんなこと言うなんて、私はとっても辛い。""", - "JP", - ], - [ # ChatGPTに考えてもらった告白セリフ - """私、ずっと前からあなたのことを見てきました。あなたの笑顔、優しさ、強さに、心惹かれていたんです。 -友達として過ごす中で、あなたのことがだんだんと特別な存在になっていくのがわかりました。 -えっと、私、あなたのことが好きです!もしよければ、私と付き合ってくれませんか?""", - "JP", - ], - [ # 夏目漱石『吾輩は猫である』 - """吾輩は猫である。名前はまだ無い。 -どこで生れたかとんと見当がつかぬ。なんでも薄暗いじめじめした所でニャーニャー泣いていた事だけは記憶している。 -吾輩はここで初めて人間というものを見た。しかもあとで聞くと、それは書生という、人間中で一番獰悪な種族であったそうだ。 -この書生というのは時々我々を捕まえて煮て食うという話である。""", - "JP", - ], - [ # 梶井基次郎『桜の樹の下には』 - """桜の樹の下には屍体が埋まっている!これは信じていいことなんだよ。 -何故って、桜の花があんなにも見事に咲くなんて信じられないことじゃないか。俺はあの美しさが信じられないので、このにさんにち不安だった。 -しかしいま、やっとわかるときが来た。桜の樹の下には屍体が埋まっている。これは信じていいことだ。""", - "JP", - ], - [ # ChatGPTと考えた、感情を表すセリフ - """やったー!テストで満点取れた!私とっても嬉しいな! -どうして私の意見を無視するの?許せない!ムカつく!あんたなんか死ねばいいのに。 -あはははっ!この漫画めっちゃ笑える、見てよこれ、ふふふ、あはは。 -あなたがいなくなって、私は一人になっちゃって、泣いちゃいそうなほど悲しい。""", - "JP", - ], - [ # 上の丁寧語バージョン - """やりました!テストで満点取れましたよ!私とっても嬉しいです! -どうして私の意見を無視するんですか?許せません!ムカつきます!あんたなんか死んでください。 -あはははっ!この漫画めっちゃ笑えます、見てくださいこれ、ふふふ、あはは。 -あなたがいなくなって、私は一人になっちゃって、泣いちゃいそうなほど悲しいです。""", - "JP", - ], - [ # ChatGPTに考えてもらった音声合成の説明文章 - """音声合成は、機械学習を活用して、テキストから人の声を再現する技術です。この技術は、言語の構造を解析し、それに基づいて音声を生成します。 -この分野の最新の研究成果を使うと、より自然で表現豊かな音声の生成が可能である。深層学習の応用により、感情やアクセントを含む声質の微妙な変化も再現することが出来る。""", - "JP", - ], - [ - "Speech synthesis is the artificial production of human speech. 
A computer system used for this purpose is called a speech synthesizer, and can be implemented in software or hardware products.", - "EN", - ], - [ - "语音合成是人工制造人类语音。用于此目的的计算机系统称为语音合成器,可以通过软件或硬件产品实现。", - "ZH", - ], -] - -initial_md = f""" -# Style-Bert-VITS2 ver {LATEST_VERSION} 音声合成 - -- Ver 2.3で追加されたエディターのほうが実際に読み上げさせるには使いやすいかもしれません。`Editor.bat`か`python server_editor.py`で起動できます。 - -- 初期からある[jvnvのモデル](https://huggingface.co/litagin/style_bert_vits2_jvnv)は、[JVNVコーパス(言語音声と非言語音声を持つ日本語感情音声コーパス)](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvnv_corpus)で学習されたモデルです。ライセンスは[CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/deed.ja)です。 -""" - -how_to_md = """ -下のように`model_assets`ディレクトリの中にモデルファイルたちを置いてください。 -``` -model_assets -├── your_model -│ ├── config.json -│ ├── your_model_file1.safetensors -│ ├── your_model_file2.safetensors -│ ├── ... -│ └── style_vectors.npy -└── another_model - ├── ... -``` -各モデルにはファイルたちが必要です: -- `config.json`:学習時の設定ファイル -- `*.safetensors`:学習済みモデルファイル(1つ以上が必要、複数可) -- `style_vectors.npy`:スタイルベクトルファイル - -上2つは`Train.bat`による学習で自動的に正しい位置に保存されます。`style_vectors.npy`は`Style.bat`を実行して指示に従って生成してください。 -""" - -style_md = f""" -- プリセットまたは音声ファイルから読み上げの声音・感情・スタイルのようなものを制御できます。 -- デフォルトの{DEFAULT_STYLE}でも、十分に読み上げる文に応じた感情で感情豊かに読み上げられます。このスタイル制御は、それを重み付きで上書きするような感じです。 -- 強さを大きくしすぎると発音が変になったり声にならなかったりと崩壊することがあります。 -- どのくらいに強さがいいかはモデルやスタイルによって異なるようです。 -- 音声ファイルを入力する場合は、学習データと似た声音の話者(特に同じ性別)でないとよい効果が出ないかもしれません。 -""" - - -def make_interactive(): - return gr.update(interactive=True, value="音声合成") - - -def make_non_interactive(): - return gr.update(interactive=False, value="音声合成(モデルをロードしてください)") - - -def gr_util(item): - if item == "プリセットから選ぶ": - return (gr.update(visible=True), gr.Audio(visible=False, value=None)) - else: - return (gr.update(visible=False), gr.update(visible=True)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU") - parser.add_argument( - "--dir", "-d", type=str, help="Model directory", default=assets_root - ) - parser.add_argument( - "--share", action="store_true", help="Share this app publicly", default=False - ) - parser.add_argument( - "--server-name", - type=str, - default=None, - help="Server name for Gradio app", - ) - parser.add_argument( - "--no-autolaunch", - action="store_true", - default=False, - help="Do not launch app automatically", - ) - args = parser.parse_args() - model_dir = Path(args.dir) - - if args.cpu: - device = "cpu" - else: - device = "cuda" if torch.cuda.is_available() else "cpu" - - model_holder = ModelHolder(model_dir, device) - - model_names = model_holder.model_names - if len(model_names) == 0: - logger.error( - f"モデルが見つかりませんでした。{model_dir}にモデルを置いてください。" - ) - sys.exit(1) - initial_id = 0 - initial_pth_files = model_holder.model_files_dict[model_names[initial_id]] - - with gr.Blocks(theme=GRADIO_THEME) as app: - gr.Markdown(initial_md) - with gr.Accordion(label="使い方", open=False): - gr.Markdown(how_to_md) - with gr.Row(): - with gr.Column(): - with gr.Row(): - with gr.Column(scale=3): - model_name = gr.Dropdown( - label="モデル一覧", - choices=model_names, - value=model_names[initial_id], - ) - model_path = gr.Dropdown( - label="モデルファイル", - choices=initial_pth_files, - value=initial_pth_files[0], - ) - refresh_button = gr.Button("更新", scale=1, visible=True) - load_button = gr.Button("ロード", scale=1, variant="primary") - text_input = gr.TextArea(label="テキスト", value=initial_text) - pitch_scale = gr.Slider( - minimum=0.8, - 
maximum=1.5, - value=1, - step=0.05, - label="音程(1以外では音質劣化)", - visible=False, # pyworldが必要 - ) - intonation_scale = gr.Slider( - minimum=0, - maximum=2, - value=1, - step=0.1, - label="抑揚(1以外では音質劣化)", - visible=False, # pyworldが必要 - ) - - line_split = gr.Checkbox( - label="改行で分けて生成(分けたほうが感情が乗ります)", - value=DEFAULT_LINE_SPLIT, - ) - split_interval = gr.Slider( - minimum=0.0, - maximum=2, - value=DEFAULT_SPLIT_INTERVAL, - step=0.1, - label="改行ごとに挟む無音の長さ(秒)", - ) - line_split.change( - lambda x: (gr.Slider(visible=x)), - inputs=[line_split], - outputs=[split_interval], - ) - tone = gr.Textbox( - label="アクセント調整(数値は 0=低 か1=高 のみ)", - info="改行で分けない場合のみ使えます。万能ではありません。", - ) - use_tone = gr.Checkbox(label="アクセント調整を使う", value=False) - use_tone.change( - lambda x: (gr.Checkbox(value=False) if x else gr.Checkbox()), - inputs=[use_tone], - outputs=[line_split], - ) - language = gr.Dropdown(choices=languages, value="JP", label="Language") - speaker = gr.Dropdown(label="話者") - with gr.Accordion(label="詳細設定", open=False): - sdp_ratio = gr.Slider( - minimum=0, - maximum=1, - value=DEFAULT_SDP_RATIO, - step=0.1, - label="SDP Ratio", - ) - noise_scale = gr.Slider( - minimum=0.1, - maximum=2, - value=DEFAULT_NOISE, - step=0.1, - label="Noise", - ) - noise_scale_w = gr.Slider( - minimum=0.1, - maximum=2, - value=DEFAULT_NOISEW, - step=0.1, - label="Noise_W", - ) - length_scale = gr.Slider( - minimum=0.1, - maximum=2, - value=DEFAULT_LENGTH, - step=0.1, - label="Length", - ) - use_assist_text = gr.Checkbox( - label="Assist textを使う", value=False - ) - assist_text = gr.Textbox( - label="Assist text", - placeholder="どうして私の意見を無視するの?許せない、ムカつく!死ねばいいのに。", - info="このテキストの読み上げと似た声音・感情になりやすくなります。ただ抑揚やテンポ等が犠牲になる傾向があります。", - visible=False, - ) - assist_text_weight = gr.Slider( - minimum=0, - maximum=1, - value=DEFAULT_ASSIST_TEXT_WEIGHT, - step=0.1, - label="Assist textの強さ", - visible=False, - ) - use_assist_text.change( - lambda x: (gr.Textbox(visible=x), gr.Slider(visible=x)), - inputs=[use_assist_text], - outputs=[assist_text, assist_text_weight], - ) - with gr.Column(): - with gr.Accordion("スタイルについて詳細", open=False): - gr.Markdown(style_md) - style_mode = gr.Radio( - ["プリセットから選ぶ", "音声ファイルを入力"], - label="スタイルの指定方法", - value="プリセットから選ぶ", - ) - style = gr.Dropdown( - label=f"スタイル({DEFAULT_STYLE}が平均スタイル)", - choices=["モデルをロードしてください"], - value="モデルをロードしてください", - ) - style_weight = gr.Slider( - minimum=0, - maximum=50, - value=DEFAULT_STYLE_WEIGHT, - step=0.1, - label="スタイルの強さ", - ) - ref_audio_path = gr.Audio( - label="参照音声", type="filepath", visible=False - ) - tts_button = gr.Button( - "音声合成(モデルをロードしてください)", - variant="primary", - interactive=False, - ) - text_output = gr.Textbox(label="情報") - audio_output = gr.Audio(label="結果") - with gr.Accordion("テキスト例", open=False): - gr.Examples(examples, inputs=[text_input, language]) - - tts_button.click( - tts_fn, - inputs=[ - model_name, - model_path, - text_input, - language, - ref_audio_path, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - line_split, - split_interval, - assist_text, - assist_text_weight, - use_assist_text, - style, - style_weight, - tone, - use_tone, - speaker, - pitch_scale, - intonation_scale, - ], - outputs=[text_output, audio_output, tone], - ) - - model_name.change( - model_holder.update_model_files_gr, - inputs=[model_name], - outputs=[model_path], - ) +from pathlib import Path - model_path.change(make_non_interactive, outputs=[tts_button]) +pyopenjtalk.unset_user_dict() - refresh_button.click( - model_holder.update_model_names_gr, - 
outputs=[model_name, model_path, tts_button], - ) +setting_json = Path("webui/setting.json") - load_button.click( - model_holder.load_model_gr, - inputs=[model_name, model_path], - outputs=[style, tts_button, speaker], - ) +with gr.Blocks() as app: + with gr.Tabs(): + with gr.Tab("Hello"): + gr.Markdown("## Hello, Gradio!") + gr.Textbox("input", label="Input Text") + with gr.Tab("Dataset"): + create_dataset_app() + with gr.Tab("Train"): + create_train_app() + with gr.Tab("Merge"): + create_merge_app() + with gr.Tab("Create Style Vectors"): + create_style_vectors_app() - style_mode.change( - gr_util, - inputs=[style_mode], - outputs=[style, ref_audio_path], - ) - app.launch( - inbrowser=not args.no_autolaunch, share=args.share, server_name=args.server_name - ) +app.launch(inbrowser=True) diff --git a/webui/__init__.py b/webui/__init__.py new file mode 100644 index 000000000..4bc8efab8 --- /dev/null +++ b/webui/__init__.py @@ -0,0 +1,16 @@ +from .dataset import create_dataset_app +from .inference import create_inference_app +from .merge import create_merge_app +from .style_vectors import create_style_vectors_app +from .train import create_train_app + + +class TrainSettings: + def __init__(self, setting_json): + self.setting_json = setting_json + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass diff --git a/webui_dataset.py b/webui/dataset.py similarity index 50% rename from webui_dataset.py rename to webui/dataset.py index fec7a9ac9..5ed656c4d 100644 --- a/webui_dataset.py +++ b/webui/dataset.py @@ -1,208 +1,220 @@ -import argparse -import os - -import gradio as gr -import yaml - -from common.constants import GRADIO_THEME -from common.log import logger -from common.subprocess_utils import run_script_with_log - -# Get path settings -with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: - path_config: dict[str, str] = yaml.safe_load(f.read()) - dataset_root = path_config["dataset_root"] - # assets_root = path_config["assets_root"] - - -def do_slice( - model_name: str, - min_sec: float, - max_sec: float, - min_silence_dur_ms: int, - input_dir: str, -): - if model_name == "": - return "Error: モデル名を入力してください。" - logger.info("Start slicing...") - cmd = [ - "slice.py", - "--model_name", - model_name, - "--min_sec", - str(min_sec), - "--max_sec", - str(max_sec), - "--min_silence_dur_ms", - str(min_silence_dur_ms), - ] - if input_dir != "": - cmd += ["--input_dir", input_dir] - # onnxの警告が出るので無視する - success, message = run_script_with_log(cmd, ignore_warning=True) - if not success: - return f"Error: {message}" - return "音声のスライスが完了しました。" - - -def do_transcribe( - model_name, whisper_model, compute_type, language, initial_prompt, device -): - if model_name == "": - return "Error: モデル名を入力してください。" - - success, message = run_script_with_log( - [ - "transcribe.py", - "--model_name", - model_name, - "--model", - whisper_model, - "--compute_type", - compute_type, - "--device", - device, - "--language", - language, - "--initial_prompt", - f'"{initial_prompt}"', - ] - ) - if not success: - return f"Error: {message}" - return "音声の文字起こしが完了しました。" - - -initial_md = """ -# 簡易学習用データセット作成ツール - -Style-Bert-VITS2の学習用データセットを作成するためのツールです。以下の2つからなります。 - -- 与えられた音声からちょうどいい長さの発話区間を切り取りスライス -- 音声に対して文字起こし - -このうち両方を使ってもよいし、スライスする必要がない場合は後者のみを使ってもよいです。 - -## 必要なもの - -学習したい音声が入ったwavファイルいくつか。 -合計時間がある程度はあったほうがいいかも、10分とかでも大丈夫だったとの報告あり。単一ファイルでも良いし複数ファイルでもよい。 - -## スライス使い方 -1. `inputs`フォルダにwavファイルをすべて入れる -2. `モデル名`を入力して、設定を必要なら調整して`音声のスライス`ボタンを押す -3. 
出来上がった音声ファイルたちは`Data/{モデル名}/raw`に保存される - -## 書き起こし使い方 - -1. 書き起こしたい音声ファイルのあるフォルダを指定(デフォルトは`Data/{モデル名}/raw`なのでスライス後に行う場合は省略してよい) -2. 設定を必要なら調整してボタンを押す -3. 書き起こしファイルは`Data/{モデル名}/esd.list`に保存される - -## 注意 - -- 長すぎる秒数(12-15秒くらいより長い?)のwavファイルは学習に用いられないようです。また短すぎてもあまりよくない可能性もあります。 -- 書き起こしの結果をどれだけ修正すればいいかはデータセットに依存しそうです。 -- 手動で書き起こしをいろいろ修正したり結果を細かく確認したい場合は、[Aivis Dataset](https://github.com/litagin02/Aivis-Dataset)もおすすめします。書き起こし部分もかなり工夫されています。ですがファイル数が多い場合などは、このツールで簡易的に切り出してデータセットを作るだけでも十分という気もしています。 -""" - -with gr.Blocks(theme=GRADIO_THEME) as app: - gr.Markdown(initial_md) - model_name = gr.Textbox( - label="モデル名を入力してください(話者名としても使われます)。" - ) - with gr.Accordion("音声のスライス"): - with gr.Row(): - with gr.Column(): - input_dir = gr.Textbox( - label="入力フォルダ名(デフォルトはinputs)", - placeholder="inputs", - info="下記フォルダにwavファイルを入れておいてください", - ) - min_sec = gr.Slider( - minimum=0, - maximum=10, - value=2, - step=0.5, - label="この秒数未満は切り捨てる", - ) - max_sec = gr.Slider( - minimum=0, - maximum=15, - value=12, - step=0.5, - label="この秒数以上は切り捨てる", - ) - min_silence_dur_ms = gr.Slider( - minimum=0, - maximum=2000, - value=700, - step=100, - label="無音とみなして区切る最小の無音の長さ(ms)", - ) - slice_button = gr.Button("スライスを実行") - result1 = gr.Textbox(label="結果") - with gr.Row(): - with gr.Column(): - whisper_model = gr.Dropdown( - ["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], - label="Whisperモデル", - value="large-v3", - ) - compute_type = gr.Dropdown( - [ - "int8", - "int8_float32", - "int8_float16", - "int8_bfloat16", - "int16", - "float16", - "bfloat16", - "float32", - ], - label="計算精度", - value="bfloat16", - ) - device = gr.Radio(["cuda", "cpu"], label="デバイス", value="cuda") - language = gr.Dropdown(["ja", "en", "zh"], value="ja", label="言語") - initial_prompt = gr.Textbox( - label="初期プロンプト", - value="こんにちは。元気、ですかー?ふふっ、私は……ちゃんと元気だよ!", - info="このように書き起こしてほしいという例文(句読点の入れ方・笑い方・固有名詞等)", - ) - transcribe_button = gr.Button("音声の文字起こし") - result2 = gr.Textbox(label="結果") - slice_button.click( - do_slice, - inputs=[model_name, min_sec, max_sec, min_silence_dur_ms, input_dir], - outputs=[result1], - ) - transcribe_button.click( - do_transcribe, - inputs=[ - model_name, - whisper_model, - compute_type, - language, - initial_prompt, - device, - ], - outputs=[result2], - ) - -parser = argparse.ArgumentParser() -parser.add_argument( - "--server-name", - type=str, - default=None, - help="Server name for Gradio app", -) -parser.add_argument( - "--no-autolaunch", - action="store_true", - default=False, - help="Do not launch app automatically", -) -args = parser.parse_args() - -app.launch(inbrowser=not args.no_autolaunch, server_name=args.server_name) +import argparse +import os + +import gradio as gr +import yaml + +from common.constants import GRADIO_THEME +from common.log import logger +from common.subprocess_utils import run_script_with_log + +# Get path settings +with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: + path_config: dict[str, str] = yaml.safe_load(f.read()) + dataset_root = path_config["dataset_root"] + # assets_root = path_config["assets_root"] + + +def do_slice( + model_name: str, + min_sec: float, + max_sec: float, + min_silence_dur_ms: int, + input_dir: str, +): + if model_name == "": + return "Error: モデル名を入力してください。" + logger.info("Start slicing...") + cmd = [ + "slice.py", + "--model_name", + model_name, + "--min_sec", + str(min_sec), + "--max_sec", + str(max_sec), + "--min_silence_dur_ms", + str(min_silence_dur_ms), + ] + if input_dir != "": + cmd += ["--input_dir", 
input_dir] + # onnxの警告が出るので無視する + success, message = run_script_with_log(cmd, ignore_warning=True) + if not success: + return f"Error: {message}" + return "音声のスライスが完了しました。" + + +def do_transcribe( + model_name, whisper_model, compute_type, language, initial_prompt, device +): + if model_name == "": + return "Error: モデル名を入力してください。" + + success, message = run_script_with_log( + [ + "transcribe.py", + "--model_name", + model_name, + "--model", + whisper_model, + "--compute_type", + compute_type, + "--device", + device, + "--language", + language, + "--initial_prompt", + f'"{initial_prompt}"', + ] + ) + if not success: + return f"Error: {message}" + return "音声の文字起こしが完了しました。" + + +initial_md = """ +# 簡易学習用データセット作成ツール + +Style-Bert-VITS2の学習用データセットを作成するためのツールです。以下の2つからなります。 + +- 与えられた音声からちょうどいい長さの発話区間を切り取りスライス +- 音声に対して文字起こし + +このうち両方を使ってもよいし、スライスする必要がない場合は後者のみを使ってもよいです。 + +## 必要なもの + +学習したい音声が入ったwavファイルいくつか。 +合計時間がある程度はあったほうがいいかも、10分とかでも大丈夫だったとの報告あり。単一ファイルでも良いし複数ファイルでもよい。 + +## スライス使い方 +1. `inputs`フォルダにwavファイルをすべて入れる +2. `モデル名`を入力して、設定を必要なら調整して`音声のスライス`ボタンを押す +3. 出来上がった音声ファイルたちは`Data/{モデル名}/raw`に保存される + +## 書き起こし使い方 + +1. 書き起こしたい音声ファイルのあるフォルダを指定(デフォルトは`Data/{モデル名}/raw`なのでスライス後に行う場合は省略してよい) +2. 設定を必要なら調整してボタンを押す +3. 書き起こしファイルは`Data/{モデル名}/esd.list`に保存される + +## 注意 + +- 長すぎる秒数(12-15秒くらいより長い?)のwavファイルは学習に用いられないようです。また短すぎてもあまりよくない可能性もあります。 +- 書き起こしの結果をどれだけ修正すればいいかはデータセットに依存しそうです。 +- 手動で書き起こしをいろいろ修正したり結果を細かく確認したい場合は、[Aivis Dataset](https://github.com/litagin02/Aivis-Dataset)もおすすめします。書き起こし部分もかなり工夫されています。ですがファイル数が多い場合などは、このツールで簡易的に切り出してデータセットを作るだけでも十分という気もしています。 +""" + + +def create_dataset_app(): + + with gr.Blocks(theme=GRADIO_THEME) as app: + gr.Markdown(initial_md) + model_name = gr.Textbox( + label="モデル名を入力してください(話者名としても使われます)。" + ) + with gr.Accordion("音声のスライス"): + with gr.Row(): + with gr.Column(): + input_dir = gr.Textbox( + label="入力フォルダ名(デフォルトはinputs)", + placeholder="inputs", + info="下記フォルダにwavファイルを入れておいてください", + ) + min_sec = gr.Slider( + minimum=0, + maximum=10, + value=2, + step=0.5, + label="この秒数未満は切り捨てる", + ) + max_sec = gr.Slider( + minimum=0, + maximum=15, + value=12, + step=0.5, + label="この秒数以上は切り捨てる", + ) + min_silence_dur_ms = gr.Slider( + minimum=0, + maximum=2000, + value=700, + step=100, + label="無音とみなして区切る最小の無音の長さ(ms)", + ) + slice_button = gr.Button("スライスを実行") + result1 = gr.Textbox(label="結果") + with gr.Row(): + with gr.Column(): + whisper_model = gr.Dropdown( + [ + "tiny", + "base", + "small", + "medium", + "large", + "large-v2", + "large-v3", + ], + label="Whisperモデル", + value="large-v3", + ) + compute_type = gr.Dropdown( + [ + "int8", + "int8_float32", + "int8_float16", + "int8_bfloat16", + "int16", + "float16", + "bfloat16", + "float32", + ], + label="計算精度", + value="bfloat16", + ) + device = gr.Radio(["cuda", "cpu"], label="デバイス", value="cuda") + language = gr.Dropdown(["ja", "en", "zh"], value="ja", label="言語") + initial_prompt = gr.Textbox( + label="初期プロンプト", + value="こんにちは。元気、ですかー?ふふっ、私は……ちゃんと元気だよ!", + info="このように書き起こしてほしいという例文(句読点の入れ方・笑い方・固有名詞等)", + ) + transcribe_button = gr.Button("音声の文字起こし") + result2 = gr.Textbox(label="結果") + slice_button.click( + do_slice, + inputs=[model_name, min_sec, max_sec, min_silence_dur_ms, input_dir], + outputs=[result1], + ) + transcribe_button.click( + do_transcribe, + inputs=[ + model_name, + whisper_model, + compute_type, + language, + initial_prompt, + device, + ], + outputs=[result2], + ) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--server-name", + type=str, + default=None, + help="Server name for Gradio 
app", + ) + parser.add_argument( + "--no-autolaunch", + action="store_true", + default=False, + help="Do not launch app automatically", + ) + args = parser.parse_args() + + # app.launch(inbrowser=not args.no_autolaunch, server_name=args.server_name) + return app diff --git a/webui/inference.py b/webui/inference.py new file mode 100644 index 000000000..663d5a637 --- /dev/null +++ b/webui/inference.py @@ -0,0 +1,504 @@ +import argparse +import datetime +import json +import os +import sys +from pathlib import Path +from typing import Optional + +import gradio as gr +import torch +import yaml + +from common.constants import ( + DEFAULT_ASSIST_TEXT_WEIGHT, + DEFAULT_LENGTH, + DEFAULT_LINE_SPLIT, + DEFAULT_NOISE, + DEFAULT_NOISEW, + DEFAULT_SDP_RATIO, + DEFAULT_SPLIT_INTERVAL, + DEFAULT_STYLE, + DEFAULT_STYLE_WEIGHT, + GRADIO_THEME, + LATEST_VERSION, + Languages, +) +from common.log import logger +from common.tts_model import ModelHolder +from infer import InvalidToneError +from text.japanese import g2kata_tone, kata_tone2phone_tone, text_normalize + +# Get path settings +with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: + path_config: dict[str, str] = yaml.safe_load(f.read()) + # dataset_root = path_config["dataset_root"] + assets_root = path_config["assets_root"] + +languages = [l.value for l in Languages] + + +def tts_fn( + model_name, + model_path, + text, + language, + reference_audio_path, + sdp_ratio, + noise_scale, + noise_scale_w, + length_scale, + line_split, + split_interval, + assist_text, + assist_text_weight, + use_assist_text, + style, + style_weight, + kata_tone_json_str, + use_tone, + speaker, + pitch_scale, + intonation_scale, +): + model_holder.load_model_gr(model_name, model_path) + + wrong_tone_message = "" + kata_tone: Optional[list[tuple[str, int]]] = None + if use_tone and kata_tone_json_str != "": + if language != "JP": + logger.warning("Only Japanese is supported for tone generation.") + wrong_tone_message = "アクセント指定は現在日本語のみ対応しています。" + if line_split: + logger.warning("Tone generation is not supported for line split.") + wrong_tone_message = ( + "アクセント指定は改行で分けて生成を使わない場合のみ対応しています。" + ) + try: + kata_tone = [] + json_data = json.loads(kata_tone_json_str) + # tupleを使うように変換 + for kana, tone in json_data: + assert isinstance(kana, str) and tone in (0, 1), f"{kana}, {tone}" + kata_tone.append((kana, tone)) + except Exception as e: + logger.warning(f"Error occurred when parsing kana_tone_json: {e}") + wrong_tone_message = f"アクセント指定が不正です: {e}" + kata_tone = None + + # toneは実際に音声合成に代入される際のみnot Noneになる + tone: Optional[list[int]] = None + if kata_tone is not None: + phone_tone = kata_tone2phone_tone(kata_tone) + tone = [t for _, t in phone_tone] + + speaker_id = model_holder.current_model.spk2id[speaker] + + start_time = datetime.datetime.now() + + assert model_holder.current_model is not None + + try: + sr, audio = model_holder.current_model.infer( + text=text, + language=language, + reference_audio_path=reference_audio_path, + sdp_ratio=sdp_ratio, + noise=noise_scale, + noisew=noise_scale_w, + length=length_scale, + line_split=line_split, + split_interval=split_interval, + assist_text=assist_text, + assist_text_weight=assist_text_weight, + use_assist_text=use_assist_text, + style=style, + style_weight=style_weight, + given_tone=tone, + sid=speaker_id, + pitch_scale=pitch_scale, + intonation_scale=intonation_scale, + ) + except InvalidToneError as e: + logger.error(f"Tone error: {e}") + return f"Error: アクセント指定が不正です:\n{e}", None, kata_tone_json_str + 
except ValueError as e: + logger.error(f"Value error: {e}") + return f"Error: {e}", None, kata_tone_json_str + + end_time = datetime.datetime.now() + duration = (end_time - start_time).total_seconds() + + if tone is None and language == "JP": + # アクセント指定に使えるようにアクセント情報を返す + norm_text = text_normalize(text) + kata_tone = g2kata_tone(norm_text) + kata_tone_json_str = json.dumps(kata_tone, ensure_ascii=False) + elif tone is None: + kata_tone_json_str = "" + message = f"Success, time: {duration} seconds." + if wrong_tone_message != "": + message = wrong_tone_message + "\n" + message + return message, (sr, audio), kata_tone_json_str + + +initial_text = "こんにちは、初めまして。あなたの名前はなんていうの?" + +examples = [ + [initial_text, "JP"], + [ + """あなたがそんなこと言うなんて、私はとっても嬉しい。 +あなたがそんなこと言うなんて、私はとっても怒ってる。 +あなたがそんなこと言うなんて、私はとっても驚いてる。 +あなたがそんなこと言うなんて、私はとっても辛い。""", + "JP", + ], + [ # ChatGPTに考えてもらった告白セリフ + """私、ずっと前からあなたのことを見てきました。あなたの笑顔、優しさ、強さに、心惹かれていたんです。 +友達として過ごす中で、あなたのことがだんだんと特別な存在になっていくのがわかりました。 +えっと、私、あなたのことが好きです!もしよければ、私と付き合ってくれませんか?""", + "JP", + ], + [ # 夏目漱石『吾輩は猫である』 + """吾輩は猫である。名前はまだ無い。 +どこで生れたかとんと見当がつかぬ。なんでも薄暗いじめじめした所でニャーニャー泣いていた事だけは記憶している。 +吾輩はここで初めて人間というものを見た。しかもあとで聞くと、それは書生という、人間中で一番獰悪な種族であったそうだ。 +この書生というのは時々我々を捕まえて煮て食うという話である。""", + "JP", + ], + [ # 梶井基次郎『桜の樹の下には』 + """桜の樹の下には屍体が埋まっている!これは信じていいことなんだよ。 +何故って、桜の花があんなにも見事に咲くなんて信じられないことじゃないか。俺はあの美しさが信じられないので、このにさんにち不安だった。 +しかしいま、やっとわかるときが来た。桜の樹の下には屍体が埋まっている。これは信じていいことだ。""", + "JP", + ], + [ # ChatGPTと考えた、感情を表すセリフ + """やったー!テストで満点取れた!私とっても嬉しいな! +どうして私の意見を無視するの?許せない!ムカつく!あんたなんか死ねばいいのに。 +あはははっ!この漫画めっちゃ笑える、見てよこれ、ふふふ、あはは。 +あなたがいなくなって、私は一人になっちゃって、泣いちゃいそうなほど悲しい。""", + "JP", + ], + [ # 上の丁寧語バージョン + """やりました!テストで満点取れましたよ!私とっても嬉しいです! +どうして私の意見を無視するんですか?許せません!ムカつきます!あんたなんか死んでください。 +あはははっ!この漫画めっちゃ笑えます、見てくださいこれ、ふふふ、あはは。 +あなたがいなくなって、私は一人になっちゃって、泣いちゃいそうなほど悲しいです。""", + "JP", + ], + [ # ChatGPTに考えてもらった音声合成の説明文章 + """音声合成は、機械学習を活用して、テキストから人の声を再現する技術です。この技術は、言語の構造を解析し、それに基づいて音声を生成します。 +この分野の最新の研究成果を使うと、より自然で表現豊かな音声の生成が可能である。深層学習の応用により、感情やアクセントを含む声質の微妙な変化も再現することが出来る。""", + "JP", + ], + [ + "Speech synthesis is the artificial production of human speech. A computer system used for this purpose is called a speech synthesizer, and can be implemented in software or hardware products.", + "EN", + ], + [ + "语音合成是人工制造人类语音。用于此目的的计算机系统称为语音合成器,可以通过软件或硬件产品实现。", + "ZH", + ], +] + +initial_md = f""" +# Style-Bert-VITS2 ver {LATEST_VERSION} 音声合成 + +- Ver 2.3で追加されたエディターのほうが実際に読み上げさせるには使いやすいかもしれません。`Editor.bat`か`python server_editor.py`で起動できます。 + +- 初期からある[jvnvのモデル](https://huggingface.co/litagin/style_bert_vits2_jvnv)は、[JVNVコーパス(言語音声と非言語音声を持つ日本語感情音声コーパス)](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvnv_corpus)で学習されたモデルです。ライセンスは[CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/deed.ja)です。 +""" + +how_to_md = """ +下のように`model_assets`ディレクトリの中にモデルファイルたちを置いてください。 +``` +model_assets +├── your_model +│ ├── config.json +│ ├── your_model_file1.safetensors +│ ├── your_model_file2.safetensors +│ ├── ... +│ └── style_vectors.npy +└── another_model + ├── ... 
+``` +各モデルにはファイルたちが必要です: +- `config.json`:学習時の設定ファイル +- `*.safetensors`:学習済みモデルファイル(1つ以上が必要、複数可) +- `style_vectors.npy`:スタイルベクトルファイル + +上2つは`Train.bat`による学習で自動的に正しい位置に保存されます。`style_vectors.npy`は`Style.bat`を実行して指示に従って生成してください。 +""" + +style_md = f""" +- プリセットまたは音声ファイルから読み上げの声音・感情・スタイルのようなものを制御できます。 +- デフォルトの{DEFAULT_STYLE}でも、十分に読み上げる文に応じた感情で感情豊かに読み上げられます。このスタイル制御は、それを重み付きで上書きするような感じです。 +- 強さを大きくしすぎると発音が変になったり声にならなかったりと崩壊することがあります。 +- どのくらいに強さがいいかはモデルやスタイルによって異なるようです。 +- 音声ファイルを入力する場合は、学習データと似た声音の話者(特に同じ性別)でないとよい効果が出ないかもしれません。 +""" + + +def make_interactive(): + return gr.update(interactive=True, value="音声合成") + + +def make_non_interactive(): + return gr.update(interactive=False, value="音声合成(モデルをロードしてください)") + + +def gr_util(item): + if item == "プリセットから選ぶ": + return (gr.update(visible=True), gr.Audio(visible=False, value=None)) + else: + return (gr.update(visible=False), gr.update(visible=True)) + + +def create_inference_app(): + parser = argparse.ArgumentParser() + parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU") + parser.add_argument( + "--dir", "-d", type=str, help="Model directory", default=assets_root + ) + parser.add_argument( + "--share", action="store_true", help="Share this app publicly", default=False + ) + parser.add_argument( + "--server-name", + type=str, + default=None, + help="Server name for Gradio app", + ) + parser.add_argument( + "--no-autolaunch", + action="store_true", + default=False, + help="Do not launch app automatically", + ) + args = parser.parse_args() + model_dir = Path(args.dir) + + if args.cpu: + device = "cpu" + else: + device = "cuda" if torch.cuda.is_available() else "cpu" + + model_holder = ModelHolder(model_dir, device) + + model_names = model_holder.model_names + if len(model_names) == 0: + logger.error( + f"モデルが見つかりませんでした。{model_dir}にモデルを置いてください。" + ) + sys.exit(1) + initial_id = 0 + initial_pth_files = model_holder.model_files_dict[model_names[initial_id]] + + with gr.Blocks(theme=GRADIO_THEME) as app: + gr.Markdown(initial_md) + with gr.Accordion(label="使い方", open=False): + gr.Markdown(how_to_md) + with gr.Row(): + with gr.Column(): + with gr.Row(): + with gr.Column(scale=3): + model_name = gr.Dropdown( + label="モデル一覧", + choices=model_names, + value=model_names[initial_id], + ) + model_path = gr.Dropdown( + label="モデルファイル", + choices=initial_pth_files, + value=initial_pth_files[0], + ) + refresh_button = gr.Button("更新", scale=1, visible=True) + load_button = gr.Button("ロード", scale=1, variant="primary") + text_input = gr.TextArea(label="テキスト", value=initial_text) + pitch_scale = gr.Slider( + minimum=0.8, + maximum=1.5, + value=1, + step=0.05, + label="音程(1以外では音質劣化)", + visible=False, # pyworldが必要 + ) + intonation_scale = gr.Slider( + minimum=0, + maximum=2, + value=1, + step=0.1, + label="抑揚(1以外では音質劣化)", + visible=False, # pyworldが必要 + ) + + line_split = gr.Checkbox( + label="改行で分けて生成(分けたほうが感情が乗ります)", + value=DEFAULT_LINE_SPLIT, + ) + split_interval = gr.Slider( + minimum=0.0, + maximum=2, + value=DEFAULT_SPLIT_INTERVAL, + step=0.1, + label="改行ごとに挟む無音の長さ(秒)", + ) + line_split.change( + lambda x: (gr.Slider(visible=x)), + inputs=[line_split], + outputs=[split_interval], + ) + tone = gr.Textbox( + label="アクセント調整(数値は 0=低 か1=高 のみ)", + info="改行で分けない場合のみ使えます。万能ではありません。", + ) + use_tone = gr.Checkbox(label="アクセント調整を使う", value=False) + use_tone.change( + lambda x: (gr.Checkbox(value=False) if x else gr.Checkbox()), + inputs=[use_tone], + outputs=[line_split], + ) + language = gr.Dropdown(choices=languages, value="JP", 
label="Language") + speaker = gr.Dropdown(label="話者") + with gr.Accordion(label="詳細設定", open=False): + sdp_ratio = gr.Slider( + minimum=0, + maximum=1, + value=DEFAULT_SDP_RATIO, + step=0.1, + label="SDP Ratio", + ) + noise_scale = gr.Slider( + minimum=0.1, + maximum=2, + value=DEFAULT_NOISE, + step=0.1, + label="Noise", + ) + noise_scale_w = gr.Slider( + minimum=0.1, + maximum=2, + value=DEFAULT_NOISEW, + step=0.1, + label="Noise_W", + ) + length_scale = gr.Slider( + minimum=0.1, + maximum=2, + value=DEFAULT_LENGTH, + step=0.1, + label="Length", + ) + use_assist_text = gr.Checkbox( + label="Assist textを使う", value=False + ) + assist_text = gr.Textbox( + label="Assist text", + placeholder="どうして私の意見を無視するの?許せない、ムカつく!死ねばいいのに。", + info="このテキストの読み上げと似た声音・感情になりやすくなります。ただ抑揚やテンポ等が犠牲になる傾向があります。", + visible=False, + ) + assist_text_weight = gr.Slider( + minimum=0, + maximum=1, + value=DEFAULT_ASSIST_TEXT_WEIGHT, + step=0.1, + label="Assist textの強さ", + visible=False, + ) + use_assist_text.change( + lambda x: (gr.Textbox(visible=x), gr.Slider(visible=x)), + inputs=[use_assist_text], + outputs=[assist_text, assist_text_weight], + ) + with gr.Column(): + with gr.Accordion("スタイルについて詳細", open=False): + gr.Markdown(style_md) + style_mode = gr.Radio( + ["プリセットから選ぶ", "音声ファイルを入力"], + label="スタイルの指定方法", + value="プリセットから選ぶ", + ) + style = gr.Dropdown( + label=f"スタイル({DEFAULT_STYLE}が平均スタイル)", + choices=["モデルをロードしてください"], + value="モデルをロードしてください", + ) + style_weight = gr.Slider( + minimum=0, + maximum=50, + value=DEFAULT_STYLE_WEIGHT, + step=0.1, + label="スタイルの強さ", + ) + ref_audio_path = gr.Audio( + label="参照音声", type="filepath", visible=False + ) + tts_button = gr.Button( + "音声合成(モデルをロードしてください)", + variant="primary", + interactive=False, + ) + text_output = gr.Textbox(label="情報") + audio_output = gr.Audio(label="結果") + with gr.Accordion("テキスト例", open=False): + gr.Examples(examples, inputs=[text_input, language]) + + tts_button.click( + tts_fn, + inputs=[ + model_name, + model_path, + text_input, + language, + ref_audio_path, + sdp_ratio, + noise_scale, + noise_scale_w, + length_scale, + line_split, + split_interval, + assist_text, + assist_text_weight, + use_assist_text, + style, + style_weight, + tone, + use_tone, + speaker, + pitch_scale, + intonation_scale, + ], + outputs=[text_output, audio_output, tone], + ) + + model_name.change( + model_holder.update_model_files_gr, + inputs=[model_name], + outputs=[model_path], + ) + + model_path.change(make_non_interactive, outputs=[tts_button]) + + refresh_button.click( + model_holder.update_model_names_gr, + outputs=[model_name, model_path, tts_button], + ) + + load_button.click( + model_holder.load_model_gr, + inputs=[model_name, model_path], + outputs=[style, tts_button, speaker], + ) + + style_mode.change( + gr_util, + inputs=[style_mode], + outputs=[style, ref_audio_path], + ) + + # app.launch( + # inbrowser=not args.no_autolaunch, share=args.share, server_name=args.server_name + # ) + + return app diff --git a/webui_merge.py b/webui/merge.py similarity index 65% rename from webui_merge.py rename to webui/merge.py index a58471a80..c002e3388 100644 --- a/webui_merge.py +++ b/webui/merge.py @@ -330,187 +330,192 @@ def load_styles_gr(model_name_a, model_name_b): - 構造上の相性の関係で、スタイルベクトルを混ぜる重みは、上の「話し方」と同じ比率で混ぜられます。例えば「話し方」が0のときはモデルAのみしか使われません。 """ -model_names = model_holder.model_names -if len(model_names) == 0: - logger.error( - f"モデルが見つかりませんでした。{assets_root}にモデルを置いてください。" - ) - sys.exit(1) -initial_id = 0 -initial_model_files = 
model_holder.model_files_dict[model_names[initial_id]] -with gr.Blocks(theme=GRADIO_THEME) as app: - gr.Markdown(initial_md) - with gr.Accordion(label="使い方", open=False): +def create_merge_app(): + model_names = model_holder.model_names + if len(model_names) == 0: + logger.error( + f"モデルが見つかりませんでした。{assets_root}にモデルを置いてください。" + ) + sys.exit(1) + initial_id = 0 + initial_model_files = model_holder.model_files_dict[model_names[initial_id]] + + with gr.Blocks(theme=GRADIO_THEME) as app: gr.Markdown(initial_md) - with gr.Row(): - with gr.Column(scale=3): - model_name_a = gr.Dropdown( - label="モデルA", - choices=model_names, - value=model_names[initial_id], - ) - model_path_a = gr.Dropdown( - label="モデルファイル", - choices=initial_model_files, - value=initial_model_files[0], - ) - with gr.Column(scale=3): - model_name_b = gr.Dropdown( - label="モデルB", - choices=model_names, - value=model_names[initial_id], - ) - model_path_b = gr.Dropdown( - label="モデルファイル", - choices=initial_model_files, - value=initial_model_files[0], - ) - refresh_button = gr.Button("更新", scale=1, visible=True) - with gr.Column(variant="panel"): - new_name = gr.Textbox(label="新しいモデル名", placeholder="new_model") + with gr.Accordion(label="使い方", open=False): + gr.Markdown(initial_md) with gr.Row(): - voice_slider = gr.Slider( - label="声質", - value=0, - minimum=0, - maximum=1, - step=0.1, - ) - voice_pitch_slider = gr.Slider( - label="声の高さ", - value=0, - minimum=0, - maximum=1, - step=0.1, - ) - speech_style_slider = gr.Slider( - label="話し方(抑揚・感情表現等)", - value=0, - minimum=0, - maximum=1, - step=0.1, - ) - tempo_slider = gr.Slider( - label="話す速さ・リズム・テンポ", - value=0, - minimum=0, - maximum=1, - step=0.1, - ) - use_slerp_instead_of_lerp = gr.Checkbox( - label="線形補完のかわりに球面線形補完を使う", - value=False, - ) - with gr.Column(variant="panel"): - gr.Markdown("## モデルファイル(safetensors)のマージ") - model_merge_button = gr.Button("モデルファイルのマージ", variant="primary") - info_model_merge = gr.Textbox(label="情報") + with gr.Column(scale=3): + model_name_a = gr.Dropdown( + label="モデルA", + choices=model_names, + value=model_names[initial_id], + ) + model_path_a = gr.Dropdown( + label="モデルファイル", + choices=initial_model_files, + value=initial_model_files[0], + ) + with gr.Column(scale=3): + model_name_b = gr.Dropdown( + label="モデルB", + choices=model_names, + value=model_names[initial_id], + ) + model_path_b = gr.Dropdown( + label="モデルファイル", + choices=initial_model_files, + value=initial_model_files[0], + ) + refresh_button = gr.Button("更新", scale=1, visible=True) with gr.Column(variant="panel"): - gr.Markdown(style_merge_md) + new_name = gr.Textbox(label="新しいモデル名", placeholder="new_model") with gr.Row(): - load_style_button = gr.Button("スタイル一覧をロード", scale=1) - styles_a = gr.Textbox(label="モデルAのスタイル一覧") - styles_b = gr.Textbox(label="モデルBのスタイル一覧") - style_triple_list = gr.TextArea( - label="スタイルのマージリスト", - placeholder=f"{DEFAULT_STYLE}, {DEFAULT_STYLE},{DEFAULT_STYLE}\nAngry, Angry, Angry", - value=f"{DEFAULT_STYLE}, {DEFAULT_STYLE}, {DEFAULT_STYLE}", - ) - style_merge_button = gr.Button("スタイルのマージ", variant="primary") - info_style_merge = gr.Textbox(label="情報") + voice_slider = gr.Slider( + label="声質", + value=0, + minimum=0, + maximum=1, + step=0.1, + ) + voice_pitch_slider = gr.Slider( + label="声の高さ", + value=0, + minimum=0, + maximum=1, + step=0.1, + ) + speech_style_slider = gr.Slider( + label="話し方(抑揚・感情表現等)", + value=0, + minimum=0, + maximum=1, + step=0.1, + ) + tempo_slider = gr.Slider( + label="話す速さ・リズム・テンポ", + value=0, + minimum=0, + maximum=1, + step=0.1, + ) 
+ use_slerp_instead_of_lerp = gr.Checkbox( + label="線形補完のかわりに球面線形補完を使う", + value=False, + ) + with gr.Column(variant="panel"): + gr.Markdown("## モデルファイル(safetensors)のマージ") + model_merge_button = gr.Button( + "モデルファイルのマージ", variant="primary" + ) + info_model_merge = gr.Textbox(label="情報") + with gr.Column(variant="panel"): + gr.Markdown(style_merge_md) + with gr.Row(): + load_style_button = gr.Button("スタイル一覧をロード", scale=1) + styles_a = gr.Textbox(label="モデルAのスタイル一覧") + styles_b = gr.Textbox(label="モデルBのスタイル一覧") + style_triple_list = gr.TextArea( + label="スタイルのマージリスト", + placeholder=f"{DEFAULT_STYLE}, {DEFAULT_STYLE},{DEFAULT_STYLE}\nAngry, Angry, Angry", + value=f"{DEFAULT_STYLE}, {DEFAULT_STYLE}, {DEFAULT_STYLE}", + ) + style_merge_button = gr.Button("スタイルのマージ", variant="primary") + info_style_merge = gr.Textbox(label="情報") + + text_input = gr.TextArea( + label="テキスト", value="これはテストです。聞こえていますか?" + ) + style = gr.Dropdown( + label="スタイル", + choices=["スタイルをマージしてください"], + value="スタイルをマージしてください", + ) + emotion_weight = gr.Slider( + minimum=0, + maximum=50, + value=1, + step=0.1, + label="スタイルの強さ", + ) + tts_button = gr.Button("音声合成", variant="primary") + audio_output = gr.Audio(label="結果") - text_input = gr.TextArea( - label="テキスト", value="これはテストです。聞こえていますか?" - ) - style = gr.Dropdown( - label="スタイル", - choices=["スタイルをマージしてください"], - value="スタイルをマージしてください", - ) - emotion_weight = gr.Slider( - minimum=0, - maximum=50, - value=1, - step=0.1, - label="スタイルの強さ", - ) - tts_button = gr.Button("音声合成", variant="primary") - audio_output = gr.Audio(label="結果") + model_name_a.change( + model_holder.update_model_files_gr, + inputs=[model_name_a], + outputs=[model_path_a], + ) + model_name_b.change( + model_holder.update_model_files_gr, + inputs=[model_name_b], + outputs=[model_path_b], + ) - model_name_a.change( - model_holder.update_model_files_gr, - inputs=[model_name_a], - outputs=[model_path_a], - ) - model_name_b.change( - model_holder.update_model_files_gr, - inputs=[model_name_b], - outputs=[model_path_b], - ) + refresh_button.click( + update_two_model_names_dropdown, + outputs=[model_name_a, model_path_a, model_name_b, model_path_b], + ) - refresh_button.click( - update_two_model_names_dropdown, - outputs=[model_name_a, model_path_a, model_name_b, model_path_b], - ) + load_style_button.click( + load_styles_gr, + inputs=[model_name_a, model_name_b], + outputs=[styles_a, styles_b, style_triple_list], + ) - load_style_button.click( - load_styles_gr, - inputs=[model_name_a, model_name_b], - outputs=[styles_a, styles_b, style_triple_list], - ) + model_merge_button.click( + merge_models_gr, + inputs=[ + model_name_a, + model_path_a, + model_name_b, + model_path_b, + new_name, + voice_slider, + voice_pitch_slider, + speech_style_slider, + tempo_slider, + use_slerp_instead_of_lerp, + ], + outputs=[info_model_merge], + ) - model_merge_button.click( - merge_models_gr, - inputs=[ - model_name_a, - model_path_a, - model_name_b, - model_path_b, - new_name, - voice_slider, - voice_pitch_slider, - speech_style_slider, - tempo_slider, - use_slerp_instead_of_lerp, - ], - outputs=[info_model_merge], - ) + style_merge_button.click( + merge_style_gr, + inputs=[ + model_name_a, + model_name_b, + speech_style_slider, + new_name, + style_triple_list, + ], + outputs=[info_style_merge, style], + ) - style_merge_button.click( - merge_style_gr, - inputs=[ - model_name_a, - model_name_b, - speech_style_slider, - new_name, - style_triple_list, - ], - outputs=[info_style_merge, style], - ) + tts_button.click( + simple_tts, + 
inputs=[new_name, text_input, style, emotion_weight], + outputs=[audio_output], + ) - tts_button.click( - simple_tts, - inputs=[new_name, text_input, style, emotion_weight], - outputs=[audio_output], + parser = argparse.ArgumentParser() + parser.add_argument( + "--server-name", + type=str, + default=None, + help="Server name for Gradio app", + ) + parser.add_argument( + "--no-autolaunch", + action="store_true", + default=False, + help="Do not launch app automatically", ) + parser.add_argument("--share", action="store_true", default=False) + args = parser.parse_args() -parser = argparse.ArgumentParser() -parser.add_argument( - "--server-name", - type=str, - default=None, - help="Server name for Gradio app", -) -parser.add_argument( - "--no-autolaunch", - action="store_true", - default=False, - help="Do not launch app automatically", -) -parser.add_argument("--share", action="store_true", default=False) -args = parser.parse_args() - -app.launch( - inbrowser=not args.no_autolaunch, server_name=args.server_name, share=args.share -) + # app.launch( + # inbrowser=not args.no_autolaunch, server_name=args.server_name, share=args.share + # ) + return app diff --git a/webui_style_vectors.py b/webui/style_vectors.py similarity index 65% rename from webui_style_vectors.py rename to webui/style_vectors.py index b89c9c9f8..27b95bb65 100644 --- a/webui_style_vectors.py +++ b/webui/style_vectors.py @@ -323,158 +323,161 @@ def save_style_vectors_from_files( https://ja.wikipedia.org/wiki/DBSCAN """ -with gr.Blocks(theme=GRADIO_THEME) as app: - gr.Markdown(initial_md) - with gr.Row(): - model_name = gr.Textbox(placeholder="your_model_name", label="モデル名") - reduction_method = gr.Radio( - choices=["UMAP", "t-SNE"], - label="次元削減方法", - info="v 1.3以前はt-SNEでしたがUMAPのほうがよい可能性もあります。", - value="UMAP", - ) - load_button = gr.Button("スタイルベクトルを読み込む", variant="primary") - output = gr.Plot(label="音声スタイルの可視化") - load_button.click(load, inputs=[model_name, reduction_method], outputs=[output]) - with gr.Tab("方法1: スタイル分けを自動で行う"): - with gr.Tab("スタイル分け1"): - n_clusters = gr.Slider( - minimum=2, - maximum=10, - step=1, - value=4, - label="作るスタイルの数(平均スタイルを除く)", - info="上の図を見ながらスタイルの数を試行錯誤してください。", - ) - c_method = gr.Radio( - choices=[ - "Agglomerative after reduction", - "KMeans after reduction", - "Agglomerative", - "KMeans", - ], - label="アルゴリズム", - info="分類する(クラスタリング)アルゴリズムを選択します。いろいろ試してみてください。", - value="Agglomerative after reduction", - ) - c_button = gr.Button("スタイル分けを実行") - with gr.Tab("スタイル分け2: DBSCAN"): - gr.Markdown(dbscan_md) - eps = gr.Slider( - minimum=0.1, - maximum=10, - step=0.01, - value=0.3, - label="eps", - ) - min_samples = gr.Slider( - minimum=1, - maximum=50, - step=1, - value=15, - label="min_samples", - ) - with gr.Row(): - dbscan_button = gr.Button("スタイル分けを実行") - num_styles_result = gr.Textbox(label="スタイル数") - gr.Markdown("スタイル分けの結果") - gr.Markdown( - "注意: もともと256次元なものをを2次元に落としているので、正確なベクトルの位置関係ではありません。" - ) + +def create_style_vectors_app(): + with gr.Blocks(theme=GRADIO_THEME) as app: + gr.Markdown(initial_md) with gr.Row(): - gr_plot = gr.Plot() - with gr.Column(): - with gr.Row(): - cluster_index = gr.Slider( - minimum=1, - maximum=MAX_CLUSTER_NUM, - step=1, - value=1, - label="スタイル番号", - info="選択したスタイルの代表音声を表示します。", - ) - num_files = gr.Slider( - minimum=1, - maximum=MAX_AUDIO_NUM, - step=1, - value=5, - label="代表音声の数をいくつ表示するか", - ) - get_audios_button = gr.Button("代表音声を取得") - with gr.Row(): - audio_list = [] - for i in range(MAX_AUDIO_NUM): - audio_list.append(gr.Audio(visible=False, 
show_label=True)) - c_button.click( - do_clustering_gradio, - inputs=[n_clusters, c_method], - outputs=[gr_plot, cluster_index] + audio_list, + model_name = gr.Textbox(placeholder="your_model_name", label="モデル名") + reduction_method = gr.Radio( + choices=["UMAP", "t-SNE"], + label="次元削減方法", + info="v 1.3以前はt-SNEでしたがUMAPのほうがよい可能性もあります。", + value="UMAP", ) - dbscan_button.click( - do_dbscan_gradio, - inputs=[eps, min_samples], - outputs=[gr_plot, cluster_index, num_styles_result] + audio_list, + load_button = gr.Button("スタイルベクトルを読み込む", variant="primary") + output = gr.Plot(label="音声スタイルの可視化") + load_button.click(load, inputs=[model_name, reduction_method], outputs=[output]) + with gr.Tab("方法1: スタイル分けを自動で行う"): + with gr.Tab("スタイル分け1"): + n_clusters = gr.Slider( + minimum=2, + maximum=10, + step=1, + value=4, + label="作るスタイルの数(平均スタイルを除く)", + info="上の図を見ながらスタイルの数を試行錯誤してください。", + ) + c_method = gr.Radio( + choices=[ + "Agglomerative after reduction", + "KMeans after reduction", + "Agglomerative", + "KMeans", + ], + label="アルゴリズム", + info="分類する(クラスタリング)アルゴリズムを選択します。いろいろ試してみてください。", + value="Agglomerative after reduction", + ) + c_button = gr.Button("スタイル分けを実行") + with gr.Tab("スタイル分け2: DBSCAN"): + gr.Markdown(dbscan_md) + eps = gr.Slider( + minimum=0.1, + maximum=10, + step=0.01, + value=0.3, + label="eps", + ) + min_samples = gr.Slider( + minimum=1, + maximum=50, + step=1, + value=15, + label="min_samples", + ) + with gr.Row(): + dbscan_button = gr.Button("スタイル分けを実行") + num_styles_result = gr.Textbox(label="スタイル数") + gr.Markdown("スタイル分けの結果") + gr.Markdown( + "注意: もともと256次元なものをを2次元に落としているので、正確なベクトルの位置関係ではありません。" ) - get_audios_button.click( - representative_wav_files_gradio, - inputs=[cluster_index, num_files], - outputs=audio_list, + with gr.Row(): + gr_plot = gr.Plot() + with gr.Column(): + with gr.Row(): + cluster_index = gr.Slider( + minimum=1, + maximum=MAX_CLUSTER_NUM, + step=1, + value=1, + label="スタイル番号", + info="選択したスタイルの代表音声を表示します。", + ) + num_files = gr.Slider( + minimum=1, + maximum=MAX_AUDIO_NUM, + step=1, + value=5, + label="代表音声の数をいくつ表示するか", + ) + get_audios_button = gr.Button("代表音声を取得") + with gr.Row(): + audio_list = [] + for i in range(MAX_AUDIO_NUM): + audio_list.append(gr.Audio(visible=False, show_label=True)) + c_button.click( + do_clustering_gradio, + inputs=[n_clusters, c_method], + outputs=[gr_plot, cluster_index] + audio_list, + ) + dbscan_button.click( + do_dbscan_gradio, + inputs=[eps, min_samples], + outputs=[gr_plot, cluster_index, num_styles_result] + audio_list, + ) + get_audios_button.click( + representative_wav_files_gradio, + inputs=[cluster_index, num_files], + outputs=audio_list, + ) + gr.Markdown("結果が良さそうなら、これを保存します。") + style_names = gr.Textbox( + "Angry, Sad, Happy", + label="スタイルの名前", + info=f"スタイルの名前を`,`で区切って入力してください(日本語可)。例: `Angry, Sad, Happy`や`怒り, 悲しみ, 喜び`など。平均音声は{DEFAULT_STYLE}として自動的に保存されます。", ) - gr.Markdown("結果が良さそうなら、これを保存します。") - style_names = gr.Textbox( - "Angry, Sad, Happy", - label="スタイルの名前", - info=f"スタイルの名前を`,`で区切って入力してください(日本語可)。例: `Angry, Sad, Happy`や`怒り, 悲しみ, 喜び`など。平均音声は{DEFAULT_STYLE}として自動的に保存されます。", - ) - with gr.Row(): - save_button1 = gr.Button("スタイルベクトルを保存", variant="primary") - info2 = gr.Textbox(label="保存結果") + with gr.Row(): + save_button1 = gr.Button("スタイルベクトルを保存", variant="primary") + info2 = gr.Textbox(label="保存結果") - save_button1.click( - save_style_vectors_from_clustering, - inputs=[model_name, style_names], - outputs=[info2], - ) - with gr.Tab("方法2: 手動でスタイルを選ぶ"): - gr.Markdown( - 
"下のテキスト欄に、各スタイルの代表音声のファイル名を`,`区切りで、その横に対応するスタイル名を`,`区切りで入力してください。" - ) - gr.Markdown("例: `angry.wav, sad.wav, happy.wav`と`Angry, Sad, Happy`") - gr.Markdown( - f"注意: {DEFAULT_STYLE}スタイルは自動的に保存されます、手動では{DEFAULT_STYLE}という名前のスタイルは指定しないでください。" - ) - with gr.Row(): - audio_files_text = gr.Textbox( - label="音声ファイル名", placeholder="angry.wav, sad.wav, happy.wav" + save_button1.click( + save_style_vectors_from_clustering, + inputs=[model_name, style_names], + outputs=[info2], ) - style_names_text = gr.Textbox( - label="スタイル名", placeholder="Angry, Sad, Happy" + with gr.Tab("方法2: 手動でスタイルを選ぶ"): + gr.Markdown( + "下のテキスト欄に、各スタイルの代表音声のファイル名を`,`区切りで、その横に対応するスタイル名を`,`区切りで入力してください。" ) - with gr.Row(): - save_button2 = gr.Button("スタイルベクトルを保存", variant="primary") - info2 = gr.Textbox(label="保存結果") - save_button2.click( - save_style_vectors_from_files, - inputs=[model_name, audio_files_text, style_names_text], - outputs=[info2], + gr.Markdown("例: `angry.wav, sad.wav, happy.wav`と`Angry, Sad, Happy`") + gr.Markdown( + f"注意: {DEFAULT_STYLE}スタイルは自動的に保存されます、手動では{DEFAULT_STYLE}という名前のスタイルは指定しないでください。" ) + with gr.Row(): + audio_files_text = gr.Textbox( + label="音声ファイル名", placeholder="angry.wav, sad.wav, happy.wav" + ) + style_names_text = gr.Textbox( + label="スタイル名", placeholder="Angry, Sad, Happy" + ) + with gr.Row(): + save_button2 = gr.Button("スタイルベクトルを保存", variant="primary") + info2 = gr.Textbox(label="保存結果") + save_button2.click( + save_style_vectors_from_files, + inputs=[model_name, audio_files_text, style_names_text], + outputs=[info2], + ) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--server-name", + type=str, + default=None, + help="Server name for Gradio app", + ) + parser.add_argument( + "--no-autolaunch", + action="store_true", + default=False, + help="Do not launch app automatically", + ) + parser.add_argument("--share", action="store_true", default=False) + args = parser.parse_args() -parser = argparse.ArgumentParser() -parser.add_argument( - "--server-name", - type=str, - default=None, - help="Server name for Gradio app", -) -parser.add_argument( - "--no-autolaunch", - action="store_true", - default=False, - help="Do not launch app automatically", -) -parser.add_argument("--share", action="store_true", default=False) -args = parser.parse_args() - -app.launch( - inbrowser=not args.no_autolaunch, server_name=args.server_name, share=args.share -) + # app.launch( + # inbrowser=not args.no_autolaunch, server_name=args.server_name, share=args.share + # ) + return app diff --git a/webui_train.py b/webui/train.py similarity index 99% rename from webui_train.py rename to webui/train.py index 59cc9f594..1c9a2145e 100644 --- a/webui_train.py +++ b/webui/train.py @@ -450,7 +450,8 @@ def run_tensorboard(model_name): 日本語話者の単一話者データセットでも構いません。 """ -if __name__ == "__main__": + +def create_train_app(): with gr.Blocks(theme=GRADIO_THEME).queue() as app: gr.Markdown(initial_md) with gr.Accordion(label="データの前準備", open=False): @@ -808,4 +809,5 @@ def run_tensorboard(model_name): ) args = parser.parse_args() - app.launch(inbrowser=not args.no_autolaunch, server_name=args.server_name) + # app.launch(inbrowser=not args.no_autolaunch, server_name=args.server_name) + return app From bc1058270c9a8648cae01bd301a136ef4ff92466 Mon Sep 17 00:00:00 2001 From: kale4eat Date: Tue, 5 Mar 2024 12:16:14 +0900 Subject: [PATCH 002/148] Delete pyopenjtalk import --- server_editor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/server_editor.py b/server_editor.py index afb9321df..1ae323f32 100644 --- 
a/server_editor.py +++ b/server_editor.py @@ -19,7 +19,6 @@ import yaml import numpy as np -import pyopenjtalk import requests import torch import uvicorn From 1eb8fb4b08c5dba3c2d7bdd40bf56da46369d321 Mon Sep 17 00:00:00 2001 From: kale4eat Date: Tue, 5 Mar 2024 12:17:15 +0900 Subject: [PATCH 003/148] add openjtalk worker pkg --- text/pyopenjtalk_worker/__init__.py | 100 ++++++++++++++++++++++ text/pyopenjtalk_worker/__main__.py | 16 ++++ text/pyopenjtalk_worker/worker_client.py | 40 +++++++++ text/pyopenjtalk_worker/worker_common.py | 44 ++++++++++ text/pyopenjtalk_worker/worker_server.py | 103 +++++++++++++++++++++++ 5 files changed, 303 insertions(+) create mode 100644 text/pyopenjtalk_worker/__init__.py create mode 100644 text/pyopenjtalk_worker/__main__.py create mode 100644 text/pyopenjtalk_worker/worker_client.py create mode 100644 text/pyopenjtalk_worker/worker_common.py create mode 100644 text/pyopenjtalk_worker/worker_server.py diff --git a/text/pyopenjtalk_worker/__init__.py b/text/pyopenjtalk_worker/__init__.py new file mode 100644 index 000000000..6ab4ca162 --- /dev/null +++ b/text/pyopenjtalk_worker/__init__.py @@ -0,0 +1,100 @@ +""" +Run the pyopenjtalk worker in a separate process +to avoid user dictionary access error +""" + +from typing import Optional, Any + +from .worker_common import WOKER_PORT +from .worker_client import WorkerClient + +from common.log import logger + +WORKER_CLIENT: Optional[WorkerClient] = None + +# pyopenjtalk interface + +# g2p: not used + + +def run_frontend(text: str) -> list[dict[str, Any]]: + assert WORKER_CLIENT + ret = WORKER_CLIENT.dispatch_pyopenjtalk("run_frontend", text) + assert isinstance(ret, list) + return ret + + +def make_label(njd_features) -> list[str]: + assert WORKER_CLIENT + ret = WORKER_CLIENT.dispatch_pyopenjtalk("make_label", njd_features) + assert isinstance(ret, list) + return ret + + +def mecab_dict_index(path: str, out_path: str, dn_mecab: Optional[str] = None): + assert WORKER_CLIENT + WORKER_CLIENT.dispatch_pyopenjtalk("mecab_dict_index", path, out_path, dn_mecab) + + +def update_global_jtalk_with_user_dict(path: str): + assert WORKER_CLIENT + WORKER_CLIENT.dispatch_pyopenjtalk("update_global_jtalk_with_user_dict", path) + + +def unset_user_dict(): + assert WORKER_CLIENT + WORKER_CLIENT.dispatch_pyopenjtalk("unset_user_dict") + + +# initialize module when imported + + +def initialize(port: int = WOKER_PORT): + import time + import socket + import sys + import atexit + + global WORKER_CLIENT + logger.debug("initialize") + if WORKER_CLIENT: + return + + client = None + try: + client = WorkerClient(port) + except (socket.timeout, socket.error): + logger.debug("try starting worker server") + import os + import subprocess + + worker_pkg_path = os.path.relpath( + os.path.dirname(__file__), os.getcwd() + ).replace(os.sep, ".") + subprocess.Popen([sys.executable, "-m", worker_pkg_path, "--port", str(port)]) + # wait until server listening + count = 0 + while True: + try: + client = WorkerClient(port) + break + except socket.error: + time.sleep(1) + count += 1 + # 10: max number of retries + if count == 10: + raise TimeoutError("サーバーに接続できませんでした") + + WORKER_CLIENT = client + + def terminate(): + global WORKER_CLIENT + if not WORKER_CLIENT: + return + + if WORKER_CLIENT.status().get("client-count") == 1: + WORKER_CLIENT.quit_server() + WORKER_CLIENT.close() + WORKER_CLIENT = None + + atexit.register(terminate) diff --git a/text/pyopenjtalk_worker/__main__.py b/text/pyopenjtalk_worker/__main__.py new file mode 100644 index 
000000000..8b67aa0bb --- /dev/null +++ b/text/pyopenjtalk_worker/__main__.py @@ -0,0 +1,16 @@ +import argparse + +from .worker_server import WorkerServer +from .worker_common import WOKER_PORT + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=WOKER_PORT) + args = parser.parse_args() + server = WorkerServer() + server.start_server(port=args.port) + + +if __name__ == "__main__": + main() diff --git a/text/pyopenjtalk_worker/worker_client.py b/text/pyopenjtalk_worker/worker_client.py new file mode 100644 index 000000000..bd9a32abc --- /dev/null +++ b/text/pyopenjtalk_worker/worker_client.py @@ -0,0 +1,40 @@ +from typing import Any +import socket + +from .worker_common import RequestType, receive_data, send_data + + +class WorkerClient: + def __init__(self, port: int) -> None: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # 5: timeout + sock.settimeout(5) + sock.connect((socket.gethostname(), port)) + self.sock = sock + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def close(self): + self.sock.close() + + def dispatch_pyopenjtalk(self, func: str, *args, **kwargs): + data = { + "request-type": RequestType.PYOPENJTALK, + "func": func, + "args": args, + "kwargs": kwargs, + } + send_data(self.sock, data) + return receive_data(self.sock).get("return") + + def status(self): + send_data(self.sock, {"request-type": RequestType.STATUS}) + return receive_data(self.sock) + + def quit_server(self): + send_data(self.sock, {"request-type": RequestType.QUIT_SERVER}) + receive_data(self.sock) diff --git a/text/pyopenjtalk_worker/worker_common.py b/text/pyopenjtalk_worker/worker_common.py new file mode 100644 index 000000000..bea552e6a --- /dev/null +++ b/text/pyopenjtalk_worker/worker_common.py @@ -0,0 +1,44 @@ +from typing import Any, Optional, Final +from enum import IntEnum, auto +import socket +import json + +WOKER_PORT: Final[int] = 7861 +HEADER_SIZE: Final[int] = 4 + + +class RequestType(IntEnum): + STATUS = auto() + QUIT_SERVER = auto() + PYOPENJTALK = auto() + + +class ConnectionClosedException(Exception): + pass + + +# socket communication + + +def send_data(sock: socket.socket, data: dict[str, Any]): + json_data = json.dumps(data).encode() + header = len(json_data).to_bytes(HEADER_SIZE, byteorder="big") + sock.sendall(header + json_data) + + +def _receive_until(sock: socket.socket, size: int): + data = b"" + while len(data) < size: + part = sock.recv(size - len(data)) + if part == b"": + raise ConnectionClosedException("接続が閉じられました") + data += part + + return data + + +def receive_data(sock: socket.socket) -> dict[str, Any]: + header = _receive_until(sock, HEADER_SIZE) + data_length = int.from_bytes(header, byteorder="big") + body = _receive_until(sock, data_length) + return json.loads(body.decode()) diff --git a/text/pyopenjtalk_worker/worker_server.py b/text/pyopenjtalk_worker/worker_server.py new file mode 100644 index 000000000..a2b9f2f0e --- /dev/null +++ b/text/pyopenjtalk_worker/worker_server.py @@ -0,0 +1,103 @@ +import pyopenjtalk +import socket +import select + + +from .worker_common import ( + ConnectionClosedException, + RequestType, + receive_data, + send_data, +) + +from common.log import logger + +# To make it as fast as possible +# Probably faster than calling getattr every time +_PYOPENJTALK_FUNC_DICT = { + "run_frontend": pyopenjtalk.run_frontend, + "make_label": pyopenjtalk.make_label, + "mecab_dict_index": pyopenjtalk.mecab_dict_index, + 
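The `worker_common.py` helpers above define the wire format between `WorkerClient` and `WorkerServer`: a 4-byte big-endian length header followed by a UTF-8 JSON body. A tiny self-contained illustration of that framing (not part of the patch; it frames one request by hand over a local socket pair and, for brevity, assumes the whole body arrives in a single `recv()`, whereas `_receive_until()` loops until the full size is read):

```python
import json
import socket

HEADER_SIZE = 4  # same value as worker_common.HEADER_SIZE


def frame(payload: dict) -> bytes:
    body = json.dumps(payload).encode()
    return len(body).to_bytes(HEADER_SIZE, byteorder="big") + body


a, b = socket.socketpair()
a.sendall(frame({"request-type": 1}))  # 1 == RequestType.STATUS

header = b.recv(HEADER_SIZE)
length = int.from_bytes(header, byteorder="big")
print(json.loads(b.recv(length).decode()))  # -> {'request-type': 1}
```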
"update_global_jtalk_with_user_dict": pyopenjtalk.update_global_jtalk_with_user_dict, + "unset_user_dict": pyopenjtalk.unset_user_dict, +} + + +class WorkerServer: + def __init__(self) -> None: + self.client_count: int = 0 + self.quit: bool = False + + def handle_request(self, request): + request_type = None + try: + request_type = RequestType(request.get("request-type")) + except Exception: + return { + "success": False, + "reason": "request-type is invalid", + } + + if request_type: + if request_type == RequestType.STATUS: + response = { + "success": True, + "client-count": self.client_count, + } + elif request_type == RequestType.QUIT_SERVER: + self.quit = True + response = {"success": True} + elif request_type == RequestType.PYOPENJTALK: + func_name = request.get("func") + assert isinstance(func_name, str) + func = _PYOPENJTALK_FUNC_DICT[func_name] + args = request.get("args") + kwargs = request.get("kwargs") + assert isinstance(args, list) + assert isinstance(kwargs, dict) + ret = func(*args, **kwargs) + response = {"success": True, "return": ret} + else: + # NOT REACHED + response = request + + return response + + def start_server(self, port: int): + logger.info("start pyopenjtalk worker server") + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket: + server_socket.bind((socket.gethostname(), port)) + server_socket.listen() + sockets = [server_socket] + while True: + ready_sockets, _, _ = select.select(sockets, [], [], 0.1) + for sock in ready_sockets: + if sock is server_socket: + logger.info("new client connected") + client_socket, _ = server_socket.accept() + sockets.append(client_socket) + self.client_count += 1 + else: + # client + try: + request = receive_data(sock) + except ConnectionClosedException as e: + sock.close() + sockets.remove(sock) + self.client_count -= 1 + logger.info("close connection") + continue + + logger.debug(f"receive request: {request}") + + response = self.handle_request(request) + logger.debug(f"send response: {response}") + try: + send_data(sock, response) + except Exception: + logger.warning( + "an exception occurred during sending responce" + ) + if self.quit: + logger.info("quit pyopenjtalk worker server") + return From 85f5b9bd25e930dc8e0fd8e96614cd1eaa5127f7 Mon Sep 17 00:00:00 2001 From: kale4eat Date: Tue, 5 Mar 2024 12:18:05 +0900 Subject: [PATCH 004/148] replace pyopenjtalk import with worker --- text/japanese.py | 4 +++- text/user_dict/__init__.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/text/japanese.py b/text/japanese.py index b18bc682e..fea0eaa5d 100644 --- a/text/japanese.py +++ b/text/japanese.py @@ -4,7 +4,9 @@ import unicodedata from pathlib import Path -import pyopenjtalk +from . import pyopenjtalk_worker as pyopenjtalk + +pyopenjtalk.initialize() from num2words import num2words from transformers import AutoTokenizer diff --git a/text/user_dict/__init__.py b/text/user_dict/__init__.py index c12b3d1aa..95515cbbc 100644 --- a/text/user_dict/__init__.py +++ b/text/user_dict/__init__.py @@ -12,7 +12,9 @@ from uuid import UUID, uuid4 import numpy as np -import pyopenjtalk +from .. 
import pyopenjtalk_worker as pyopenjtalk + +pyopenjtalk.initialize() from fastapi import HTTPException from .word_model import UserDictWord, WordTypes From 98ff976ec0e187bef23c6bf5d91bbe09829c9474 Mon Sep 17 00:00:00 2001 From: kale4eat Date: Tue, 5 Mar 2024 15:04:27 +0900 Subject: [PATCH 005/148] modify logging --- text/pyopenjtalk_worker/__init__.py | 2 +- text/pyopenjtalk_worker/worker_client.py | 8 +++++++- text/pyopenjtalk_worker/worker_server.py | 6 +++--- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/text/pyopenjtalk_worker/__init__.py b/text/pyopenjtalk_worker/__init__.py index 6ab4ca162..ef5873c7a 100644 --- a/text/pyopenjtalk_worker/__init__.py +++ b/text/pyopenjtalk_worker/__init__.py @@ -64,7 +64,7 @@ def initialize(port: int = WOKER_PORT): try: client = WorkerClient(port) except (socket.timeout, socket.error): - logger.debug("try starting worker server") + logger.debug("try starting pyopenjtalk worker server") import os import subprocess diff --git a/text/pyopenjtalk_worker/worker_client.py b/text/pyopenjtalk_worker/worker_client.py index bd9a32abc..23f7dbe8a 100644 --- a/text/pyopenjtalk_worker/worker_client.py +++ b/text/pyopenjtalk_worker/worker_client.py @@ -3,6 +3,8 @@ from .worker_common import RequestType, receive_data, send_data +from common.log import logger + class WorkerClient: def __init__(self, port: int) -> None: @@ -28,8 +30,12 @@ def dispatch_pyopenjtalk(self, func: str, *args, **kwargs): "args": args, "kwargs": kwargs, } + logger.trace(f"client sends request: {data}") send_data(self.sock, data) - return receive_data(self.sock).get("return") + logger.trace("client sent request successfully") + response = receive_data(self.sock) + logger.trace(f"client received response: {response}") + return response.get("return") def status(self): send_data(self.sock, {"request-type": RequestType.STATUS}) diff --git a/text/pyopenjtalk_worker/worker_server.py b/text/pyopenjtalk_worker/worker_server.py index a2b9f2f0e..13bc7716d 100644 --- a/text/pyopenjtalk_worker/worker_server.py +++ b/text/pyopenjtalk_worker/worker_server.py @@ -2,7 +2,6 @@ import socket import select - from .worker_common import ( ConnectionClosedException, RequestType, @@ -88,12 +87,13 @@ def start_server(self, port: int): logger.info("close connection") continue - logger.debug(f"receive request: {request}") + logger.trace(f"server received request: {request}") response = self.handle_request(request) - logger.debug(f"send response: {response}") + logger.trace(f"server sends response: {response}") try: send_data(sock, response) + logger.trace("server sent response successfully") except Exception: logger.warning( "an exception occurred during sending responce" From 2ab025d5fe4a21522ffe7c9ba7d12a0a061cfa1d Mon Sep 17 00:00:00 2001 From: kale4eat Date: Wed, 6 Mar 2024 09:45:25 +0900 Subject: [PATCH 006/148] Run in a separate process group to avoid receiving signals by Ctrl + C Minor Correction: * logging * declaration of terminate --- text/pyopenjtalk_worker/__init__.py | 34 +++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/text/pyopenjtalk_worker/__init__.py b/text/pyopenjtalk_worker/__init__.py index ef5873c7a..3fd997153 100644 --- a/text/pyopenjtalk_worker/__init__.py +++ b/text/pyopenjtalk_worker/__init__.py @@ -55,8 +55,8 @@ def initialize(port: int = WOKER_PORT): import sys import atexit - global WORKER_CLIENT logger.debug("initialize") + global WORKER_CLIENT if WORKER_CLIENT: return @@ -71,7 +71,16 @@ def initialize(port: int = WOKER_PORT): 
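Importing the worker package under the `pyopenjtalk` name and calling `initialize()` once, as the call sites above now do, keeps the rest of the text-processing code unchanged. A minimal usage sketch (illustrative text, run from the repository root):

```python
from text import pyopenjtalk_worker as pyopenjtalk

pyopenjtalk.initialize()  # connects to the worker, spawning the server if needed

njd = pyopenjtalk.run_frontend("こんにちは")  # dispatched to the worker process
labels = pyopenjtalk.make_label(njd)          # full-context labels, also via the worker
print(labels[0])
```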
worker_pkg_path = os.path.relpath( os.path.dirname(__file__), os.getcwd() ).replace(os.sep, ".") - subprocess.Popen([sys.executable, "-m", worker_pkg_path, "--port", str(port)]) + args = [sys.executable, "-m", worker_pkg_path, "--port", str(port)] + # new session, new process group + if sys.platform.startswith("win"): + cf = subprocess.DETACHED_PROCESS | subprocess.CREATE_NEW_PROCESS_GROUP # type: ignore + subprocess.Popen(args, creationflags=cf) + else: + # align with Windows behavior + # start_new_session is same as specifying setsid in preexec_fn + subprocess.Popen(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True) # type: ignore + # wait until server listening count = 0 while True: @@ -86,15 +95,22 @@ def initialize(port: int = WOKER_PORT): raise TimeoutError("サーバーに接続できませんでした") WORKER_CLIENT = client + atexit.register(terminate) - def terminate(): - global WORKER_CLIENT - if not WORKER_CLIENT: - return +# top-level declaration +def terminate(): + logger.debug("terminate") + global WORKER_CLIENT + if not WORKER_CLIENT: + return + + # repare for unexpected errors + try: if WORKER_CLIENT.status().get("client-count") == 1: WORKER_CLIENT.quit_server() - WORKER_CLIENT.close() - WORKER_CLIENT = None + except Exception as e: + logger.error(e) - atexit.register(terminate) + WORKER_CLIENT.close() + WORKER_CLIENT = None From 45c6bde2e75b3115ce32cb50059d783495d8168e Mon Sep 17 00:00:00 2001 From: kale4eat Date: Wed, 6 Mar 2024 11:19:55 +0900 Subject: [PATCH 007/148] In Windows, create new console and hide it Minor Correction: * logging * change status return type --- text/pyopenjtalk_worker/__init__.py | 9 ++++++--- text/pyopenjtalk_worker/worker_client.py | 17 +++++++++++++---- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/text/pyopenjtalk_worker/__init__.py b/text/pyopenjtalk_worker/__init__.py index 3fd997153..8a1026636 100644 --- a/text/pyopenjtalk_worker/__init__.py +++ b/text/pyopenjtalk_worker/__init__.py @@ -74,8 +74,11 @@ def initialize(port: int = WOKER_PORT): args = [sys.executable, "-m", worker_pkg_path, "--port", str(port)] # new session, new process group if sys.platform.startswith("win"): - cf = subprocess.DETACHED_PROCESS | subprocess.CREATE_NEW_PROCESS_GROUP # type: ignore - subprocess.Popen(args, creationflags=cf) + cf = subprocess.CREATE_NEW_CONSOLE | subprocess.CREATE_NEW_PROCESS_GROUP # type: ignore + si = subprocess.STARTUPINFO() # type: ignore + si.dwFlags |= subprocess.STARTF_USESHOWWINDOW # type: ignore + si.wShowWindow = subprocess.SW_HIDE # type: ignore + subprocess.Popen(args, creationflags=cf, startupinfo=si) else: # align with Windows behavior # start_new_session is same as specifying setsid in preexec_fn @@ -107,7 +110,7 @@ def terminate(): # repare for unexpected errors try: - if WORKER_CLIENT.status().get("client-count") == 1: + if WORKER_CLIENT.status() == 1: WORKER_CLIENT.quit_server() except Exception as e: logger.error(e) diff --git a/text/pyopenjtalk_worker/worker_client.py b/text/pyopenjtalk_worker/worker_client.py index 23f7dbe8a..86d8969a9 100644 --- a/text/pyopenjtalk_worker/worker_client.py +++ b/text/pyopenjtalk_worker/worker_client.py @@ -38,9 +38,18 @@ def dispatch_pyopenjtalk(self, func: str, *args, **kwargs): return response.get("return") def status(self): - send_data(self.sock, {"request-type": RequestType.STATUS}) - return receive_data(self.sock) + data = {"request-type": RequestType.STATUS} + logger.trace(f"client sends request: {data}") + send_data(self.sock, data) + logger.trace("client sent 
request successfully") + response = receive_data(self.sock) + logger.trace(f"client received response: {response}") + return response.get("client-count") def quit_server(self): - send_data(self.sock, {"request-type": RequestType.QUIT_SERVER}) - receive_data(self.sock) + data = {"request-type": RequestType.QUIT_SERVER} + logger.trace(f"client sends request: {data}") + send_data(self.sock, data) + logger.trace("client sent request successfully") + response = receive_data(self.sock) + logger.trace(f"client received response: {response}") From ed90af1b8718fe31f7bf3e7565b8583a35f77113 Mon Sep 17 00:00:00 2001 From: kale4eat Date: Wed, 6 Mar 2024 12:13:27 +0900 Subject: [PATCH 008/148] Enhanced server error handling Add signal handling for when the process is killed --- text/pyopenjtalk_worker/__init__.py | 10 ++++++++++ text/pyopenjtalk_worker/worker_server.py | 6 ++++++ 2 files changed, 16 insertions(+) diff --git a/text/pyopenjtalk_worker/__init__.py b/text/pyopenjtalk_worker/__init__.py index 8a1026636..e968aed76 100644 --- a/text/pyopenjtalk_worker/__init__.py +++ b/text/pyopenjtalk_worker/__init__.py @@ -54,6 +54,7 @@ def initialize(port: int = WOKER_PORT): import socket import sys import atexit + import signal logger.debug("initialize") global WORKER_CLIENT @@ -100,6 +101,15 @@ def initialize(port: int = WOKER_PORT): WORKER_CLIENT = client atexit.register(terminate) + # when the process is killed + def signal_handler(signum, frame): + with open("signal_handler.txt", mode="w") as f: + + pass + terminate() + + signal.signal(signal.SIGTERM, signal_handler) + # top-level declaration def terminate(): diff --git a/text/pyopenjtalk_worker/worker_server.py b/text/pyopenjtalk_worker/worker_server.py index 13bc7716d..dc6d476c1 100644 --- a/text/pyopenjtalk_worker/worker_server.py +++ b/text/pyopenjtalk_worker/worker_server.py @@ -86,6 +86,12 @@ def start_server(self, port: int): self.client_count -= 1 logger.info("close connection") continue + except Exception as e: + sock.close() + sockets.remove(sock) + self.client_count -= 1 + logger.error(e) + continue logger.trace(f"server received request: {request}") From e4cd4f84239a6121d767f6d082a80ea2c4305d41 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Wed, 6 Mar 2024 07:38:51 +0000 Subject: [PATCH 009/148] Fix: ignore .venv/ --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index b556dfff7..88048d770 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__/ venv/ +.venv/ .ipynb_checkpoints/ /*.yml From f26def436972c13d8a01cf57e2ee874970cbca90 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Wed, 6 Mar 2024 12:39:24 +0000 Subject: [PATCH 010/148] Remove: code that is not referenced anywhere --- spec_gen.py | 87 -------------------------------------------- update_status.py | 93 ------------------------------------------------ 2 files changed, 180 deletions(-) delete mode 100644 spec_gen.py delete mode 100644 update_status.py diff --git a/spec_gen.py b/spec_gen.py deleted file mode 100644 index b6715faff..000000000 --- a/spec_gen.py +++ /dev/null @@ -1,87 +0,0 @@ -import torch -from tqdm import tqdm -from multiprocessing import Pool -from mel_processing import spectrogram_torch, mel_spectrogram_torch -from utils import load_wav_to_torch - - -class AudioProcessor: - def __init__( - self, - max_wav_value, - use_mel_spec_posterior, - filter_length, - n_mel_channels, - sampling_rate, - hop_length, - win_length, - mel_fmin, - mel_fmax, - ): - self.max_wav_value = max_wav_value - self.use_mel_spec_posterior = 
use_mel_spec_posterior - self.filter_length = filter_length - self.n_mel_channels = n_mel_channels - self.sampling_rate = sampling_rate - self.hop_length = hop_length - self.win_length = win_length - self.mel_fmin = mel_fmin - self.mel_fmax = mel_fmax - - def process_audio(self, filename): - audio, sampling_rate = load_wav_to_torch(filename) - audio_norm = audio / self.max_wav_value - audio_norm = audio_norm.unsqueeze(0) - spec_filename = filename.replace(".wav", ".spec.pt") - if self.use_mel_spec_posterior: - spec_filename = spec_filename.replace(".spec.pt", ".mel.pt") - try: - spec = torch.load(spec_filename) - except: - if self.use_mel_spec_posterior: - spec = mel_spectrogram_torch( - audio_norm, - self.filter_length, - self.n_mel_channels, - self.sampling_rate, - self.hop_length, - self.win_length, - self.mel_fmin, - self.mel_fmax, - center=False, - ) - else: - spec = spectrogram_torch( - audio_norm, - self.filter_length, - self.sampling_rate, - self.hop_length, - self.win_length, - center=False, - ) - spec = torch.squeeze(spec, 0) - torch.save(spec, spec_filename) - return spec, audio_norm - - -# 使用示例 -processor = AudioProcessor( - max_wav_value=32768.0, - use_mel_spec_posterior=False, - filter_length=2048, - n_mel_channels=128, - sampling_rate=44100, - hop_length=512, - win_length=2048, - mel_fmin=0.0, - mel_fmax="null", -) - -with open("filelists/train.list", "r") as f: - filepaths = [line.split("|")[0] for line in f] # 取每一行的第一部分作为audiopath - -# 使用多进程处理 -with Pool(processes=32) as pool: # 使用4个进程 - with tqdm(total=len(filepaths)) as pbar: - for i, _ in enumerate(pool.imap_unordered(processor.process_audio, filepaths)): - pbar.update() diff --git a/update_status.py b/update_status.py deleted file mode 100644 index 7d768c663..000000000 --- a/update_status.py +++ /dev/null @@ -1,93 +0,0 @@ -import os -import gradio as gr - -lang_dict = {"EN(英文)": "_en", "ZH(中文)": "_zh", "JP(日语)": "_jp"} - - -def raw_dir_convert_to_path(target_dir: str, lang): - res = target_dir.rstrip("/").rstrip("\\") - if (not target_dir.startswith("raw")) and (not target_dir.startswith("./raw")): - res = os.path.join("./raw", res) - if ( - (not res.endswith("_zh")) - and (not res.endswith("_jp")) - and (not res.endswith("_en")) - ): - res += lang_dict[lang] - return res - - -def update_g_files(): - g_files = [] - cnt = 0 - for root, dirs, files in os.walk(os.path.abspath("./logs")): - for file in files: - if file.startswith("G_") and file.endswith(".pth"): - g_files.append(os.path.join(root, file)) - cnt += 1 - print(g_files) - return f"更新模型列表完成, 共找到{cnt}个模型", gr.Dropdown.update(choices=g_files) - - -def update_c_files(): - c_files = [] - cnt = 0 - for root, dirs, files in os.walk(os.path.abspath("./logs")): - for file in files: - if file.startswith("config.json"): - c_files.append(os.path.join(root, file)) - cnt += 1 - print(c_files) - return f"更新模型列表完成, 共找到{cnt}个配置文件", gr.Dropdown.update( - choices=c_files - ) - - -def update_model_folders(): - subdirs = [] - cnt = 0 - for root, dirs, files in os.walk(os.path.abspath("./logs")): - for dir_name in dirs: - if os.path.basename(dir_name) != "eval": - subdirs.append(os.path.join(root, dir_name)) - cnt += 1 - print(subdirs) - return f"更新模型文件夹列表完成, 共找到{cnt}个文件夹", gr.Dropdown.update( - choices=subdirs - ) - - -def update_wav_lab_pairs(): - wav_count = tot_count = 0 - for root, _, files in os.walk("./raw"): - for file in files: - # print(file) - file_path = os.path.join(root, file) - if file.lower().endswith(".wav"): - lab_file = os.path.splitext(file_path)[0] + ".lab" - if 
os.path.exists(lab_file): - wav_count += 1 - tot_count += 1 - return f"{wav_count} / {tot_count}" - - -def update_raw_folders(): - subdirs = [] - cnt = 0 - script_path = os.path.dirname(os.path.abspath(__file__)) # 获取当前脚本的绝对路径 - raw_path = os.path.join(script_path, "raw") - print(raw_path) - os.makedirs(raw_path, exist_ok=True) - for root, dirs, files in os.walk(raw_path): - for dir_name in dirs: - relative_path = os.path.relpath( - os.path.join(root, dir_name), script_path - ) # 获取相对路径 - subdirs.append(relative_path) - cnt += 1 - print(subdirs) - return ( - f"更新raw音频文件夹列表完成, 共找到{cnt}个文件夹", - gr.Dropdown.update(choices=subdirs), - gr.Textbox.update(value=update_wav_lab_pairs()), - ) From 419d0e5bed6e91028c570dad1282f4cc027cf433 Mon Sep 17 00:00:00 2001 From: kale4eat Date: Wed, 6 Mar 2024 22:38:44 +0900 Subject: [PATCH 011/148] Delete debugging traces --- text/pyopenjtalk_worker/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/text/pyopenjtalk_worker/__init__.py b/text/pyopenjtalk_worker/__init__.py index e968aed76..9677666b2 100644 --- a/text/pyopenjtalk_worker/__init__.py +++ b/text/pyopenjtalk_worker/__init__.py @@ -103,9 +103,6 @@ def initialize(port: int = WOKER_PORT): # when the process is killed def signal_handler(signum, frame): - with open("signal_handler.txt", mode="w") as f: - - pass terminate() signal.signal(signal.SIGTERM, signal_handler) From 918d168ae707c464a0582b186f5ffdfdad8fbf94 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Wed, 6 Mar 2024 20:56:21 +0000 Subject: [PATCH 012/148] Refactor: rewrote Japanese natural language processing code imported from server_editor.py The logic has not been changed, only renaming, splitting and moving modules on a per-function basis. Existing code will be left in place for the time being to avoid breaking the training code, which is not subject to refactoring this time. 
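As a quick orientation to the new package layout, the relocated pipeline is meant to be driven roughly as follows (a sketch based on the signatures introduced in this commit; the tokenizer path is the one used by `server_editor.py` below):

```python
from transformers import AutoTokenizer

from style_bert_vits2.text_processing.japanese.normalizer import normalize_text
from style_bert_vits2.text_processing.japanese.g2p import g2p

# BERT tokenizer used for word segmentation (same one as server_editor.py)
tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese-char-wwm")

text = normalize_text("こんにちは、世界!")
phones, tones, word2ph = g2p(text, tokenizer, use_jp_extra=True)
assert len(phones) == len(tones) == sum(word2ph)
```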
--- server_editor.py | 40 +- style_bert_vits2/.editorconfig | 15 + style_bert_vits2/__init__.py | 0 style_bert_vits2/constants.py | 32 ++ style_bert_vits2/logging.py | 15 + .../text_processing/japanese/g2p.py | 493 ++++++++++++++++++ .../text_processing/japanese/g2p_utils.py | 94 ++++ .../text_processing/japanese/mora_list.py | 236 +++++++++ .../text_processing/japanese/normalizer.py | 161 ++++++ style_bert_vits2/text_processing/symbols.py | 192 +++++++ style_bert_vits2/utils/stdout_wrapper.py | 47 ++ 11 files changed, 1312 insertions(+), 13 deletions(-) create mode 100644 style_bert_vits2/.editorconfig create mode 100644 style_bert_vits2/__init__.py create mode 100644 style_bert_vits2/constants.py create mode 100644 style_bert_vits2/logging.py create mode 100644 style_bert_vits2/text_processing/japanese/g2p.py create mode 100644 style_bert_vits2/text_processing/japanese/g2p_utils.py create mode 100644 style_bert_vits2/text_processing/japanese/mora_list.py create mode 100644 style_bert_vits2/text_processing/japanese/normalizer.py create mode 100644 style_bert_vits2/text_processing/symbols.py create mode 100644 style_bert_vits2/utils/stdout_wrapper.py diff --git a/server_editor.py b/server_editor.py index afb9321df..70f7c9bc4 100644 --- a/server_editor.py +++ b/server_editor.py @@ -12,14 +12,13 @@ import shutil import sys import webbrowser +import yaml import zipfile from datetime import datetime from io import BytesIO from pathlib import Path -import yaml import numpy as np -import pyopenjtalk import requests import torch import uvicorn @@ -29,21 +28,29 @@ from fastapi.staticfiles import StaticFiles from pydantic import BaseModel from scipy.io import wavfile +from transformers import AutoTokenizer -from common.constants import ( +from common.tts_model import ModelHolder +from style_bert_vits2.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, DEFAULT_NOISE, DEFAULT_NOISEW, DEFAULT_SDP_RATIO, DEFAULT_STYLE, DEFAULT_STYLE_WEIGHT, - LATEST_VERSION, + VERSION, Languages, ) -from common.log import logger -from common.tts_model import ModelHolder -from text.japanese import g2kata_tone, kata_tone2phone_tone, text_normalize -from text.user_dict import apply_word, update_dict, read_dict, rewrite_word, delete_word +from style_bert_vits2.logging import logger +from style_bert_vits2.text_processing.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone +from style_bert_vits2.text_processing.japanese.normalizer import normalize_text +from text.user_dict import ( + apply_word, + update_dict, + read_dict, + rewrite_word, + delete_word, +) # ---フロントエンド部分に関する処理--- @@ -140,6 +147,12 @@ def save_last_download(latest_release): # ---フロントエンド部分に関する処理ここまで--- # 以降はAPIの設定 +# 最初に pyopenjtalk の辞書を更新 +update_dict() + +# 単語分割に使う BERT トークナイザーをロード +tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese-char-wwm") + class AudioResponse(Response): media_type = "audio/wav" @@ -197,7 +210,7 @@ class AudioResponse(Response): @router.get("/version") def version() -> str: - return LATEST_VERSION + return VERSION class MoraTone(BaseModel): @@ -213,8 +226,8 @@ class TextRequest(BaseModel): async def read_item(item: TextRequest): try: # 最初に正規化しないと整合性がとれない - text = text_normalize(item.text) - kata_tone_list = g2kata_tone(text) + text = normalize_text(item.text) + kata_tone_list = g2kata_tone(text, tokenizer) except Exception as e: raise HTTPException( status_code=400, @@ -224,8 +237,8 @@ async def read_item(item: TextRequest): @router.post("/normalize") -async def normalize_text(item: TextRequest): - return 
text_normalize(item.text) +async def normalize(item: TextRequest): + return normalize_text(item.text) @router.get("/models_info") @@ -311,6 +324,7 @@ def multi_synthesis(request: MultiSynthesisRequest): detail=f"行数は{args.line_count}行以下にしてください。", ) audios = [] + sr = None for i, req in enumerate(lines): if args.line_length is not None and len(req.text) > args.line_length: raise HTTPException( diff --git a/style_bert_vits2/.editorconfig b/style_bert_vits2/.editorconfig new file mode 100644 index 000000000..bbc8a70b6 --- /dev/null +++ b/style_bert_vits2/.editorconfig @@ -0,0 +1,15 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +indent_size = 4 +indent_style = space +trim_trailing_whitespace = true + +[*.md] +trim_trailing_whitespace = false + +[*.yml] +indent_size = 2 diff --git a/style_bert_vits2/__init__.py b/style_bert_vits2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/style_bert_vits2/constants.py b/style_bert_vits2/constants.py new file mode 100644 index 000000000..a90bc01ba --- /dev/null +++ b/style_bert_vits2/constants.py @@ -0,0 +1,32 @@ +from enum import Enum +from pathlib import Path + + +# Style-Bert-VITS2 のバージョン +VERSION = "2.3.1" + +# ユーザー辞書ディレクトリ +USER_DICT_DIR = Path("dict_data") + +# Gradio のテーマ +## Built-in theme: "default", "base", "monochrome", "soft", "glass" +## See https://huggingface.co/spaces/gradio/theme-gallery for more themes +GRADIO_THEME = "NoCrypt/miku" + +# 利用可能な言語 +class Languages(str, Enum): + JP = "JP" + EN = "EN" + ZH = "ZH" + +# 推論パラメータのデフォルト値 +DEFAULT_STYLE = "Neutral" +DEFAULT_STYLE_WEIGHT = 5.0 +DEFAULT_SDP_RATIO = 0.2 +DEFAULT_NOISE = 0.6 +DEFAULT_NOISEW = 0.8 +DEFAULT_LENGTH = 1.0 +DEFAULT_LINE_SPLIT = True +DEFAULT_SPLIT_INTERVAL = 0.5 +DEFAULT_ASSIST_TEXT_WEIGHT = 0.7 +DEFAULT_ASSIST_TEXT_WEIGHT = 1.0 diff --git a/style_bert_vits2/logging.py b/style_bert_vits2/logging.py new file mode 100644 index 000000000..eec887ce3 --- /dev/null +++ b/style_bert_vits2/logging.py @@ -0,0 +1,15 @@ +from loguru import logger + +from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT + + +# Remove all default handlers +logger.remove() + +# Add a new handler +logger.add( + SAFE_STDOUT, + format = "{time:MM-DD HH:mm:ss} |{level:^8}| {file}:{line} | {message}", + backtrace = True, + diagnose = True, +) diff --git a/style_bert_vits2/text_processing/japanese/g2p.py b/style_bert_vits2/text_processing/japanese/g2p.py new file mode 100644 index 000000000..1b8bd6972 --- /dev/null +++ b/style_bert_vits2/text_processing/japanese/g2p.py @@ -0,0 +1,493 @@ +import pyopenjtalk +import re +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +from style_bert_vits2.logging import logger +from style_bert_vits2.text_processing.japanese.mora_list import MORA_KATA_TO_MORA_PHONEMES +from style_bert_vits2.text_processing.japanese.normalizer import replace_punctuation +from style_bert_vits2.text_processing.symbols import PUNCTUATIONS + + +def g2p( + norm_text: str, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, + use_jp_extra: bool = True, + raise_yomi_error: bool = False +) -> tuple[list[str], list[int], list[int]]: + """ + 他で使われるメインの関数。`normalize_text()` で正規化された `norm_text` を受け取り、 + - phones: 音素のリスト(ただし `!` や `,` や `.` など punctuation が含まれうる) + - tones: アクセントのリスト、0(低)と1(高)からなり、phones と同じ長さ + - word2ph: 元のテキストの各文字に音素が何個割り当てられるかを表すリスト + のタプルを返す。 + ただし `phones` と `tones` の最初と終わりに `_` が入り、応じて `word2ph` の最初と最後に 1 が追加される。 + tokenizer には deberta-v2-large-japanese-char-wwm を 
AutoTokenizer.from_pretrained() でロードしたものを指定する。 + + Args: + norm_text (str): 正規化されたテキスト + tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast): 単語分割に使うロード済みの BERT Tokenizer インスタンス + use_jp_extra (bool, optional): False の場合、「ん」の音素を「N」ではなく「n」とする。Defaults to True. + raise_yomi_error (bool, optional): False の場合、読めない文字が消えたような扱いとして処理される。Defaults to False. + + Returns: + tuple[list[str], list[int], list[int]]: 音素のリスト、アクセントのリスト、word2ph のリスト + """ + + # pyopenjtalk のフルコンテキストラベルを使ってアクセントを取り出すと、punctuation の位置が消えてしまい情報が失われてしまう: + # 「こんにちは、世界。」と「こんにちは!世界。」と「こんにちは!!!???世界……。」は全て同じになる。 + # よって、まず punctuation 無しの音素とアクセントのリストを作り、 + # それとは別に pyopenjtalk.run_frontend() で得られる音素リスト(こちらは punctuation が保持される)を使い、 + # アクセント割当をしなおすことによって punctuation を含めた音素とアクセントのリストを作る。 + + # punctuation がすべて消えた、音素とアクセントのタプルのリスト(「ん」は「N」) + phone_tone_list_wo_punct = __g2phone_tone_wo_punct(norm_text) + + # sep_text: 単語単位の単語のリスト、読めない文字があったら raise_yomi_error なら例外、そうでないなら読めない文字が消えて返ってくる + # sep_kata: 単語単位の単語のカタカナ読みのリスト + sep_text, sep_kata = text_to_sep_kata(norm_text, raise_yomi_error=raise_yomi_error) + + # sep_phonemes: 各単語ごとの音素のリストのリスト + sep_phonemes = __handle_long([__kata_to_phoneme_list(i) for i in sep_kata]) + + # phone_w_punct: sep_phonemes を結合した、punctuation を元のまま保持した音素列 + phone_w_punct: list[str] = [] + for i in sep_phonemes: + phone_w_punct += i + + # punctuation 無しのアクセント情報を使って、punctuation を含めたアクセント情報を作る + phone_tone_list = __align_tones(phone_w_punct, phone_tone_list_wo_punct) + # logger.debug(f"phone_tone_list:\n{phone_tone_list}") + + # word2ph は厳密な解答は不可能なので(「今日」「眼鏡」等の熟字訓が存在)、 + # Bert-VITS2 では、単語単位の分割を使って、単語の文字ごとにだいたい均等に音素を分配する + + # sep_text から、各単語を1文字1文字分割して、文字のリスト(のリスト)を作る + sep_tokenized: list[list[str]] = [] + for i in sep_text: + if i not in PUNCTUATIONS: + sep_tokenized.append( + tokenizer.tokenize(i) + ) # ここでおそらく`i`が文字単位に分割される + else: + sep_tokenized.append([i]) + + # 各単語について、音素の数と文字の数を比較して、均等っぽく分配する + word2ph = [] + for token, phoneme in zip(sep_tokenized, sep_phonemes): + phone_len = len(phoneme) + word_len = len(token) + word2ph += __distribute_phone(phone_len, word_len) + + # 最初と最後に `_` 記号を追加、アクセントは 0(低)、word2ph もそれに合わせて追加 + phone_tone_list = [("_", 0)] + phone_tone_list + [("_", 0)] + word2ph = [1] + word2ph + [1] + + phones = [phone for phone, _ in phone_tone_list] + tones = [tone for _, tone in phone_tone_list] + + assert len(phones) == sum(word2ph), f"{len(phones)} != {sum(word2ph)}" + + # use_jp_extra でない場合は「N」を「n」に変換 + if not use_jp_extra: + phones = [phone if phone != "N" else "n" for phone in phones] + + return phones, tones, word2ph + + +def text_to_sep_kata( + norm_text: str, + raise_yomi_error: bool = False +) -> tuple[list[str], list[str]]: + """ + `normalize_text` で正規化済みの `norm_text` を受け取り、それを単語分割し、 + 分割された単語リストとその読み(カタカナ or 記号1文字)のリストのタプルを返す。 + 単語分割結果は、`g2p()` の `word2ph` で1文字あたりに割り振る音素記号の数を決めるために使う。 + 例: + `私はそう思う!って感じ?` → + ["私", "は", "そう", "思う", "!", "って", "感じ", "?"], ["ワタシ", "ワ", "ソー", "オモウ", "!", "ッテ", "カンジ", "?"] + + Args: + norm_text (str): 正規化されたテキスト + raise_yomi_error (bool, optional): False の場合、読めない文字が消えたような扱いとして処理される。Defaults to False. 
+ + Returns: + tuple[list[str], list[str]]: 分割された単語リストと、その読み(カタカナ or 記号1文字)のリスト + """ + + # parsed: OpenJTalkの解析結果 + parsed = pyopenjtalk.run_frontend(norm_text) + sep_text: list[str] = [] + sep_kata: list[str] = [] + + for parts in parsed: + # word: 実際の単語の文字列 + # yomi: その読み、但し無声化サインの`’`は除去 + word, yomi = replace_punctuation(parts["string"]), parts["pron"].replace( + "’", "" + ) + """ + ここで `yomi` の取りうる値は以下の通りのはず。 + - `word` が通常単語 → 通常の読み(カタカナ) + (カタカナからなり、長音記号も含みうる、`アー` 等) + - `word` が `ー` から始まる → `ーラー` や `ーーー` など + - `word` が句読点や空白等 → `、` + - `word` が punctuation の繰り返し → 全角にしたもの + 基本的に punctuation は1文字ずつ分かれるが、何故かある程度連続すると1つにまとまる。 + 他にも `word` が読めないキリル文字アラビア文字等が来ると `、` になるが、正規化でこの場合は起きないはず。 + また元のコードでは `yomi` が空白の場合の処理があったが、これは起きないはず。 + 処理すべきは `yomi` が `、` の場合のみのはず。 + """ + assert yomi != "", f"Empty yomi: {word}" + if yomi == "、": + # word は正規化されているので、`.`, `,`, `!`, `'`, `-`, `--` のいずれか + if not set(word).issubset(set(PUNCTUATIONS)): # 記号繰り返しか判定 + # ここは pyopenjtalk が読めない文字等のときに起こる + if raise_yomi_error: + raise YomiError(f"Cannot read: {word} in:\n{norm_text}") + logger.warning(f"Ignoring unknown: {word} in:\n{norm_text}") + continue + # yomi は元の記号のままに変更 + yomi = word + elif yomi == "?": + assert word == "?", f"yomi `?` comes from: {word}" + yomi = "?" + sep_text.append(word) + sep_kata.append(yomi) + + return sep_text, sep_kata + + +def __g2phone_tone_wo_punct(text: str) -> list[tuple[str, int]]: + """ + テキストに対して、音素とアクセント(0か1)のペアのリストを返す。 + ただし「!」「.」「?」等の非音素記号 (punctuation) は全て消える(ポーズ記号も残さない)。 + 非音素記号を含める処理は `align_tones()` で行われる。 + また「っ」は「q」に、「ん」は「N」に変換される。 + 例: "こんにちは、世界ー。。元気?!" → + [('k', 0), ('o', 0), ('N', 1), ('n', 1), ('i', 1), ('ch', 1), ('i', 1), ('w', 1), ('a', 1), ('s', 1), ('e', 1), ('k', 0), ('a', 0), ('i', 0), ('i', 0), ('g', 1), ('e', 1), ('N', 0), ('k', 0), ('i', 0)] + + Args: + text (str): テキスト + + Returns: + list[tuple[str, int]]: 音素とアクセントのペアのリスト + """ + + prosodies = pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True) + # logger.debug(f"prosodies: {prosodies}") + result: list[tuple[str, int]] = [] + current_phrase: list[tuple[str, int]] = [] + current_tone = 0 + + for i, letter in enumerate(prosodies): + # 特殊記号の処理 + + # 文頭記号、無視する + if letter == "^": + assert i == 0, "Unexpected ^" + # アクセント句の終わりに来る記号 + elif letter in ("$", "?", "_", "#"): + # 保持しているフレーズを、アクセント数値を 0-1 に修正し結果に追加 + result.extend(__fix_phone_tone(current_phrase)) + # 末尾に来る終了記号、無視(文中の疑問文は `_` になる) + if letter in ("$", "?"): + assert i == len(prosodies) - 1, f"Unexpected {letter}" + # あとは "_"(ポーズ)と "#"(アクセント句の境界)のみ + # これらは残さず、次のアクセント句に備える。 + current_phrase = [] + # 0 を基準点にしてそこから上昇・下降する(負の場合は上の `fix_phone_tone` で直る) + current_tone = 0 + # アクセント上昇記号 + elif letter == "[": + current_tone = current_tone + 1 + # アクセント下降記号 + elif letter == "]": + current_tone = current_tone - 1 + # それ以外は通常の音素 + else: + if letter == "cl": # 「っ」の処理 + letter = "q" + # elif letter == "N": # 「ん」の処理 + # letter = "n" + current_phrase.append((letter, current_tone)) + + return result + + +def pyopenjtalk_g2p_prosody(text: str, drop_unvoiced_vowels: bool = True) -> list[str]: + """ + ESPnet の実装から引用、変更点無し。「ん」は「N」なことに注意。 + ref: https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py + ------------------------------------------------------------------------------------------ + + Extract phoneme + prosoody symbol sequence from input full-context labels. + + The algorithm is based on `Prosodic features control by symbols as input of + sequence-to-sequence acoustic modeling for neural TTS`_ with some r9y9's tweaks. 
+ + Args: + text (str): Input text. + drop_unvoiced_vowels (bool): whether to drop unvoiced vowels. + + Returns: + List[str]: List of phoneme + prosody symbols. + + Examples: + >>> from espnet2.text.phoneme_tokenizer import pyopenjtalk_g2p_prosody + >>> pyopenjtalk_g2p_prosody("こんにちは。") + ['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$'] + + .. _`Prosodic features control by symbols as input of sequence-to-sequence acoustic + modeling for neural TTS`: https://doi.org/10.1587/transinf.2020EDP7104 + """ + + def _numeric_feature_by_regex(regex: str, s: str) -> int: + match = re.search(regex, s) + if match is None: + return -50 + return int(match.group(1)) + + labels = pyopenjtalk.make_label(pyopenjtalk.run_frontend(text)) + N = len(labels) + + phones = [] + for n in range(N): + lab_curr = labels[n] + + # current phoneme + p3 = re.search(r"\-(.*?)\+", lab_curr).group(1) # type: ignore + # deal unvoiced vowels as normal vowels + if drop_unvoiced_vowels and p3 in "AEIOU": + p3 = p3.lower() + + # deal with sil at the beginning and the end of text + if p3 == "sil": + assert n == 0 or n == N - 1 + if n == 0: + phones.append("^") + elif n == N - 1: + # check question form or not + e3 = _numeric_feature_by_regex(r"!(\d+)_", lab_curr) + if e3 == 0: + phones.append("$") + elif e3 == 1: + phones.append("?") + continue + elif p3 == "pau": + phones.append("_") + continue + else: + phones.append(p3) + + # accent type and position info (forward or backward) + a1 = _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", lab_curr) + a2 = _numeric_feature_by_regex(r"\+(\d+)\+", lab_curr) + a3 = _numeric_feature_by_regex(r"\+(\d+)/", lab_curr) + + # number of mora in accent phrase + f1 = _numeric_feature_by_regex(r"/F:(\d+)_", lab_curr) + + a2_next = _numeric_feature_by_regex(r"\+(\d+)\+", labels[n + 1]) + # accent phrase border + if a3 == 1 and a2_next == 1 and p3 in "aeiouAEIOUNcl": + phones.append("#") + # pitch falling + elif a1 == 0 and a2_next == a2 + 1 and a2 != f1: + phones.append("]") + # pitch rising + elif a2 == 1 and a2_next == 2: + phones.append("[") + + return phones + + +def __fix_phone_tone(phone_tone_list: list[tuple[str, int]]) -> list[tuple[str, int]]: + """ + `phone_tone_list` の tone(アクセントの値)を 0 か 1 の範囲に修正する。 + 例: [(a, 0), (i, -1), (u, -1)] → [(a, 1), (i, 0), (u, 0)] + + Args: + phone_tone_list (list[tuple[str, int]]): 音素とアクセントのペアのリスト + + Returns: + list[tuple[str, int]]: 修正された音素とアクセントのペアのリスト + """ + + tone_values = set(tone for _, tone in phone_tone_list) + if len(tone_values) == 1: + assert tone_values == {0}, tone_values + return phone_tone_list + elif len(tone_values) == 2: + if tone_values == {0, 1}: + return phone_tone_list + elif tone_values == {-1, 0}: + return [ + (letter, 0 if tone == -1 else 1) for letter, tone in phone_tone_list + ] + else: + raise ValueError(f"Unexpected tone values: {tone_values}") + else: + raise ValueError(f"Unexpected tone values: {tone_values}") + + +def __handle_long(sep_phonemes: list[list[str]]) -> list[list[str]]: + """ + フレーズごとに分かれた音素(長音記号がそのまま)のリストのリスト `sep_phonemes` を受け取り、 + その長音記号を処理して、音素のリストのリストを返す。 + 基本的には直前の音素を伸ばすが、直前の音素が母音でない場合もしくは冒頭の場合は、 + おそらく長音記号とダッシュを勘違いしていると思われるので、ダッシュに対応する音素 `-` に変換する。 + + Args: + sep_phonemes (list[list[str]]): フレーズごとに分かれた音素のリストのリスト + + Returns: + list[list[str]]: 長音記号を処理した音素のリストのリスト + """ + + # 母音の集合 (便宜上「ん」を含める) + VOWELS = {"a", "i", "u", "e", "o", "N"} + + for i in range(len(sep_phonemes)): + if len(sep_phonemes[i]) == 0: + # 空白文字等でリストが空の場合 + continue + if sep_phonemes[i][0] == "ー": + if i != 0: + prev_phoneme = 
sep_phonemes[i - 1][-1] + if prev_phoneme in VOWELS: + # 母音と「ん」のあとの伸ばし棒なので、その母音に変換 + sep_phonemes[i][0] = sep_phonemes[i - 1][-1] + else: + # 「。ーー」等おそらく予期しない長音記号 + # ダッシュの勘違いだと思われる + sep_phonemes[i][0] = "-" + else: + # 冒頭に長音記号が来ていおり、これはダッシュの勘違いと思われる + sep_phonemes[i][0] = "-" + if "ー" in sep_phonemes[i]: + for j in range(len(sep_phonemes[i])): + if sep_phonemes[i][j] == "ー": + sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1] + + return sep_phonemes + + +def __kata_to_phoneme_list(text: str) -> list[str]: + """ + 原則カタカナの `text` を受け取り、それをそのままいじらずに音素記号のリストに変換。 + 注意点: + - punctuation かその繰り返しが来た場合、punctuation たちをそのままリストにして返す。 + - 冒頭に続く「ー」はそのまま「ー」のままにする(`handle_long()` で処理される) + - 文中の「ー」は前の音素記号の最後の音素記号に変換される。 + 例: + `ーーソーナノカーー` → ["ー", "ー", "s", "o", "o", "n", "a", "n", "o", "k", "a", "a", "a"] + `?` → ["?"] + `!?!?!?!?!` → ["!", "?", "!", "?", "!", "?", "!", "?", "!"] + + Args: + text (str): カタカナのテキスト + + Returns: + list[str]: 音素記号のリスト + """ + + if set(text).issubset(set(PUNCTUATIONS)): + return list(text) + # `text`がカタカナ(`ー`含む)のみからなるかどうかをチェック + if re.fullmatch(r"[\u30A0-\u30FF]+", text) is None: + raise ValueError(f"Input must be katakana only: {text}") + sorted_keys = sorted(MORA_KATA_TO_MORA_PHONEMES.keys(), key=len, reverse=True) + pattern = "|".join(map(re.escape, sorted_keys)) + + def mora2phonemes(mora: str) -> str: + cosonant, vowel = MORA_KATA_TO_MORA_PHONEMES[mora] + if cosonant is None: + return f" {vowel}" + return f" {cosonant} {vowel}" + + spaced_phonemes = re.sub(pattern, lambda m: mora2phonemes(m.group()), text) + + # 長音記号「ー」の処理 + long_pattern = r"(\w)(ー*)" + long_replacement = lambda m: m.group(1) + (" " + m.group(1)) * len(m.group(2)) # type: ignore + spaced_phonemes = re.sub(long_pattern, long_replacement, spaced_phonemes) + + return spaced_phonemes.strip().split(" ") + + +def __align_tones( + phones_with_punct: list[str], + phone_tone_list: list[tuple[str, int]] +) -> list[tuple[str, int]]: + """ + 例: …私は、、そう思う。 + phones_with_punct: + [".", ".", ".", "w", "a", "t", "a", "sh", "i", "w", "a", ",", ",", "s", "o", "o", "o", "m", "o", "u", "."] + phone_tone_list: + [("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), ("_", 0), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0))] + Return: + [(".", 0), (".", 0), (".", 0), ("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), (",", 0), (",", 0), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0), (".", 0)] + + Args: + phones_with_punct (list[str]): punctuation を含む音素のリスト + phone_tone_list (list[tuple[str, int]]): punctuation を含まない音素とアクセントのペアのリスト + + Returns: + list[tuple[str, int]]: punctuation を含む音素とアクセントのペアのリスト + """ + + result: list[tuple[str, int]] = [] + tone_index = 0 + for phone in phones_with_punct: + if tone_index >= len(phone_tone_list): + # 余ったpunctuationがある場合 → (punctuation, 0)を追加 + result.append((phone, 0)) + elif phone == phone_tone_list[tone_index][0]: + # phone_tone_listの現在の音素と一致する場合 → toneをそこから取得、(phone, tone)を追加 + result.append((phone, phone_tone_list[tone_index][1])) + # 探すindexを1つ進める + tone_index += 1 + elif phone in PUNCTUATIONS: + # phoneがpunctuationの場合 → (phone, 0)を追加 + result.append((phone, 0)) + else: + logger.debug(f"phones: {phones_with_punct}") + logger.debug(f"phone_tone_list: {phone_tone_list}") + logger.debug(f"result: {result}") + logger.debug(f"tone_index: {tone_index}") + logger.debug(f"phone: {phone}") + raise ValueError(f"Unexpected phone: {phone}") + + return result + + +def __distribute_phone(n_phone: 
int, n_word: int) -> list[int]: + """ + 左から右に 1 ずつ振り分け、次にまた左から右に1ずつ増やし、というふうに、 + 音素の数 `n_phone` を単語の数 `n_word` に分配する。 + + Args: + n_phone (int): 音素の数 + n_word (int): 単語の数 + + Returns: + list[int]: 単語ごとの音素の数のリスト + """ + + phones_per_word = [0] * n_word + for _ in range(n_phone): + min_tasks = min(phones_per_word) + min_index = phones_per_word.index(min_tasks) + phones_per_word[min_index] += 1 + + return phones_per_word + + +class YomiError(Exception): + """ + OpenJTalk で、読みが正しく取得できない箇所があるときに発生する例外。 + 基本的に「学習の前処理のテキスト処理時」には発生させ、そうでない場合は、 + ignore_yomi_error=True にしておいて、この例外を発生させないようにする。 + """ + + pass diff --git a/style_bert_vits2/text_processing/japanese/g2p_utils.py b/style_bert_vits2/text_processing/japanese/g2p_utils.py new file mode 100644 index 000000000..e09560263 --- /dev/null +++ b/style_bert_vits2/text_processing/japanese/g2p_utils.py @@ -0,0 +1,94 @@ +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +from style_bert_vits2.text_processing.japanese.g2p import g2p +from style_bert_vits2.text_processing.japanese.mora_list import ( + MORA_KATA_TO_MORA_PHONEMES, + MORA_PHONEMES_TO_MORA_KATA, +) +from style_bert_vits2.text_processing.symbols import PUNCTUATIONS + + +def g2kata_tone(norm_text: str, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast) -> list[tuple[str, int]]: + """ + テキストからカタカナとアクセントのペアのリストを返す。 + 推論時のみに使われるので、常に`raise_yomi_error=False`でg2pを呼ぶ。 + tokenizer には deberta-v2-large-japanese-char-wwm を AutoTokenizer.from_pretrained() でロードしたものを指定する。 + + Args: + norm_text: 正規化されたテキスト。 + tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast): 単語分割に使うロード済みの BERT Tokenizer インスタンス + + Returns: + カタカナと音高のリスト。 + """ + + phones, tones, _ = g2p(norm_text, tokenizer, use_jp_extra=True, raise_yomi_error=False) + return phone_tone2kata_tone(list(zip(phones, tones))) + + +def phone_tone2kata_tone(phone_tone: list[tuple[str, int]]) -> list[tuple[str, int]]: + """ + phone_tone の phone 部分をカタカナに変換する。ただし最初と最後の ("_", 0) は無視する。 + + Args: + phone_tone: 音素と音高のリスト。 + + Returns: + カタカナと音高のリスト。 + """ + + # 子音の集合 + CONSONANTS = set([ + consonant + for consonant, _ in MORA_KATA_TO_MORA_PHONEMES.values() + if consonant is not None + ]) + + phone_tone = phone_tone[1:] # 最初の("_", 0)を無視 + phones = [phone for phone, _ in phone_tone] + tones = [tone for _, tone in phone_tone] + result: list[tuple[str, int]] = [] + current_mora = "" + for phone, next_phone, tone, next_tone in zip(phones, phones[1:], tones, tones[1:]): + # zip の関係で最後の ("_", 0) は無視されている + if phone in PUNCTUATIONS: + result.append((phone, tone)) + continue + if phone in CONSONANTS: # n以外の子音の場合 + assert current_mora == "", f"Unexpected {phone} after {current_mora}" + assert tone == next_tone, f"Unexpected {phone} tone {tone} != {next_tone}" + current_mora = phone + else: + # phoneが母音もしくは「N」 + current_mora += phone + result.append((MORA_PHONEMES_TO_MORA_KATA[current_mora], tone)) + current_mora = "" + + return result + + +def kata_tone2phone_tone(kata_tone: list[tuple[str, int]]) -> list[tuple[str, int]]: + """ + `phone_tone2kata_tone()` の逆の変換を行う。 + + Args: + kata_tone: カタカナと音高のリスト。 + + Returns: + 音素と音高のリスト。 + """ + + result: list[tuple[str, int]] = [("_", 0)] + for mora, tone in kata_tone: + if mora in PUNCTUATIONS: + result.append((mora, tone)) + else: + consonant, vowel = MORA_KATA_TO_MORA_PHONEMES[mora] + if consonant is None: + result.append((vowel, tone)) + else: + result.append((consonant, tone)) + result.append((vowel, tone)) + result.append(("_", 0)) + + return result diff --git 
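The two helpers above are inverses of each other (up to the leading and trailing silence marks), which is what the accent-editing endpoints rely on. A small round-trip check with an illustrative input:

```python
from style_bert_vits2.text_processing.japanese.g2p_utils import (
    kata_tone2phone_tone,
    phone_tone2kata_tone,
)

kata_tone = [("コ", 0), ("ン", 1), ("ニ", 1), ("チ", 1), ("ワ", 1)]
phone_tone = kata_tone2phone_tone(kata_tone)
# -> [("_", 0), ("k", 0), ("o", 0), ("N", 1), ("n", 1), ("i", 1),
#     ("ch", 1), ("i", 1), ("w", 1), ("a", 1), ("_", 0)]
assert phone_tone2kata_tone(phone_tone) == kata_tone
```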
a/style_bert_vits2/text_processing/japanese/mora_list.py b/style_bert_vits2/text_processing/japanese/mora_list.py new file mode 100644 index 000000000..a0dab2fc7 --- /dev/null +++ b/style_bert_vits2/text_processing/japanese/mora_list.py @@ -0,0 +1,236 @@ +""" +以下のコードは VOICEVOX のソースコードからお借りし最低限の改造を行ったもの。 +https://github.com/VOICEVOX/voicevox_engine/blob/master/voicevox_engine/tts_pipeline/mora_list.py +""" + +""" +以下のモーラ対応表は OpenJTalk のソースコードから取得し、 +カタカナ表記とモーラが一対一対応するように改造した。 +ライセンス表記: +----------------------------------------------------------------- + The Japanese TTS System "Open JTalk" + developed by HTS Working Group + http://open-jtalk.sourceforge.net/ +----------------------------------------------------------------- + + Copyright (c) 2008-2014 Nagoya Institute of Technology + Department of Computer Science + +All rights reserved. + +Redistribution and use in source and binary forms, with or +without modification, are permitted provided that the following +conditions are met: + +- Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. +- Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. +- Neither the name of the HTS working group nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS +BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. 
+""" + +from typing import Optional + + +# (カタカナ, 子音, 母音)の順。子音がない場合は None を入れる。 +# 但し「ン」と「ッ」は母音のみという扱いで、「ン」は「N」、「ッ」は「q」とする。 +# (元々「ッ」は「cl」) +# また「デェ = dy e」は pyopenjtalk の出力(de e)と合わないため削除 +__MORA_LIST_MINIMUM: list[tuple[str, Optional[str], str]] = [ + ("ヴォ", "v", "o"), + ("ヴェ", "v", "e"), + ("ヴィ", "v", "i"), + ("ヴァ", "v", "a"), + ("ヴ", "v", "u"), + ("ン", None, "N"), + ("ワ", "w", "a"), + ("ロ", "r", "o"), + ("レ", "r", "e"), + ("ル", "r", "u"), + ("リョ", "ry", "o"), + ("リュ", "ry", "u"), + ("リャ", "ry", "a"), + ("リェ", "ry", "e"), + ("リ", "r", "i"), + ("ラ", "r", "a"), + ("ヨ", "y", "o"), + ("ユ", "y", "u"), + ("ヤ", "y", "a"), + ("モ", "m", "o"), + ("メ", "m", "e"), + ("ム", "m", "u"), + ("ミョ", "my", "o"), + ("ミュ", "my", "u"), + ("ミャ", "my", "a"), + ("ミェ", "my", "e"), + ("ミ", "m", "i"), + ("マ", "m", "a"), + ("ポ", "p", "o"), + ("ボ", "b", "o"), + ("ホ", "h", "o"), + ("ペ", "p", "e"), + ("ベ", "b", "e"), + ("ヘ", "h", "e"), + ("プ", "p", "u"), + ("ブ", "b", "u"), + ("フォ", "f", "o"), + ("フェ", "f", "e"), + ("フィ", "f", "i"), + ("ファ", "f", "a"), + ("フ", "f", "u"), + ("ピョ", "py", "o"), + ("ピュ", "py", "u"), + ("ピャ", "py", "a"), + ("ピェ", "py", "e"), + ("ピ", "p", "i"), + ("ビョ", "by", "o"), + ("ビュ", "by", "u"), + ("ビャ", "by", "a"), + ("ビェ", "by", "e"), + ("ビ", "b", "i"), + ("ヒョ", "hy", "o"), + ("ヒュ", "hy", "u"), + ("ヒャ", "hy", "a"), + ("ヒェ", "hy", "e"), + ("ヒ", "h", "i"), + ("パ", "p", "a"), + ("バ", "b", "a"), + ("ハ", "h", "a"), + ("ノ", "n", "o"), + ("ネ", "n", "e"), + ("ヌ", "n", "u"), + ("ニョ", "ny", "o"), + ("ニュ", "ny", "u"), + ("ニャ", "ny", "a"), + ("ニェ", "ny", "e"), + ("ニ", "n", "i"), + ("ナ", "n", "a"), + ("ドゥ", "d", "u"), + ("ド", "d", "o"), + ("トゥ", "t", "u"), + ("ト", "t", "o"), + ("デョ", "dy", "o"), + ("デュ", "dy", "u"), + ("デャ", "dy", "a"), + # ("デェ", "dy", "e"), + ("ディ", "d", "i"), + ("デ", "d", "e"), + ("テョ", "ty", "o"), + ("テュ", "ty", "u"), + ("テャ", "ty", "a"), + ("ティ", "t", "i"), + ("テ", "t", "e"), + ("ツォ", "ts", "o"), + ("ツェ", "ts", "e"), + ("ツィ", "ts", "i"), + ("ツァ", "ts", "a"), + ("ツ", "ts", "u"), + ("ッ", None, "q"), # 「cl」から「q」に変更 + ("チョ", "ch", "o"), + ("チュ", "ch", "u"), + ("チャ", "ch", "a"), + ("チェ", "ch", "e"), + ("チ", "ch", "i"), + ("ダ", "d", "a"), + ("タ", "t", "a"), + ("ゾ", "z", "o"), + ("ソ", "s", "o"), + ("ゼ", "z", "e"), + ("セ", "s", "e"), + ("ズィ", "z", "i"), + ("ズ", "z", "u"), + ("スィ", "s", "i"), + ("ス", "s", "u"), + ("ジョ", "j", "o"), + ("ジュ", "j", "u"), + ("ジャ", "j", "a"), + ("ジェ", "j", "e"), + ("ジ", "j", "i"), + ("ショ", "sh", "o"), + ("シュ", "sh", "u"), + ("シャ", "sh", "a"), + ("シェ", "sh", "e"), + ("シ", "sh", "i"), + ("ザ", "z", "a"), + ("サ", "s", "a"), + ("ゴ", "g", "o"), + ("コ", "k", "o"), + ("ゲ", "g", "e"), + ("ケ", "k", "e"), + ("グヮ", "gw", "a"), + ("グ", "g", "u"), + ("クヮ", "kw", "a"), + ("ク", "k", "u"), + ("ギョ", "gy", "o"), + ("ギュ", "gy", "u"), + ("ギャ", "gy", "a"), + ("ギェ", "gy", "e"), + ("ギ", "g", "i"), + ("キョ", "ky", "o"), + ("キュ", "ky", "u"), + ("キャ", "ky", "a"), + ("キェ", "ky", "e"), + ("キ", "k", "i"), + ("ガ", "g", "a"), + ("カ", "k", "a"), + ("オ", None, "o"), + ("エ", None, "e"), + ("ウォ", "w", "o"), + ("ウェ", "w", "e"), + ("ウィ", "w", "i"), + ("ウ", None, "u"), + ("イェ", "y", "e"), + ("イ", None, "i"), + ("ア", None, "a"), +] +__MORA_LIST_ADDITIONAL: list[tuple[str, Optional[str], str]] = [ + ("ヴョ", "by", "o"), + ("ヴュ", "by", "u"), + ("ヴャ", "by", "a"), + ("ヲ", None, "o"), + ("ヱ", None, "e"), + ("ヰ", None, "i"), + ("ヮ", "w", "a"), + ("ョ", "y", "o"), + ("ュ", "y", "u"), + ("ヅ", "z", "u"), + ("ヂ", "j", "i"), + ("ヶ", "k", "e"), + ("ャ", "y", "a"), + ("ォ", None, "o"), + ("ェ", None, "e"), + ("ゥ", None, "u"), + ("ィ", 
None, "i"), + ("ァ", None, "a"), +] + +# モーラの音素表記とカタカナの対応表 +# 例: "vo" -> "ヴォ", "a" -> "ア" +MORA_PHONEMES_TO_MORA_KATA: dict[str, str] = { + (consonant or "") + vowel: kana for [kana, consonant, vowel] in __MORA_LIST_MINIMUM +} + +# モーラのカタカナ表記と音素の対応表 +# 例: "ヴォ" -> ("v", "o"), "ア" -> (None, "a") +MORA_KATA_TO_MORA_PHONEMES: dict[str, tuple[Optional[str], str]] = { + kana: (consonant, vowel) + for [kana, consonant, vowel] in __MORA_LIST_MINIMUM + __MORA_LIST_ADDITIONAL +} diff --git a/style_bert_vits2/text_processing/japanese/normalizer.py b/style_bert_vits2/text_processing/japanese/normalizer.py new file mode 100644 index 000000000..92c2e879f --- /dev/null +++ b/style_bert_vits2/text_processing/japanese/normalizer.py @@ -0,0 +1,161 @@ +import re +import unicodedata +from num2words import num2words + +from style_bert_vits2.text_processing.symbols import PUNCTUATIONS + + +def normalize_text(text: str) -> str: + """ + 日本語のテキストを正規化する。 + 結果は、ちょうど次の文字のみからなる: + - ひらがな + - カタカナ(全角長音記号「ー」が入る!) + - 漢字 + - 半角アルファベット(大文字と小文字) + - ギリシャ文字 + - `.` (句点`。`や`…`の一部や改行等) + - `,` (読点`、`や`:`等) + - `?` (疑問符`?`) + - `!` (感嘆符`!`) + - `'` (`「`や`」`等) + - `-` (`―`(ダッシュ、長音記号ではない)や`-`等) + + 注意点: + - 三点リーダー`…`は`...`に変換される(`なるほど…。` → `なるほど....`) + - 数字は漢字に変換される(`1,100円` → `千百円`、`52.34` → `五十二点三四`) + - 読点や疑問符等の位置・個数等は保持される(`??あ、、!!!` → `??あ,,!!!`) + + Args: + text (str): 正規化するテキスト + + Returns: + str: 正規化されたテキスト + """ + + res = unicodedata.normalize("NFKC", text) # ここでアルファベットは半角になる + res = __convert_numbers_to_words(res) # 「100円」→「百円」等 + # 「~」と「~」も長音記号として扱う + res = res.replace("~", "ー") + res = res.replace("~", "ー") + + res = replace_punctuation(res) # 句読点等正規化、読めない文字を削除 + + # 結合文字の濁点・半濁点を削除 + # 通常の「ば」等はそのままのこされる、「あ゛」は上で「あ゙」になりここで「あ」になる + res = res.replace("\u3099", "") # 結合文字の濁点を削除、る゙ → る + res = res.replace("\u309A", "") # 結合文字の半濁点を削除、な゚ → な + return res + + +def __convert_numbers_to_words(text: str) -> str: + """ + 記号や数字を日本語の文字表現に変換する。 + + Args: + text (str): 変換するテキスト + + Returns: + str: 変換されたテキスト + """ + + NUMBER_WITH_SEPARATOR_PATTERN = re.compile("[0-9]{1,3}(,[0-9]{3})+") + CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"} + CURRENCY_PATTERN = re.compile(r"([$¥£€])([0-9.]*[0-9])") + NUMBER_PATTERN = re.compile(r"[0-9]+(\.[0-9]+)?") + + res = NUMBER_WITH_SEPARATOR_PATTERN.sub(lambda m: m[0].replace(",", ""), text) + res = CURRENCY_PATTERN.sub(lambda m: m[2] + CURRENCY_MAP.get(m[1], m[1]), res) + res = NUMBER_PATTERN.sub(lambda m: num2words(m[0], lang="ja"), res) + + return res + + +def replace_punctuation(text: str) -> str: + """ + 句読点等を「.」「,」「!」「?」「'」「-」に正規化し、OpenJTalk で読みが取得できるもののみ残す: + 漢字・平仮名・カタカナ、アルファベット、ギリシャ文字 + + Args: + text (str): 正規化するテキスト + + Returns: + str: 正規化されたテキスト + """ + + # 記号類の正規化変換マップ + REPLACE_MAP = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + ".": ".", + "…": "...", + "···": "...", + "・・・": "...", + "·": ",", + "・": ",", + "、": ",", + "$": ".", + "“": "'", + "”": "'", + '"': "'", + "‘": "'", + "’": "'", + "(": "'", + ")": "'", + "(": "'", + ")": "'", + "《": "'", + "》": "'", + "【": "'", + "】": "'", + "[": "'", + "]": "'", + # NFKC 正規化後のハイフン・ダッシュの変種を全て通常半角ハイフン - \u002d に変換 + "\u02d7": "\u002d", # ˗, Modifier Letter Minus Sign + "\u2010": "\u002d", # ‐, Hyphen, + # "\u2011": "\u002d", # ‑, Non-Breaking Hyphen, NFKC により \u2010 に変換される + "\u2012": "\u002d", # ‒, Figure Dash + "\u2013": "\u002d", # –, En Dash + "\u2014": "\u002d", # —, Em Dash + "\u2015": "\u002d", # ―, Horizontal Bar + "\u2043": "\u002d", # ⁃, Hyphen Bullet + "\u2212": "\u002d", # −, 
Minus Sign + "\u23af": "\u002d", # ⎯, Horizontal Line Extension + "\u23e4": "\u002d", # ⏤, Straightness + "\u2500": "\u002d", # ─, Box Drawings Light Horizontal + "\u2501": "\u002d", # ━, Box Drawings Heavy Horizontal + "\u2e3a": "\u002d", # ⸺, Two-Em Dash + "\u2e3b": "\u002d", # ⸻, Three-Em Dash + # "~": "-", # これは長音記号「ー」として扱うよう変更 + # "~": "-", # これも長音記号「ー」として扱うよう変更 + "「": "'", + "」": "'", + } + + pattern = re.compile("|".join(re.escape(p) for p in REPLACE_MAP.keys())) + + # 句読点を辞書で置換 + replaced_text = pattern.sub(lambda x: REPLACE_MAP[x.group()], text) + + replaced_text = re.sub( + # ↓ ひらがな、カタカナ、漢字 + r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005" + # ↓ 半角アルファベット(大文字と小文字) + + r"\u0041-\u005A\u0061-\u007A" + # ↓ 全角アルファベット(大文字と小文字) + + r"\uFF21-\uFF3A\uFF41-\uFF5A" + # ↓ ギリシャ文字 + + r"\u0370-\u03FF\u1F00-\u1FFF" + # ↓ "!", "?", "…", ",", ".", "'", "-", 但し`…`はすでに`...`に変換されている + + "".join(PUNCTUATIONS) + r"]+", + # 上述以外の文字を削除 + "", + replaced_text, + ) + + return replaced_text diff --git a/style_bert_vits2/text_processing/symbols.py b/style_bert_vits2/text_processing/symbols.py new file mode 100644 index 000000000..628edc61a --- /dev/null +++ b/style_bert_vits2/text_processing/symbols.py @@ -0,0 +1,192 @@ +# Punctuations +PUNCTUATIONS = ["!", "?", "…", ",", ".", "'", "-"] + +# Punctuations and special tokens +PUNCTUATION_SYMBOLS = PUNCTUATIONS + ["SP", "UNK"] + +# Padding +PAD = "_" + +# Chinese symbols +ZH_SYMBOLS = [ + "E", + "En", + "a", + "ai", + "an", + "ang", + "ao", + "b", + "c", + "ch", + "d", + "e", + "ei", + "en", + "eng", + "er", + "f", + "g", + "h", + "i", + "i0", + "ia", + "ian", + "iang", + "iao", + "ie", + "in", + "ing", + "iong", + "ir", + "iu", + "j", + "k", + "l", + "m", + "n", + "o", + "ong", + "ou", + "p", + "q", + "r", + "s", + "sh", + "t", + "u", + "ua", + "uai", + "uan", + "uang", + "ui", + "un", + "uo", + "v", + "van", + "ve", + "vn", + "w", + "x", + "y", + "z", + "zh", + "AA", + "EE", + "OO", +] +NUM_ZH_TONES = 6 + +# japanese +JA_SYMBOLS = [ + "N", + "a", + "a:", + "b", + "by", + "ch", + "d", + "dy", + "e", + "e:", + "f", + "g", + "gy", + "h", + "hy", + "i", + "i:", + "j", + "k", + "ky", + "m", + "my", + "n", + "ny", + "o", + "o:", + "p", + "py", + "q", + "r", + "ry", + "s", + "sh", + "t", + "ts", + "ty", + "u", + "u:", + "w", + "y", + "z", + "zy", +] +NUM_JA_TONES = 2 + +# English +EN_SYMBOLS = [ + "aa", + "ae", + "ah", + "ao", + "aw", + "ay", + "b", + "ch", + "d", + "dh", + "eh", + "er", + "ey", + "f", + "g", + "hh", + "ih", + "iy", + "jh", + "k", + "l", + "m", + "n", + "ng", + "ow", + "oy", + "p", + "r", + "s", + "sh", + "t", + "th", + "uh", + "uw", + "V", + "w", + "y", + "z", + "zh", +] +NUM_EN_TONES = 4 + +# Combine all symbols +NORMAL_SYMBOLS = sorted(set(ZH_SYMBOLS + JA_SYMBOLS + EN_SYMBOLS)) +SYMBOLS = [PAD] + NORMAL_SYMBOLS + PUNCTUATION_SYMBOLS +SIL_PHONEMES_IDS = [SYMBOLS.index(i) for i in PUNCTUATION_SYMBOLS] + +# Combine all tones +num_tones = NUM_ZH_TONES + NUM_JA_TONES + NUM_EN_TONES + +# Language maps +LANGUAGE_ID_MAP = {"ZH": 0, "JP": 1, "EN": 2} +NUM_LANGUAGES = len(LANGUAGE_ID_MAP.keys()) + +LANGUAGE_TONE_START_MAP = { + "ZH": 0, + "JP": NUM_ZH_TONES, + "EN": NUM_ZH_TONES + NUM_JA_TONES, +} + +if __name__ == "__main__": + a = set(ZH_SYMBOLS) + b = set(EN_SYMBOLS) + print(sorted(a & b)) diff --git a/style_bert_vits2/utils/stdout_wrapper.py b/style_bert_vits2/utils/stdout_wrapper.py new file mode 100644 index 000000000..09254ada4 --- /dev/null +++ b/style_bert_vits2/utils/stdout_wrapper.py @@ -0,0 +1,47 @@ +import sys +import tempfile 
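Expected behaviour of `normalize_text()` introduced above, reusing the examples from its docstring (exact number-to-word output depends on `num2words`):

```python
from style_bert_vits2.text_processing.japanese.normalizer import normalize_text

print(normalize_text("なるほど…。"))      # -> なるほど....
print(normalize_text("1,100円です!"))    # -> 千百円です!
print(normalize_text("??あ、、!!!"))  # -> ??あ,,!!!
```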
+from typing import TextIO + + +class StdoutWrapper(TextIO): + """ + `sys.stdout` wrapper for both Google Colab and local environment. + """ + + + def __init__(self) -> None: + self.temp_file = tempfile.NamedTemporaryFile( + mode="w+", delete=False, encoding="utf-8" + ) + self.original_stdout = sys.stdout + + + def write(self, message: str) -> int: + result = self.temp_file.write(message) + self.temp_file.flush() + print(message, end="", file=self.original_stdout) + return result + + + def flush(self) -> None: + self.temp_file.flush() + + + def read(self, n: int = -1) -> str: + self.temp_file.seek(0) + return self.temp_file.read(n) + + + def close(self) -> None: + self.temp_file.close() + + + def fileno(self) -> int: + return self.temp_file.fileno() + + +try: + import google.colab # type: ignore + SAFE_STDOUT = StdoutWrapper() +except ImportError: + SAFE_STDOUT = sys.stdout From 46c83cf89a1b298fff5b956912108ad6aec80023 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Wed, 6 Mar 2024 21:35:47 +0000 Subject: [PATCH 013/148] Refactor: moved the user dictionary implementation ported from VOICEVOX to style_bert_vits2/text_processing/japanese/user_dict/ --- server_editor.py | 2 +- style_bert_vits2/constants.py | 21 ++++++----- .../japanese/user_dict/README.md | 27 ++++++++++++++ .../japanese}/user_dict/__init__.py | 35 +++++++++---------- .../user_dict/part_of_speech_data.py | 13 +++---- .../japanese}/user_dict/word_model.py | 12 ++++--- text/japanese.py | 2 +- text/user_dict/README.md | 19 ---------- 8 files changed, 71 insertions(+), 60 deletions(-) create mode 100644 style_bert_vits2/text_processing/japanese/user_dict/README.md rename {text => style_bert_vits2/text_processing/japanese}/user_dict/__init__.py (93%) rename {text => style_bert_vits2/text_processing/japanese}/user_dict/part_of_speech_data.py (87%) rename {text => style_bert_vits2/text_processing/japanese}/user_dict/word_model.py (93%) delete mode 100644 text/user_dict/README.md diff --git a/server_editor.py b/server_editor.py index 70f7c9bc4..7577637cf 100644 --- a/server_editor.py +++ b/server_editor.py @@ -44,7 +44,7 @@ from style_bert_vits2.logging import logger from style_bert_vits2.text_processing.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone from style_bert_vits2.text_processing.japanese.normalizer import normalize_text -from text.user_dict import ( +from style_bert_vits2.text_processing.japanese.user_dict import ( apply_word, update_dict, read_dict, diff --git a/style_bert_vits2/constants.py b/style_bert_vits2/constants.py index a90bc01ba..2d6e9adc1 100644 --- a/style_bert_vits2/constants.py +++ b/style_bert_vits2/constants.py @@ -5,21 +5,17 @@ # Style-Bert-VITS2 のバージョン VERSION = "2.3.1" -# ユーザー辞書ディレクトリ -USER_DICT_DIR = Path("dict_data") - # Gradio のテーマ ## Built-in theme: "default", "base", "monochrome", "soft", "glass" ## See https://huggingface.co/spaces/gradio/theme-gallery for more themes GRADIO_THEME = "NoCrypt/miku" -# 利用可能な言語 -class Languages(str, Enum): - JP = "JP" - EN = "EN" - ZH = "ZH" +# デフォルトのユーザー辞書ディレクトリ +## style_bert_vits2.text_processing.japanese.user_dict モジュールのデフォルト値として利用される +## ライブラリとしての利用などで外部のユーザー辞書を指定したい場合は、user_dict 以下の各関数の実行時、引数に辞書データファイルのパスを指定する +DEFAULT_USER_DICT_DIR = Path(__file__).parent.parent / "dict_data" -# 推論パラメータのデフォルト値 +# デフォルトの推論パラメータ DEFAULT_STYLE = "Neutral" DEFAULT_STYLE_WEIGHT = 5.0 DEFAULT_SDP_RATIO = 0.2 @@ -30,3 +26,10 @@ class Languages(str, Enum): DEFAULT_SPLIT_INTERVAL = 0.5 DEFAULT_ASSIST_TEXT_WEIGHT = 0.7 DEFAULT_ASSIST_TEXT_WEIGHT = 1.0 + +# 利用可能な言語 +## JP-Extra 
モデル利用時は JP 以外の言語の音声合成はできない +class Languages(str, Enum): + JP = "JP" + EN = "EN" + ZH = "ZH" diff --git a/style_bert_vits2/text_processing/japanese/user_dict/README.md b/style_bert_vits2/text_processing/japanese/user_dict/README.md new file mode 100644 index 000000000..3e7a0c034 --- /dev/null +++ b/style_bert_vits2/text_processing/japanese/user_dict/README.md @@ -0,0 +1,27 @@ + +## ユーザー辞書関連のコードについて + +このフォルダに含まれるユーザー辞書関連のコードは、[VOICEVOX ENGINE](https://github.com/VOICEVOX/voicevox_engine) プロジェクトのコードを改変したものを使用しています。 +VOICEVOX プロジェクトのチームに深く感謝し、その貢献を尊重します。 + +### 元のコード + +- [voicevox_engine/user_dict/](https://github.com/VOICEVOX/voicevox_engine/tree/f181411ec69812296989d9cc583826c22eec87ae/voicevox_engine/user_dict) +- [voicevox_engine/model.py](https://github.com/VOICEVOX/voicevox_engine/blob/f181411ec69812296989d9cc583826c22eec87ae/voicevox_engine/model.py#L207) + +### 改変の詳細 + +- ファイル名の書き換えおよびそれに伴う import 文の書き換え。 +- VOICEVOX 固有の部分をコメントアウト。 +- mutex を使用している部分をコメントアウト。 +- 参照している pyopenjtalk の違いによるメソッド名の書き換え。 +- UserDictWord の mora_count のデフォルト値を None に指定。 +- `model.py` のうち、必要な Pydantic モデルのみを抽出。 + +### ライセンス + +元の VOICEVOX ENGINE のリポジトリのコードは、LGPL v3 と、ソースコードの公開が不要な別ライセンスのデュアルライセンスの下で使用されています。 +当プロジェクトにおけるこのモジュールも LGPL ライセンスの下にあります。 + +詳細については、プロジェクトのルートディレクトリにある [LGPL_LICENSE](/LGPL_LICENSE) ファイルをご参照ください。 +また、元の VOICEVOX ENGINE プロジェクトのライセンスについては、[こちら](https://github.com/VOICEVOX/voicevox_engine/blob/master/LICENSE) をご覧ください。 diff --git a/text/user_dict/__init__.py b/style_bert_vits2/text_processing/japanese/user_dict/__init__.py similarity index 93% rename from text/user_dict/__init__.py rename to style_bert_vits2/text_processing/japanese/user_dict/__init__.py index c12b3d1aa..10ea02f13 100644 --- a/text/user_dict/__init__.py +++ b/style_bert_vits2/text_processing/japanese/user_dict/__init__.py @@ -1,11 +1,12 @@ -# このファイルは、VOICEVOXプロジェクトのVOICEVOX engineからお借りしています。 -# 引用元: -# https://github.com/VOICEVOX/voicevox_engine/blob/f181411ec69812296989d9cc583826c22eec87ae/voicevox_engine/user_dict/user_dict.py -# ライセンス: LGPL-3.0 -# 詳しくは、このファイルと同じフォルダにあるREADME.mdを参照してください。 +""" +このファイルは、VOICEVOX プロジェクトの VOICEVOX ENGINE からお借りしています。 +引用元: https://github.com/VOICEVOX/voicevox_engine/blob/f181411ec69812296989d9cc583826c22eec87ae/voicevox_engine/user_dict/user_dict.py +ライセンス: LGPL-3.0 +詳しくは、このファイルと同じフォルダにある README.md を参照してください。 +""" + import json import sys -import threading import traceback from pathlib import Path from typing import Dict, List, Optional @@ -15,25 +16,21 @@ import pyopenjtalk from fastapi import HTTPException -from .word_model import UserDictWord, WordTypes - +from style_bert_vits2.constants import DEFAULT_USER_DICT_DIR +from style_bert_vits2.text_processing.japanese.user_dict.word_model import UserDictWord, WordTypes # from ..utility.mutex_utility import mutex_wrapper # from ..utility.path_utility import engine_root, get_save_dir -from .part_of_speech_data import MAX_PRIORITY, MIN_PRIORITY, part_of_speech_data -from common.constants import USER_DICT_DIR +from style_bert_vits2.text_processing.japanese.user_dict.part_of_speech_data import MAX_PRIORITY, MIN_PRIORITY, part_of_speech_data # root_dir = engine_root() # save_dir = get_save_dir() -root_dir = Path(USER_DICT_DIR) -save_dir = Path(USER_DICT_DIR) - -if not save_dir.is_dir(): - save_dir.mkdir(parents=True) +# if not save_dir.is_dir(): +# save_dir.mkdir(parents=True) -default_dict_path = root_dir / "default.csv" # VOICEVOXデフォルト辞書ファイルのパス -user_dict_path = save_dir / "user_dict.json" # ユーザー辞書ファイルのパス -compiled_dict_path = save_dir / 
"user.dic" # コンパイル済み辞書ファイルのパス +default_dict_path = DEFAULT_USER_DICT_DIR / "default.csv" # VOICEVOXデフォルト辞書ファイルのパス +user_dict_path = DEFAULT_USER_DICT_DIR / "user_dict.json" # ユーザー辞書ファイルのパス +compiled_dict_path = DEFAULT_USER_DICT_DIR / "user.dic" # コンパイル済み辞書ファイルのパス # # 同時書き込みの制御 @@ -54,7 +51,7 @@ def _write_to_json(user_dict: Dict[str, UserDictWord], user_dict_path: Path) -> """ converted_user_dict = {} for word_uuid, word in user_dict.items(): - word_dict = word.dict() + word_dict = word.model_dump() word_dict["cost"] = _priority2cost( word_dict["context_id"], word_dict["priority"] ) diff --git a/text/user_dict/part_of_speech_data.py b/style_bert_vits2/text_processing/japanese/user_dict/part_of_speech_data.py similarity index 87% rename from text/user_dict/part_of_speech_data.py rename to style_bert_vits2/text_processing/japanese/user_dict/part_of_speech_data.py index 7e22699b7..db42d3869 100644 --- a/text/user_dict/part_of_speech_data.py +++ b/style_bert_vits2/text_processing/japanese/user_dict/part_of_speech_data.py @@ -1,12 +1,13 @@ -# このファイルは、VOICEVOXプロジェクトのVOICEVOX engineからお借りしています。 -# 引用元: -# https://github.com/VOICEVOX/voicevox_engine/blob/f181411ec69812296989d9cc583826c22eec87ae/voicevox_engine/user_dict/part_of_speech_data.py -# ライセンス: LGPL-3.0 -# 詳しくは、このファイルと同じフォルダにあるREADME.mdを参照してください。 +""" +このファイルは、VOICEVOX プロジェクトの VOICEVOX ENGINE からお借りしています。 +引用元: https://github.com/VOICEVOX/voicevox_engine/blob/f181411ec69812296989d9cc583826c22eec87ae/voicevox_engine/user_dict/part_of_speech_data.py +ライセンス: LGPL-3.0 +詳しくは、このファイルと同じフォルダにある README.md を参照してください。 +""" from typing import Dict -from .word_model import ( +from style_bert_vits2.text_processing.japanese.user_dict.word_model import ( USER_DICT_MAX_PRIORITY, USER_DICT_MIN_PRIORITY, PartOfSpeechDetail, diff --git a/text/user_dict/word_model.py b/style_bert_vits2/text_processing/japanese/user_dict/word_model.py similarity index 93% rename from text/user_dict/word_model.py rename to style_bert_vits2/text_processing/japanese/user_dict/word_model.py index f05d8dc47..bcd4d377f 100644 --- a/text/user_dict/word_model.py +++ b/style_bert_vits2/text_processing/japanese/user_dict/word_model.py @@ -1,8 +1,10 @@ -# このファイルは、VOICEVOXプロジェクトのVOICEVOX engineからお借りしています。 -# 引用元: -# https://github.com/VOICEVOX/voicevox_engine/blob/f181411ec69812296989d9cc583826c22eec87ae/voicevox_engine/model.py#L207 -# ライセンス: LGPL-3.0 -# 詳しくは、このファイルと同じフォルダにあるREADME.mdを参照してください。 +""" +このファイルは、VOICEVOX プロジェクトの VOICEVOX ENGINE からお借りしています。 +引用元: https://github.com/VOICEVOX/voicevox_engine/blob/f181411ec69812296989d9cc583826c22eec87ae/voicevox_engine/model.py#L207 +ライセンス: LGPL-3.0 +詳しくは、このファイルと同じフォルダにある README.md を参照してください。 +""" + from enum import Enum from re import findall, fullmatch from typing import List, Optional diff --git a/text/japanese.py b/text/japanese.py index b18bc682e..47b21c5b8 100644 --- a/text/japanese.py +++ b/text/japanese.py @@ -15,7 +15,7 @@ mora_phonemes_to_mora_kata, ) -from text.user_dict import update_dict +from style_bert_vits2.text_processing.japanese.user_dict import update_dict # 最初にpyopenjtalkの辞書を更新 update_dict() diff --git a/text/user_dict/README.md b/text/user_dict/README.md deleted file mode 100644 index 6f5618eda..000000000 --- a/text/user_dict/README.md +++ /dev/null @@ -1,19 +0,0 @@ -このフォルダに含まれるユーザー辞書関連のコードは、[VOICEVOX engine](https://github.com/VOICEVOX/voicevox_engine)プロジェクトのコードを改変したものを使用しています。VOICEVOXプロジェクトのチームに深く感謝し、その貢献を尊重します。 - -**元のコード**: - -- 
[voicevox_engine/user_dict/](https://github.com/VOICEVOX/voicevox_engine/tree/f181411ec69812296989d9cc583826c22eec87ae/voicevox_engine/user_dict) -- [voicevox_engine/model.py](https://github.com/VOICEVOX/voicevox_engine/blob/f181411ec69812296989d9cc583826c22eec87ae/voicevox_engine/model.py#L207) - -**改変の詳細**: - -- ファイル名の書き換えおよびそれに伴うimport文の書き換え。 -- VOICEVOX固有の部分をコメントアウト。 -- mutexを使用している部分をコメントアウト。 -- 参照しているpyopenjtalkの違いによるメソッド名の書き換え。 -- UserDictWordのmora_countのデフォルト値をNoneに指定。 -- Pydanticのモデルで必要な箇所のみを抽出。 - -**ライセンス**: - -元のVOICEVOX engineのリポジトリのコードは、LGPL v3 と、ソースコードの公開が不要な別ライセンスのデュアルライセンスの下で使用されています。当プロジェクトにおけるこのモジュールもLGPLライセンスの下にあります。詳細については、プロジェクトのルートディレクトリにある[LGPL_LICENSE](/LGPL_LICENSE)ファイルをご参照ください。また、元のVOICEVOX engineプロジェクトのライセンスについては、[こちら](https://github.com/VOICEVOX/voicevox_engine/blob/master/LICENSE)をご覧ください。 From 95464954349441f06dcc344be98bcbc845ea1d87 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Wed, 6 Mar 2024 22:17:03 +0000 Subject: [PATCH 014/148] Refactor: moved commons.py to style_bert_vits2/models/ and added type definitions and comments --- attentions.py | 2 +- bert_gen.py | 2 +- commons.py | 152 ------------- data_utils.py | 2 +- infer.py | 2 +- models.py | 4 +- models_jp_extra.py | 4 +- modules.py | 4 +- style_bert_vits2/models/commons.py | 336 +++++++++++++++++++++++++++++ train_ms.py | 2 +- train_ms_jp_extra.py | 2 +- 11 files changed, 348 insertions(+), 164 deletions(-) delete mode 100644 commons.py create mode 100644 style_bert_vits2/models/commons.py diff --git a/attentions.py b/attentions.py index 9a4bba9cb..87a7f080d 100644 --- a/attentions.py +++ b/attentions.py @@ -3,7 +3,7 @@ from torch import nn from torch.nn import functional as F -import commons +from style_bert_vits2.models import commons from common.log import logger as logging diff --git a/bert_gen.py b/bert_gen.py index a5f7c258f..70af9fd79 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -5,7 +5,7 @@ import torch.multiprocessing as mp from tqdm import tqdm -import commons +from style_bert_vits2.models import commons import utils from common.log import logger from common.stdout_wrapper import SAFE_STDOUT diff --git a/commons.py b/commons.py deleted file mode 100644 index 081b8a061..000000000 --- a/commons.py +++ /dev/null @@ -1,152 +0,0 @@ -import math -import torch -from torch.nn import functional as F - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - - -def convert_pad_shape(pad_shape): - layer = pad_shape[::-1] - pad_shape = [item for sublist in layer for item in sublist] - return pad_shape - - -def intersperse(lst, item): - result = [item] * (len(lst) * 2 + 1) - result[1::2] = lst - return result - - -def kl_divergence(m_p, logs_p, m_q, logs_q): - """KL(P||Q)""" - kl = (logs_q - logs_p) - 0.5 - kl += ( - 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) - ) - return kl - - -def rand_gumbel(shape): - """Sample from the Gumbel distribution, protect from overflows.""" - uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 - return -torch.log(-torch.log(uniform_samples)) - - -def rand_gumbel_like(x): - g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) - return g - - -def slice_segments(x, ids_str, segment_size=4): - gather_indices = ids_str.view(x.size(0), 1, 1).repeat( - 1, x.size(1), 1 - ) + torch.arange(segment_size, device=x.device) - return torch.gather(x, 
2, gather_indices) - - -def rand_slice_segments(x, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0) - ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - return ret, ids_str - - -def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): - position = torch.arange(length, dtype=torch.float) - num_timescales = channels // 2 - log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( - num_timescales - 1 - ) - inv_timescales = min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment - ) - scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) - signal = F.pad(signal, [0, 0, 0, channels % 2]) - signal = signal.view(1, channels, length) - return signal - - -def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return x + signal.to(dtype=x.dtype, device=x.device) - - -def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) - - -def subsequent_mask(length): - mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) - return mask - - -@torch.jit.script -def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): - n_channels_int = n_channels[0] - in_act = input_a + input_b - t_act = torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) - acts = t_act * s_act - return acts - - -def shift_1d(x): - x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] - return x - - -def sequence_mask(length, max_length=None): - if max_length is None: - max_length = length.max() - x = torch.arange(max_length, dtype=length.dtype, device=length.device) - return x.unsqueeze(0) < length.unsqueeze(1) - - -def generate_path(duration, mask): - """ - duration: [b, 1, t_x] - mask: [b, 1, t_y, t_x] - """ - - b, _, t_y, t_x = mask.shape - cum_duration = torch.cumsum(duration, -1) - - cum_duration_flat = cum_duration.view(b * t_x) - path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) - path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - path = path.unsqueeze(1).transpose(2, 3) * mask - return path - - -def clip_grad_value_(parameters, clip_value, norm_type=2): - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) - norm_type = float(norm_type) - if clip_value is not None: - clip_value = float(clip_value) - - total_norm = 0 - for p in parameters: - param_norm = p.grad.data.norm(norm_type) - total_norm += param_norm.item() ** norm_type - if clip_value is not None: - p.grad.data.clamp_(min=-clip_value, max=clip_value) - total_norm = total_norm ** (1.0 / norm_type) - return total_norm diff --git a/data_utils.py b/data_utils.py index ac038c22f..81c0ade7c 100644 --- a/data_utils.py +++ b/data_utils.py @@ -7,7 +7,7 @@ import torch.utils.data from tqdm import tqdm -import commons +from style_bert_vits2.models 
import commons from config import config from mel_processing import mel_spectrogram_torch, spectrogram_torch from text import cleaned_text_to_sequence diff --git a/infer.py b/infer.py index 3707df1f9..b976486ed 100644 --- a/infer.py +++ b/infer.py @@ -1,6 +1,6 @@ import torch -import commons +from style_bert_vits2.models import commons import utils from models import SynthesizerTrn from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra diff --git a/models.py b/models.py index eb706fd66..501fcaf3d 100644 --- a/models.py +++ b/models.py @@ -8,10 +8,10 @@ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm import attentions -import commons +from style_bert_vits2.models import commons import modules import monotonic_align -from commons import get_padding, init_weights +from style_bert_vits2.models.commons import get_padding, init_weights from text import num_languages, num_tones, symbols diff --git a/models_jp_extra.py b/models_jp_extra.py index 1bb2dd2e3..3e87ceda2 100644 --- a/models_jp_extra.py +++ b/models_jp_extra.py @@ -3,7 +3,7 @@ from torch import nn from torch.nn import functional as F -import commons +from style_bert_vits2.models import commons import modules import attentions import monotonic_align @@ -11,7 +11,7 @@ from torch.nn import Conv1d, ConvTranspose1d, Conv2d from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm -from commons import init_weights, get_padding +from style_bert_vits2.models.commons import init_weights, get_padding from text import symbols, num_tones, num_languages diff --git a/modules.py b/modules.py index 86b93b50a..68b0b9a67 100644 --- a/modules.py +++ b/modules.py @@ -7,9 +7,9 @@ from torch.nn import functional as F from torch.nn.utils import remove_weight_norm, weight_norm -import commons +from style_bert_vits2.models import commons from attentions import Encoder -from commons import get_padding, init_weights +from style_bert_vits2.models.commons import get_padding, init_weights from transforms import piecewise_rational_quadratic_transform LRELU_SLOPE = 0.1 diff --git a/style_bert_vits2/models/commons.py b/style_bert_vits2/models/commons.py new file mode 100644 index 000000000..064ef5f58 --- /dev/null +++ b/style_bert_vits2/models/commons.py @@ -0,0 +1,336 @@ +""" +以下に記述されている関数のコメントはリファクタリング時に GPT-4 に生成させたもので、 +コードと完全に一致している保証はない。あくまで参考程度とすること。 +""" + +import math +import torch +from torch.nn import functional as F +from typing import Any + + +def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None: + """ + モジュールの重みを初期化する + + Args: + m (torch.nn.Module): 重みを初期化する対象のモジュール + mean (float): 正規分布の平均 + std (float): 正規分布の標準偏差 + """ + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size: int, dilation: int = 1) -> int: + """ + カーネルサイズと膨張率からパディングの大きさを計算する + + Args: + kernel_size (int): カーネルのサイズ + dilation (int): 膨張率 + + Returns: + int: 計算されたパディングの大きさ + """ + return int((kernel_size * dilation - dilation) / 2) + + +def convert_pad_shape(pad_shape: list[list[Any]]) -> list[Any]: + """ + パディングの形状を変換する + + Args: + pad_shape (list[list[Any]]): 変換前のパディングの形状 + + Returns: + list[Any]: 変換後のパディングの形状 + """ + layer = pad_shape[::-1] + new_pad_shape = [item for sublist in layer for item in sublist] + return new_pad_shape + + +def intersperse(lst: list[Any], item: Any) -> list[Any]: + """ + リストの要素の間に特定のアイテムを挿入する + + Args: + lst (list[Any]): 元のリスト + item (Any): 挿入するアイテム + + Returns: + list[Any]: 新しいリスト + """ + result 
= [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p: torch.Tensor, logs_p: torch.Tensor, m_q: torch.Tensor, logs_q: torch.Tensor) -> torch.Tensor: + """ + 2つの正規分布間の KL ダイバージェンスを計算する + + Args: + m_p (torch.Tensor): P の平均 + logs_p (torch.Tensor): P の対数標準偏差 + m_q (torch.Tensor): Q の平均 + logs_q (torch.Tensor): Q の対数標準偏差 + + Returns: + torch.Tensor: KL ダイバージェンスの値。 + """ + kl = (logs_q - logs_p) - 0.5 + kl += ( + 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + ) + return kl + + +def rand_gumbel(shape: torch.Size) -> torch.Tensor: + """ + Gumbel 分布からサンプリングし、オーバーフローを防ぐ + + Args: + shape (torch.Size): サンプルの形状 + + Returns: + torch.Tensor: Gumbel 分布からのサンプル + """ + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x: torch.Tensor) -> torch.Tensor: + """ + 引数と同じ形状のテンソルで、Gumbel 分布からサンプリングする + + Args: + x (torch.Tensor): 形状を基にするテンソル + + Returns: + torch.Tensor: Gumbel 分布からのサンプル + """ + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4) -> torch.Tensor: + """ + テンソルからセグメントをスライスする + + Args: + x (torch.Tensor): 入力テンソル + ids_str (torch.Tensor): スライスを開始するインデックス + segment_size (int, optional): スライスのサイズ (デフォルト: 4) + + Returns: + torch.Tensor: スライスされたセグメント + """ + gather_indices = ids_str.view(x.size(0), 1, 1).repeat( + 1, x.size(1), 1 + ) + torch.arange(segment_size, device=x.device) + return torch.gather(x, 2, gather_indices) + + +def rand_slice_segments(x: torch.Tensor, x_lengths: torch.Tensor | None = None, segment_size: int = 4) -> tuple[torch.Tensor, torch.Tensor]: + """ + ランダムなセグメントをスライスする + + Args: + x (torch.Tensor): 入力テンソル + x_lengths (torch.Tensor, optional): 各バッチの長さ (デフォルト: None) + segment_size (int, optional): スライスのサイズ (デフォルト: 4) + + Returns: + tuple[torch.Tensor, torch.Tensor]: スライスされたセグメントと開始インデックス + """ + b, d, t = x.size() + if x_lengths is None: + x_lengths = t # type: ignore + ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0) # type: ignore + ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length: int, channels: int, min_timescale: float = 1.0, max_timescale: float = 1.0e4) -> torch.Tensor: + """ + 1D タイミング信号を取得する + + Args: + length (int): シグナルの長さ + channels (int): シグナルのチャネル数 + min_timescale (float, optional): 最小のタイムスケール (デフォルト: 1.0) + max_timescale (float, optional): 最大のタイムスケール (デフォルト: 1.0e4) + + Returns: + torch.Tensor: タイミング信号 + """ + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x: torch.Tensor, min_timescale: float = 1.0, max_timescale: float = 1.0e4) -> torch.Tensor: + """ + 1D タイミング信号をテンソルに追加する + + Args: + x (torch.Tensor): 入力テンソル + min_timescale (float, optional): 最小のタイムスケール (デフォルト: 1.0) + max_timescale (float, optional): 最大のタイムスケール 
(デフォルト: 1.0e4) + + Returns: + torch.Tensor: タイミング信号が追加されたテンソル + """ + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x: torch.Tensor, min_timescale: float = 1.0, max_timescale: float = 1.0e4, axis: int = 1) -> torch.Tensor: + """ + 1D タイミング信号をテンソルに連結する + + Args: + x (torch.Tensor): 入力テンソル + min_timescale (float, optional): 最小のタイムスケール (デフォルト: 1.0) + max_timescale (float, optional): 最大のタイムスケール (デフォルト: 1.0e4) + axis (int, optional): 連結する軸 (デフォルト: 1) + + Returns: + torch.Tensor: タイミング信号が連結されたテンソル + """ + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length: int) -> torch.Tensor: + """ + 後続のマスクを生成する + + Args: + length (int): マスクのサイズ + + Returns: + torch.Tensor: 生成されたマスク + """ + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script # type: ignore +def fused_add_tanh_sigmoid_multiply(input_a: torch.Tensor, input_b: torch.Tensor, n_channels: torch.Tensor) -> torch.Tensor: + """ + 加算、tanh、sigmoid の活性化関数を組み合わせた演算を行う + + Args: + input_a (torch.Tensor): 入力テンソル A + input_b (torch.Tensor): 入力テンソル B + n_channels (torch.Tensor): チャネル数 + + Returns: + torch.Tensor: 演算結果 + """ + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def shift_1d(x: torch.Tensor) -> torch.Tensor: + """ + 与えられたテンソルを 1D でシフトする + + Args: + x (torch.Tensor): シフトするテンソル + + Returns: + torch.Tensor: シフトされたテンソル + """ + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length: torch.Tensor, max_length: int | None = None) -> torch.Tensor: + """ + シーケンスマスクを生成する + + Args: + length (torch.Tensor): 各シーケンスの長さ + max_length (int | None): 最大のシーケンス長さ。指定されていない場合は length の最大値を使用 + + Returns: + torch.Tensor: 生成されたシーケンスマスク + """ + if max_length is None: + max_length = length.max() # type: ignore + x = torch.arange(max_length, dtype=length.dtype, device=length.device) # type: ignore + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + パスを生成する + + Args: + duration (torch.Tensor): 各時間ステップの持続時間 + mask (torch.Tensor): マスクテンソル + + Returns: + torch.Tensor: 生成されたパス + """ + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters: torch.Tensor | list[torch.Tensor], clip_value: float | None, norm_type: float = 2.0) -> float: + """ + 勾配の値をクリップする + + Args: + parameters (torch.Tensor | list[torch.Tensor]): クリップするパラメータ + clip_value (float | None): クリップする値。None の場合はクリップしない + norm_type (float): ノルムの種類 + + Returns: + float: 総ノルム + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 
0.0 + for p in parameters: + assert p.grad is not None + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm diff --git a/train_ms.py b/train_ms.py index 07b510015..0631ffe1c 100644 --- a/train_ms.py +++ b/train_ms.py @@ -15,7 +15,7 @@ from tqdm import tqdm # logging.getLogger("numba").setLevel(logging.WARNING) -import commons +from style_bert_vits2.models import commons import default_style import utils from common.log import logger diff --git a/train_ms_jp_extra.py b/train_ms_jp_extra.py index bae922d52..aa5925b77 100644 --- a/train_ms_jp_extra.py +++ b/train_ms_jp_extra.py @@ -15,7 +15,7 @@ from huggingface_hub import HfApi # logging.getLogger("numba").setLevel(logging.WARNING) -import commons +from style_bert_vits2.models import commons import default_style import utils from common.log import logger From f880641eb5e82508b2a801014e859bf71c52e0c8 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Wed, 6 Mar 2024 22:29:12 +0000 Subject: [PATCH 015/148] Remove: modules under common/ that have been rewritten --- app.py | 8 ++++---- attentions.py | 2 +- bert_gen.py | 4 ++-- common/constants.py | 28 ---------------------------- common/log.py | 17 ----------------- common/stdout_wrapper.py | 40 ---------------------------------------- config.py | 2 +- data_utils.py | 2 +- default_style.py | 4 ++-- infer.py | 2 +- initialize.py | 2 +- losses.py | 2 +- preprocess_text.py | 4 ++-- resample.py | 4 ++-- server_fastapi.py | 4 ++-- slice.py | 4 ++-- speech_mos.py | 2 +- style_gen.py | 4 ++-- text/japanese.py | 2 +- train_ms.py | 4 ++-- train_ms_jp_extra.py | 4 ++-- transcribe.py | 6 +++--- utils.py | 2 +- webui_dataset.py | 4 ++-- webui_merge.py | 4 ++-- webui_style_vectors.py | 4 ++-- webui_train.py | 8 ++++---- 27 files changed, 44 insertions(+), 129 deletions(-) delete mode 100644 common/constants.py delete mode 100644 common/log.py delete mode 100644 common/stdout_wrapper.py diff --git a/app.py b/app.py index 1514444f7..acdb64655 100644 --- a/app.py +++ b/app.py @@ -10,7 +10,7 @@ import torch import yaml -from common.constants import ( +from style_bert_vits2.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, DEFAULT_LENGTH, DEFAULT_LINE_SPLIT, @@ -21,10 +21,10 @@ DEFAULT_STYLE, DEFAULT_STYLE_WEIGHT, GRADIO_THEME, - LATEST_VERSION, + VERSION, Languages, ) -from common.log import logger +from style_bert_vits2.logging import logger from common.tts_model import ModelHolder from infer import InvalidToneError from text.japanese import g2kata_tone, kata_tone2phone_tone, text_normalize @@ -202,7 +202,7 @@ def tts_fn( ] initial_md = f""" -# Style-Bert-VITS2 ver {LATEST_VERSION} 音声合成 +# Style-Bert-VITS2 ver {VERSION} 音声合成 - Ver 2.3で追加されたエディターのほうが実際に読み上げさせるには使いやすいかもしれません。`Editor.bat`か`python server_editor.py`で起動できます。 diff --git a/attentions.py b/attentions.py index 87a7f080d..1cca086ba 100644 --- a/attentions.py +++ b/attentions.py @@ -4,7 +4,7 @@ from torch.nn import functional as F from style_bert_vits2.models import commons -from common.log import logger as logging +from style_bert_vits2.logging import logger as logging class LayerNorm(nn.Module): diff --git a/bert_gen.py b/bert_gen.py index 70af9fd79..fd0b54eff 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -7,8 +7,8 @@ from style_bert_vits2.models import commons import utils -from common.log import logger -from common.stdout_wrapper import SAFE_STDOUT +from 
style_bert_vits2.logging import logger +from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT from config import config from text import cleaned_text_to_sequence, get_bert diff --git a/common/constants.py b/common/constants.py deleted file mode 100644 index fe620195d..000000000 --- a/common/constants.py +++ /dev/null @@ -1,28 +0,0 @@ -import enum - -# Built-in theme: "default", "base", "monochrome", "soft", "glass" -# See https://huggingface.co/spaces/gradio/theme-gallery for more themes -GRADIO_THEME: str = "NoCrypt/miku" - -LATEST_VERSION: str = "2.3.1" - -USER_DICT_DIR = "dict_data" - -DEFAULT_STYLE: str = "Neutral" -DEFAULT_STYLE_WEIGHT: float = 5.0 - - -class Languages(str, enum.Enum): - JP = "JP" - EN = "EN" - ZH = "ZH" - - -DEFAULT_SDP_RATIO: float = 0.2 -DEFAULT_NOISE: float = 0.6 -DEFAULT_NOISEW: float = 0.8 -DEFAULT_LENGTH: float = 1.0 -DEFAULT_LINE_SPLIT: bool = True -DEFAULT_SPLIT_INTERVAL: float = 0.5 -DEFAULT_ASSIST_TEXT_WEIGHT: float = 0.7 -DEFAULT_ASSIST_TEXT_WEIGHT: float = 1.0 diff --git a/common/log.py b/common/log.py deleted file mode 100644 index 679bb2c77..000000000 --- a/common/log.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -logger封装 -""" - -from loguru import logger - -from .stdout_wrapper import SAFE_STDOUT - -# 移除所有默认的处理器 -logger.remove() - -# 自定义格式并添加到标准输出 -log_format = ( - "{time:MM-DD HH:mm:ss} |{level:^8}| {file}:{line} | {message}" -) - -logger.add(SAFE_STDOUT, format=log_format, backtrace=True, diagnose=True) diff --git a/common/stdout_wrapper.py b/common/stdout_wrapper.py deleted file mode 100644 index 192f9084a..000000000 --- a/common/stdout_wrapper.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -`sys.stdout` wrapper for both Google Colab and local environment. -""" - -import sys -import tempfile - - -class StdoutWrapper: - def __init__(self): - self.temp_file = tempfile.NamedTemporaryFile( - mode="w+", delete=False, encoding="utf-8" - ) - self.original_stdout = sys.stdout - - def write(self, message: str): - self.temp_file.write(message) - self.temp_file.flush() - print(message, end="", file=self.original_stdout) - - def flush(self): - self.temp_file.flush() - - def read(self): - self.temp_file.seek(0) - return self.temp_file.read() - - def close(self): - self.temp_file.close() - - def fileno(self): - return self.temp_file.fileno() - - -try: - import google.colab - - SAFE_STDOUT = StdoutWrapper() -except ImportError: - SAFE_STDOUT = sys.stdout diff --git a/config.py b/config.py index 056e7f1cf..6369e6bbd 100644 --- a/config.py +++ b/config.py @@ -9,7 +9,7 @@ import torch import yaml -from common.log import logger +from style_bert_vits2.logging import logger # If not cuda available, set possible devices to cpu cuda_available = torch.cuda.is_available() diff --git a/data_utils.py b/data_utils.py index 81c0ade7c..111810250 100644 --- a/data_utils.py +++ b/data_utils.py @@ -11,7 +11,7 @@ from config import config from mel_processing import mel_spectrogram_torch, spectrogram_torch from text import cleaned_text_to_sequence -from common.log import logger +from style_bert_vits2.logging import logger from utils import load_filepaths_and_text, load_wav_to_torch """Multi speaker version""" diff --git a/default_style.py b/default_style.py index 9198ca8b2..763e29140 100644 --- a/default_style.py +++ b/default_style.py @@ -1,6 +1,6 @@ import os -from common.log import logger -from common.constants import DEFAULT_STYLE +from style_bert_vits2.logging import logger +from style_bert_vits2.constants import DEFAULT_STYLE import numpy as np import json diff --git a/infer.py 
b/infer.py index b976486ed..525219d35 100644 --- a/infer.py +++ b/infer.py @@ -7,7 +7,7 @@ from text import cleaned_text_to_sequence, get_bert from text.cleaner import clean_text from text.symbols import symbols -from common.log import logger +from style_bert_vits2.logging import logger class InvalidToneError(ValueError): diff --git a/initialize.py b/initialize.py index 5e35061f2..927a91a80 100644 --- a/initialize.py +++ b/initialize.py @@ -5,7 +5,7 @@ import yaml from huggingface_hub import hf_hub_download -from common.log import logger +from style_bert_vits2.logging import logger def download_bert_models(): diff --git a/losses.py b/losses.py index 763cc0285..4a890ba30 100644 --- a/losses.py +++ b/losses.py @@ -2,7 +2,7 @@ import torchaudio from transformers import AutoModel -from common.log import logger +from style_bert_vits2.logging import logger def feature_loss(fmap_r, fmap_g): diff --git a/preprocess_text.py b/preprocess_text.py index b3aaa17f7..126ba2c59 100644 --- a/preprocess_text.py +++ b/preprocess_text.py @@ -7,8 +7,8 @@ import click from tqdm import tqdm -from common.log import logger -from common.stdout_wrapper import SAFE_STDOUT +from style_bert_vits2.logging import logger +from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT from config import config from text.cleaner import clean_text diff --git a/resample.py b/resample.py index 5aff5ef49..7001af6a9 100644 --- a/resample.py +++ b/resample.py @@ -7,8 +7,8 @@ import soundfile from tqdm import tqdm -from common.log import logger -from common.stdout_wrapper import SAFE_STDOUT +from style_bert_vits2.logging import logger +from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT from config import config DEFAULT_BLOCK_SIZE: float = 0.400 # seconds diff --git a/server_fastapi.py b/server_fastapi.py index ce6ed0432..132cc175d 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -20,7 +20,7 @@ from fastapi.responses import FileResponse, Response from scipy.io import wavfile -from common.constants import ( +from style_bert_vits2.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, DEFAULT_LENGTH, DEFAULT_LINE_SPLIT, @@ -32,7 +32,7 @@ DEFAULT_STYLE_WEIGHT, Languages, ) -from common.log import logger +from style_bert_vits2.logging import logger from common.tts_model import Model, ModelHolder from config import config diff --git a/slice.py b/slice.py index 2d56427d1..c69f8bf88 100644 --- a/slice.py +++ b/slice.py @@ -8,8 +8,8 @@ import yaml from tqdm import tqdm -from common.log import logger -from common.stdout_wrapper import SAFE_STDOUT +from style_bert_vits2.logging import logger +from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT vad_model, utils = torch.hub.load( repo_or_dir="snakers4/silero-vad", diff --git a/speech_mos.py b/speech_mos.py index d69a23ab5..15cccef5f 100644 --- a/speech_mos.py +++ b/speech_mos.py @@ -10,7 +10,7 @@ import torch from tqdm import tqdm -from common.log import logger +from style_bert_vits2.logging import logger from common.tts_model import Model from config import config diff --git a/style_gen.py b/style_gen.py index 97a0aee83..1c1f0340f 100644 --- a/style_gen.py +++ b/style_gen.py @@ -7,8 +7,8 @@ from tqdm import tqdm import utils -from common.log import logger -from common.stdout_wrapper import SAFE_STDOUT +from style_bert_vits2.logging import logger +from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT from config import config warnings.filterwarnings("ignore", category=UserWarning) diff --git a/text/japanese.py b/text/japanese.py index 47b21c5b8..12dc349a8 100644 
--- a/text/japanese.py +++ b/text/japanese.py @@ -8,7 +8,7 @@ from num2words import num2words from transformers import AutoTokenizer -from common.log import logger +from style_bert_vits2.logging import logger from text import punctuation from text.japanese_mora_list import ( mora_kata_to_mora_phonemes, diff --git a/train_ms.py b/train_ms.py index 0631ffe1c..3e9e70723 100644 --- a/train_ms.py +++ b/train_ms.py @@ -18,8 +18,8 @@ from style_bert_vits2.models import commons import default_style import utils -from common.log import logger -from common.stdout_wrapper import SAFE_STDOUT +from style_bert_vits2.logging import logger +from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT from config import config from data_utils import ( DistributedBucketSampler, diff --git a/train_ms_jp_extra.py b/train_ms_jp_extra.py index aa5925b77..9a964c34a 100644 --- a/train_ms_jp_extra.py +++ b/train_ms_jp_extra.py @@ -18,8 +18,8 @@ from style_bert_vits2.models import commons import default_style import utils -from common.log import logger -from common.stdout_wrapper import SAFE_STDOUT +from style_bert_vits2.logging import logger +from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT from config import config from data_utils import ( DistributedBucketSampler, diff --git a/transcribe.py b/transcribe.py index b3f9014b7..18509c9ea 100644 --- a/transcribe.py +++ b/transcribe.py @@ -7,9 +7,9 @@ from faster_whisper import WhisperModel from tqdm import tqdm -from common.constants import Languages -from common.log import logger -from common.stdout_wrapper import SAFE_STDOUT +from style_bert_vits2.constants import Languages +from style_bert_vits2.logging import logger +from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT def transcribe(wav_path: Path, initial_prompt=None, language="ja"): diff --git a/utils.py b/utils.py index ca194bfb5..80dfa66ed 100644 --- a/utils.py +++ b/utils.py @@ -13,7 +13,7 @@ from safetensors.torch import save_file from scipy.io.wavfile import read -from common.log import logger +from style_bert_vits2.logging import logger MATPLOTLIB_FLAG = False diff --git a/webui_dataset.py b/webui_dataset.py index fec7a9ac9..169cc4dbf 100644 --- a/webui_dataset.py +++ b/webui_dataset.py @@ -4,8 +4,8 @@ import gradio as gr import yaml -from common.constants import GRADIO_THEME -from common.log import logger +from style_bert_vits2.constants import GRADIO_THEME +from style_bert_vits2.logging import logger from common.subprocess_utils import run_script_with_log # Get path settings diff --git a/webui_merge.py b/webui_merge.py index a58471a80..0a3990258 100644 --- a/webui_merge.py +++ b/webui_merge.py @@ -11,8 +11,8 @@ from safetensors import safe_open from safetensors.torch import save_file -from common.constants import DEFAULT_STYLE, GRADIO_THEME -from common.log import logger +from style_bert_vits2.constants import DEFAULT_STYLE, GRADIO_THEME +from style_bert_vits2.logging import logger from common.tts_model import Model, ModelHolder voice_keys = ["dec"] diff --git a/webui_style_vectors.py b/webui_style_vectors.py index b89c9c9f8..cf53ca22a 100644 --- a/webui_style_vectors.py +++ b/webui_style_vectors.py @@ -12,8 +12,8 @@ from sklearn.manifold import TSNE from umap import UMAP -from common.constants import DEFAULT_STYLE, GRADIO_THEME -from common.log import logger +from style_bert_vits2.constants import DEFAULT_STYLE, GRADIO_THEME +from style_bert_vits2.logging import logger from config import config # Get path settings diff --git a/webui_train.py b/webui_train.py index 
59cc9f594..0f272e835 100644 --- a/webui_train.py +++ b/webui_train.py @@ -14,9 +14,9 @@ import gradio as gr import yaml -from common.constants import GRADIO_THEME, LATEST_VERSION -from common.log import logger -from common.stdout_wrapper import SAFE_STDOUT +from style_bert_vits2.constants import GRADIO_THEME, VERSION +from style_bert_vits2.logging import logger +from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT from common.subprocess_utils import run_script_with_log, second_elem_of logger_handler = None @@ -399,7 +399,7 @@ def run_tensorboard(model_name): initial_md = f""" -# Style-Bert-VITS2 ver {LATEST_VERSION} 学習用WebUI +# Style-Bert-VITS2 ver {VERSION} 学習用WebUI ## 使い方 From 1936344c0c6af8213b4327cb964b32dbd20130ba Mon Sep 17 00:00:00 2001 From: tsukumi Date: Wed, 6 Mar 2024 22:51:25 +0000 Subject: [PATCH 016/148] Refactor: remove old code that can be deleted and update where modules are imported --- app.py | 5 +- infer.py | 6 +- models.py | 8 +- models_jp_extra.py | 8 +- style_bert_vits2/text_processing/symbols.py | 2 +- text/__init__.py | 8 +- text/chinese.py | 8 +- text/english.py | 13 +- text/japanese.py | 34 ++- text/japanese_bert.py | 6 +- text/japanese_mora_list.py | 232 -------------------- text/symbols.py | 187 ---------------- train_ms.py | 4 +- train_ms_jp_extra.py | 4 +- 14 files changed, 52 insertions(+), 473 deletions(-) delete mode 100644 text/japanese_mora_list.py delete mode 100644 text/symbols.py diff --git a/app.py b/app.py index acdb64655..c70a89dbf 100644 --- a/app.py +++ b/app.py @@ -27,7 +27,8 @@ from style_bert_vits2.logging import logger from common.tts_model import ModelHolder from infer import InvalidToneError -from text.japanese import g2kata_tone, kata_tone2phone_tone, text_normalize +from style_bert_vits2.text_processing.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone +from style_bert_vits2.text_processing.japanese.normalizer import normalize_text # Get path settings with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: @@ -131,7 +132,7 @@ def tts_fn( if tone is None and language == "JP": # アクセント指定に使えるようにアクセント情報を返す - norm_text = text_normalize(text) + norm_text = normalize_text(text) kata_tone = g2kata_tone(norm_text) kata_tone_json_str = json.dumps(kata_tone, ensure_ascii=False) elif tone is None: diff --git a/infer.py b/infer.py index 525219d35..4afd048fd 100644 --- a/infer.py +++ b/infer.py @@ -6,7 +6,7 @@ from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra from text import cleaned_text_to_sequence, get_bert from text.cleaner import clean_text -from text.symbols import symbols +from style_bert_vits2.text_processing.symbols import SYMBOLS from style_bert_vits2.logging import logger @@ -18,7 +18,7 @@ def get_net_g(model_path: str, version: str, device: str, hps): if version.endswith("JP-Extra"): logger.info("Using JP-Extra model") net_g = SynthesizerTrnJPExtra( - len(symbols), + len(SYMBOLS), hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, @@ -27,7 +27,7 @@ def get_net_g(model_path: str, version: str, device: str, hps): else: logger.info("Using normal model") net_g = SynthesizerTrn( - len(symbols), + len(SYMBOLS), hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, diff --git a/models.py b/models.py index 501fcaf3d..ef581f291 100644 --- a/models.py +++ b/models.py @@ -12,7 +12,7 @@ import modules import monotonic_align from style_bert_vits2.models.commons import get_padding, 
init_weights -from text import num_languages, num_tones, symbols +from style_bert_vits2.text_processing.symbols import NUM_LANGUAGES, NUM_TONES, SYMBOLS class DurationDiscriminator(nn.Module): # vits2 @@ -334,11 +334,11 @@ def __init__( self.kernel_size = kernel_size self.p_dropout = p_dropout self.gin_channels = gin_channels - self.emb = nn.Embedding(len(symbols), hidden_channels) + self.emb = nn.Embedding(len(SYMBOLS), hidden_channels) nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) - self.tone_emb = nn.Embedding(num_tones, hidden_channels) + self.tone_emb = nn.Embedding(NUM_TONES, hidden_channels) nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5) - self.language_emb = nn.Embedding(num_languages, hidden_channels) + self.language_emb = nn.Embedding(NUM_LANGUAGES, hidden_channels) nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5) self.bert_proj = nn.Conv1d(1024, hidden_channels, 1) self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1) diff --git a/models_jp_extra.py b/models_jp_extra.py index 3e87ceda2..16cacc7cd 100644 --- a/models_jp_extra.py +++ b/models_jp_extra.py @@ -12,7 +12,7 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm from style_bert_vits2.models.commons import init_weights, get_padding -from text import symbols, num_tones, num_languages +from style_bert_vits2.text_processing.symbols import SYMBOLS, NUM_TONES, NUM_LANGUAGES class DurationDiscriminator(nn.Module): # vits2 @@ -353,11 +353,11 @@ def __init__( self.kernel_size = kernel_size self.p_dropout = p_dropout self.gin_channels = gin_channels - self.emb = nn.Embedding(len(symbols), hidden_channels) + self.emb = nn.Embedding(len(SYMBOLS), hidden_channels) nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) - self.tone_emb = nn.Embedding(num_tones, hidden_channels) + self.tone_emb = nn.Embedding(NUM_TONES, hidden_channels) nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5) - self.language_emb = nn.Embedding(num_languages, hidden_channels) + self.language_emb = nn.Embedding(NUM_LANGUAGES, hidden_channels) nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5) self.bert_proj = nn.Conv1d(1024, hidden_channels, 1) diff --git a/style_bert_vits2/text_processing/symbols.py b/style_bert_vits2/text_processing/symbols.py index 628edc61a..d69bc1c42 100644 --- a/style_bert_vits2/text_processing/symbols.py +++ b/style_bert_vits2/text_processing/symbols.py @@ -174,7 +174,7 @@ SIL_PHONEMES_IDS = [SYMBOLS.index(i) for i in PUNCTUATION_SYMBOLS] # Combine all tones -num_tones = NUM_ZH_TONES + NUM_JA_TONES + NUM_EN_TONES +NUM_TONES = NUM_ZH_TONES + NUM_JA_TONES + NUM_EN_TONES # Language maps LANGUAGE_ID_MAP = {"ZH": 0, "JP": 1, "EN": 2} diff --git a/text/__init__.py b/text/__init__.py index d8ae88dea..ce4c008c1 100644 --- a/text/__init__.py +++ b/text/__init__.py @@ -1,6 +1,6 @@ -from text.symbols import * +from style_bert_vits2.text_processing.symbols import * -_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_symbol_to_id = {s: i for i, s in enumerate(SYMBOLS)} def cleaned_text_to_sequence(cleaned_text, tones, language): @@ -11,9 +11,9 @@ def cleaned_text_to_sequence(cleaned_text, tones, language): List of integers corresponding to the symbols in the text """ phones = [_symbol_to_id[symbol] for symbol in cleaned_text] - tone_start = language_tone_start_map[language] + tone_start = LANGUAGE_TONE_START_MAP[language] tones = [i + tone_start for i in tones] - lang_id = language_id_map[language] + lang_id = 
LANGUAGE_ID_MAP[language] lang_ids = [lang_id for i in phones] return phones, tones, lang_ids diff --git a/text/chinese.py b/text/chinese.py index d9174ee0e..56dc4f34f 100644 --- a/text/chinese.py +++ b/text/chinese.py @@ -4,7 +4,7 @@ import cn2an from pypinyin import lazy_pinyin, Style -from text.symbols import punctuation +from style_bert_vits2.text_processing.symbols import PUNCTUATIONS from text.tone_sandhi import ToneSandhi current_file_path = os.path.dirname(__file__) @@ -60,14 +60,14 @@ def replace_punctuation(text): replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) replaced_text = re.sub( - r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text + r"[^\u4e00-\u9fa5" + "".join(PUNCTUATIONS) + r"]+", "", replaced_text ) return replaced_text def g2p(text): - pattern = r"(?<=[{0}])\s*".format("".join(punctuation)) + pattern = r"(?<=[{0}])\s*".format("".join(PUNCTUATIONS)) sentences = [i for i in re.split(pattern, text) if i.strip() != ""] phones, tones, word2ph = _g2p(sentences) assert sum(word2ph) == len(phones) @@ -119,7 +119,7 @@ def _g2p(segments): # NOTE: post process for pypinyin outputs # we discriminate i, ii and iii if c == v: - assert c in punctuation + assert c in PUNCTUATIONS phone = [c] tone = "0" word2ph.append(1) diff --git a/text/english.py b/text/english.py index 4a2af9523..f38ee84ae 100644 --- a/text/english.py +++ b/text/english.py @@ -4,8 +4,7 @@ from g2p_en import G2p from transformers import DebertaV2Tokenizer -from text import symbols -from text.symbols import punctuation +from style_bert_vits2.text_processing.symbols import PUNCTUATIONS, SYMBOLS current_file_path = os.path.dirname(__file__) CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep") @@ -107,9 +106,9 @@ def post_replace_ph(ph): } if ph in rep_map.keys(): ph = rep_map[ph] - if ph in symbols: + if ph in SYMBOLS: return ph - if ph not in symbols: + if ph not in SYMBOLS: ph = "UNK" return ph @@ -399,13 +398,13 @@ def text_to_words(text): if t.startswith("▁"): words.append([t[1:]]) else: - if t in punctuation: + if t in PUNCTUATIONS: if idx == len(tokens) - 1: words.append([f"{t}"]) else: if ( not tokens[idx + 1].startswith("▁") - and tokens[idx + 1] not in punctuation + and tokens[idx + 1] not in PUNCTUATIONS ): if idx == 0: words.append([]) @@ -433,7 +432,7 @@ def g2p(text): if "'" in word: word = ["".join(word)] for w in word: - if w in punctuation: + if w in PUNCTUATIONS: temp_phones.append(w) temp_tones.append(0) continue diff --git a/text/japanese.py b/text/japanese.py index 12dc349a8..0dc6aa83f 100644 --- a/text/japanese.py +++ b/text/japanese.py @@ -2,20 +2,18 @@ # compatible with Julius https://github.com/julius-speech/segmentation-kit import re import unicodedata -from pathlib import Path import pyopenjtalk from num2words import num2words from transformers import AutoTokenizer from style_bert_vits2.logging import logger -from text import punctuation -from text.japanese_mora_list import ( - mora_kata_to_mora_phonemes, - mora_phonemes_to_mora_kata, +from style_bert_vits2.text_processing.japanese.mora_list import ( + MORA_KATA_TO_MORA_PHONEMES, + MORA_PHONEMES_TO_MORA_KATA, ) - from style_bert_vits2.text_processing.japanese.user_dict import update_dict +from style_bert_vits2.text_processing.symbols import PUNCTUATIONS # 最初にpyopenjtalkの辞書を更新 update_dict() @@ -24,7 +22,7 @@ COSONANTS = set( [ cosonant - for cosonant, _ in mora_kata_to_mora_phonemes.values() + for cosonant, _ in MORA_KATA_TO_MORA_PHONEMES.values() if cosonant is not None ] ) @@ -153,7 +151,7 @@ def 
replace_punctuation(text: str) -> str: # ↓ ギリシャ文字 + r"\u0370-\u03FF\u1F00-\u1FFF" # ↓ "!", "?", "…", ",", ".", "'", "-", 但し`…`はすでに`...`に変換されている - + "".join(punctuation) + r"]+", + + "".join(PUNCTUATIONS) + r"]+", # 上述以外の文字を削除 "", replaced_text, @@ -220,7 +218,7 @@ def g2p( # sep_textから、各単語を1文字1文字分割して、文字のリスト(のリスト)を作る sep_tokenized: list[list[str]] = [] for i in sep_text: - if i not in punctuation: + if i not in PUNCTUATIONS: sep_tokenized.append( tokenizer.tokenize(i) ) # ここでおそらく`i`が文字単位に分割される @@ -268,7 +266,7 @@ def phone_tone2kata_tone(phone_tone: list[tuple[str, int]]) -> list[tuple[str, i current_mora = "" for phone, next_phone, tone, next_tone in zip(phones, phones[1:], tones, tones[1:]): # zipの関係で最後の("_", 0)は無視されている - if phone in punctuation: + if phone in PUNCTUATIONS: result.append((phone, tone)) continue if phone in COSONANTS: # n以外の子音の場合 @@ -278,7 +276,7 @@ def phone_tone2kata_tone(phone_tone: list[tuple[str, int]]) -> list[tuple[str, i else: # phoneが母音もしくは「N」 current_mora += phone - result.append((mora_phonemes_to_mora_kata[current_mora], tone)) + result.append((MORA_PHONEMES_TO_MORA_KATA[current_mora], tone)) current_mora = "" return result @@ -287,10 +285,10 @@ def kata_tone2phone_tone(kata_tone: list[tuple[str, int]]) -> list[tuple[str, in """`phone_tone2kata_tone()`の逆。""" result: list[tuple[str, int]] = [("_", 0)] for mora, tone in kata_tone: - if mora in punctuation: + if mora in PUNCTUATIONS: result.append((mora, tone)) else: - cosonant, vowel = mora_kata_to_mora_phonemes[mora] + cosonant, vowel = MORA_KATA_TO_MORA_PHONEMES[mora] if cosonant is None: result.append((vowel, tone)) else: @@ -387,7 +385,7 @@ def text2sep_kata( assert yomi != "", f"Empty yomi: {word}" if yomi == "、": # wordは正規化されているので、`.`, `,`, `!`, `'`, `-`, `--` のいずれか - if not set(word).issubset(set(punctuation)): # 記号繰り返しか判定 + if not set(word).issubset(set(PUNCTUATIONS)): # 記号繰り返しか判定 # ここはpyopenjtalkが読めない文字等のときに起こる if raise_yomi_error: raise YomiError(f"Cannot read: {word} in:\n{norm_text}") @@ -581,7 +579,7 @@ def align_tones( result.append((phone, phone_tone_list[tone_index][1])) # 探すindexを1つ進める tone_index += 1 - elif phone in punctuation: + elif phone in PUNCTUATIONS: # phoneがpunctuationの場合 → (phone, 0)を追加 result.append((phone, 0)) else: @@ -606,16 +604,16 @@ def kata2phoneme_list(text: str) -> list[str]: `?` → ["?"] `!?!?!?!?!` → ["!", "?", "!", "?", "!", "?", "!", "?", "!"] """ - if set(text).issubset(set(punctuation)): + if set(text).issubset(set(PUNCTUATIONS)): return list(text) # `text`がカタカナ(`ー`含む)のみからなるかどうかをチェック if re.fullmatch(r"[\u30A0-\u30FF]+", text) is None: raise ValueError(f"Input must be katakana only: {text}") - sorted_keys = sorted(mora_kata_to_mora_phonemes.keys(), key=len, reverse=True) + sorted_keys = sorted(MORA_KATA_TO_MORA_PHONEMES.keys(), key=len, reverse=True) pattern = "|".join(map(re.escape, sorted_keys)) def mora2phonemes(mora: str) -> str: - cosonant, vowel = mora_kata_to_mora_phonemes[mora] + cosonant, vowel = MORA_KATA_TO_MORA_PHONEMES[mora] if cosonant is None: return f" {vowel}" return f" {cosonant} {vowel}" diff --git a/text/japanese_bert.py b/text/japanese_bert.py index dcee0f3d2..fbeb94d6c 100644 --- a/text/japanese_bert.py +++ b/text/japanese_bert.py @@ -4,7 +4,7 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer from config import config -from text.japanese import text2sep_kata, text_normalize +from style_bert_vits2.text_processing.japanese.g2p import text_to_sep_kata LOCAL_PATH = "./bert/deberta-v2-large-japanese-char-wwm" @@ -22,10 +22,10 @@ def 
get_bert_feature( ): # 各単語が何文字かを作る`word2ph`を使う必要があるので、読めない文字は必ず無視する # でないと`word2ph`の結果とテキストの文字数結果が整合性が取れない - text = "".join(text2sep_kata(text, raise_yomi_error=False)[0]) + text = "".join(text_to_sep_kata(text, raise_yomi_error=False)[0]) if assist_text: - assist_text = "".join(text2sep_kata(assist_text, raise_yomi_error=False)[0]) + assist_text = "".join(text_to_sep_kata(assist_text, raise_yomi_error=False)[0]) if ( sys.platform == "darwin" and torch.backends.mps.is_available() diff --git a/text/japanese_mora_list.py b/text/japanese_mora_list.py deleted file mode 100644 index b43e54d8d..000000000 --- a/text/japanese_mora_list.py +++ /dev/null @@ -1,232 +0,0 @@ -""" -VOICEVOXのソースコードからお借りして最低限に改造したコード。 -https://github.com/VOICEVOX/voicevox_engine/blob/master/voicevox_engine/tts_pipeline/mora_list.py -""" - -""" -以下のモーラ対応表はOpenJTalkのソースコードから取得し、 -カタカナ表記とモーラが一対一対応するように改造した。 -ライセンス表記: ------------------------------------------------------------------ - The Japanese TTS System "Open JTalk" - developed by HTS Working Group - http://open-jtalk.sourceforge.net/ ------------------------------------------------------------------ - - Copyright (c) 2008-2014 Nagoya Institute of Technology - Department of Computer Science - -All rights reserved. - -Redistribution and use in source and binary forms, with or -without modification, are permitted provided that the following -conditions are met: - -- Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. -- Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - disclaimer in the documentation and/or other materials provided - with the distribution. -- Neither the name of the HTS working group nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND -CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, -INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS -BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED -TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. 
-""" -from typing import Optional - -# (カタカナ, 子音, 母音)の順。子音がない場合はNoneを入れる。 -# 但し「ン」と「ッ」は母音のみという扱いで、「ン」は「N」、「ッ」は「q」とする。 -# (元々「ッ」は「cl」) -# また「デェ = dy e」はpyopenjtalkの出力(de e)と合わないため削除 -_mora_list_minimum: list[tuple[str, Optional[str], str]] = [ - ("ヴォ", "v", "o"), - ("ヴェ", "v", "e"), - ("ヴィ", "v", "i"), - ("ヴァ", "v", "a"), - ("ヴ", "v", "u"), - ("ン", None, "N"), - ("ワ", "w", "a"), - ("ロ", "r", "o"), - ("レ", "r", "e"), - ("ル", "r", "u"), - ("リョ", "ry", "o"), - ("リュ", "ry", "u"), - ("リャ", "ry", "a"), - ("リェ", "ry", "e"), - ("リ", "r", "i"), - ("ラ", "r", "a"), - ("ヨ", "y", "o"), - ("ユ", "y", "u"), - ("ヤ", "y", "a"), - ("モ", "m", "o"), - ("メ", "m", "e"), - ("ム", "m", "u"), - ("ミョ", "my", "o"), - ("ミュ", "my", "u"), - ("ミャ", "my", "a"), - ("ミェ", "my", "e"), - ("ミ", "m", "i"), - ("マ", "m", "a"), - ("ポ", "p", "o"), - ("ボ", "b", "o"), - ("ホ", "h", "o"), - ("ペ", "p", "e"), - ("ベ", "b", "e"), - ("ヘ", "h", "e"), - ("プ", "p", "u"), - ("ブ", "b", "u"), - ("フォ", "f", "o"), - ("フェ", "f", "e"), - ("フィ", "f", "i"), - ("ファ", "f", "a"), - ("フ", "f", "u"), - ("ピョ", "py", "o"), - ("ピュ", "py", "u"), - ("ピャ", "py", "a"), - ("ピェ", "py", "e"), - ("ピ", "p", "i"), - ("ビョ", "by", "o"), - ("ビュ", "by", "u"), - ("ビャ", "by", "a"), - ("ビェ", "by", "e"), - ("ビ", "b", "i"), - ("ヒョ", "hy", "o"), - ("ヒュ", "hy", "u"), - ("ヒャ", "hy", "a"), - ("ヒェ", "hy", "e"), - ("ヒ", "h", "i"), - ("パ", "p", "a"), - ("バ", "b", "a"), - ("ハ", "h", "a"), - ("ノ", "n", "o"), - ("ネ", "n", "e"), - ("ヌ", "n", "u"), - ("ニョ", "ny", "o"), - ("ニュ", "ny", "u"), - ("ニャ", "ny", "a"), - ("ニェ", "ny", "e"), - ("ニ", "n", "i"), - ("ナ", "n", "a"), - ("ドゥ", "d", "u"), - ("ド", "d", "o"), - ("トゥ", "t", "u"), - ("ト", "t", "o"), - ("デョ", "dy", "o"), - ("デュ", "dy", "u"), - ("デャ", "dy", "a"), - # ("デェ", "dy", "e"), - ("ディ", "d", "i"), - ("デ", "d", "e"), - ("テョ", "ty", "o"), - ("テュ", "ty", "u"), - ("テャ", "ty", "a"), - ("ティ", "t", "i"), - ("テ", "t", "e"), - ("ツォ", "ts", "o"), - ("ツェ", "ts", "e"), - ("ツィ", "ts", "i"), - ("ツァ", "ts", "a"), - ("ツ", "ts", "u"), - ("ッ", None, "q"), # 「cl」から「q」に変更 - ("チョ", "ch", "o"), - ("チュ", "ch", "u"), - ("チャ", "ch", "a"), - ("チェ", "ch", "e"), - ("チ", "ch", "i"), - ("ダ", "d", "a"), - ("タ", "t", "a"), - ("ゾ", "z", "o"), - ("ソ", "s", "o"), - ("ゼ", "z", "e"), - ("セ", "s", "e"), - ("ズィ", "z", "i"), - ("ズ", "z", "u"), - ("スィ", "s", "i"), - ("ス", "s", "u"), - ("ジョ", "j", "o"), - ("ジュ", "j", "u"), - ("ジャ", "j", "a"), - ("ジェ", "j", "e"), - ("ジ", "j", "i"), - ("ショ", "sh", "o"), - ("シュ", "sh", "u"), - ("シャ", "sh", "a"), - ("シェ", "sh", "e"), - ("シ", "sh", "i"), - ("ザ", "z", "a"), - ("サ", "s", "a"), - ("ゴ", "g", "o"), - ("コ", "k", "o"), - ("ゲ", "g", "e"), - ("ケ", "k", "e"), - ("グヮ", "gw", "a"), - ("グ", "g", "u"), - ("クヮ", "kw", "a"), - ("ク", "k", "u"), - ("ギョ", "gy", "o"), - ("ギュ", "gy", "u"), - ("ギャ", "gy", "a"), - ("ギェ", "gy", "e"), - ("ギ", "g", "i"), - ("キョ", "ky", "o"), - ("キュ", "ky", "u"), - ("キャ", "ky", "a"), - ("キェ", "ky", "e"), - ("キ", "k", "i"), - ("ガ", "g", "a"), - ("カ", "k", "a"), - ("オ", None, "o"), - ("エ", None, "e"), - ("ウォ", "w", "o"), - ("ウェ", "w", "e"), - ("ウィ", "w", "i"), - ("ウ", None, "u"), - ("イェ", "y", "e"), - ("イ", None, "i"), - ("ア", None, "a"), -] -_mora_list_additional: list[tuple[str, Optional[str], str]] = [ - ("ヴョ", "by", "o"), - ("ヴュ", "by", "u"), - ("ヴャ", "by", "a"), - ("ヲ", None, "o"), - ("ヱ", None, "e"), - ("ヰ", None, "i"), - ("ヮ", "w", "a"), - ("ョ", "y", "o"), - ("ュ", "y", "u"), - ("ヅ", "z", "u"), - ("ヂ", "j", "i"), - ("ヶ", "k", "e"), - ("ャ", "y", "a"), - ("ォ", None, "o"), - ("ェ", None, "e"), - ("ゥ", None, "u"), - ("ィ", None, 
"i"), - ("ァ", None, "a"), -] - -# 例: "vo" -> "ヴォ", "a" -> "ア" -mora_phonemes_to_mora_kata: dict[str, str] = { - (consonant or "") + vowel: kana for [kana, consonant, vowel] in _mora_list_minimum -} - -# 例: "ヴォ" -> ("v", "o"), "ア" -> (None, "a") -mora_kata_to_mora_phonemes: dict[str, tuple[Optional[str], str]] = { - kana: (consonant, vowel) - for [kana, consonant, vowel] in _mora_list_minimum + _mora_list_additional -} diff --git a/text/symbols.py b/text/symbols.py deleted file mode 100644 index 846de6458..000000000 --- a/text/symbols.py +++ /dev/null @@ -1,187 +0,0 @@ -punctuation = ["!", "?", "…", ",", ".", "'", "-"] -pu_symbols = punctuation + ["SP", "UNK"] -pad = "_" - -# chinese -zh_symbols = [ - "E", - "En", - "a", - "ai", - "an", - "ang", - "ao", - "b", - "c", - "ch", - "d", - "e", - "ei", - "en", - "eng", - "er", - "f", - "g", - "h", - "i", - "i0", - "ia", - "ian", - "iang", - "iao", - "ie", - "in", - "ing", - "iong", - "ir", - "iu", - "j", - "k", - "l", - "m", - "n", - "o", - "ong", - "ou", - "p", - "q", - "r", - "s", - "sh", - "t", - "u", - "ua", - "uai", - "uan", - "uang", - "ui", - "un", - "uo", - "v", - "van", - "ve", - "vn", - "w", - "x", - "y", - "z", - "zh", - "AA", - "EE", - "OO", -] -num_zh_tones = 6 - -# japanese -ja_symbols = [ - "N", - "a", - "a:", - "b", - "by", - "ch", - "d", - "dy", - "e", - "e:", - "f", - "g", - "gy", - "h", - "hy", - "i", - "i:", - "j", - "k", - "ky", - "m", - "my", - "n", - "ny", - "o", - "o:", - "p", - "py", - "q", - "r", - "ry", - "s", - "sh", - "t", - "ts", - "ty", - "u", - "u:", - "w", - "y", - "z", - "zy", -] -num_ja_tones = 2 - -# English -en_symbols = [ - "aa", - "ae", - "ah", - "ao", - "aw", - "ay", - "b", - "ch", - "d", - "dh", - "eh", - "er", - "ey", - "f", - "g", - "hh", - "ih", - "iy", - "jh", - "k", - "l", - "m", - "n", - "ng", - "ow", - "oy", - "p", - "r", - "s", - "sh", - "t", - "th", - "uh", - "uw", - "V", - "w", - "y", - "z", - "zh", -] -num_en_tones = 4 - -# combine all symbols -normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols)) -symbols = [pad] + normal_symbols + pu_symbols -sil_phonemes_ids = [symbols.index(i) for i in pu_symbols] - -# combine all tones -num_tones = num_zh_tones + num_ja_tones + num_en_tones - -# language maps -language_id_map = {"ZH": 0, "JP": 1, "EN": 2} -num_languages = len(language_id_map.keys()) - -language_tone_start_map = { - "ZH": 0, - "JP": num_zh_tones, - "EN": num_zh_tones + num_ja_tones, -} - -if __name__ == "__main__": - a = set(zh_symbols) - b = set(en_symbols) - print(sorted(a & b)) diff --git a/train_ms.py b/train_ms.py index 3e9e70723..783b15700 100644 --- a/train_ms.py +++ b/train_ms.py @@ -29,7 +29,7 @@ from losses import discriminator_loss, feature_loss, generator_loss, kl_loss from mel_processing import mel_spectrogram_torch, spec_to_mel_torch from models import DurationDiscriminator, MultiPeriodDiscriminator, SynthesizerTrn -from text.symbols import symbols +from style_bert_vits2.text_processing.symbols import SYMBOLS torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = ( @@ -279,7 +279,7 @@ def run(): logger.info("Using normal encoder for VITS1") net_g = SynthesizerTrn( - len(symbols), + len(SYMBOLS), hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, diff --git a/train_ms_jp_extra.py b/train_ms_jp_extra.py index 9a964c34a..e4e4e5681 100644 --- a/train_ms_jp_extra.py +++ b/train_ms_jp_extra.py @@ -34,7 +34,7 @@ SynthesizerTrn, WavLMDiscriminator, ) -from text.symbols import symbols +from 
style_bert_vits2.text_processing.symbols import SYMBOLS torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = ( @@ -293,7 +293,7 @@ def run(): logger.info("Using normal encoder for VITS1") net_g = SynthesizerTrn( - len(symbols), + len(SYMBOLS), hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, From ca4c03c67bbf838eec27ca6811c90a5b9d9a5ec2 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Wed, 6 Mar 2024 23:03:06 +0000 Subject: [PATCH 017/148] Fix: import error --- common/subprocess_utils.py | 4 ++-- common/tts_model.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/common/subprocess_utils.py b/common/subprocess_utils.py index 40426f765..d7f3fc407 100644 --- a/common/subprocess_utils.py +++ b/common/subprocess_utils.py @@ -1,8 +1,8 @@ import subprocess import sys -from .log import logger -from .stdout_wrapper import SAFE_STDOUT +from style_bert_vits2.logging import logger +from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT python = sys.executable diff --git a/common/tts_model.py b/common/tts_model.py index e09787e7e..3f17d1529 100644 --- a/common/tts_model.py +++ b/common/tts_model.py @@ -1,11 +1,9 @@ -import os import warnings from pathlib import Path from typing import Optional, Union import gradio as gr import numpy as np - import torch from gradio.processing_utils import convert_to_16_bit_wav @@ -14,7 +12,7 @@ from models import SynthesizerTrn from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra -from .constants import ( +from style_bert_vits2.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, DEFAULT_LENGTH, DEFAULT_LINE_SPLIT, @@ -25,7 +23,7 @@ DEFAULT_STYLE, DEFAULT_STYLE_WEIGHT, ) -from .log import logger +from style_bert_vits2.logging import logger def adjust_voice(fs, wave, pitch_scale, intonation_scale): From a52fda7a88ad296831419993364f261559147939 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Wed, 6 Mar 2024 23:11:27 +0000 Subject: [PATCH 018/148] Refactor: moved common/subprocess_utils.py to style_bert_vits2/utils/subprocess.py --- common/subprocess_utils.py | 33 ---------------- style_bert_vits2/utils/subprocess.py | 56 ++++++++++++++++++++++++++++ webui_dataset.py | 2 +- webui_train.py | 2 +- 4 files changed, 58 insertions(+), 35 deletions(-) delete mode 100644 common/subprocess_utils.py create mode 100644 style_bert_vits2/utils/subprocess.py diff --git a/common/subprocess_utils.py b/common/subprocess_utils.py deleted file mode 100644 index d7f3fc407..000000000 --- a/common/subprocess_utils.py +++ /dev/null @@ -1,33 +0,0 @@ -import subprocess -import sys - -from style_bert_vits2.logging import logger -from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT - -python = sys.executable - - -def run_script_with_log(cmd: list[str], ignore_warning=False) -> tuple[bool, str]: - logger.info(f"Running: {' '.join(cmd)}") - result = subprocess.run( - [python] + cmd, - stdout=SAFE_STDOUT, # type: ignore - stderr=subprocess.PIPE, - text=True, - encoding="utf-8", - ) - if result.returncode != 0: - logger.error(f"Error: {' '.join(cmd)}\n{result.stderr}") - return False, result.stderr - elif result.stderr and not ignore_warning: - logger.warning(f"Warning: {' '.join(cmd)}\n{result.stderr}") - return True, result.stderr - logger.success(f"Success: {' '.join(cmd)}") - return True, "" - - -def second_elem_of(original_function): - def inner_function(*args, **kwargs): - return original_function(*args, **kwargs)[1] - - return inner_function diff --git 
a/style_bert_vits2/utils/subprocess.py b/style_bert_vits2/utils/subprocess.py new file mode 100644 index 000000000..b152702ef --- /dev/null +++ b/style_bert_vits2/utils/subprocess.py @@ -0,0 +1,56 @@ +import subprocess +import sys +from typing import Any, Callable + +from style_bert_vits2.logging import logger +from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT + +PYTHON = sys.executable + + +def run_script_with_log(cmd: list[str], ignore_warning: bool = False) -> tuple[bool, str]: + """ + 指定されたコマンドを実行し、そのログを記録する + + Args: + cmd: 実行するコマンドのリスト + ignore_warning: 警告を無視するかどうかのフラグ + + Returns: + tuple[bool, str]: 実行が成功したかどうかのブール値と、エラーまたは警告のメッセージ(ある場合) + """ + + logger.info(f"Running: {' '.join(cmd)}") + result = subprocess.run( + [PYTHON] + cmd, + stdout = SAFE_STDOUT, + stderr = subprocess.PIPE, + text = True, + encoding = "utf-8", + ) + if result.returncode != 0: + logger.error(f"Error: {' '.join(cmd)}\n{result.stderr}") + return False, result.stderr + elif result.stderr and not ignore_warning: + logger.warning(f"Warning: {' '.join(cmd)}\n{result.stderr}") + return True, result.stderr + logger.success(f"Success: {' '.join(cmd)}") + + return True, "" + + +def second_elem_of(original_function: Callable[..., tuple[Any, Any]]) -> Callable[..., Any]: + """ + 与えられた関数をラップし、その戻り値の 2 番目の要素のみを返す関数を生成する + + Args: + original_function (Callable[..., tuple[Any, Any]])): ラップする元の関数 + + Returns: + Callable[..., Any]: 元の関数の戻り値の 2 番目の要素のみを返す関数 + """ + + def inner_function(*args, **kwargs) -> Any: # type: ignore + return original_function(*args, **kwargs)[1] + + return inner_function diff --git a/webui_dataset.py b/webui_dataset.py index 169cc4dbf..3ad63c163 100644 --- a/webui_dataset.py +++ b/webui_dataset.py @@ -6,7 +6,7 @@ from style_bert_vits2.constants import GRADIO_THEME from style_bert_vits2.logging import logger -from common.subprocess_utils import run_script_with_log +from style_bert_vits2.utils.subprocess import run_script_with_log # Get path settings with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: diff --git a/webui_train.py b/webui_train.py index 0f272e835..6c89b261f 100644 --- a/webui_train.py +++ b/webui_train.py @@ -17,7 +17,7 @@ from style_bert_vits2.constants import GRADIO_THEME, VERSION from style_bert_vits2.logging import logger from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT -from common.subprocess_utils import run_script_with_log, second_elem_of +from style_bert_vits2.utils.subprocess import run_script_with_log, second_elem_of logger_handler = None tensorboard_executed = False From 89825e68d88586225d93b7486d866578f1e4990c Mon Sep 17 00:00:00 2001 From: tsukumi Date: Wed, 6 Mar 2024 23:43:25 +0000 Subject: [PATCH 019/148] Refactor: moved model, attentions definitions and inference code to style_bert_vits2/models/ The code has not yet been cleaned up, just moved. 
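Note for downstream callers: since the files are only moved, behaviour should be unchanged and
only import paths need updating. A minimal sketch of the new layout, with the paths taken from
the hunks below (the old top-level modules models, models_jp_extra, attentions, modules and
infer are renamed away by this patch):

    # old top-level imports                  -> new package imports
    # from models import SynthesizerTrn      -> style_bert_vits2.models.models
    # from infer import get_net_g, infer     -> style_bert_vits2.models.infer
    from style_bert_vits2.models.infer import get_net_g, infer
    from style_bert_vits2.models.models import SynthesizerTrn
    from style_bert_vits2.models.models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra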
--- app.py | 2 +- common/tts_model.py | 7 ++--- .../models/attentions.py | 3 +- infer.py => style_bert_vits2/models/infer.py | 8 +++--- .../models/models.py | 7 ++--- .../models/models_jp_extra.py | 26 ++++++++--------- .../models/modules.py | 28 +++++++++---------- train_ms.py | 13 +++++---- train_ms_jp_extra.py | 10 +++---- webui.py | 11 ++++---- 10 files changed, 56 insertions(+), 59 deletions(-) rename attentions.py => style_bert_vits2/models/attentions.py (99%) rename infer.py => style_bert_vits2/models/infer.py (95%) rename models.py => style_bert_vits2/models/models.py (99%) rename models_jp_extra.py => style_bert_vits2/models/models_jp_extra.py (98%) rename modules.py => style_bert_vits2/models/modules.py (95%) diff --git a/app.py b/app.py index c70a89dbf..a0a215cdd 100644 --- a/app.py +++ b/app.py @@ -26,7 +26,7 @@ ) from style_bert_vits2.logging import logger from common.tts_model import ModelHolder -from infer import InvalidToneError +from style_bert_vits2.models.infer import InvalidToneError from style_bert_vits2.text_processing.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone from style_bert_vits2.text_processing.japanese.normalizer import normalize_text diff --git a/common/tts_model.py b/common/tts_model.py index 3f17d1529..c924686f1 100644 --- a/common/tts_model.py +++ b/common/tts_model.py @@ -8,10 +8,6 @@ from gradio.processing_utils import convert_to_16_bit_wav import utils -from infer import get_net_g, infer -from models import SynthesizerTrn -from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra - from style_bert_vits2.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, DEFAULT_LENGTH, @@ -23,6 +19,9 @@ DEFAULT_STYLE, DEFAULT_STYLE_WEIGHT, ) +from style_bert_vits2.models.infer import get_net_g, infer +from style_bert_vits2.models.models import SynthesizerTrn +from style_bert_vits2.models.models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra from style_bert_vits2.logging import logger diff --git a/attentions.py b/style_bert_vits2/models/attentions.py similarity index 99% rename from attentions.py rename to style_bert_vits2/models/attentions.py index 1cca086ba..6d43e0864 100644 --- a/attentions.py +++ b/style_bert_vits2/models/attentions.py @@ -4,7 +4,6 @@ from torch.nn import functional as F from style_bert_vits2.models import commons -from style_bert_vits2.logging import logger as logging class LayerNorm(nn.Module): @@ -67,7 +66,7 @@ def __init__( self.cond_layer_idx = ( kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2 ) - # logging.debug(self.gin_channels, self.cond_layer_idx) + # logger.debug(self.gin_channels, self.cond_layer_idx) assert ( self.cond_layer_idx < self.n_layers ), "cond_layer_idx should be less than n_layers" diff --git a/infer.py b/style_bert_vits2/models/infer.py similarity index 95% rename from infer.py rename to style_bert_vits2/models/infer.py index 4afd048fd..9abd3782b 100644 --- a/infer.py +++ b/style_bert_vits2/models/infer.py @@ -1,13 +1,13 @@ import torch -from style_bert_vits2.models import commons import utils -from models import SynthesizerTrn -from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra from text import cleaned_text_to_sequence, get_bert from text.cleaner import clean_text -from style_bert_vits2.text_processing.symbols import SYMBOLS from style_bert_vits2.logging import logger +from style_bert_vits2.models import commons +from style_bert_vits2.models.models import SynthesizerTrn +from style_bert_vits2.models.models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra 
+from style_bert_vits2.text_processing.symbols import SYMBOLS class InvalidToneError(ValueError): diff --git a/models.py b/style_bert_vits2/models/models.py similarity index 99% rename from models.py rename to style_bert_vits2/models/models.py index ef581f291..a8c669596 100644 --- a/models.py +++ b/style_bert_vits2/models/models.py @@ -1,5 +1,4 @@ import math -import warnings import torch from torch import nn @@ -7,10 +6,10 @@ from torch.nn import functional as F from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm -import attentions -from style_bert_vits2.models import commons -import modules import monotonic_align +from style_bert_vits2.models import attentions +from style_bert_vits2.models import commons +from style_bert_vits2.models import modules from style_bert_vits2.models.commons import get_padding, init_weights from style_bert_vits2.text_processing.symbols import NUM_LANGUAGES, NUM_TONES, SYMBOLS diff --git a/models_jp_extra.py b/style_bert_vits2/models/models_jp_extra.py similarity index 98% rename from models_jp_extra.py rename to style_bert_vits2/models/models_jp_extra.py index 16cacc7cd..8c4a4ecf0 100644 --- a/models_jp_extra.py +++ b/style_bert_vits2/models/models_jp_extra.py @@ -1,17 +1,15 @@ import math + import torch from torch import nn +from torch.nn import Conv1d, Conv2d, ConvTranspose1d from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm -from style_bert_vits2.models import commons -import modules -import attentions import monotonic_align - -from torch.nn import Conv1d, ConvTranspose1d, Conv2d -from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm - -from style_bert_vits2.models.commons import init_weights, get_padding +from style_bert_vits2.models import attentions +from style_bert_vits2.models import commons +from style_bert_vits2.models import modules from style_bert_vits2.text_processing.symbols import SYMBOLS, NUM_TONES, NUM_LANGUAGES @@ -529,7 +527,7 @@ def __init__( self.resblocks.append(resblock(ch, k, d)) self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) - self.ups.apply(init_weights) + self.ups.apply(commons.init_weights) if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) @@ -577,7 +575,7 @@ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 32, (kernel_size, 1), (stride, 1), - padding=(get_padding(kernel_size, 1), 0), + padding=(commons.get_padding(kernel_size, 1), 0), ) ), norm_f( @@ -586,7 +584,7 @@ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 128, (kernel_size, 1), (stride, 1), - padding=(get_padding(kernel_size, 1), 0), + padding=(commons.get_padding(kernel_size, 1), 0), ) ), norm_f( @@ -595,7 +593,7 @@ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 512, (kernel_size, 1), (stride, 1), - padding=(get_padding(kernel_size, 1), 0), + padding=(commons.get_padding(kernel_size, 1), 0), ) ), norm_f( @@ -604,7 +602,7 @@ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 1024, (kernel_size, 1), (stride, 1), - padding=(get_padding(kernel_size, 1), 0), + padding=(commons.get_padding(kernel_size, 1), 0), ) ), norm_f( @@ -613,7 +611,7 @@ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 1024, (kernel_size, 1), 1, - padding=(get_padding(kernel_size, 1), 0), + padding=(commons.get_padding(kernel_size, 1), 0), ) ), ] diff --git a/modules.py b/style_bert_vits2/models/modules.py 
similarity index 95% rename from modules.py rename to style_bert_vits2/models/modules.py index 68b0b9a67..e0885c4b8 100644 --- a/modules.py +++ b/style_bert_vits2/models/modules.py @@ -1,16 +1,14 @@ import math -import warnings import torch from torch import nn from torch.nn import Conv1d from torch.nn import functional as F from torch.nn.utils import remove_weight_norm, weight_norm +from transforms import piecewise_rational_quadratic_transform from style_bert_vits2.models import commons -from attentions import Encoder -from style_bert_vits2.models.commons import get_padding, init_weights -from transforms import piecewise_rational_quadratic_transform +from style_bert_vits2.models.attentions import Encoder LRELU_SLOPE = 0.1 @@ -231,7 +229,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): kernel_size, 1, dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), + padding=commons.get_padding(kernel_size, dilation[0]), ) ), weight_norm( @@ -241,7 +239,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): kernel_size, 1, dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), + padding=commons.get_padding(kernel_size, dilation[1]), ) ), weight_norm( @@ -251,12 +249,12 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): kernel_size, 1, dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]), + padding=commons.get_padding(kernel_size, dilation[2]), ) ), ] ) - self.convs1.apply(init_weights) + self.convs1.apply(commons.init_weights) self.convs2 = nn.ModuleList( [ @@ -267,7 +265,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): kernel_size, 1, dilation=1, - padding=get_padding(kernel_size, 1), + padding=commons.get_padding(kernel_size, 1), ) ), weight_norm( @@ -277,7 +275,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): kernel_size, 1, dilation=1, - padding=get_padding(kernel_size, 1), + padding=commons.get_padding(kernel_size, 1), ) ), weight_norm( @@ -287,12 +285,12 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): kernel_size, 1, dilation=1, - padding=get_padding(kernel_size, 1), + padding=commons.get_padding(kernel_size, 1), ) ), ] ) - self.convs2.apply(init_weights) + self.convs2.apply(commons.init_weights) def forward(self, x, x_mask=None): for c1, c2 in zip(self.convs1, self.convs2): @@ -328,7 +326,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3)): kernel_size, 1, dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), + padding=commons.get_padding(kernel_size, dilation[0]), ) ), weight_norm( @@ -338,12 +336,12 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3)): kernel_size, 1, dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), + padding=commons.get_padding(kernel_size, dilation[1]), ) ), ] ) - self.convs.apply(init_weights) + self.convs.apply(commons.init_weights) def forward(self, x, x_mask=None): for c in self.convs: diff --git a/train_ms.py b/train_ms.py index 783b15700..05cd4cb02 100644 --- a/train_ms.py +++ b/train_ms.py @@ -1,6 +1,5 @@ import argparse import datetime -import gc import os import platform @@ -15,11 +14,8 @@ from tqdm import tqdm # logging.getLogger("numba").setLevel(logging.WARNING) -from style_bert_vits2.models import commons import default_style import utils -from style_bert_vits2.logging import logger -from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT from config import config from data_utils import ( DistributedBucketSampler, @@ -28,8 +24,15 
@@ ) from losses import discriminator_loss, feature_loss, generator_loss, kl_loss from mel_processing import mel_spectrogram_torch, spec_to_mel_torch -from models import DurationDiscriminator, MultiPeriodDiscriminator, SynthesizerTrn +from style_bert_vits2.logging import logger +from style_bert_vits2.models import commons +from style_bert_vits2.models.models import ( + DurationDiscriminator, + MultiPeriodDiscriminator, + SynthesizerTrn, +) from style_bert_vits2.text_processing.symbols import SYMBOLS +from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = ( diff --git a/train_ms_jp_extra.py b/train_ms_jp_extra.py index e4e4e5681..1a287d9c2 100644 --- a/train_ms_jp_extra.py +++ b/train_ms_jp_extra.py @@ -6,20 +6,17 @@ import torch import torch.distributed as dist +from huggingface_hub import HfApi from torch.cuda.amp import GradScaler, autocast from torch.nn import functional as F from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from huggingface_hub import HfApi # logging.getLogger("numba").setLevel(logging.WARNING) -from style_bert_vits2.models import commons import default_style import utils -from style_bert_vits2.logging import logger -from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT from config import config from data_utils import ( DistributedBucketSampler, @@ -28,13 +25,16 @@ ) from losses import WavLMLoss, discriminator_loss, feature_loss, generator_loss, kl_loss from mel_processing import mel_spectrogram_torch, spec_to_mel_torch -from models_jp_extra import ( +from style_bert_vits2.logging import logger +from style_bert_vits2.models import commons +from style_bert_vits2.models.models_jp_extra import ( DurationDiscriminator, MultiPeriodDiscriminator, SynthesizerTrn, WavLMDiscriminator, ) from style_bert_vits2.text_processing.symbols import SYMBOLS +from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = ( diff --git a/webui.py b/webui.py index 90318a1e7..1f31e9483 100644 --- a/webui.py +++ b/webui.py @@ -19,15 +19,16 @@ logger = logging.getLogger(__name__) -import torch -import utils -from infer import infer, latest_version, get_net_g, infer_multilang import gradio as gr -import webbrowser +import librosa import numpy as np +import torch +import webbrowser + +import utils from config import config +from style_bert_vits2.models.infer import infer, latest_version, get_net_g, infer_multilang from tools.translate import translate -import librosa net_g = None From e826faf62e69552d5cbcbab079ab0f4ffe5d8972 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Thu, 7 Mar 2024 00:24:28 +0000 Subject: [PATCH 020/148] Refactor: moved text/cleaner.py to style_bert_vits2/text_processing/ --- preprocess_text.py | 6 +- style_bert_vits2/models/infer.py | 633 ++++++++++---------- style_bert_vits2/text_processing/cleaner.py | 46 ++ text/chinese.py | 4 +- text/cleaner.py | 26 - text/english.py | 2 +- text/japanese.py | 8 +- 7 files changed, 375 insertions(+), 350 deletions(-) create mode 100644 style_bert_vits2/text_processing/cleaner.py delete mode 100644 text/cleaner.py diff --git a/preprocess_text.py b/preprocess_text.py index 126ba2c59..92e00b99b 100644 --- a/preprocess_text.py +++ b/preprocess_text.py @@ -7,10 +7,10 @@ import click from tqdm import tqdm +from config import config from 
style_bert_vits2.logging import logger +from style_bert_vits2.text_processing.cleaner import clean_text from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT -from config import config -from text.cleaner import clean_text preprocess_text_config = config.preprocess_text_config @@ -72,7 +72,7 @@ def preprocess( utt, spk, language, text = line.strip().split("|") norm_text, phones, tones, word2ph = clean_text( text=text, - language=language, + language=language, # type: ignore use_jp_extra=use_jp_extra, raise_yomi_error=(yomi_error != "use"), ) diff --git a/style_bert_vits2/models/infer.py b/style_bert_vits2/models/infer.py index 9abd3782b..e0d869110 100644 --- a/style_bert_vits2/models/infer.py +++ b/style_bert_vits2/models/infer.py @@ -1,314 +1,319 @@ -import torch - -import utils -from text import cleaned_text_to_sequence, get_bert -from text.cleaner import clean_text -from style_bert_vits2.logging import logger -from style_bert_vits2.models import commons -from style_bert_vits2.models.models import SynthesizerTrn -from style_bert_vits2.models.models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra -from style_bert_vits2.text_processing.symbols import SYMBOLS - - -class InvalidToneError(ValueError): - pass - - -def get_net_g(model_path: str, version: str, device: str, hps): - if version.endswith("JP-Extra"): - logger.info("Using JP-Extra model") - net_g = SynthesizerTrnJPExtra( - len(SYMBOLS), - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, - n_speakers=hps.data.n_speakers, - **hps.model, - ).to(device) - else: - logger.info("Using normal model") - net_g = SynthesizerTrn( - len(SYMBOLS), - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, - n_speakers=hps.data.n_speakers, - **hps.model, - ).to(device) - net_g.state_dict() - _ = net_g.eval() - if model_path.endswith(".pth") or model_path.endswith(".pt"): - _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True) - elif model_path.endswith(".safetensors"): - _ = utils.load_safetensors(model_path, net_g, True) - else: - raise ValueError(f"Unknown model format: {model_path}") - return net_g - - -def get_text( - text, - language_str, - hps, - device, - assist_text=None, - assist_text_weight=0.7, - given_tone=None, -): - use_jp_extra = hps.version.endswith("JP-Extra") - # 推論のときにのみ呼び出されるので、raise_yomi_errorはFalseに設定 - norm_text, phone, tone, word2ph = clean_text( - text, language_str, use_jp_extra, raise_yomi_error=False - ) - if given_tone is not None: - if len(given_tone) != len(phone): - raise InvalidToneError( - f"Length of given_tone ({len(given_tone)}) != length of phone ({len(phone)})" - ) - tone = given_tone - phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str) - - if hps.data.add_blank: - phone = commons.intersperse(phone, 0) - tone = commons.intersperse(tone, 0) - language = commons.intersperse(language, 0) - for i in range(len(word2ph)): - word2ph[i] = word2ph[i] * 2 - word2ph[0] += 1 - bert_ori = get_bert( - norm_text, - word2ph, - language_str, - device, - assist_text, - assist_text_weight, - ) - del word2ph - assert bert_ori.shape[-1] == len(phone), phone - - if language_str == "ZH": - bert = bert_ori - ja_bert = torch.zeros(1024, len(phone)) - en_bert = torch.zeros(1024, len(phone)) - elif language_str == "JP": - bert = torch.zeros(1024, len(phone)) - ja_bert = bert_ori - en_bert = torch.zeros(1024, len(phone)) - elif language_str == "EN": - bert = torch.zeros(1024, len(phone)) - ja_bert = torch.zeros(1024, 
len(phone)) - en_bert = bert_ori - else: - raise ValueError("language_str should be ZH, JP or EN") - - assert bert.shape[-1] == len( - phone - ), f"Bert seq len {bert.shape[-1]} != {len(phone)}" - - phone = torch.LongTensor(phone) - tone = torch.LongTensor(tone) - language = torch.LongTensor(language) - return bert, ja_bert, en_bert, phone, tone, language - - -def infer( - text, - style_vec, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - sid: int, # In the original Bert-VITS2, its speaker_name: str, but here it's id - language, - hps, - net_g, - device, - skip_start=False, - skip_end=False, - assist_text=None, - assist_text_weight=0.7, - given_tone=None, -): - is_jp_extra = hps.version.endswith("JP-Extra") - bert, ja_bert, en_bert, phones, tones, lang_ids = get_text( - text, - language, - hps, - device, - assist_text=assist_text, - assist_text_weight=assist_text_weight, - given_tone=given_tone, - ) - if skip_start: - phones = phones[3:] - tones = tones[3:] - lang_ids = lang_ids[3:] - bert = bert[:, 3:] - ja_bert = ja_bert[:, 3:] - en_bert = en_bert[:, 3:] - if skip_end: - phones = phones[:-2] - tones = tones[:-2] - lang_ids = lang_ids[:-2] - bert = bert[:, :-2] - ja_bert = ja_bert[:, :-2] - en_bert = en_bert[:, :-2] - with torch.no_grad(): - x_tst = phones.to(device).unsqueeze(0) - tones = tones.to(device).unsqueeze(0) - lang_ids = lang_ids.to(device).unsqueeze(0) - bert = bert.to(device).unsqueeze(0) - ja_bert = ja_bert.to(device).unsqueeze(0) - en_bert = en_bert.to(device).unsqueeze(0) - x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device) - style_vec = torch.from_numpy(style_vec).to(device).unsqueeze(0) - del phones - sid_tensor = torch.LongTensor([sid]).to(device) - if is_jp_extra: - output = net_g.infer( - x_tst, - x_tst_lengths, - sid_tensor, - tones, - lang_ids, - ja_bert, - style_vec=style_vec, - sdp_ratio=sdp_ratio, - noise_scale=noise_scale, - noise_scale_w=noise_scale_w, - length_scale=length_scale, - ) - else: - output = net_g.infer( - x_tst, - x_tst_lengths, - sid_tensor, - tones, - lang_ids, - bert, - ja_bert, - en_bert, - style_vec=style_vec, - sdp_ratio=sdp_ratio, - noise_scale=noise_scale, - noise_scale_w=noise_scale_w, - length_scale=length_scale, - ) - audio = output[0][0, 0].data.cpu().float().numpy() - del ( - x_tst, - tones, - lang_ids, - bert, - x_tst_lengths, - sid_tensor, - ja_bert, - en_bert, - style_vec, - ) # , emo - if torch.cuda.is_available(): - torch.cuda.empty_cache() - return audio - - -def infer_multilang( - text, - style_vec, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - sid, - language, - hps, - net_g, - device, - skip_start=False, - skip_end=False, -): - bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], [] - # emo = get_emo_(reference_audio, emotion, sid) - # if isinstance(reference_audio, np.ndarray): - # emo = get_clap_audio_feature(reference_audio, device) - # else: - # emo = get_clap_text_feature(emotion, device) - # emo = torch.squeeze(emo, dim=1) - for idx, (txt, lang) in enumerate(zip(text, language)): - _skip_start = (idx != 0) or (skip_start and idx == 0) - _skip_end = (idx != len(language) - 1) or skip_end - ( - temp_bert, - temp_ja_bert, - temp_en_bert, - temp_phones, - temp_tones, - temp_lang_ids, - ) = get_text(txt, lang, hps, device) - if _skip_start: - temp_bert = temp_bert[:, 3:] - temp_ja_bert = temp_ja_bert[:, 3:] - temp_en_bert = temp_en_bert[:, 3:] - temp_phones = temp_phones[3:] - temp_tones = temp_tones[3:] - temp_lang_ids = temp_lang_ids[3:] - if _skip_end: - temp_bert 
= temp_bert[:, :-2] - temp_ja_bert = temp_ja_bert[:, :-2] - temp_en_bert = temp_en_bert[:, :-2] - temp_phones = temp_phones[:-2] - temp_tones = temp_tones[:-2] - temp_lang_ids = temp_lang_ids[:-2] - bert.append(temp_bert) - ja_bert.append(temp_ja_bert) - en_bert.append(temp_en_bert) - phones.append(temp_phones) - tones.append(temp_tones) - lang_ids.append(temp_lang_ids) - bert = torch.concatenate(bert, dim=1) - ja_bert = torch.concatenate(ja_bert, dim=1) - en_bert = torch.concatenate(en_bert, dim=1) - phones = torch.concatenate(phones, dim=0) - tones = torch.concatenate(tones, dim=0) - lang_ids = torch.concatenate(lang_ids, dim=0) - with torch.no_grad(): - x_tst = phones.to(device).unsqueeze(0) - tones = tones.to(device).unsqueeze(0) - lang_ids = lang_ids.to(device).unsqueeze(0) - bert = bert.to(device).unsqueeze(0) - ja_bert = ja_bert.to(device).unsqueeze(0) - en_bert = en_bert.to(device).unsqueeze(0) - # emo = emo.to(device).unsqueeze(0) - x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device) - del phones - speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device) - audio = ( - net_g.infer( - x_tst, - x_tst_lengths, - speakers, - tones, - lang_ids, - bert, - ja_bert, - en_bert, - style_vec=style_vec, - sdp_ratio=sdp_ratio, - noise_scale=noise_scale, - noise_scale_w=noise_scale_w, - length_scale=length_scale, - )[0][0, 0] - .data.cpu() - .float() - .numpy() - ) - del ( - x_tst, - tones, - lang_ids, - bert, - x_tst_lengths, - speakers, - ja_bert, - en_bert, - ) # , emo - if torch.cuda.is_available(): - torch.cuda.empty_cache() - return audio +from typing import Literal + +import torch + +import utils +from text import cleaned_text_to_sequence, get_bert +from style_bert_vits2.logging import logger +from style_bert_vits2.models import commons +from style_bert_vits2.models.models import SynthesizerTrn +from style_bert_vits2.models.models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra +from style_bert_vits2.text_processing.cleaner import clean_text +from style_bert_vits2.text_processing.symbols import SYMBOLS + + +class InvalidToneError(ValueError): + pass + + +def get_net_g(model_path: str, version: str, device: str, hps): + if version.endswith("JP-Extra"): + logger.info("Using JP-Extra model") + net_g = SynthesizerTrnJPExtra( + len(SYMBOLS), + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ).to(device) + else: + logger.info("Using normal model") + net_g = SynthesizerTrn( + len(SYMBOLS), + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ).to(device) + net_g.state_dict() + _ = net_g.eval() + if model_path.endswith(".pth") or model_path.endswith(".pt"): + _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True) + elif model_path.endswith(".safetensors"): + _ = utils.load_safetensors(model_path, net_g, True) + else: + raise ValueError(f"Unknown model format: {model_path}") + return net_g + + +def get_text( + text: str, + language_str: Literal["JP", "EN", "ZH"], + hps, + device: str, + assist_text: str | None = None, + assist_text_weight: float = 0.7, + given_tone: list[int] | None = None, +): + use_jp_extra = hps.version.endswith("JP-Extra") + # 推論時のみ呼び出されるので、raise_yomi_error は False に設定 + norm_text, phone, tone, word2ph = clean_text( + text, + language_str, + use_jp_extra = use_jp_extra, + raise_yomi_error = False, + ) + if given_tone is not None: + if len(given_tone) != len(phone): + raise 
InvalidToneError( + f"Length of given_tone ({len(given_tone)}) != length of phone ({len(phone)})" + ) + tone = given_tone + phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str) + + if hps.data.add_blank: + phone = commons.intersperse(phone, 0) + tone = commons.intersperse(tone, 0) + language = commons.intersperse(language, 0) + for i in range(len(word2ph)): + word2ph[i] = word2ph[i] * 2 + word2ph[0] += 1 + bert_ori = get_bert( + norm_text, + word2ph, + language_str, + device, + assist_text, + assist_text_weight, + ) + del word2ph + assert bert_ori.shape[-1] == len(phone), phone + + if language_str == "ZH": + bert = bert_ori + ja_bert = torch.zeros(1024, len(phone)) + en_bert = torch.zeros(1024, len(phone)) + elif language_str == "JP": + bert = torch.zeros(1024, len(phone)) + ja_bert = bert_ori + en_bert = torch.zeros(1024, len(phone)) + elif language_str == "EN": + bert = torch.zeros(1024, len(phone)) + ja_bert = torch.zeros(1024, len(phone)) + en_bert = bert_ori + else: + raise ValueError("language_str should be ZH, JP or EN") + + assert bert.shape[-1] == len( + phone + ), f"Bert seq len {bert.shape[-1]} != {len(phone)}" + + phone = torch.LongTensor(phone) + tone = torch.LongTensor(tone) + language = torch.LongTensor(language) + return bert, ja_bert, en_bert, phone, tone, language + + +def infer( + text: str, + style_vec, + sdp_ratio: float, + noise_scale: float, + noise_scale_w: float, + length_scale: float, + sid: int, # In the original Bert-VITS2, its speaker_name: str, but here it's id + language: Literal["JP", "EN", "ZH"], + hps, + net_g, + device: str, + skip_start: bool = False, + skip_end: bool = False, + assist_text: str | None = None, + assist_text_weight: float = 0.7, + given_tone: list[int] | None = None, +): + is_jp_extra = hps.version.endswith("JP-Extra") + bert, ja_bert, en_bert, phones, tones, lang_ids = get_text( + text, + language, + hps, + device, + assist_text=assist_text, + assist_text_weight=assist_text_weight, + given_tone=given_tone, + ) + if skip_start: + phones = phones[3:] + tones = tones[3:] + lang_ids = lang_ids[3:] + bert = bert[:, 3:] + ja_bert = ja_bert[:, 3:] + en_bert = en_bert[:, 3:] + if skip_end: + phones = phones[:-2] + tones = tones[:-2] + lang_ids = lang_ids[:-2] + bert = bert[:, :-2] + ja_bert = ja_bert[:, :-2] + en_bert = en_bert[:, :-2] + with torch.no_grad(): + x_tst = phones.to(device).unsqueeze(0) + tones = tones.to(device).unsqueeze(0) + lang_ids = lang_ids.to(device).unsqueeze(0) + bert = bert.to(device).unsqueeze(0) + ja_bert = ja_bert.to(device).unsqueeze(0) + en_bert = en_bert.to(device).unsqueeze(0) + x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device) + style_vec = torch.from_numpy(style_vec).to(device).unsqueeze(0) + del phones + sid_tensor = torch.LongTensor([sid]).to(device) + if is_jp_extra: + output = net_g.infer( + x_tst, + x_tst_lengths, + sid_tensor, + tones, + lang_ids, + ja_bert, + style_vec=style_vec, + sdp_ratio=sdp_ratio, + noise_scale=noise_scale, + noise_scale_w=noise_scale_w, + length_scale=length_scale, + ) + else: + output = net_g.infer( + x_tst, + x_tst_lengths, + sid_tensor, + tones, + lang_ids, + bert, + ja_bert, + en_bert, + style_vec=style_vec, + sdp_ratio=sdp_ratio, + noise_scale=noise_scale, + noise_scale_w=noise_scale_w, + length_scale=length_scale, + ) + audio = output[0][0, 0].data.cpu().float().numpy() + del ( + x_tst, + tones, + lang_ids, + bert, + x_tst_lengths, + sid_tensor, + ja_bert, + en_bert, + style_vec, + ) # , emo + if torch.cuda.is_available(): + 
torch.cuda.empty_cache() + return audio + + +def infer_multilang( + text: str, + style_vec, + sdp_ratio: float, + noise_scale: float, + noise_scale_w: float, + length_scale: float, + sid: int, + language: Literal["JP", "EN", "ZH"], + hps, + net_g, + device: str, + skip_start: bool = False, + skip_end: bool = False, +): + bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], [] + # emo = get_emo_(reference_audio, emotion, sid) + # if isinstance(reference_audio, np.ndarray): + # emo = get_clap_audio_feature(reference_audio, device) + # else: + # emo = get_clap_text_feature(emotion, device) + # emo = torch.squeeze(emo, dim=1) + for idx, (txt, lang) in enumerate(zip(text, language)): + _skip_start = (idx != 0) or (skip_start and idx == 0) + _skip_end = (idx != len(language) - 1) or skip_end + ( + temp_bert, + temp_ja_bert, + temp_en_bert, + temp_phones, + temp_tones, + temp_lang_ids, + ) = get_text(txt, lang, hps, device) # type: ignore + if _skip_start: + temp_bert = temp_bert[:, 3:] + temp_ja_bert = temp_ja_bert[:, 3:] + temp_en_bert = temp_en_bert[:, 3:] + temp_phones = temp_phones[3:] + temp_tones = temp_tones[3:] + temp_lang_ids = temp_lang_ids[3:] + if _skip_end: + temp_bert = temp_bert[:, :-2] + temp_ja_bert = temp_ja_bert[:, :-2] + temp_en_bert = temp_en_bert[:, :-2] + temp_phones = temp_phones[:-2] + temp_tones = temp_tones[:-2] + temp_lang_ids = temp_lang_ids[:-2] + bert.append(temp_bert) + ja_bert.append(temp_ja_bert) + en_bert.append(temp_en_bert) + phones.append(temp_phones) + tones.append(temp_tones) + lang_ids.append(temp_lang_ids) + bert = torch.concatenate(bert, dim=1) + ja_bert = torch.concatenate(ja_bert, dim=1) + en_bert = torch.concatenate(en_bert, dim=1) + phones = torch.concatenate(phones, dim=0) + tones = torch.concatenate(tones, dim=0) + lang_ids = torch.concatenate(lang_ids, dim=0) + with torch.no_grad(): + x_tst = phones.to(device).unsqueeze(0) + tones = tones.to(device).unsqueeze(0) + lang_ids = lang_ids.to(device).unsqueeze(0) + bert = bert.to(device).unsqueeze(0) + ja_bert = ja_bert.to(device).unsqueeze(0) + en_bert = en_bert.to(device).unsqueeze(0) + # emo = emo.to(device).unsqueeze(0) + x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device) + del phones + speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device) + audio = ( + net_g.infer( + x_tst, + x_tst_lengths, + speakers, + tones, + lang_ids, + bert, + ja_bert, + en_bert, + style_vec=style_vec, + sdp_ratio=sdp_ratio, + noise_scale=noise_scale, + noise_scale_w=noise_scale_w, + length_scale=length_scale, + )[0][0, 0] + .data.cpu() + .float() + .numpy() + ) + del ( + x_tst, + tones, + lang_ids, + bert, + x_tst_lengths, + speakers, + ja_bert, + en_bert, + ) # , emo + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return audio diff --git a/style_bert_vits2/text_processing/cleaner.py b/style_bert_vits2/text_processing/cleaner.py new file mode 100644 index 000000000..400d19880 --- /dev/null +++ b/style_bert_vits2/text_processing/cleaner.py @@ -0,0 +1,46 @@ +from typing import Literal + + +def clean_text( + text: str, + language: Literal["JP", "EN", "ZH"], + use_jp_extra: bool = True, + raise_yomi_error: bool = False, +) -> tuple[str, list[str], list[int], list[int]]: + """ + テキストをクリーニングし、音素に変換する + + Args: + text (str): クリーニングするテキスト + language (Literal["JP", "EN", "ZH"]): テキストの言語 + use_jp_extra (bool, optional): テキストが日本語の場合に JP-Extra モデルを利用するかどうか。Defaults to True. + raise_yomi_error (bool, optional): False の場合、読めない文字が消えたような扱いとして処理される。Defaults to False. 
+ + Returns: + tuple[str, list[str], list[int], list[int]]: クリーニングされたテキストと、音素・アクセント・元のテキストの各文字に音素が何個割り当てられるかのリスト + """ + + # Changed to import inside if condition to avoid unnecessary import + if language == "JP": + from transformers import AutoTokenizer + from style_bert_vits2.text_processing.japanese.g2p import g2p + from style_bert_vits2.text_processing.japanese.normalizer import normalize_text + norm_text = normalize_text(text) + phones, tones, word2ph = g2p( + norm_text, + tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese-char-wwm"), # 暫定的にここで指定 + use_jp_extra = use_jp_extra, + raise_yomi_error = raise_yomi_error, + ) + elif language == "EN": + from ...text import english as language_module + norm_text = language_module.normalize_text(text) + phones, tones, word2ph = language_module.g2p(norm_text) + elif language == "ZH": + from ...text import chinese as language_module + norm_text = language_module.normalize_text(text) + phones, tones, word2ph = language_module.g2p(norm_text) + else: + raise ValueError(f"Language {language} not supported") + + return norm_text, phones, tones, word2ph diff --git a/text/chinese.py b/text/chinese.py index 56dc4f34f..3d9c39226 100644 --- a/text/chinese.py +++ b/text/chinese.py @@ -168,7 +168,7 @@ def _g2p(segments): return phones_list, tones_list, word2ph -def text_normalize(text): +def normalize_text(text): numbers = re.findall(r"\d+(?:\.?\d+)?", text) for number in numbers: text = text.replace(number, cn2an.an2cn(number), 1) @@ -186,7 +186,7 @@ def get_bert_feature(text, word2ph): from text.chinese_bert import get_bert_feature text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏" - text = text_normalize(text) + text = normalize_text(text) print(text) phones, tones, word2ph = g2p(text) bert = get_bert_feature(text, word2ph) diff --git a/text/cleaner.py b/text/cleaner.py deleted file mode 100644 index d805b5145..000000000 --- a/text/cleaner.py +++ /dev/null @@ -1,26 +0,0 @@ -def clean_text(text, language, use_jp_extra=True, raise_yomi_error=False): - # Changed to import inside if condition to avoid unnecessary import - if language == "ZH": - from . import chinese as language_module - - norm_text = language_module.text_normalize(text) - phones, tones, word2ph = language_module.g2p(norm_text) - elif language == "EN": - from . import english as language_module - - norm_text = language_module.text_normalize(text) - phones, tones, word2ph = language_module.g2p(norm_text) - elif language == "JP": - from . 
import japanese as language_module - - norm_text = language_module.text_normalize(text) - phones, tones, word2ph = language_module.g2p( - norm_text, use_jp_extra, raise_yomi_error=raise_yomi_error - ) - else: - raise ValueError(f"Language {language} not supported") - return norm_text, phones, tones, word2ph - - -if __name__ == "__main__": - pass diff --git a/text/english.py b/text/english.py index f38ee84ae..3dcfdecd7 100644 --- a/text/english.py +++ b/text/english.py @@ -369,7 +369,7 @@ def normalize_numbers(text): return text -def text_normalize(text): +def normalize_text(text): text = normalize_numbers(text) text = replace_punctuation(text) text = re.sub(r"([,;.\?\!])([\w])", r"\1 \2", text) diff --git a/text/japanese.py b/text/japanese.py index 0dc6aa83f..03682e360 100644 --- a/text/japanese.py +++ b/text/japanese.py @@ -96,7 +96,7 @@ class YomiError(Exception): } -def text_normalize(text): +def normalize_text(text): """ 日本語のテキストを正規化する。 結果は、ちょうど次の文字のみからなる: @@ -177,7 +177,7 @@ def g2p( norm_text: str, use_jp_extra: bool = True, raise_yomi_error: bool = False ) -> tuple[list[str], list[int], list[int]]: """ - 他で使われるメインの関数。`text_normalize()`で正規化された`norm_text`を受け取り、 + 他で使われるメインの関数。`normalize_text()`で正規化された`norm_text`を受け取り、 - phones: 音素のリスト(ただし`!`や`,`や`.`等punctuationが含まれうる) - tones: アクセントのリスト、0(低)と1(高)からなり、phonesと同じ長さ - word2ph: 元のテキストの各文字に音素が何個割り当てられるかを表すリスト @@ -350,7 +350,7 @@ def text2sep_kata( norm_text: str, raise_yomi_error: bool = False ) -> tuple[list[str], list[str]]: """ - `text_normalize`で正規化済みの`norm_text`を受け取り、それを単語分割し、 + `normalize_text()`で正規化済みの`norm_text`を受け取り、それを単語分割し、 分割された単語リストとその読み(カタカナor記号1文字)のリストのタプルを返す。 単語分割結果は、`g2p()`の`word2ph`で1文字あたりに割り振る音素記号の数を決めるために使う。 例: @@ -634,7 +634,7 @@ def mora2phonemes(mora: str) -> str: text = "こんにちは、世界。" from text.japanese_bert import get_bert_feature - text = text_normalize(text) + text = normalize_text(text) phones, tones, word2ph = g2p(text) bert = get_bert_feature(text, word2ph) From 2994873e3885fc1bae602dae8d311838697b788e Mon Sep 17 00:00:00 2001 From: kale4eat Date: Thu, 7 Mar 2024 09:40:21 +0900 Subject: [PATCH 021/148] add no client timeout summarize the except statement at disconnection --- text/pyopenjtalk_worker/worker_server.py | 25 ++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/text/pyopenjtalk_worker/worker_server.py b/text/pyopenjtalk_worker/worker_server.py index dc6d476c1..6babd91d6 100644 --- a/text/pyopenjtalk_worker/worker_server.py +++ b/text/pyopenjtalk_worker/worker_server.py @@ -1,6 +1,7 @@ import pyopenjtalk import socket import select +import time from .worker_common import ( ConnectionClosedException, @@ -62,13 +63,23 @@ def handle_request(self, request): return response - def start_server(self, port: int): + def start_server(self, port: int, no_client_timeout: int = 30): logger.info("start pyopenjtalk worker server") with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket: server_socket.bind((socket.gethostname(), port)) server_socket.listen() sockets = [server_socket] + no_client_since = time.time() while True: + if self.client_count == 0: + if no_client_since is None: + no_client_since = time.time() + elif (time.time() - no_client_since) > no_client_timeout: + logger.info("quit because there is no client") + return + else: + no_client_since = None + ready_sockets, _, _ = select.select(sockets, [], [], 0.1) for sock in ready_sockets: if sock is server_socket: @@ -80,17 +91,15 @@ def start_server(self, port: int): # client try: request = receive_data(sock) 
- except ConnectionClosedException as e: - sock.close() - sockets.remove(sock) - self.client_count -= 1 - logger.info("close connection") - continue except Exception as e: sock.close() sockets.remove(sock) self.client_count -= 1 - logger.error(e) + # unexpected disconnections + if not isinstance(e, ConnectionClosedException): + logger.error(e) + + logger.info("close connection") continue logger.trace(f"server received request: {request}") From c3c0dd8b32db40b385995a1389402d3683a21762 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Thu, 7 Mar 2024 02:31:30 +0000 Subject: [PATCH 022/148] Refactor: add style_bert_vits2/text_processing/bert_models.py to hold loaded BERT models/tokenizer and replace all from_pretrained() to load_model/load_tokenizer --- server_editor.py | 10 +- style_bert_vits2/constants.py | 34 +- style_bert_vits2/models/infer.py | 15 +- .../text_processing/bert_models.py | 123 ++++ style_bert_vits2/text_processing/cleaner.py | 20 +- .../text_processing/japanese/g2p.py | 8 +- .../text_processing/japanese/g2p_utils.py | 10 +- text/__init__.py | 24 +- text/chinese_bert.py | 17 +- text/english.py | 7 +- text/english_bert_mock.py | 18 +- text/japanese.py | 642 ------------------ text/japanese_bert.py | 17 +- 13 files changed, 217 insertions(+), 728 deletions(-) create mode 100644 style_bert_vits2/text_processing/bert_models.py delete mode 100644 text/japanese.py diff --git a/server_editor.py b/server_editor.py index 7577637cf..eb1cffd87 100644 --- a/server_editor.py +++ b/server_editor.py @@ -28,7 +28,6 @@ from fastapi.staticfiles import StaticFiles from pydantic import BaseModel from scipy.io import wavfile -from transformers import AutoTokenizer from common.tts_model import ModelHolder from style_bert_vits2.constants import ( @@ -42,6 +41,7 @@ Languages, ) from style_bert_vits2.logging import logger +from style_bert_vits2.text_processing import bert_models from style_bert_vits2.text_processing.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone from style_bert_vits2.text_processing.japanese.normalizer import normalize_text from style_bert_vits2.text_processing.japanese.user_dict import ( @@ -150,8 +150,10 @@ def save_last_download(latest_release): # 最初に pyopenjtalk の辞書を更新 update_dict() -# 単語分割に使う BERT トークナイザーをロード -tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese-char-wwm") +# 単語分割に使う BERT モデル/トークナイザーを事前にロードしておく +## server_editor.py は日本語にしか対応していないため、日本語の BERT モデル/トークナイザーのみロードする +bert_models.load_model(Languages.JP) +bert_models.load_tokenizer(Languages.JP) class AudioResponse(Response): @@ -227,7 +229,7 @@ async def read_item(item: TextRequest): try: # 最初に正規化しないと整合性がとれない text = normalize_text(item.text) - kata_tone_list = g2kata_tone(text, tokenizer) + kata_tone_list = g2kata_tone(text) except Exception as e: raise HTTPException( status_code=400, diff --git a/style_bert_vits2/constants.py b/style_bert_vits2/constants.py index 2d6e9adc1..d6b54c896 100644 --- a/style_bert_vits2/constants.py +++ b/style_bert_vits2/constants.py @@ -1,19 +1,31 @@ -from enum import Enum +from enum import StrEnum from pathlib import Path # Style-Bert-VITS2 のバージョン VERSION = "2.3.1" -# Gradio のテーマ -## Built-in theme: "default", "base", "monochrome", "soft", "glass" -## See https://huggingface.co/spaces/gradio/theme-gallery for more themes -GRADIO_THEME = "NoCrypt/miku" +# Style-Bert-VITS2 のベースディレクトリ +BASE_DIR = Path(__file__).parent.parent + +# 利用可能な言語 +## JP-Extra モデル利用時は JP 以外の言語の音声合成はできない +class Languages(StrEnum): + JP = "JP" + EN = "EN" + ZH = "ZH" + +# 言語ごとのデフォルトの 
BERT トークナイザーのパス +DEFAULT_BERT_TOKENIZER_PATHS = { + Languages.JP: BASE_DIR / "bert" / "deberta-v2-large-japanese-char-wwm", + Languages.EN: BASE_DIR / "bert" / "deberta-v3-large", + Languages.ZH: BASE_DIR / "bert" / "chinese-roberta-wwm-ext-large", +} # デフォルトのユーザー辞書ディレクトリ ## style_bert_vits2.text_processing.japanese.user_dict モジュールのデフォルト値として利用される ## ライブラリとしての利用などで外部のユーザー辞書を指定したい場合は、user_dict 以下の各関数の実行時、引数に辞書データファイルのパスを指定する -DEFAULT_USER_DICT_DIR = Path(__file__).parent.parent / "dict_data" +DEFAULT_USER_DICT_DIR = BASE_DIR / "dict_data" # デフォルトの推論パラメータ DEFAULT_STYLE = "Neutral" @@ -27,9 +39,7 @@ DEFAULT_ASSIST_TEXT_WEIGHT = 0.7 DEFAULT_ASSIST_TEXT_WEIGHT = 1.0 -# 利用可能な言語 -## JP-Extra モデル利用時は JP 以外の言語の音声合成はできない -class Languages(str, Enum): - JP = "JP" - EN = "EN" - ZH = "ZH" +# Gradio のテーマ +## Built-in theme: "default", "base", "monochrome", "soft", "glass" +## See https://huggingface.co/spaces/gradio/theme-gallery for more themes +GRADIO_THEME = "NoCrypt/miku" diff --git a/style_bert_vits2/models/infer.py b/style_bert_vits2/models/infer.py index e0d869110..99999a957 100644 --- a/style_bert_vits2/models/infer.py +++ b/style_bert_vits2/models/infer.py @@ -1,9 +1,8 @@ -from typing import Literal - import torch import utils from text import cleaned_text_to_sequence, get_bert +from style_bert_vits2.constants import Languages from style_bert_vits2.logging import logger from style_bert_vits2.models import commons from style_bert_vits2.models.models import SynthesizerTrn @@ -48,7 +47,7 @@ def get_net_g(model_path: str, version: str, device: str, hps): def get_text( text: str, - language_str: Literal["JP", "EN", "ZH"], + language_str: Languages, hps, device: str, assist_text: str | None = None, @@ -89,15 +88,15 @@ def get_text( del word2ph assert bert_ori.shape[-1] == len(phone), phone - if language_str == "ZH": + if language_str == Languages.ZH: bert = bert_ori ja_bert = torch.zeros(1024, len(phone)) en_bert = torch.zeros(1024, len(phone)) - elif language_str == "JP": + elif language_str == Languages.JP: bert = torch.zeros(1024, len(phone)) ja_bert = bert_ori en_bert = torch.zeros(1024, len(phone)) - elif language_str == "EN": + elif language_str == Languages.EN: bert = torch.zeros(1024, len(phone)) ja_bert = torch.zeros(1024, len(phone)) en_bert = bert_ori @@ -122,7 +121,7 @@ def infer( noise_scale_w: float, length_scale: float, sid: int, # In the original Bert-VITS2, its speaker_name: str, but here it's id - language: Literal["JP", "EN", "ZH"], + language: Languages, hps, net_g, device: str, @@ -222,7 +221,7 @@ def infer_multilang( noise_scale_w: float, length_scale: float, sid: int, - language: Literal["JP", "EN", "ZH"], + language: Languages, hps, net_g, device: str, diff --git a/style_bert_vits2/text_processing/bert_models.py b/style_bert_vits2/text_processing/bert_models.py new file mode 100644 index 000000000..9132082c9 --- /dev/null +++ b/style_bert_vits2/text_processing/bert_models.py @@ -0,0 +1,123 @@ +""" +Style-Bert-VITS2 の学習・推論に必要な各言語ごとの BERT モデルをロード/取得するためのモジュール。 + +オリジナルの Bert-VITS2 では各言語ごとの BERT モデルが初回インポート時にハードコードされたパスから「暗黙的に」ロードされているが、 +場合によっては多重にロードされて非効率なほか、BERT モデルのロード元のパスがハードコードされているためライブラリ化ができない。 + +そこで、ライブラリの利用前に、音声合成に利用する言語の BERT モデルだけを「明示的に」ロードできるようにした。 +一度 load_tokenizer() で当該言語の BERT モデルがロードされていれば、ライブラリ内部のどこからでもロード済みのモデル/トークナイザーを取得できる。 +""" + +from typing import cast + +from transformers import ( + AutoModelForMaskedLM, + AutoTokenizer, + DebertaV2Model, + DebertaV2Tokenizer, + PreTrainedModel, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) + +from 
style_bert_vits2.constants import DEFAULT_BERT_TOKENIZER_PATHS, Languages +from style_bert_vits2.logging import logger + + +# 各言語ごとのロード済みの BERT モデルを格納する辞書 +loaded_models: dict[Languages, PreTrainedModel | DebertaV2Model] = {} + +# 各言語ごとのロード済みの BERT トークナイザーを格納する辞書 +loaded_tokenizers: dict[Languages, PreTrainedTokenizer | PreTrainedTokenizerFast | DebertaV2Tokenizer] = {} + + +def load_model( + language: Languages, + pretrained_model_name_or_path: str | None = None, +) -> PreTrainedModel | DebertaV2Model: + """ + 指定された言語の BERT モデルをロードし、ロード済みの BERT モデルを返す + 一度ロードされていれば、ロード済みの BERT モデルを即座に返す + ライブラリ利用時は常に pretrain_model_name_or_path (Hugging Face のリポジトリ名 or ローカルのファイルパス) を指定する必要がある + ロードにはそれなりに時間がかかるため、ライブラリ利用前に明示的に pretrained_model_name_or_path を指定してロードしておくべき + + Style-Bert-VITS2 では、BERT モデルに下記の 3 つが利用されている + これ以外の BERT モデルを指定した場合は正常に動作しない可能性が高い + - 日本語: ku-nlp/deberta-v2-large-japanese-char-wwm + - 英語: microsoft/deberta-v3-large + - 中国語: hfl/chinese-roberta-wwm-ext-large + + Args: + language (Languages): ロードする学習済みモデルの対象言語 + pretrained_model_name_or_path (str | None): ロードする学習済みモデルの名前またはパス。指定しない場合はデフォルトのパスが利用される (デフォルト: None) + + Returns: + PreTrainedModel | DebertaV2Model: ロード済みの BERT モデル + """ + + # すでにロード済みの場合はそのまま返す + if language in loaded_models: + return loaded_models[language] + + # pretrained_model_name_or_path が指定されていない場合はデフォルトのパスを利用 + if pretrained_model_name_or_path is None: + assert DEFAULT_BERT_TOKENIZER_PATHS[language].exists(), \ + f"The default {language} BERT model does not exist on the file system. Please specify the path to the pre-trained model." + pretrained_model_name_or_path = str(DEFAULT_BERT_TOKENIZER_PATHS[language]) + + # BERT モデルをロードし、辞書に格納して返す + ## 英語のみ DebertaV2Model でロードする必要がある + if language == Languages.EN: + model = cast(DebertaV2Model, DebertaV2Model.from_pretrained(pretrained_model_name_or_path)) + else: + model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path) + loaded_models[language] = model + logger.info(f"Loaded the {language} BERT model from {pretrained_model_name_or_path}") + + return model + + +def load_tokenizer( + language: Languages, + pretrained_model_name_or_path: str | None = None, +) -> PreTrainedTokenizer | PreTrainedTokenizerFast | DebertaV2Tokenizer: + """ + 指定された言語の BERT モデルをロードし、ロード済みの BERT トークナイザーを返す + 一度ロードされていれば、ロード済みの BERT トークナイザーを即座に返す + ライブラリ利用時は常に pretrain_model_name_or_path (Hugging Face のリポジトリ名 or ローカルのファイルパス) を指定する必要がある + ロードにはそれなりに時間がかかるため、ライブラリ利用前に明示的に pretrained_model_name_or_path を指定してロードしておくべき + + Style-Bert-VITS2 では、BERT モデルに下記の 3 つが利用されている + これ以外の BERT モデルを指定した場合は正常に動作しない可能性が高い + - 日本語: ku-nlp/deberta-v2-large-japanese-char-wwm + - 英語: microsoft/deberta-v3-large + - 中国語: hfl/chinese-roberta-wwm-ext-large + + Args: + language (Languages): ロードする学習済みモデルの対象言語 + pretrained_model_name_or_path (str | None): ロードする学習済みモデルの名前またはパス。指定しない場合はデフォルトのパスが利用される (デフォルト: None) + + Returns: + PreTrainedTokenizer | PreTrainedTokenizerFast | DebertaV2Tokenizer: ロード済みの BERT トークナイザー + """ + + # すでにロード済みの場合はそのまま返す + if language in loaded_tokenizers: + return loaded_tokenizers[language] + + # pretrained_model_name_or_path が指定されていない場合はデフォルトのパスを利用 + if pretrained_model_name_or_path is None: + assert DEFAULT_BERT_TOKENIZER_PATHS[language].exists(), \ + f"The default {language} BERT tokenizer does not exist on the file system. Please specify the path to the pre-trained model." 
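As a usage sketch for the new bert_models module (illustrative only, not code from this patch): callers load the BERT weights they need once, up front, and every later call simply returns the cached instance. The Hugging Face repository ID below is the Japanese one named in the module docstring; the exact call sites are assumptions.

    from style_bert_vits2.constants import Languages
    from style_bert_vits2.text_processing import bert_models

    # Load the Japanese BERT model/tokenizer explicitly, once (the slow part).
    bert_models.load_model(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")
    bert_models.load_tokenizer(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")

    # Any later call, from anywhere in the library, returns the cached instance.
    tokenizer = bert_models.load_tokenizer(Languages.JP)
    print(tokenizer.tokenize("こんにちは"))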
+ pretrained_model_name_or_path = str(DEFAULT_BERT_TOKENIZER_PATHS[language]) + + # BERT トークナイザーをロードし、辞書に格納して返す + ## 英語のみ DebertaV2Tokenizer でロードする必要がある + if language == Languages.EN: + tokenizer = DebertaV2Tokenizer.from_pretrained(pretrained_model_name_or_path) + else: + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) + loaded_tokenizers[language] = tokenizer + logger.info(f"Loaded the {language} BERT tokenizer from {pretrained_model_name_or_path}") + + return tokenizer diff --git a/style_bert_vits2/text_processing/cleaner.py b/style_bert_vits2/text_processing/cleaner.py index 400d19880..042009964 100644 --- a/style_bert_vits2/text_processing/cleaner.py +++ b/style_bert_vits2/text_processing/cleaner.py @@ -1,9 +1,9 @@ -from typing import Literal +from style_bert_vits2.constants import Languages def clean_text( text: str, - language: Literal["JP", "EN", "ZH"], + language: Languages, use_jp_extra: bool = True, raise_yomi_error: bool = False, ) -> tuple[str, list[str], list[int], list[int]]: @@ -12,7 +12,7 @@ def clean_text( Args: text (str): クリーニングするテキスト - language (Literal["JP", "EN", "ZH"]): テキストの言語 + language (Languages): テキストの言語 use_jp_extra (bool, optional): テキストが日本語の場合に JP-Extra モデルを利用するかどうか。Defaults to True. raise_yomi_error (bool, optional): False の場合、読めない文字が消えたような扱いとして処理される。Defaults to False. @@ -21,22 +21,16 @@ def clean_text( """ # Changed to import inside if condition to avoid unnecessary import - if language == "JP": - from transformers import AutoTokenizer + if language == Languages.JP: from style_bert_vits2.text_processing.japanese.g2p import g2p from style_bert_vits2.text_processing.japanese.normalizer import normalize_text norm_text = normalize_text(text) - phones, tones, word2ph = g2p( - norm_text, - tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese-char-wwm"), # 暫定的にここで指定 - use_jp_extra = use_jp_extra, - raise_yomi_error = raise_yomi_error, - ) - elif language == "EN": + phones, tones, word2ph = g2p(norm_text, use_jp_extra, raise_yomi_error) + elif language == Languages.EN: from ...text import english as language_module norm_text = language_module.normalize_text(text) phones, tones, word2ph = language_module.g2p(norm_text) - elif language == "ZH": + elif language == Languages.ZH: from ...text import chinese as language_module norm_text = language_module.normalize_text(text) phones, tones, word2ph = language_module.g2p(norm_text) diff --git a/style_bert_vits2/text_processing/japanese/g2p.py b/style_bert_vits2/text_processing/japanese/g2p.py index 1b8bd6972..8ef4d3eac 100644 --- a/style_bert_vits2/text_processing/japanese/g2p.py +++ b/style_bert_vits2/text_processing/japanese/g2p.py @@ -1,8 +1,9 @@ import pyopenjtalk import re -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from style_bert_vits2.constants import Languages from style_bert_vits2.logging import logger +from style_bert_vits2.text_processing import bert_models from style_bert_vits2.text_processing.japanese.mora_list import MORA_KATA_TO_MORA_PHONEMES from style_bert_vits2.text_processing.japanese.normalizer import replace_punctuation from style_bert_vits2.text_processing.symbols import PUNCTUATIONS @@ -10,7 +11,6 @@ def g2p( norm_text: str, - tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, use_jp_extra: bool = True, raise_yomi_error: bool = False ) -> tuple[list[str], list[int], list[int]]: @@ -21,11 +21,9 @@ def g2p( - word2ph: 元のテキストの各文字に音素が何個割り当てられるかを表すリスト のタプルを返す。 ただし `phones` と `tones` の最初と終わりに `_` が入り、応じて `word2ph` の最初と最後に 
1 が追加される。 - tokenizer には deberta-v2-large-japanese-char-wwm を AutoTokenizer.from_pretrained() でロードしたものを指定する。 Args: norm_text (str): 正規化されたテキスト - tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast): 単語分割に使うロード済みの BERT Tokenizer インスタンス use_jp_extra (bool, optional): False の場合、「ん」の音素を「N」ではなく「n」とする。Defaults to True. raise_yomi_error (bool, optional): False の場合、読めない文字が消えたような扱いとして処理される。Defaults to False. @@ -66,7 +64,7 @@ def g2p( for i in sep_text: if i not in PUNCTUATIONS: sep_tokenized.append( - tokenizer.tokenize(i) + bert_models.load_tokenizer(Languages.JP).tokenize(i) ) # ここでおそらく`i`が文字単位に分割される else: sep_tokenized.append([i]) diff --git a/style_bert_vits2/text_processing/japanese/g2p_utils.py b/style_bert_vits2/text_processing/japanese/g2p_utils.py index e09560263..3a91a00da 100644 --- a/style_bert_vits2/text_processing/japanese/g2p_utils.py +++ b/style_bert_vits2/text_processing/japanese/g2p_utils.py @@ -1,5 +1,3 @@ -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast - from style_bert_vits2.text_processing.japanese.g2p import g2p from style_bert_vits2.text_processing.japanese.mora_list import ( MORA_KATA_TO_MORA_PHONEMES, @@ -8,21 +6,19 @@ from style_bert_vits2.text_processing.symbols import PUNCTUATIONS -def g2kata_tone(norm_text: str, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast) -> list[tuple[str, int]]: +def g2kata_tone(norm_text: str) -> list[tuple[str, int]]: """ テキストからカタカナとアクセントのペアのリストを返す。 - 推論時のみに使われるので、常に`raise_yomi_error=False`でg2pを呼ぶ。 - tokenizer には deberta-v2-large-japanese-char-wwm を AutoTokenizer.from_pretrained() でロードしたものを指定する。 + 推論時のみに使われるので、常に `raise_yomi_error=False` で g2p を呼ぶ。 Args: norm_text: 正規化されたテキスト。 - tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast): 単語分割に使うロード済みの BERT Tokenizer インスタンス Returns: カタカナと音高のリスト。 """ - phones, tones, _ = g2p(norm_text, tokenizer, use_jp_extra=True, raise_yomi_error=False) + phones, tones, _ = g2p(norm_text, use_jp_extra=True, raise_yomi_error=False) return phone_tone2kata_tone(list(zip(phones, tones))) diff --git a/text/__init__.py b/text/__init__.py index ce4c008c1..efff8302d 100644 --- a/text/__init__.py +++ b/text/__init__.py @@ -1,12 +1,17 @@ +from style_bert_vits2.constants import Languages from style_bert_vits2.text_processing.symbols import * + _symbol_to_id = {s: i for i, s in enumerate(SYMBOLS)} -def cleaned_text_to_sequence(cleaned_text, tones, language): - """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. +def cleaned_text_to_sequence(cleaned_text: str, tones: list[int], language: Languages): + """ + Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
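For the Japanese G2P side of this change, a minimal sketch of the new call pattern (a sketch only: the sample sentence and the accent values in the comment are illustrative, and the default local tokenizer path is assumed to exist):

    from style_bert_vits2.constants import Languages
    from style_bert_vits2.text_processing import bert_models
    from style_bert_vits2.text_processing.japanese.g2p_utils import g2kata_tone
    from style_bert_vits2.text_processing.japanese.normalizer import normalize_text

    # The tokenizer argument is gone; g2p() now fetches the shared JP tokenizer itself.
    bert_models.load_tokenizer(Languages.JP)

    kata_tone = g2kata_tone(normalize_text("こんにちは、世界。"))
    # -> [("コ", 0), ("ン", 1), ("ニ", 1), ("チ", 1), ("ワ", 1), (",", 0), ...]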
+ Args: text: string to convert to a sequence + Returns: List of integers corresponding to the symbols in the text """ @@ -18,12 +23,19 @@ def cleaned_text_to_sequence(cleaned_text, tones, language): return phones, tones, lang_ids -def get_bert(text, word2ph, language, device, assist_text=None, assist_text_weight=0.7): - if language == "ZH": +def get_bert( + text: str, + word2ph, + language: Languages, + device: str, + assist_text: str | None = None, + assist_text_weight: float = 0.7, +): + if language == Languages.ZH: from .chinese_bert import get_bert_feature - elif language == "EN": + elif language == Languages.EN: from .english_bert_mock import get_bert_feature - elif language == "JP": + elif language == Languages.JP: from .japanese_bert import get_bert_feature else: raise ValueError(f"Language {language} not supported") diff --git a/text/chinese_bert.py b/text/chinese_bert.py index 94e1408e2..2ee6e9d6c 100644 --- a/text/chinese_bert.py +++ b/text/chinese_bert.py @@ -1,23 +1,21 @@ import sys import torch -from transformers import AutoModelForMaskedLM, AutoTokenizer from config import config +from style_bert_vits2.constants import Languages +from style_bert_vits2.text_processing import bert_models -LOCAL_PATH = "./bert/chinese-roberta-wwm-ext-large" - -tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH) models = dict() def get_bert_feature( - text, + text: str, word2ph, - device=config.bert_gen_config.device, - assist_text=None, - assist_text_weight=0.7, + device = config.bert_gen_config.device, + assist_text: str | None = None, + assist_text_weight: float = 0.7, ): if ( sys.platform == "darwin" @@ -30,8 +28,9 @@ def get_bert_feature( if device == "cuda" and not torch.cuda.is_available(): device = "cpu" if device not in models.keys(): - models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device) + models[device] = bert_models.load_model(Languages.ZH).to(device) with torch.no_grad(): + tokenizer = bert_models.load_tokenizer(Languages.ZH) inputs = tokenizer(text, return_tensors="pt") for i in inputs: inputs[i] = inputs[i].to(device) diff --git a/text/english.py b/text/english.py index 3dcfdecd7..a6d71e9cb 100644 --- a/text/english.py +++ b/text/english.py @@ -2,16 +2,16 @@ import os import re from g2p_en import G2p -from transformers import DebertaV2Tokenizer +from style_bert_vits2.constants import Languages +from style_bert_vits2.text_processing import bert_models from style_bert_vits2.text_processing.symbols import PUNCTUATIONS, SYMBOLS + current_file_path = os.path.dirname(__file__) CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep") CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle") _g2p = G2p() -LOCAL_PATH = "./bert/deberta-v3-large" -tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH) arpa = { "AH0", @@ -392,6 +392,7 @@ def sep_text(text): def text_to_words(text): + tokenizer = bert_models.load_tokenizer(Languages.EN) tokens = tokenizer.tokenize(text) words = [] for idx, t in enumerate(tokens): diff --git a/text/english_bert_mock.py b/text/english_bert_mock.py index 782b65d7f..8e57df1da 100644 --- a/text/english_bert_mock.py +++ b/text/english_bert_mock.py @@ -1,24 +1,21 @@ import sys import torch -from transformers import DebertaV2Model, DebertaV2Tokenizer from config import config +from style_bert_vits2.constants import Languages +from style_bert_vits2.text_processing import bert_models -LOCAL_PATH = "./bert/deberta-v3-large" - -tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH) - models = dict() def get_bert_feature( 
- text, + text: str, word2ph, - device=config.bert_gen_config.device, - assist_text=None, - assist_text_weight=0.7, + device = config.bert_gen_config.device, + assist_text: str | None = None, + assist_text_weight: float = 0.7, ): if ( sys.platform == "darwin" @@ -31,8 +28,9 @@ def get_bert_feature( if device == "cuda" and not torch.cuda.is_available(): device = "cpu" if device not in models.keys(): - models[device] = DebertaV2Model.from_pretrained(LOCAL_PATH).to(device) + models[device] = bert_models.load_model(Languages.EN).to(device) with torch.no_grad(): + tokenizer = bert_models.load_tokenizer(Languages.EN) inputs = tokenizer(text, return_tensors="pt") for i in inputs: inputs[i] = inputs[i].to(device) diff --git a/text/japanese.py b/text/japanese.py deleted file mode 100644 index 03682e360..000000000 --- a/text/japanese.py +++ /dev/null @@ -1,642 +0,0 @@ -# Convert Japanese text to phonemes which is -# compatible with Julius https://github.com/julius-speech/segmentation-kit -import re -import unicodedata - -import pyopenjtalk -from num2words import num2words -from transformers import AutoTokenizer - -from style_bert_vits2.logging import logger -from style_bert_vits2.text_processing.japanese.mora_list import ( - MORA_KATA_TO_MORA_PHONEMES, - MORA_PHONEMES_TO_MORA_KATA, -) -from style_bert_vits2.text_processing.japanese.user_dict import update_dict -from style_bert_vits2.text_processing.symbols import PUNCTUATIONS - -# 最初にpyopenjtalkの辞書を更新 -update_dict() - -# 子音の集合 -COSONANTS = set( - [ - cosonant - for cosonant, _ in MORA_KATA_TO_MORA_PHONEMES.values() - if cosonant is not None - ] -) - -# 母音の集合、便宜上「ん」を含める -VOWELS = {"a", "i", "u", "e", "o", "N"} - - -class YomiError(Exception): - """ - OpenJTalkで、読みが正しく取得できない箇所があるときに発生する例外。 - 基本的に「学習の前処理のテキスト処理時」には発生させ、そうでない場合は、 - ignore_yomi_error=Trueにしておいて、この例外を発生させないようにする。 - """ - - pass - - -# 正規化で記号を変換するための辞書 -rep_map = { - ":": ",", - ";": ",", - ",": ",", - "。": ".", - "!": "!", - "?": "?", - "\n": ".", - ".": ".", - "…": "...", - "···": "...", - "・・・": "...", - "·": ",", - "・": ",", - "、": ",", - "$": ".", - "“": "'", - "”": "'", - '"': "'", - "‘": "'", - "’": "'", - "(": "'", - ")": "'", - "(": "'", - ")": "'", - "《": "'", - "》": "'", - "【": "'", - "】": "'", - "[": "'", - "]": "'", - # NFKC正規化後のハイフン・ダッシュの変種を全て通常半角ハイフン - \u002d に変換 - "\u02d7": "\u002d", # ˗, Modifier Letter Minus Sign - "\u2010": "\u002d", # ‐, Hyphen, - # "\u2011": "\u002d", # ‑, Non-Breaking Hyphen, NFKCにより\u2010に変換される - "\u2012": "\u002d", # ‒, Figure Dash - "\u2013": "\u002d", # –, En Dash - "\u2014": "\u002d", # —, Em Dash - "\u2015": "\u002d", # ―, Horizontal Bar - "\u2043": "\u002d", # ⁃, Hyphen Bullet - "\u2212": "\u002d", # −, Minus Sign - "\u23af": "\u002d", # ⎯, Horizontal Line Extension - "\u23e4": "\u002d", # ⏤, Straightness - "\u2500": "\u002d", # ─, Box Drawings Light Horizontal - "\u2501": "\u002d", # ━, Box Drawings Heavy Horizontal - "\u2e3a": "\u002d", # ⸺, Two-Em Dash - "\u2e3b": "\u002d", # ⸻, Three-Em Dash - # "~": "-", # これは長音記号「ー」として扱うよう変更 - # "~": "-", # これも長音記号「ー」として扱うよう変更 - "「": "'", - "」": "'", -} - - -def normalize_text(text): - """ - 日本語のテキストを正規化する。 - 結果は、ちょうど次の文字のみからなる: - - ひらがな - - カタカナ(全角長音記号「ー」が入る!) 
- - 漢字 - - 半角アルファベット(大文字と小文字) - - ギリシャ文字 - - `.` (句点`。`や`…`の一部や改行等) - - `,` (読点`、`や`:`等) - - `?` (疑問符`?`) - - `!` (感嘆符`!`) - - `'` (`「`や`」`等) - - `-` (`―`(ダッシュ、長音記号ではない)や`-`等) - - 注意点: - - 三点リーダー`…`は`...`に変換される(`なるほど…。` → `なるほど....`) - - 数字は漢字に変換される(`1,100円` → `千百円`、`52.34` → `五十二点三四`) - - 読点や疑問符等の位置・個数等は保持される(`??あ、、!!!` → `??あ,,!!!`) - """ - res = unicodedata.normalize("NFKC", text) # ここでアルファベットは半角になる - res = japanese_convert_numbers_to_words(res) # 「100円」→「百円」等 - # 「~」と「~」も長音記号として扱う - res = res.replace("~", "ー") - res = res.replace("~", "ー") - - res = replace_punctuation(res) # 句読点等正規化、読めない文字を削除 - - # 結合文字の濁点・半濁点を削除 - # 通常の「ば」等はそのままのこされる、「あ゛」は上で「あ゙」になりここで「あ」になる - res = res.replace("\u3099", "") # 結合文字の濁点を削除、る゙ → る - res = res.replace("\u309A", "") # 結合文字の半濁点を削除、な゚ → な - return res - - -def replace_punctuation(text: str) -> str: - """句読点等を「.」「,」「!」「?」「'」「-」に正規化し、OpenJTalkで読みが取得できるもののみ残す: - 漢字・平仮名・カタカナ、アルファベット、ギリシャ文字 - """ - pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) - - # 句読点を辞書で置換 - replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) - - replaced_text = re.sub( - # ↓ ひらがな、カタカナ、漢字 - r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005" - # ↓ 半角アルファベット(大文字と小文字) - + r"\u0041-\u005A\u0061-\u007A" - # ↓ 全角アルファベット(大文字と小文字) - + r"\uFF21-\uFF3A\uFF41-\uFF5A" - # ↓ ギリシャ文字 - + r"\u0370-\u03FF\u1F00-\u1FFF" - # ↓ "!", "?", "…", ",", ".", "'", "-", 但し`…`はすでに`...`に変換されている - + "".join(PUNCTUATIONS) + r"]+", - # 上述以外の文字を削除 - "", - replaced_text, - ) - - return replaced_text - - -_NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+") -_CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"} -_CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])") -_NUMBER_RX = re.compile(r"[0-9]+(\.[0-9]+)?") - - -def japanese_convert_numbers_to_words(text: str) -> str: - res = _NUMBER_WITH_SEPARATOR_RX.sub(lambda m: m[0].replace(",", ""), text) - res = _CURRENCY_RX.sub(lambda m: m[2] + _CURRENCY_MAP.get(m[1], m[1]), res) - res = _NUMBER_RX.sub(lambda m: num2words(m[0], lang="ja"), res) - return res - - -def g2p( - norm_text: str, use_jp_extra: bool = True, raise_yomi_error: bool = False -) -> tuple[list[str], list[int], list[int]]: - """ - 他で使われるメインの関数。`normalize_text()`で正規化された`norm_text`を受け取り、 - - phones: 音素のリスト(ただし`!`や`,`や`.`等punctuationが含まれうる) - - tones: アクセントのリスト、0(低)と1(高)からなり、phonesと同じ長さ - - word2ph: 元のテキストの各文字に音素が何個割り当てられるかを表すリスト - のタプルを返す。 - ただし`phones`と`tones`の最初と終わりに`_`が入り、応じて`word2ph`の最初と最後に1が追加される。 - - use_jp_extra: Falseの場合、「ん」の音素を「N」ではなく「n」とする。 - raise_yomi_error: Trueの場合、読めない文字があるときに例外を発生させる。 - Falseの場合は読めない文字が消えたような扱いとして処理される。 - """ - # pyopenjtalkのフルコンテキストラベルを使ってアクセントを取り出すと、punctuationの位置が消えてしまい情報が失われてしまう: - # 「こんにちは、世界。」と「こんにちは!世界。」と「こんにちは!!!???世界……。」は全て同じになる。 - # よって、まずpunctuation無しの音素とアクセントのリストを作り、 - # それとは別にpyopenjtalk.run_frontend()で得られる音素リスト(こちらはpunctuationが保持される)を使い、 - # アクセント割当をしなおすことによってpunctuationを含めた音素とアクセントのリストを作る。 - - # punctuationがすべて消えた、音素とアクセントのタプルのリスト(「ん」は「N」) - phone_tone_list_wo_punct = g2phone_tone_wo_punct(norm_text) - - # sep_text: 単語単位の単語のリスト、読めない文字があったらraise_yomi_errorなら例外、そうでないなら読めない文字が消えて返ってくる - # sep_kata: 単語単位の単語のカタカナ読みのリスト - sep_text, sep_kata = text2sep_kata(norm_text, raise_yomi_error=raise_yomi_error) - - # sep_phonemes: 各単語ごとの音素のリストのリスト - sep_phonemes = handle_long([kata2phoneme_list(i) for i in sep_kata]) - - # phone_w_punct: sep_phonemesを結合した、punctuationを元のまま保持した音素列 - phone_w_punct: list[str] = [] - for i in sep_phonemes: - phone_w_punct += i - - # punctuation無しのアクセント情報を使って、punctuationを含めたアクセント情報を作る - 
phone_tone_list = align_tones(phone_w_punct, phone_tone_list_wo_punct) - # logger.debug(f"phone_tone_list:\n{phone_tone_list}") - # word2phは厳密な解答は不可能なので(「今日」「眼鏡」等の熟字訓が存在)、 - # Bert-VITS2では、単語単位の分割を使って、単語の文字ごとにだいたい均等に音素を分配する - - # sep_textから、各単語を1文字1文字分割して、文字のリスト(のリスト)を作る - sep_tokenized: list[list[str]] = [] - for i in sep_text: - if i not in PUNCTUATIONS: - sep_tokenized.append( - tokenizer.tokenize(i) - ) # ここでおそらく`i`が文字単位に分割される - else: - sep_tokenized.append([i]) - - # 各単語について、音素の数と文字の数を比較して、均等っぽく分配する - word2ph = [] - for token, phoneme in zip(sep_tokenized, sep_phonemes): - phone_len = len(phoneme) - word_len = len(token) - word2ph += distribute_phone(phone_len, word_len) - - # 最初と最後に`_`記号を追加、アクセントは0(低)、word2phもそれに合わせて追加 - phone_tone_list = [("_", 0)] + phone_tone_list + [("_", 0)] - word2ph = [1] + word2ph + [1] - - phones = [phone for phone, _ in phone_tone_list] - tones = [tone for _, tone in phone_tone_list] - - assert len(phones) == sum(word2ph), f"{len(phones)} != {sum(word2ph)}" - - # use_jp_extraでない場合は「N」を「n」に変換 - if not use_jp_extra: - phones = [phone if phone != "N" else "n" for phone in phones] - - return phones, tones, word2ph - - -def g2kata_tone(norm_text: str) -> list[tuple[str, int]]: - """ - テキストからカタカナとアクセントのペアのリストを返す。 - 推論時のみに使われるので、常に`raise_yomi_error=False`でg2pを呼ぶ。 - """ - phones, tones, _ = g2p(norm_text, use_jp_extra=True, raise_yomi_error=False) - return phone_tone2kata_tone(list(zip(phones, tones))) - - -def phone_tone2kata_tone(phone_tone: list[tuple[str, int]]) -> list[tuple[str, int]]: - """phone_toneをのphone部分をカタカナに変換する。ただし最初と最後の("_", 0)は無視""" - phone_tone = phone_tone[1:] # 最初の("_", 0)を無視 - phones = [phone for phone, _ in phone_tone] - tones = [tone for _, tone in phone_tone] - result: list[tuple[str, int]] = [] - current_mora = "" - for phone, next_phone, tone, next_tone in zip(phones, phones[1:], tones, tones[1:]): - # zipの関係で最後の("_", 0)は無視されている - if phone in PUNCTUATIONS: - result.append((phone, tone)) - continue - if phone in COSONANTS: # n以外の子音の場合 - assert current_mora == "", f"Unexpected {phone} after {current_mora}" - assert tone == next_tone, f"Unexpected {phone} tone {tone} != {next_tone}" - current_mora = phone - else: - # phoneが母音もしくは「N」 - current_mora += phone - result.append((MORA_PHONEMES_TO_MORA_KATA[current_mora], tone)) - current_mora = "" - return result - - -def kata_tone2phone_tone(kata_tone: list[tuple[str, int]]) -> list[tuple[str, int]]: - """`phone_tone2kata_tone()`の逆。""" - result: list[tuple[str, int]] = [("_", 0)] - for mora, tone in kata_tone: - if mora in PUNCTUATIONS: - result.append((mora, tone)) - else: - cosonant, vowel = MORA_KATA_TO_MORA_PHONEMES[mora] - if cosonant is None: - result.append((vowel, tone)) - else: - result.append((cosonant, tone)) - result.append((vowel, tone)) - result.append(("_", 0)) - return result - - -def g2phone_tone_wo_punct(text: str) -> list[tuple[str, int]]: - """ - テキストに対して、音素とアクセント(0か1)のペアのリストを返す。 - ただし「!」「.」「?」等の非音素記号(punctuation)は全て消える(ポーズ記号も残さない)。 - 非音素記号を含める処理は`align_tones()`で行われる。 - また「っ」は「q」に、「ん」は「N」に変換される。 - 例: "こんにちは、世界ー。。元気?!" 
→ - [('k', 0), ('o', 0), ('N', 1), ('n', 1), ('i', 1), ('ch', 1), ('i', 1), ('w', 1), ('a', 1), ('s', 1), ('e', 1), ('k', 0), ('a', 0), ('i', 0), ('i', 0), ('g', 1), ('e', 1), ('N', 0), ('k', 0), ('i', 0)] - """ - prosodies = pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True) - # logger.debug(f"prosodies: {prosodies}") - result: list[tuple[str, int]] = [] - current_phrase: list[tuple[str, int]] = [] - current_tone = 0 - for i, letter in enumerate(prosodies): - # 特殊記号の処理 - - # 文頭記号、無視する - if letter == "^": - assert i == 0, "Unexpected ^" - # アクセント句の終わりに来る記号 - elif letter in ("$", "?", "_", "#"): - # 保持しているフレーズを、アクセント数値を0-1に修正し結果に追加 - result.extend(fix_phone_tone(current_phrase)) - # 末尾に来る終了記号、無視(文中の疑問文は`_`になる) - if letter in ("$", "?"): - assert i == len(prosodies) - 1, f"Unexpected {letter}" - # あとは"_"(ポーズ)と"#"(アクセント句の境界)のみ - # これらは残さず、次のアクセント句に備える。 - current_phrase = [] - # 0を基準点にしてそこから上昇・下降する(負の場合は上の`fix_phone_tone`で直る) - current_tone = 0 - # アクセント上昇記号 - elif letter == "[": - current_tone = current_tone + 1 - # アクセント下降記号 - elif letter == "]": - current_tone = current_tone - 1 - # それ以外は通常の音素 - else: - if letter == "cl": # 「っ」の処理 - letter = "q" - # elif letter == "N": # 「ん」の処理 - # letter = "n" - current_phrase.append((letter, current_tone)) - return result - - -def text2sep_kata( - norm_text: str, raise_yomi_error: bool = False -) -> tuple[list[str], list[str]]: - """ - `normalize_text()`で正規化済みの`norm_text`を受け取り、それを単語分割し、 - 分割された単語リストとその読み(カタカナor記号1文字)のリストのタプルを返す。 - 単語分割結果は、`g2p()`の`word2ph`で1文字あたりに割り振る音素記号の数を決めるために使う。 - 例: - `私はそう思う!って感じ?` → - ["私", "は", "そう", "思う", "!", "って", "感じ", "?"], ["ワタシ", "ワ", "ソー", "オモウ", "!", "ッテ", "カンジ", "?"] - - raise_yomi_error: Trueの場合、読めない文字があるときに例外を発生させる。 - Falseの場合は読めない文字が消えたような扱いとして処理される。 - """ - # parsed: OpenJTalkの解析結果 - parsed = pyopenjtalk.run_frontend(norm_text) - sep_text: list[str] = [] - sep_kata: list[str] = [] - for parts in parsed: - # word: 実際の単語の文字列 - # yomi: その読み、但し無声化サインの`’`は除去 - word, yomi = replace_punctuation(parts["string"]), parts["pron"].replace( - "’", "" - ) - """ - ここで`yomi`の取りうる値は以下の通りのはず。 - - `word`が通常単語 → 通常の読み(カタカナ) - (カタカナからなり、長音記号も含みうる、`アー` 等) - - `word`が`ー` から始まる → `ーラー` や `ーーー` など - - `word`が句読点や空白等 → `、` - - `word`がpunctuationの繰り返し → 全角にしたもの - 基本的にpunctuationは1文字ずつ分かれるが、何故かある程度連続すると1つにまとまる。 - 他にも`word`が読めないキリル文字アラビア文字等が来ると`、`になるが、正規化でこの場合は起きないはず。 - また元のコードでは`yomi`が空白の場合の処理があったが、これは起きないはず。 - 処理すべきは`yomi`が`、`の場合のみのはず。 - """ - assert yomi != "", f"Empty yomi: {word}" - if yomi == "、": - # wordは正規化されているので、`.`, `,`, `!`, `'`, `-`, `--` のいずれか - if not set(word).issubset(set(PUNCTUATIONS)): # 記号繰り返しか判定 - # ここはpyopenjtalkが読めない文字等のときに起こる - if raise_yomi_error: - raise YomiError(f"Cannot read: {word} in:\n{norm_text}") - logger.warning(f"Ignoring unknown: {word} in:\n{norm_text}") - continue - # yomiは元の記号のままに変更 - yomi = word - elif yomi == "?": - assert word == "?", f"yomi `?` comes from: {word}" - yomi = "?" - sep_text.append(word) - sep_kata.append(yomi) - return sep_text, sep_kata - - -# ESPnetの実装から引用、変更点無し。「ん」は「N」なことに注意。 -# https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py -def pyopenjtalk_g2p_prosody(text: str, drop_unvoiced_vowels: bool = True) -> list[str]: - """Extract phoneme + prosoody symbol sequence from input full-context labels. - - The algorithm is based on `Prosodic features control by symbols as input of - sequence-to-sequence acoustic modeling for neural TTS`_ with some r9y9's tweaks. - - Args: - text (str): Input text. 
- drop_unvoiced_vowels (bool): whether to drop unvoiced vowels. - - Returns: - List[str]: List of phoneme + prosody symbols. - - Examples: - >>> from espnet2.text.phoneme_tokenizer import pyopenjtalk_g2p_prosody - >>> pyopenjtalk_g2p_prosody("こんにちは。") - ['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$'] - - .. _`Prosodic features control by symbols as input of sequence-to-sequence acoustic - modeling for neural TTS`: https://doi.org/10.1587/transinf.2020EDP7104 - - """ - labels = pyopenjtalk.make_label(pyopenjtalk.run_frontend(text)) - N = len(labels) - - phones = [] - for n in range(N): - lab_curr = labels[n] - - # current phoneme - p3 = re.search(r"\-(.*?)\+", lab_curr).group(1) - # deal unvoiced vowels as normal vowels - if drop_unvoiced_vowels and p3 in "AEIOU": - p3 = p3.lower() - - # deal with sil at the beginning and the end of text - if p3 == "sil": - assert n == 0 or n == N - 1 - if n == 0: - phones.append("^") - elif n == N - 1: - # check question form or not - e3 = _numeric_feature_by_regex(r"!(\d+)_", lab_curr) - if e3 == 0: - phones.append("$") - elif e3 == 1: - phones.append("?") - continue - elif p3 == "pau": - phones.append("_") - continue - else: - phones.append(p3) - - # accent type and position info (forward or backward) - a1 = _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", lab_curr) - a2 = _numeric_feature_by_regex(r"\+(\d+)\+", lab_curr) - a3 = _numeric_feature_by_regex(r"\+(\d+)/", lab_curr) - - # number of mora in accent phrase - f1 = _numeric_feature_by_regex(r"/F:(\d+)_", lab_curr) - - a2_next = _numeric_feature_by_regex(r"\+(\d+)\+", labels[n + 1]) - # accent phrase border - if a3 == 1 and a2_next == 1 and p3 in "aeiouAEIOUNcl": - phones.append("#") - # pitch falling - elif a1 == 0 and a2_next == a2 + 1 and a2 != f1: - phones.append("]") - # pitch rising - elif a2 == 1 and a2_next == 2: - phones.append("[") - - return phones - - -def _numeric_feature_by_regex(regex, s): - match = re.search(regex, s) - if match is None: - return -50 - return int(match.group(1)) - - -def fix_phone_tone(phone_tone_list: list[tuple[str, int]]) -> list[tuple[str, int]]: - """ - `phone_tone_list`のtone(アクセントの値)を0か1の範囲に修正する。 - 例: [(a, 0), (i, -1), (u, -1)] → [(a, 1), (i, 0), (u, 0)] - """ - tone_values = set(tone for _, tone in phone_tone_list) - if len(tone_values) == 1: - assert tone_values == {0}, tone_values - return phone_tone_list - elif len(tone_values) == 2: - if tone_values == {0, 1}: - return phone_tone_list - elif tone_values == {-1, 0}: - return [ - (letter, 0 if tone == -1 else 1) for letter, tone in phone_tone_list - ] - else: - raise ValueError(f"Unexpected tone values: {tone_values}") - else: - raise ValueError(f"Unexpected tone values: {tone_values}") - - -def distribute_phone(n_phone: int, n_word: int) -> list[int]: - """ - 左から右に1ずつ振り分け、次にまた左から右に1ずつ増やし、というふうに、 - 音素の数`n_phone`を単語の数`n_word`に分配する。 - """ - phones_per_word = [0] * n_word - for _ in range(n_phone): - min_tasks = min(phones_per_word) - min_index = phones_per_word.index(min_tasks) - phones_per_word[min_index] += 1 - return phones_per_word - - -def handle_long(sep_phonemes: list[list[str]]) -> list[list[str]]: - """ - フレーズごとに分かれた音素(長音記号がそのまま)のリストのリスト`sep_phonemes`を受け取り、 - その長音記号を処理して、音素のリストのリストを返す。 - 基本的には直前の音素を伸ばすが、直前の音素が母音でない場合もしくは冒頭の場合は、 - おそらく長音記号とダッシュを勘違いしていると思われるので、ダッシュに対応する音素`-`に変換する。 - """ - for i in range(len(sep_phonemes)): - if len(sep_phonemes[i]) == 0: - # 空白文字等でリストが空の場合 - continue - if sep_phonemes[i][0] == "ー": - if i != 0: - prev_phoneme = sep_phonemes[i - 1][-1] - if prev_phoneme in 
VOWELS: - # 母音と「ん」のあとの伸ばし棒なので、その母音に変換 - sep_phonemes[i][0] = sep_phonemes[i - 1][-1] - else: - # 「。ーー」等おそらく予期しない長音記号 - # ダッシュの勘違いだと思われる - sep_phonemes[i][0] = "-" - else: - # 冒頭に長音記号が来ていおり、これはダッシュの勘違いと思われる - sep_phonemes[i][0] = "-" - if "ー" in sep_phonemes[i]: - for j in range(len(sep_phonemes[i])): - if sep_phonemes[i][j] == "ー": - sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1] - return sep_phonemes - - -tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese-char-wwm") - - -def align_tones( - phones_with_punct: list[str], phone_tone_list: list[tuple[str, int]] -) -> list[tuple[str, int]]: - """ - 例: - …私は、、そう思う。 - phones_with_punct: - [".", ".", ".", "w", "a", "t", "a", "sh", "i", "w", "a", ",", ",", "s", "o", "o", "o", "m", "o", "u", "."] - phone_tone_list: - [("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), ("_", 0), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0))] - Return: - [(".", 0), (".", 0), (".", 0), ("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), (",", 0), (",", 0), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0), (".", 0)] - """ - result: list[tuple[str, int]] = [] - tone_index = 0 - for phone in phones_with_punct: - if tone_index >= len(phone_tone_list): - # 余ったpunctuationがある場合 → (punctuation, 0)を追加 - result.append((phone, 0)) - elif phone == phone_tone_list[tone_index][0]: - # phone_tone_listの現在の音素と一致する場合 → toneをそこから取得、(phone, tone)を追加 - result.append((phone, phone_tone_list[tone_index][1])) - # 探すindexを1つ進める - tone_index += 1 - elif phone in PUNCTUATIONS: - # phoneがpunctuationの場合 → (phone, 0)を追加 - result.append((phone, 0)) - else: - logger.debug(f"phones: {phones_with_punct}") - logger.debug(f"phone_tone_list: {phone_tone_list}") - logger.debug(f"result: {result}") - logger.debug(f"tone_index: {tone_index}") - logger.debug(f"phone: {phone}") - raise ValueError(f"Unexpected phone: {phone}") - return result - - -def kata2phoneme_list(text: str) -> list[str]: - """ - 原則カタカナの`text`を受け取り、それをそのままいじらずに音素記号のリストに変換。 - 注意点: - - punctuationかその繰り返しが来た場合、punctuationたちをそのままリストにして返す。 - - 冒頭に続く「ー」はそのまま「ー」のままにする(`handle_long()`で処理される) - - 文中の「ー」は前の音素記号の最後の音素記号に変換される。 - 例: - `ーーソーナノカーー` → ["ー", "ー", "s", "o", "o", "n", "a", "n", "o", "k", "a", "a", "a"] - `?` → ["?"] - `!?!?!?!?!` → ["!", "?", "!", "?", "!", "?", "!", "?", "!"] - """ - if set(text).issubset(set(PUNCTUATIONS)): - return list(text) - # `text`がカタカナ(`ー`含む)のみからなるかどうかをチェック - if re.fullmatch(r"[\u30A0-\u30FF]+", text) is None: - raise ValueError(f"Input must be katakana only: {text}") - sorted_keys = sorted(MORA_KATA_TO_MORA_PHONEMES.keys(), key=len, reverse=True) - pattern = "|".join(map(re.escape, sorted_keys)) - - def mora2phonemes(mora: str) -> str: - cosonant, vowel = MORA_KATA_TO_MORA_PHONEMES[mora] - if cosonant is None: - return f" {vowel}" - return f" {cosonant} {vowel}" - - spaced_phonemes = re.sub(pattern, lambda m: mora2phonemes(m.group()), text) - - # 長音記号「ー」の処理 - long_pattern = r"(\w)(ー*)" - long_replacement = lambda m: m.group(1) + (" " + m.group(1)) * len(m.group(2)) - spaced_phonemes = re.sub(long_pattern, long_replacement, spaced_phonemes) - return spaced_phonemes.strip().split(" ") - - -if __name__ == "__main__": - tokenizer = AutoTokenizer.from_pretrained( - "./bert/deberta-v2-large-japanese-char-wwm" - ) - text = "こんにちは、世界。" - from text.japanese_bert import get_bert_feature - - text = normalize_text(text) - - phones, tones, word2ph = g2p(text) - bert = get_bert_feature(text, word2ph) 
- - print(phones, tones, word2ph, bert.shape) diff --git a/text/japanese_bert.py b/text/japanese_bert.py index fbeb94d6c..4c1bccc0c 100644 --- a/text/japanese_bert.py +++ b/text/japanese_bert.py @@ -1,24 +1,22 @@ import sys import torch -from transformers import AutoModelForMaskedLM, AutoTokenizer from config import config +from style_bert_vits2.constants import Languages +from style_bert_vits2.text_processing import bert_models from style_bert_vits2.text_processing.japanese.g2p import text_to_sep_kata -LOCAL_PATH = "./bert/deberta-v2-large-japanese-char-wwm" - -tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH) models = dict() def get_bert_feature( - text, + text: str, word2ph, - device=config.bert_gen_config.device, - assist_text=None, - assist_text_weight=0.7, + device = config.bert_gen_config.device, + assist_text: str | None = None, + assist_text_weight: float = 0.7, ): # 各単語が何文字かを作る`word2ph`を使う必要があるので、読めない文字は必ず無視する # でないと`word2ph`の結果とテキストの文字数結果が整合性が取れない @@ -37,8 +35,9 @@ def get_bert_feature( if device == "cuda" and not torch.cuda.is_available(): device = "cpu" if device not in models.keys(): - models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device) + models[device] = bert_models.load_model(Languages.JP).to(device) with torch.no_grad(): + tokenizer = bert_models.load_tokenizer(Languages.JP) inputs = tokenizer(text, return_tensors="pt") for i in inputs: inputs[i] = inputs[i].to(device) From 62919e904e643d0d72ef18fd5424260a38157cb8 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Thu, 7 Mar 2024 03:32:07 +0000 Subject: [PATCH 023/148] Refactor: moved the module for extracting BERT features from text in each language to style_bert_vits2/text_processing/(language)/bert_feature.py --- app.py | 3 +- bert_gen.py | 8 +-- data_utils.py | 6 +- server_fastapi.py | 4 +- style_bert_vits2/models/infer.py | 4 +- style_bert_vits2/text_processing/__init__.py | 68 +++++++++++++++++++ .../text_processing/chinese/bert_feature.py | 36 +++++++--- .../text_processing/english/bert_feature.py | 36 +++++++--- .../text_processing/japanese/bert_feature.py | 39 ++++++++--- text/__init__.py | 43 ------------ text/chinese.py | 10 +-- text/english.py | 6 -- webui_dataset.py | 1 + webui_merge.py | 3 +- webui_style_vectors.py | 3 +- webui_train.py | 1 + 16 files changed, 171 insertions(+), 100 deletions(-) create mode 100644 style_bert_vits2/text_processing/__init__.py rename text/chinese_bert.py => style_bert_vits2/text_processing/chinese/bert_feature.py (74%) rename text/english_bert_mock.py => style_bert_vits2/text_processing/english/bert_feature.py (65%) rename text/japanese_bert.py => style_bert_vits2/text_processing/japanese/bert_feature.py (63%) delete mode 100644 text/__init__.py diff --git a/app.py b/app.py index a0a215cdd..02c01cdf3 100644 --- a/app.py +++ b/app.py @@ -10,6 +10,7 @@ import torch import yaml +from common.tts_model import ModelHolder from style_bert_vits2.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, DEFAULT_LENGTH, @@ -25,11 +26,11 @@ Languages, ) from style_bert_vits2.logging import logger -from common.tts_model import ModelHolder from style_bert_vits2.models.infer import InvalidToneError from style_bert_vits2.text_processing.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone from style_bert_vits2.text_processing.japanese.normalizer import normalize_text + # Get path settings with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: path_config: dict[str, str] = yaml.safe_load(f.read()) diff --git a/bert_gen.py b/bert_gen.py index 
fd0b54eff..2b512d211 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -5,12 +5,12 @@ import torch.multiprocessing as mp from tqdm import tqdm -from style_bert_vits2.models import commons import utils +from config import config from style_bert_vits2.logging import logger +from style_bert_vits2.models import commons +from style_bert_vits2.text_processing import cleaned_text_to_sequence, extract_bert_feature from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT -from config import config -from text import cleaned_text_to_sequence, get_bert def process_line(x): @@ -45,7 +45,7 @@ def process_line(x): bert = torch.load(bert_path) assert bert.shape[-1] == len(phone) except Exception: - bert = get_bert(text, word2ph, language_str, device) + bert = extract_bert_feature(text, word2ph, language_str, device) assert bert.shape[-1] == len(phone) torch.save(bert, bert_path) diff --git a/data_utils.py b/data_utils.py index 111810250..7738247ea 100644 --- a/data_utils.py +++ b/data_utils.py @@ -7,12 +7,12 @@ import torch.utils.data from tqdm import tqdm -from style_bert_vits2.models import commons from config import config from mel_processing import mel_spectrogram_torch, spectrogram_torch -from text import cleaned_text_to_sequence -from style_bert_vits2.logging import logger from utils import load_filepaths_and_text, load_wav_to_torch +from style_bert_vits2.logging import logger +from style_bert_vits2.models import commons +from style_bert_vits2.text_processing import cleaned_text_to_sequence """Multi speaker version""" diff --git a/server_fastapi.py b/server_fastapi.py index 132cc175d..ca9520c17 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -20,6 +20,8 @@ from fastapi.responses import FileResponse, Response from scipy.io import wavfile +from common.tts_model import Model, ModelHolder +from config import config from style_bert_vits2.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, DEFAULT_LENGTH, @@ -33,8 +35,6 @@ Languages, ) from style_bert_vits2.logging import logger -from common.tts_model import Model, ModelHolder -from config import config ln = config.server_config.language diff --git a/style_bert_vits2/models/infer.py b/style_bert_vits2/models/infer.py index 99999a957..c4f6aed6b 100644 --- a/style_bert_vits2/models/infer.py +++ b/style_bert_vits2/models/infer.py @@ -1,12 +1,12 @@ import torch import utils -from text import cleaned_text_to_sequence, get_bert from style_bert_vits2.constants import Languages from style_bert_vits2.logging import logger from style_bert_vits2.models import commons from style_bert_vits2.models.models import SynthesizerTrn from style_bert_vits2.models.models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra +from style_bert_vits2.text_processing import cleaned_text_to_sequence, extract_bert_feature from style_bert_vits2.text_processing.cleaner import clean_text from style_bert_vits2.text_processing.symbols import SYMBOLS @@ -77,7 +77,7 @@ def get_text( for i in range(len(word2ph)): word2ph[i] = word2ph[i] * 2 word2ph[0] += 1 - bert_ori = get_bert( + bert_ori = extract_bert_feature( norm_text, word2ph, language_str, diff --git a/style_bert_vits2/text_processing/__init__.py b/style_bert_vits2/text_processing/__init__.py new file mode 100644 index 000000000..42bb3a54b --- /dev/null +++ b/style_bert_vits2/text_processing/__init__.py @@ -0,0 +1,68 @@ +import torch + +from style_bert_vits2.constants import Languages +from style_bert_vits2.text_processing.symbols import ( + LANGUAGE_ID_MAP, + LANGUAGE_TONE_START_MAP, + SYMBOLS, +) + + +_symbol_to_id = {s: i 
for i, s in enumerate(SYMBOLS)} + + +def cleaned_text_to_sequence(cleaned_text: str, tones: list[int], language: Languages) -> tuple[list[int], list[int], list[int]]: + """ + Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + + Args: + cleaned_text (str): string to convert to a sequence + tones (list[int]): List of tones + language (Languages): Language of the text + + Returns: + tuple[list[int], list[int], list[int]]: List of integers corresponding to the symbols in the text + """ + + phones = [_symbol_to_id[symbol] for symbol in cleaned_text] + tone_start = LANGUAGE_TONE_START_MAP[language] + tones = [i + tone_start for i in tones] + lang_id = LANGUAGE_ID_MAP[language] + lang_ids = [lang_id for i in phones] + + return phones, tones, lang_ids + + +def extract_bert_feature( + text: str, + word2ph: list[int], + language: Languages, + device: torch.device | str, + assist_text: str | None = None, + assist_text_weight: float = 0.7, +) -> torch.Tensor: + """ + テキストから BERT の特徴量を抽出する + + Args: + text (str): テキスト + word2ph (list[int]): 元のテキストの各文字に音素が何個割り当てられるかを表すリスト + language (Languages): テキストの言語 + device (torch.device | str): 推論に利用するデバイス + assist_text (str | None, optional): 補助テキスト (デフォルト: None) + assist_text_weight (float, optional): 補助テキストの重み (デフォルト: 0.7) + + Returns: + torch.Tensor: BERT の特徴量 + """ + + if language == Languages.JP: + from style_bert_vits2.text_processing.japanese.bert_feature import extract_bert_feature + elif language == Languages.EN: + from style_bert_vits2.text_processing.english.bert_feature import extract_bert_feature + elif language == Languages.ZH: + from style_bert_vits2.text_processing.chinese.bert_feature import extract_bert_feature + else: + raise ValueError(f"Language {language} not supported") + + return extract_bert_feature(text, word2ph, device, assist_text, assist_text_weight) diff --git a/text/chinese_bert.py b/style_bert_vits2/text_processing/chinese/bert_feature.py similarity index 74% rename from text/chinese_bert.py rename to style_bert_vits2/text_processing/chinese/bert_feature.py index 2ee6e9d6c..c2085a403 100644 --- a/text/chinese_bert.py +++ b/style_bert_vits2/text_processing/chinese/bert_feature.py @@ -1,22 +1,36 @@ import sys import torch +from transformers import PreTrainedModel -from config import config from style_bert_vits2.constants import Languages from style_bert_vits2.text_processing import bert_models -models = dict() +models: dict[str, PreTrainedModel] = {} -def get_bert_feature( +def extract_bert_feature( text: str, - word2ph, - device = config.bert_gen_config.device, + word2ph: list[int], + device: torch.device | str, assist_text: str | None = None, assist_text_weight: float = 0.7, -): +) -> torch.Tensor: + """ + 中国語のテキストから BERT の特徴量を抽出する + + Args: + text (str): 中国語のテキスト + word2ph (list[int]): 元のテキストの各文字に音素が何個割り当てられるかを表すリスト + device (torch.device | str): 推論に利用するデバイス + assist_text (str | None, optional): 補助テキスト (デフォルト: None) + assist_text_weight (float, optional): 補助テキストの重み (デフォルト: 0.7) + + Returns: + torch.Tensor: BERT の特徴量 + """ + if ( sys.platform == "darwin" and torch.backends.mps.is_available() @@ -28,26 +42,30 @@ def get_bert_feature( if device == "cuda" and not torch.cuda.is_available(): device = "cpu" if device not in models.keys(): - models[device] = bert_models.load_model(Languages.ZH).to(device) + models[device] = bert_models.load_model(Languages.ZH).to(device) # type: ignore + + style_res_mean = None with torch.no_grad(): tokenizer = bert_models.load_tokenizer(Languages.ZH) inputs = 
tokenizer(text, return_tensors="pt") for i in inputs: - inputs[i] = inputs[i].to(device) + inputs[i] = inputs[i].to(device) # type: ignore res = models[device](**inputs, output_hidden_states=True) res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu() if assist_text: style_inputs = tokenizer(assist_text, return_tensors="pt") for i in style_inputs: - style_inputs[i] = style_inputs[i].to(device) + style_inputs[i] = style_inputs[i].to(device) # type: ignore style_res = models[device](**style_inputs, output_hidden_states=True) style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() style_res_mean = style_res.mean(0) + assert len(word2ph) == len(text) + 2 word2phone = word2ph phone_level_feature = [] for i in range(len(word2phone)): if assist_text: + assert style_res_mean is not None repeat_feature = ( res[i].repeat(word2phone[i], 1) * (1 - assist_text_weight) + style_res_mean.repeat(word2phone[i], 1) * assist_text_weight diff --git a/text/english_bert_mock.py b/style_bert_vits2/text_processing/english/bert_feature.py similarity index 65% rename from text/english_bert_mock.py rename to style_bert_vits2/text_processing/english/bert_feature.py index 8e57df1da..d9c1025a0 100644 --- a/text/english_bert_mock.py +++ b/style_bert_vits2/text_processing/english/bert_feature.py @@ -1,22 +1,36 @@ import sys import torch +from transformers import PreTrainedModel -from config import config from style_bert_vits2.constants import Languages from style_bert_vits2.text_processing import bert_models -models = dict() +models: dict[str, PreTrainedModel] = {} -def get_bert_feature( +def extract_bert_feature( text: str, - word2ph, - device = config.bert_gen_config.device, + word2ph: list[int], + device: torch.device | str, assist_text: str | None = None, assist_text_weight: float = 0.7, -): +) -> torch.Tensor: + """ + 英語のテキストから BERT の特徴量を抽出する + + Args: + text (str): 英語のテキスト + word2ph (list[int]): 元のテキストの各文字に音素が何個割り当てられるかを表すリスト + device (torch.device | str): 推論に利用するデバイス + assist_text (str | None, optional): 補助テキスト (デフォルト: None) + assist_text_weight (float, optional): 補助テキストの重み (デフォルト: 0.7) + + Returns: + torch.Tensor: BERT の特徴量 + """ + if ( sys.platform == "darwin" and torch.backends.mps.is_available() @@ -28,26 +42,30 @@ def get_bert_feature( if device == "cuda" and not torch.cuda.is_available(): device = "cpu" if device not in models.keys(): - models[device] = bert_models.load_model(Languages.EN).to(device) + models[device] = bert_models.load_model(Languages.EN).to(device) # type: ignore + + style_res_mean = None with torch.no_grad(): tokenizer = bert_models.load_tokenizer(Languages.EN) inputs = tokenizer(text, return_tensors="pt") for i in inputs: - inputs[i] = inputs[i].to(device) + inputs[i] = inputs[i].to(device) # type: ignore res = models[device](**inputs, output_hidden_states=True) res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu() if assist_text: style_inputs = tokenizer(assist_text, return_tensors="pt") for i in style_inputs: - style_inputs[i] = style_inputs[i].to(device) + style_inputs[i] = style_inputs[i].to(device) # type: ignore style_res = models[device](**style_inputs, output_hidden_states=True) style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() style_res_mean = style_res.mean(0) + assert len(word2ph) == res.shape[0], (text, res.shape[0], len(word2ph)) word2phone = word2ph phone_level_feature = [] for i in range(len(word2phone)): if assist_text: + assert style_res_mean is not None repeat_feature = ( res[i].repeat(word2phone[i], 1) * (1 - assist_text_weight) + 
style_res_mean.repeat(word2phone[i], 1) * assist_text_weight diff --git a/text/japanese_bert.py b/style_bert_vits2/text_processing/japanese/bert_feature.py similarity index 63% rename from text/japanese_bert.py rename to style_bert_vits2/text_processing/japanese/bert_feature.py index 4c1bccc0c..078bb5c43 100644 --- a/text/japanese_bert.py +++ b/style_bert_vits2/text_processing/japanese/bert_feature.py @@ -1,25 +1,39 @@ import sys import torch +from transformers import PreTrainedModel -from config import config from style_bert_vits2.constants import Languages from style_bert_vits2.text_processing import bert_models from style_bert_vits2.text_processing.japanese.g2p import text_to_sep_kata -models = dict() +models: dict[str, PreTrainedModel] = {} -def get_bert_feature( +def extract_bert_feature( text: str, - word2ph, - device = config.bert_gen_config.device, + word2ph: list[int], + device: torch.device | str, assist_text: str | None = None, assist_text_weight: float = 0.7, -): - # 各単語が何文字かを作る`word2ph`を使う必要があるので、読めない文字は必ず無視する - # でないと`word2ph`の結果とテキストの文字数結果が整合性が取れない +) -> torch.Tensor: + """ + 日本語のテキストから BERT の特徴量を抽出する + + Args: + text (str): 日本語のテキスト + word2ph (list[int]): 元のテキストの各文字に音素が何個割り当てられるかを表すリスト + device (torch.device | str): 推論に利用するデバイス + assist_text (str | None, optional): 補助テキスト (デフォルト: None) + assist_text_weight (float, optional): 補助テキストの重み (デフォルト: 0.7) + + Returns: + torch.Tensor: BERT の特徴量 + """ + + # 各単語が何文字かを作る `word2ph` を使う必要があるので、読めない文字は必ず無視する + # でないと `word2ph` の結果とテキストの文字数結果が整合性が取れない text = "".join(text_to_sep_kata(text, raise_yomi_error=False)[0]) if assist_text: @@ -35,18 +49,20 @@ def get_bert_feature( if device == "cuda" and not torch.cuda.is_available(): device = "cpu" if device not in models.keys(): - models[device] = bert_models.load_model(Languages.JP).to(device) + models[device] = bert_models.load_model(Languages.JP).to(device) # type: ignore + + style_res_mean = None with torch.no_grad(): tokenizer = bert_models.load_tokenizer(Languages.JP) inputs = tokenizer(text, return_tensors="pt") for i in inputs: - inputs[i] = inputs[i].to(device) + inputs[i] = inputs[i].to(device) # type: ignore res = models[device](**inputs, output_hidden_states=True) res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu() if assist_text: style_inputs = tokenizer(assist_text, return_tensors="pt") for i in style_inputs: - style_inputs[i] = style_inputs[i].to(device) + style_inputs[i] = style_inputs[i].to(device) # type: ignore style_res = models[device](**style_inputs, output_hidden_states=True) style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() style_res_mean = style_res.mean(0) @@ -56,6 +72,7 @@ def get_bert_feature( phone_level_feature = [] for i in range(len(word2phone)): if assist_text: + assert style_res_mean is not None repeat_feature = ( res[i].repeat(word2phone[i], 1) * (1 - assist_text_weight) + style_res_mean.repeat(word2phone[i], 1) * assist_text_weight diff --git a/text/__init__.py b/text/__init__.py deleted file mode 100644 index efff8302d..000000000 --- a/text/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -from style_bert_vits2.constants import Languages -from style_bert_vits2.text_processing.symbols import * - - -_symbol_to_id = {s: i for i, s in enumerate(SYMBOLS)} - - -def cleaned_text_to_sequence(cleaned_text: str, tones: list[int], language: Languages): - """ - Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
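Taken together with the dispatcher added in style_bert_vits2/text_processing/__init__.py, the replacement path can be exercised roughly as follows (a sketch, assuming the Japanese BERT weights are already loaded or present at the default local path; the device string is only an example):

    from style_bert_vits2.constants import Languages
    from style_bert_vits2.text_processing import extract_bert_feature
    from style_bert_vits2.text_processing.cleaner import clean_text

    norm_text, phones, tones, word2ph = clean_text("こんにちは", Languages.JP)
    bert = extract_bert_feature(norm_text, word2ph, Languages.JP, device="cpu")

    # One 1024-dim feature column per phoneme, padding symbols included.
    assert bert.shape[-1] == sum(word2ph) == len(phones)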
- - Args: - text: string to convert to a sequence - - Returns: - List of integers corresponding to the symbols in the text - """ - phones = [_symbol_to_id[symbol] for symbol in cleaned_text] - tone_start = LANGUAGE_TONE_START_MAP[language] - tones = [i + tone_start for i in tones] - lang_id = LANGUAGE_ID_MAP[language] - lang_ids = [lang_id for i in phones] - return phones, tones, lang_ids - - -def get_bert( - text: str, - word2ph, - language: Languages, - device: str, - assist_text: str | None = None, - assist_text_weight: float = 0.7, -): - if language == Languages.ZH: - from .chinese_bert import get_bert_feature - elif language == Languages.EN: - from .english_bert_mock import get_bert_feature - elif language == Languages.JP: - from .japanese_bert import get_bert_feature - else: - raise ValueError(f"Language {language} not supported") - - return get_bert_feature(text, word2ph, device, assist_text, assist_text_weight) diff --git a/text/chinese.py b/text/chinese.py index 3d9c39226..94266c247 100644 --- a/text/chinese.py +++ b/text/chinese.py @@ -176,20 +176,14 @@ def normalize_text(text): return text -def get_bert_feature(text, word2ph): - from text import chinese_bert - - return chinese_bert.get_bert_feature(text, word2ph) - - if __name__ == "__main__": - from text.chinese_bert import get_bert_feature + from style_bert_vits2.text_processing.chinese.bert_feature import extract_bert_feature text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏" text = normalize_text(text) print(text) phones, tones, word2ph = g2p(text) - bert = get_bert_feature(text, word2ph) + bert = extract_bert_feature(text, word2ph, 'cuda') print(phones, tones, word2ph, bert.shape) diff --git a/text/english.py b/text/english.py index a6d71e9cb..419be3b9d 100644 --- a/text/english.py +++ b/text/english.py @@ -477,12 +477,6 @@ def g2p(text): return phones, tones, word2ph -def get_bert_feature(text, word2ph): - from text import english_bert_mock - - return english_bert_mock.get_bert_feature(text, word2ph) - - if __name__ == "__main__": # print(get_dict()) # print(eng_word_to_phoneme("hello")) diff --git a/webui_dataset.py b/webui_dataset.py index 3ad63c163..1796864fd 100644 --- a/webui_dataset.py +++ b/webui_dataset.py @@ -8,6 +8,7 @@ from style_bert_vits2.logging import logger from style_bert_vits2.utils.subprocess import run_script_with_log + # Get path settings with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: path_config: dict[str, str] = yaml.safe_load(f.read()) diff --git a/webui_merge.py b/webui_merge.py index 0a3990258..5bf0568b4 100644 --- a/webui_merge.py +++ b/webui_merge.py @@ -11,9 +11,10 @@ from safetensors import safe_open from safetensors.torch import save_file +from common.tts_model import Model, ModelHolder from style_bert_vits2.constants import DEFAULT_STYLE, GRADIO_THEME from style_bert_vits2.logging import logger -from common.tts_model import Model, ModelHolder + voice_keys = ["dec"] voice_pitch_keys = ["flow"] diff --git a/webui_style_vectors.py b/webui_style_vectors.py index cf53ca22a..1cbabbd51 100644 --- a/webui_style_vectors.py +++ b/webui_style_vectors.py @@ -12,9 +12,10 @@ from sklearn.manifold import TSNE from umap import UMAP +from config import config from style_bert_vits2.constants import DEFAULT_STYLE, GRADIO_THEME from style_bert_vits2.logging import logger -from config import config + # Get path settings with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: diff --git a/webui_train.py b/webui_train.py index 6c89b261f..fda31ca2b 100644 --- 
a/webui_train.py +++ b/webui_train.py @@ -19,6 +19,7 @@ from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT from style_bert_vits2.utils.subprocess import run_script_with_log, second_elem_of + logger_handler = None tensorboard_executed = False From 4f11b011fdca5a860ba83f9bdc4062b7667780fd Mon Sep 17 00:00:00 2001 From: tsukumi Date: Thu, 7 Mar 2024 03:54:02 +0000 Subject: [PATCH 024/148] Refactor: minor adjustments --- .../text_processing/chinese/bert_feature.py | 2 +- .../text_processing/english/bert_feature.py | 2 +- .../text_processing/japanese/bert_feature.py | 2 +- style_bert_vits2/text_processing/japanese/g2p.py | 10 +++++----- .../text_processing/japanese/g2p_utils.py | 2 +- style_bert_vits2/text_processing/symbols.py | 14 ++++++++------ style_bert_vits2/utils/subprocess.py | 4 +--- utils.py | 15 --------------- 8 files changed, 18 insertions(+), 33 deletions(-) diff --git a/style_bert_vits2/text_processing/chinese/bert_feature.py b/style_bert_vits2/text_processing/chinese/bert_feature.py index c2085a403..25024cb25 100644 --- a/style_bert_vits2/text_processing/chinese/bert_feature.py +++ b/style_bert_vits2/text_processing/chinese/bert_feature.py @@ -7,7 +7,7 @@ from style_bert_vits2.text_processing import bert_models -models: dict[str, PreTrainedModel] = {} +models: dict[torch.device | str, PreTrainedModel] = {} def extract_bert_feature( diff --git a/style_bert_vits2/text_processing/english/bert_feature.py b/style_bert_vits2/text_processing/english/bert_feature.py index d9c1025a0..ec556c234 100644 --- a/style_bert_vits2/text_processing/english/bert_feature.py +++ b/style_bert_vits2/text_processing/english/bert_feature.py @@ -7,7 +7,7 @@ from style_bert_vits2.text_processing import bert_models -models: dict[str, PreTrainedModel] = {} +models: dict[torch.device | str, PreTrainedModel] = {} def extract_bert_feature( diff --git a/style_bert_vits2/text_processing/japanese/bert_feature.py b/style_bert_vits2/text_processing/japanese/bert_feature.py index 078bb5c43..3ff9d7b81 100644 --- a/style_bert_vits2/text_processing/japanese/bert_feature.py +++ b/style_bert_vits2/text_processing/japanese/bert_feature.py @@ -8,7 +8,7 @@ from style_bert_vits2.text_processing.japanese.g2p import text_to_sep_kata -models: dict[str, PreTrainedModel] = {} +models: dict[torch.device | str, PreTrainedModel] = {} def extract_bert_feature( diff --git a/style_bert_vits2/text_processing/japanese/g2p.py b/style_bert_vits2/text_processing/japanese/g2p.py index 8ef4d3eac..79687511f 100644 --- a/style_bert_vits2/text_processing/japanese/g2p.py +++ b/style_bert_vits2/text_processing/japanese/g2p.py @@ -391,7 +391,7 @@ def __kata_to_phoneme_list(text: str) -> list[str]: if set(text).issubset(set(PUNCTUATIONS)): return list(text) - # `text`がカタカナ(`ー`含む)のみからなるかどうかをチェック + # `text` がカタカナ(`ー`含む)のみからなるかどうかをチェック if re.fullmatch(r"[\u30A0-\u30FF]+", text) is None: raise ValueError(f"Input must be katakana only: {text}") sorted_keys = sorted(MORA_KATA_TO_MORA_PHONEMES.keys(), key=len, reverse=True) @@ -438,15 +438,15 @@ def __align_tones( tone_index = 0 for phone in phones_with_punct: if tone_index >= len(phone_tone_list): - # 余ったpunctuationがある場合 → (punctuation, 0)を追加 + # 余った punctuation がある場合 → (punctuation, 0) を追加 result.append((phone, 0)) elif phone == phone_tone_list[tone_index][0]: - # phone_tone_listの現在の音素と一致する場合 → toneをそこから取得、(phone, tone)を追加 + # phone_tone_list の現在の音素と一致する場合 → tone をそこから取得、(phone, tone) を追加 result.append((phone, phone_tone_list[tone_index][1])) - # 探すindexを1つ進める + # 探す index を1つ進める tone_index += 
1 elif phone in PUNCTUATIONS: - # phoneがpunctuationの場合 → (phone, 0)を追加 + # phone が punctuation の場合 → (phone, 0) を追加 result.append((phone, 0)) else: logger.debug(f"phones: {phones_with_punct}") diff --git a/style_bert_vits2/text_processing/japanese/g2p_utils.py b/style_bert_vits2/text_processing/japanese/g2p_utils.py index 3a91a00da..4ea56e819 100644 --- a/style_bert_vits2/text_processing/japanese/g2p_utils.py +++ b/style_bert_vits2/text_processing/japanese/g2p_utils.py @@ -9,7 +9,7 @@ def g2kata_tone(norm_text: str) -> list[tuple[str, int]]: """ テキストからカタカナとアクセントのペアのリストを返す。 - 推論時のみに使われるので、常に `raise_yomi_error=False` で g2p を呼ぶ。 + 推論時のみに使われる関数のため、常に `raise_yomi_error=False` を指定して g2p() を呼ぶ仕様になっている。 Args: norm_text: 正規化されたテキスト。 diff --git a/style_bert_vits2/text_processing/symbols.py b/style_bert_vits2/text_processing/symbols.py index d69bc1c42..e3a650f71 100644 --- a/style_bert_vits2/text_processing/symbols.py +++ b/style_bert_vits2/text_processing/symbols.py @@ -77,8 +77,8 @@ ] NUM_ZH_TONES = 6 -# japanese -JA_SYMBOLS = [ +# Japanese +JP_SYMBOLS = [ "N", "a", "a:", @@ -122,7 +122,7 @@ "z", "zy", ] -NUM_JA_TONES = 2 +NUM_JP_TONES = 2 # English EN_SYMBOLS = [ @@ -169,23 +169,25 @@ NUM_EN_TONES = 4 # Combine all symbols -NORMAL_SYMBOLS = sorted(set(ZH_SYMBOLS + JA_SYMBOLS + EN_SYMBOLS)) +NORMAL_SYMBOLS = sorted(set(ZH_SYMBOLS + JP_SYMBOLS + EN_SYMBOLS)) SYMBOLS = [PAD] + NORMAL_SYMBOLS + PUNCTUATION_SYMBOLS SIL_PHONEMES_IDS = [SYMBOLS.index(i) for i in PUNCTUATION_SYMBOLS] # Combine all tones -NUM_TONES = NUM_ZH_TONES + NUM_JA_TONES + NUM_EN_TONES +NUM_TONES = NUM_ZH_TONES + NUM_JP_TONES + NUM_EN_TONES # Language maps LANGUAGE_ID_MAP = {"ZH": 0, "JP": 1, "EN": 2} NUM_LANGUAGES = len(LANGUAGE_ID_MAP.keys()) +# Language tone start map LANGUAGE_TONE_START_MAP = { "ZH": 0, "JP": NUM_ZH_TONES, - "EN": NUM_ZH_TONES + NUM_JA_TONES, + "EN": NUM_ZH_TONES + NUM_JP_TONES, } + if __name__ == "__main__": a = set(ZH_SYMBOLS) b = set(EN_SYMBOLS) diff --git a/style_bert_vits2/utils/subprocess.py b/style_bert_vits2/utils/subprocess.py index b152702ef..5ff267b4e 100644 --- a/style_bert_vits2/utils/subprocess.py +++ b/style_bert_vits2/utils/subprocess.py @@ -5,8 +5,6 @@ from style_bert_vits2.logging import logger from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT -PYTHON = sys.executable - def run_script_with_log(cmd: list[str], ignore_warning: bool = False) -> tuple[bool, str]: """ @@ -22,7 +20,7 @@ def run_script_with_log(cmd: list[str], ignore_warning: bool = False) -> tuple[b logger.info(f"Running: {' '.join(cmd)}") result = subprocess.run( - [PYTHON] + cmd, + [sys.executable] + cmd, stdout = SAFE_STDOUT, stderr = subprocess.PIPE, text = True, diff --git a/utils.py b/utils.py index 80dfa66ed..d4ff765c7 100644 --- a/utils.py +++ b/utils.py @@ -450,21 +450,6 @@ def __repr__(self): return self.__dict__.__repr__() -def load_model(model_path, config_path): - hps = get_hparams_from_file(config_path) - net = SynthesizerTrn( - # len(symbols), - 108, - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, - n_speakers=hps.data.n_speakers, - **hps.model, - ).to("cpu") - _ = net.eval() - _ = load_checkpoint(model_path, net, None, skip_optimizer=True) - return net - - def mix_model( network1, network2, output_path, voice_ratio=(0.5, 0.5), tone_ratio=(0.5, 0.5) ): From b01168309d20296de273222052a537871e758678 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Thu, 7 Mar 2024 04:08:56 +0000 Subject: [PATCH 025/148] Remove: remove currently unused code in utils.py --- utils.py | 286 
++++++++++++++++++++++--------------------------------- 1 file changed, 116 insertions(+), 170 deletions(-) diff --git a/utils.py b/utils.py index d4ff765c7..8c7e84265 100644 --- a/utils.py +++ b/utils.py @@ -8,26 +8,14 @@ import numpy as np import torch -from huggingface_hub import hf_hub_download from safetensors import safe_open from safetensors.torch import save_file from scipy.io.wavfile import read from style_bert_vits2.logging import logger -MATPLOTLIB_FLAG = False - -def download_checkpoint( - dir_path, repo_config, token=None, regex="G_*.pth", mirror="openi" -): - repo_id = repo_config["repo_id"] - f_list = glob.glob(os.path.join(dir_path, regex)) - if f_list: - print("Use existed model, skip downloading.") - return - for file in ["DUR_0.pth", "D_0.pth", "G_0.pth"]: - hf_hub_download(repo_id, file, local_dir=dir_path, local_dir_use_symlinks=False) +MATPLOTLIB_FLAG = False def load_checkpoint( @@ -114,28 +102,54 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path) ) -def save_safetensors(model, iteration, checkpoint_path, is_half=False, for_infer=False): - """ - Save model with safetensors. +def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True): + """Freeing up space by deleting saved ckpts + + Arguments: + path_to_models -- Path to the model directory + n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth + sort_by_time -- True -> chronologically delete ckpts + False -> lexicographically delete ckpts """ - if hasattr(model, "module"): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() - keys = [] - for k in state_dict: - if "enc_q" in k and for_infer: - continue # noqa: E701 - keys.append(k) + import re - new_dict = ( - {k: state_dict[k].half() for k in keys} - if is_half - else {k: state_dict[k] for k in keys} - ) - new_dict["iteration"] = torch.LongTensor([iteration]) - logger.info(f"Saved safetensors to {checkpoint_path}") - save_file(new_dict, checkpoint_path) + ckpts_files = [ + f + for f in os.listdir(path_to_models) + if os.path.isfile(os.path.join(path_to_models, f)) + ] + + def name_key(_f): + return int(re.compile("._(\\d+)\\.pth").match(_f).group(1)) + + def time_key(_f): + return os.path.getmtime(os.path.join(path_to_models, _f)) + + sort_key = time_key if sort_by_time else name_key + + def x_sorted(_x): + return sorted( + [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")], + key=sort_key, + ) + + to_del = [ + os.path.join(path_to_models, fn) + for fn in ( + x_sorted("G_")[:-n_ckpts_to_keep] + + x_sorted("D_")[:-n_ckpts_to_keep] + + x_sorted("WD_")[:-n_ckpts_to_keep] + + x_sorted("DUR_")[:-n_ckpts_to_keep] + ) + ] + + def del_info(fn): + return logger.info(f"Free up space by deleting ckpt {fn}") + + def del_routine(x): + return [os.remove(x), del_info(x)] + + [del_routine(fn) for fn in to_del] def load_safetensors(checkpoint_path, model, for_infer=False): @@ -169,6 +183,30 @@ def load_safetensors(checkpoint_path, model, for_infer=False): return model, iteration +def save_safetensors(model, iteration, checkpoint_path, is_half=False, for_infer=False): + """ + Save model with safetensors. 
+ """ + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + keys = [] + for k in state_dict: + if "enc_q" in k and for_infer: + continue # noqa: E701 + keys.append(k) + + new_dict = ( + {k: state_dict[k].half() for k in keys} + if is_half + else {k: state_dict[k] for k in keys} + ) + new_dict["iteration"] = torch.LongTensor([iteration]) + logger.info(f"Saved safetensors to {checkpoint_path}") + save_file(new_dict, checkpoint_path) + + def summarize( writer, global_step, @@ -274,6 +312,51 @@ def load_filepaths_and_text(filename, split="|"): return filepaths_and_text +def get_logger(model_dir, filename="train.log"): + global logger + logger = logging.getLogger(os.path.basename(model_dir)) + logger.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") + if not os.path.exists(model_dir): + os.makedirs(model_dir) + h = logging.FileHandler(os.path.join(model_dir, filename)) + h.setLevel(logging.DEBUG) + h.setFormatter(formatter) + logger.addHandler(h) + return logger + + +def get_steps(model_path): + matches = re.findall(r"\d+", model_path) + return matches[-1] if matches else None + + +def check_git_hash(model_dir): + source_dir = os.path.dirname(os.path.realpath(__file__)) + if not os.path.exists(os.path.join(source_dir, ".git")): + logger.warning( + "{} is not a git repository, therefore hash value comparison will be ignored.".format( + source_dir + ) + ) + return + + cur_hash = subprocess.getoutput("git rev-parse HEAD") + + path = os.path.join(model_dir, "githash") + if os.path.exists(path): + saved_hash = open(path).read() + if saved_hash != cur_hash: + logger.warning( + "git hash values are different. {}(saved) != {}(current)".format( + saved_hash[:8], cur_hash[:8] + ) + ) + else: + open(path, "w").write(cur_hash) + + def get_hparams(init=True): parser = argparse.ArgumentParser() parser.add_argument( @@ -307,67 +390,6 @@ def get_hparams(init=True): return hparams -def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True): - """Freeing up space by deleting saved ckpts - - Arguments: - path_to_models -- Path to the model directory - n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth - sort_by_time -- True -> chronologically delete ckpts - False -> lexicographically delete ckpts - """ - import re - - ckpts_files = [ - f - for f in os.listdir(path_to_models) - if os.path.isfile(os.path.join(path_to_models, f)) - ] - - def name_key(_f): - return int(re.compile("._(\\d+)\\.pth").match(_f).group(1)) - - def time_key(_f): - return os.path.getmtime(os.path.join(path_to_models, _f)) - - sort_key = time_key if sort_by_time else name_key - - def x_sorted(_x): - return sorted( - [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")], - key=sort_key, - ) - - to_del = [ - os.path.join(path_to_models, fn) - for fn in ( - x_sorted("G_")[:-n_ckpts_to_keep] - + x_sorted("D_")[:-n_ckpts_to_keep] - + x_sorted("WD_")[:-n_ckpts_to_keep] - + x_sorted("DUR_")[:-n_ckpts_to_keep] - ) - ] - - def del_info(fn): - return logger.info(f"Free up space by deleting ckpt {fn}") - - def del_routine(x): - return [os.remove(x), del_info(x)] - - [del_routine(fn) for fn in to_del] - - -def get_hparams_from_dir(model_dir): - config_save_path = os.path.join(model_dir, "config.json") - with open(config_save_path, "r", encoding="utf-8") as f: - data = f.read() - config = json.loads(data) - - hparams = HParams(**config) - hparams.model_dir = 
model_dir - return hparams - - def get_hparams_from_file(config_path): # print("config_path: ", config_path) with open(config_path, "r", encoding="utf-8") as f: @@ -378,46 +400,6 @@ def get_hparams_from_file(config_path): return hparams -def check_git_hash(model_dir): - source_dir = os.path.dirname(os.path.realpath(__file__)) - if not os.path.exists(os.path.join(source_dir, ".git")): - logger.warning( - "{} is not a git repository, therefore hash value comparison will be ignored.".format( - source_dir - ) - ) - return - - cur_hash = subprocess.getoutput("git rev-parse HEAD") - - path = os.path.join(model_dir, "githash") - if os.path.exists(path): - saved_hash = open(path).read() - if saved_hash != cur_hash: - logger.warning( - "git hash values are different. {}(saved) != {}(current)".format( - saved_hash[:8], cur_hash[:8] - ) - ) - else: - open(path, "w").write(cur_hash) - - -def get_logger(model_dir, filename="train.log"): - global logger - logger = logging.getLogger(os.path.basename(model_dir)) - logger.setLevel(logging.DEBUG) - - formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") - if not os.path.exists(model_dir): - os.makedirs(model_dir) - h = logging.FileHandler(os.path.join(model_dir, filename)) - h.setLevel(logging.DEBUG) - h.setFormatter(formatter) - logger.addHandler(h) - return logger - - class HParams: def __init__(self, **kwargs): for k, v in kwargs.items(): @@ -448,39 +430,3 @@ def __contains__(self, key): def __repr__(self): return self.__dict__.__repr__() - - -def mix_model( - network1, network2, output_path, voice_ratio=(0.5, 0.5), tone_ratio=(0.5, 0.5) -): - if hasattr(network1, "module"): - state_dict1 = network1.module.state_dict() - state_dict2 = network2.module.state_dict() - else: - state_dict1 = network1.state_dict() - state_dict2 = network2.state_dict() - for k in state_dict1.keys(): - if k not in state_dict2.keys(): - continue - if "enc_p" in k: - state_dict1[k] = ( - state_dict1[k].clone() * tone_ratio[0] - + state_dict2[k].clone() * tone_ratio[1] - ) - else: - state_dict1[k] = ( - state_dict1[k].clone() * voice_ratio[0] - + state_dict2[k].clone() * voice_ratio[1] - ) - for k in state_dict2.keys(): - if k not in state_dict1.keys(): - state_dict1[k] = state_dict2[k].clone() - torch.save( - {"model": state_dict1, "iteration": 0, "optimizer": None, "learning_rate": 0}, - output_path, - ) - - -def get_steps(model_path): - matches = re.findall(r"\d+", model_path) - return matches[-1] if matches else None From def6d88425d296457c3b1e04f3bc969c6e5d60cd Mon Sep 17 00:00:00 2001 From: tsukumi Date: Thu, 7 Mar 2024 04:19:40 +0000 Subject: [PATCH 026/148] Refactor: style_bert_vits2/text_processing/cleaner.py integrated into style_bert_vits2/text_processing/__init__.py This was often used in 3 function sets and felt like a wasteful division with few lines. 
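For reference, a minimal usage sketch of the consolidated style_bert_vits2.text_processing module introduced by this patch, assuming the package's Japanese dictionaries and BERT assets are already set up; the input string is an arbitrary example, not taken from the repository:

    from style_bert_vits2.constants import Languages
    from style_bert_vits2.text_processing import clean_text, cleaned_text_to_sequence

    # clean_text() normalizes the raw text and converts it to phonemes, per-phoneme
    # accents (tones), and the number of phonemes assigned to each input character.
    norm_text, phones, tones, word2ph = clean_text(
        "こんにちは、世界。", Languages.JP, use_jp_extra=True, raise_yomi_error=False
    )

    # cleaned_text_to_sequence() then maps the phonemes and tones to the integer ID
    # sequences (plus language IDs) that the synthesizer consumes.
    phone_ids, tone_ids, lang_ids = cleaned_text_to_sequence(phones, tones, Languages.JP)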
--- preprocess_text.py | 2 +- style_bert_vits2/models/infer.py | 3 +- style_bert_vits2/text_processing/__init__.py | 83 ++++++++++++++------ style_bert_vits2/text_processing/cleaner.py | 40 ---------- 4 files changed, 63 insertions(+), 65 deletions(-) delete mode 100644 style_bert_vits2/text_processing/cleaner.py diff --git a/preprocess_text.py b/preprocess_text.py index 92e00b99b..03e1232ad 100644 --- a/preprocess_text.py +++ b/preprocess_text.py @@ -9,7 +9,7 @@ from config import config from style_bert_vits2.logging import logger -from style_bert_vits2.text_processing.cleaner import clean_text +from style_bert_vits2.text_processing import clean_text from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT preprocess_text_config = config.preprocess_text_config diff --git a/style_bert_vits2/models/infer.py b/style_bert_vits2/models/infer.py index c4f6aed6b..0e556d092 100644 --- a/style_bert_vits2/models/infer.py +++ b/style_bert_vits2/models/infer.py @@ -6,8 +6,7 @@ from style_bert_vits2.models import commons from style_bert_vits2.models.models import SynthesizerTrn from style_bert_vits2.models.models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra -from style_bert_vits2.text_processing import cleaned_text_to_sequence, extract_bert_feature -from style_bert_vits2.text_processing.cleaner import clean_text +from style_bert_vits2.text_processing import clean_text, cleaned_text_to_sequence, extract_bert_feature from style_bert_vits2.text_processing.symbols import SYMBOLS diff --git a/style_bert_vits2/text_processing/__init__.py b/style_bert_vits2/text_processing/__init__.py index 42bb3a54b..cd56ee7d4 100644 --- a/style_bert_vits2/text_processing/__init__.py +++ b/style_bert_vits2/text_processing/__init__.py @@ -11,28 +11,6 @@ _symbol_to_id = {s: i for i, s in enumerate(SYMBOLS)} -def cleaned_text_to_sequence(cleaned_text: str, tones: list[int], language: Languages) -> tuple[list[int], list[int], list[int]]: - """ - Converts a string of text to a sequence of IDs corresponding to the symbols in the text. - - Args: - cleaned_text (str): string to convert to a sequence - tones (list[int]): List of tones - language (Languages): Language of the text - - Returns: - tuple[list[int], list[int], list[int]]: List of integers corresponding to the symbols in the text - """ - - phones = [_symbol_to_id[symbol] for symbol in cleaned_text] - tone_start = LANGUAGE_TONE_START_MAP[language] - tones = [i + tone_start for i in tones] - lang_id = LANGUAGE_ID_MAP[language] - lang_ids = [lang_id for i in phones] - - return phones, tones, lang_ids - - def extract_bert_feature( text: str, word2ph: list[int], @@ -66,3 +44,64 @@ def extract_bert_feature( raise ValueError(f"Language {language} not supported") return extract_bert_feature(text, word2ph, device, assist_text, assist_text_weight) + + +def clean_text( + text: str, + language: Languages, + use_jp_extra: bool = True, + raise_yomi_error: bool = False, +) -> tuple[str, list[str], list[int], list[int]]: + """ + テキストをクリーニングし、音素に変換する + + Args: + text (str): クリーニングするテキスト + language (Languages): テキストの言語 + use_jp_extra (bool, optional): テキストが日本語の場合に JP-Extra モデルを利用するかどうか。Defaults to True. + raise_yomi_error (bool, optional): False の場合、読めない文字が消えたような扱いとして処理される。Defaults to False. 
+ + Returns: + tuple[str, list[str], list[int], list[int]]: クリーニングされたテキストと、音素・アクセント・元のテキストの各文字に音素が何個割り当てられるかのリスト + """ + + # Changed to import inside if condition to avoid unnecessary import + if language == Languages.JP: + from style_bert_vits2.text_processing.japanese.g2p import g2p + from style_bert_vits2.text_processing.japanese.normalizer import normalize_text + norm_text = normalize_text(text) + phones, tones, word2ph = g2p(norm_text, use_jp_extra, raise_yomi_error) + elif language == Languages.EN: + from ...text import english as language_module + norm_text = language_module.normalize_text(text) + phones, tones, word2ph = language_module.g2p(norm_text) + elif language == Languages.ZH: + from ...text import chinese as language_module + norm_text = language_module.normalize_text(text) + phones, tones, word2ph = language_module.g2p(norm_text) + else: + raise ValueError(f"Language {language} not supported") + + return norm_text, phones, tones, word2ph + + +def cleaned_text_to_sequence(cleaned_phones: list[str], tones: list[int], language: Languages) -> tuple[list[int], list[int], list[int]]: + """ + テキスト文字列を、テキスト内の記号に対応する一連の ID に変換する + + Args: + cleaned_phones (list[str]): clean_text() でクリーニングされた音素のリスト (?) + tones (list[int]): 各音素のアクセント + language (Languages): テキストの言語 + + Returns: + tuple[list[int], list[int], list[int]]: List of integers corresponding to the symbols in the text + """ + + phones = [_symbol_to_id[symbol] for symbol in cleaned_phones] + tone_start = LANGUAGE_TONE_START_MAP[language] + tones = [i + tone_start for i in tones] + lang_id = LANGUAGE_ID_MAP[language] + lang_ids = [lang_id for i in phones] + + return phones, tones, lang_ids diff --git a/style_bert_vits2/text_processing/cleaner.py b/style_bert_vits2/text_processing/cleaner.py deleted file mode 100644 index 042009964..000000000 --- a/style_bert_vits2/text_processing/cleaner.py +++ /dev/null @@ -1,40 +0,0 @@ -from style_bert_vits2.constants import Languages - - -def clean_text( - text: str, - language: Languages, - use_jp_extra: bool = True, - raise_yomi_error: bool = False, -) -> tuple[str, list[str], list[int], list[int]]: - """ - テキストをクリーニングし、音素に変換する - - Args: - text (str): クリーニングするテキスト - language (Languages): テキストの言語 - use_jp_extra (bool, optional): テキストが日本語の場合に JP-Extra モデルを利用するかどうか。Defaults to True. - raise_yomi_error (bool, optional): False の場合、読めない文字が消えたような扱いとして処理される。Defaults to False. 
- - Returns: - tuple[str, list[str], list[int], list[int]]: クリーニングされたテキストと、音素・アクセント・元のテキストの各文字に音素が何個割り当てられるかのリスト - """ - - # Changed to import inside if condition to avoid unnecessary import - if language == Languages.JP: - from style_bert_vits2.text_processing.japanese.g2p import g2p - from style_bert_vits2.text_processing.japanese.normalizer import normalize_text - norm_text = normalize_text(text) - phones, tones, word2ph = g2p(norm_text, use_jp_extra, raise_yomi_error) - elif language == Languages.EN: - from ...text import english as language_module - norm_text = language_module.normalize_text(text) - phones, tones, word2ph = language_module.g2p(norm_text) - elif language == Languages.ZH: - from ...text import chinese as language_module - norm_text = language_module.normalize_text(text) - phones, tones, word2ph = language_module.g2p(norm_text) - else: - raise ValueError(f"Language {language} not supported") - - return norm_text, phones, tones, word2ph From 1450bfd06f6510a3003a4613f234cb63d932b0e1 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Thu, 7 Mar 2024 04:21:48 +0000 Subject: [PATCH 027/148] Remove: remove webui.py, which is no longer maintained in Style-Bert-VITS2 Since app.py and server_editor.py already exist as alternative Web UI, there is no need to revive webui.py in the future. --- webui.py | 559 ------------------------------------------------------- 1 file changed, 559 deletions(-) delete mode 100644 webui.py diff --git a/webui.py b/webui.py deleted file mode 100644 index 1f31e9483..000000000 --- a/webui.py +++ /dev/null @@ -1,559 +0,0 @@ -""" -Original `webui.py` for Bert-VITS2, not working with Style-Bert-VITS2 yet. -""" - -# flake8: noqa: E402 -import os -import logging -import re_matching -from tools.sentence import split_by_language - -logging.getLogger("numba").setLevel(logging.WARNING) -logging.getLogger("markdown_it").setLevel(logging.WARNING) -logging.getLogger("urllib3").setLevel(logging.WARNING) -logging.getLogger("matplotlib").setLevel(logging.WARNING) - -logging.basicConfig( - level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s" -) - -logger = logging.getLogger(__name__) - -import gradio as gr -import librosa -import numpy as np -import torch -import webbrowser - -import utils -from config import config -from style_bert_vits2.models.infer import infer, latest_version, get_net_g, infer_multilang -from tools.translate import translate - -net_g = None - -device = config.webui_config.device -if device == "mps": - os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" - - -def generate_audio( - slices, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - speaker, - language, - reference_audio, - emotion, - style_text, - style_weight, - skip_start=False, - skip_end=False, -): - audio_list = [] - # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16) - with torch.no_grad(): - for idx, piece in enumerate(slices): - skip_start = idx != 0 - skip_end = idx != len(slices) - 1 - audio = infer( - piece, - reference_audio=reference_audio, - emotion=emotion, - sdp_ratio=sdp_ratio, - noise_scale=noise_scale, - noise_scale_w=noise_scale_w, - length_scale=length_scale, - sid=speaker, - language=language, - hps=hps, - net_g=net_g, - device=device, - skip_start=skip_start, - skip_end=skip_end, - assist_text=style_text, - assist_text_weight=style_weight, - ) - audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio) - audio_list.append(audio16bit) - return audio_list - - -def generate_audio_multilang( - slices, - sdp_ratio, - noise_scale, - 
noise_scale_w, - length_scale, - speaker, - language, - reference_audio, - emotion, - skip_start=False, - skip_end=False, -): - audio_list = [] - # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16) - with torch.no_grad(): - for idx, piece in enumerate(slices): - skip_start = idx != 0 - skip_end = idx != len(slices) - 1 - audio = infer_multilang( - piece, - reference_audio=reference_audio, - emotion=emotion, - sdp_ratio=sdp_ratio, - noise_scale=noise_scale, - noise_scale_w=noise_scale_w, - length_scale=length_scale, - sid=speaker, - language=language[idx], - hps=hps, - net_g=net_g, - device=device, - skip_start=skip_start, - skip_end=skip_end, - ) - audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio) - audio_list.append(audio16bit) - return audio_list - - -def tts_split( - text: str, - speaker, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - language, - cut_by_sent, - interval_between_para, - interval_between_sent, - reference_audio, - emotion, - style_text, - style_weight, -): - while text.find("\n\n") != -1: - text = text.replace("\n\n", "\n") - text = text.replace("|", "") - para_list = re_matching.cut_para(text) - para_list = [p for p in para_list if p != ""] - audio_list = [] - for p in para_list: - if not cut_by_sent: - audio_list += process_text( - p, - speaker, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - language, - reference_audio, - emotion, - style_text, - style_weight, - ) - silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16) - audio_list.append(silence) - else: - audio_list_sent = [] - sent_list = re_matching.cut_sent(p) - sent_list = [s for s in sent_list if s != ""] - for s in sent_list: - audio_list_sent += process_text( - s, - speaker, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - language, - reference_audio, - emotion, - style_text, - style_weight, - ) - silence = np.zeros((int)(44100 * interval_between_sent)) - audio_list_sent.append(silence) - if (interval_between_para - interval_between_sent) > 0: - silence = np.zeros( - (int)(44100 * (interval_between_para - interval_between_sent)) - ) - audio_list_sent.append(silence) - audio16bit = gr.processing_utils.convert_to_16_bit_wav( - np.concatenate(audio_list_sent) - ) # 对完整句子做音量归一 - audio_list.append(audio16bit) - audio_concat = np.concatenate(audio_list) - return ("Success", (hps.data.sampling_rate, audio_concat)) - - -def process_mix(slice): - _speaker = slice.pop() - _text, _lang = [], [] - for lang, content in slice: - content = content.split("|") - content = [part for part in content if part != ""] - if len(content) == 0: - continue - if len(_text) == 0: - _text = [[part] for part in content] - _lang = [[lang] for part in content] - else: - _text[-1].append(content[0]) - _lang[-1].append(lang) - if len(content) > 1: - _text += [[part] for part in content[1:]] - _lang += [[lang] for part in content[1:]] - return _text, _lang, _speaker - - -def process_auto(text): - _text, _lang = [], [] - for slice in text.split("|"): - if slice == "": - continue - temp_text, temp_lang = [], [] - sentences_list = split_by_language(slice, target_languages=["zh", "ja", "en"]) - for sentence, lang in sentences_list: - if sentence == "": - continue - temp_text.append(sentence) - if lang == "ja": - lang = "jp" - temp_lang.append(lang.upper()) - _text.append(temp_text) - _lang.append(temp_lang) - return _text, _lang - - -def process_text( - text: str, - speaker, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - language, - 
reference_audio, - emotion, - style_text=None, - style_weight=0, -): - audio_list = [] - if language == "mix": - bool_valid, str_valid = re_matching.validate_text(text) - if not bool_valid: - return str_valid, ( - hps.data.sampling_rate, - np.concatenate([np.zeros(hps.data.sampling_rate // 2)]), - ) - for slice in re_matching.text_matching(text): - _text, _lang, _speaker = process_mix(slice) - if _speaker is None: - continue - print(f"Text: {_text}\nLang: {_lang}") - audio_list.extend( - generate_audio_multilang( - _text, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - _speaker, - _lang, - reference_audio, - emotion, - ) - ) - elif language.lower() == "auto": - _text, _lang = process_auto(text) - print(f"Text: {_text}\nLang: {_lang}") - audio_list.extend( - generate_audio_multilang( - _text, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - speaker, - _lang, - reference_audio, - emotion, - ) - ) - else: - audio_list.extend( - generate_audio( - text.split("|"), - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - speaker, - language, - reference_audio, - emotion, - style_text, - style_weight, - ) - ) - return audio_list - - -def tts_fn( - text: str, - speaker, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - language, - reference_audio, - emotion, - prompt_mode, - style_text=None, - style_weight=0, -): - if style_text == "": - style_text = None - if prompt_mode == "Audio prompt": - if reference_audio == None: - return ("Invalid audio prompt", None) - else: - reference_audio = load_audio(reference_audio)[1] - else: - reference_audio = None - - audio_list = process_text( - text, - speaker, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - language, - reference_audio, - emotion, - style_text, - style_weight, - ) - - audio_concat = np.concatenate(audio_list) - return "Success", (hps.data.sampling_rate, audio_concat) - - -def format_utils(text, speaker): - _text, _lang = process_auto(text) - res = f"[{speaker}]" - for lang_s, content_s in zip(_lang, _text): - for lang, content in zip(lang_s, content_s): - res += f"<{lang.lower()}>{content}" - res += "|" - return "mix", res[:-1] - - -def load_audio(path): - audio, sr = librosa.load(path, 48000) - # audio = librosa.resample(audio, 44100, 48000) - return sr, audio - - -def gr_util(item): - if item == "Text prompt": - return {"visible": True, "__type__": "update"}, { - "visible": False, - "__type__": "update", - } - else: - return {"visible": False, "__type__": "update"}, { - "visible": True, - "__type__": "update", - } - - -if __name__ == "__main__": - if config.webui_config.debug: - logger.info("Enable DEBUG-LEVEL log") - logging.basicConfig(level=logging.DEBUG) - hps = utils.get_hparams_from_file(config.webui_config.config_path) - # 若config.json中未指定版本则默认为最新版本 - version = hps.version if hasattr(hps, "version") else latest_version - net_g = get_net_g( - model_path=config.webui_config.model, version=version, device=device, hps=hps - ) - speaker_ids = hps.data.spk2id - speakers = list(speaker_ids.keys()) - languages = ["ZH", "JP", "EN", "mix", "auto"] - with gr.Blocks() as app: - with gr.Row(): - with gr.Column(): - text = gr.TextArea( - label="输入文本内容", - placeholder=""" - 如果你选择语言为\'mix\',必须按照格式输入,否则报错: - 格式举例(zh是中文,jp是日语,不区分大小写;说话人举例:gongzi): - [说话人1]你好,こんにちは! こんにちは,世界。 - [说话人2]你好吗?元気ですか? - [说话人3]谢谢。どういたしまして。 - ... 
- 另外,所有的语言选项都可以用'|'分割长段实现分句生成。 - """, - ) - trans = gr.Button("中翻日", variant="primary") - slicer = gr.Button("快速切分", variant="primary") - formatter = gr.Button("检测语言,并整理为 MIX 格式", variant="primary") - speaker = gr.Dropdown( - choices=speakers, value=speakers[0], label="Speaker" - ) - _ = gr.Markdown( - value="提示模式(Prompt mode):可选文字提示或音频提示,用于生成文字或音频指定风格的声音。\n", - visible=False, - ) - prompt_mode = gr.Radio( - ["Text prompt", "Audio prompt"], - label="Prompt Mode", - value="Text prompt", - visible=False, - ) - text_prompt = gr.Textbox( - label="Text prompt", - placeholder="用文字描述生成风格。如:Happy", - value="Happy", - visible=False, - ) - audio_prompt = gr.Audio( - label="Audio prompt", type="filepath", visible=False - ) - sdp_ratio = gr.Slider( - minimum=0, maximum=1, value=0.5, step=0.1, label="SDP Ratio" - ) - noise_scale = gr.Slider( - minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise" - ) - noise_scale_w = gr.Slider( - minimum=0.1, maximum=2, value=0.9, step=0.1, label="Noise_W" - ) - length_scale = gr.Slider( - minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length" - ) - language = gr.Dropdown( - choices=languages, value=languages[0], label="Language" - ) - btn = gr.Button("生成音频!", variant="primary") - with gr.Column(): - with gr.Accordion("融合文本语义", open=False): - gr.Markdown( - value="使用辅助文本的语意来辅助生成对话(语言保持与主文本相同)\n\n" - "**注意**:不要使用**指令式文本**(如:开心),要使用**带有强烈情感的文本**(如:我好快乐!!!)\n\n" - "效果较不明确,留空即为不使用该功能" - ) - style_text = gr.Textbox(label="辅助文本") - style_weight = gr.Slider( - minimum=0, - maximum=1, - value=0.7, - step=0.1, - label="Weight", - info="主文本和辅助文本的bert混合比率,0表示仅主文本,1表示仅辅助文本", - ) - with gr.Row(): - with gr.Column(): - interval_between_sent = gr.Slider( - minimum=0, - maximum=5, - value=0.2, - step=0.1, - label="句间停顿(秒),勾选按句切分才生效", - ) - interval_between_para = gr.Slider( - minimum=0, - maximum=10, - value=1, - step=0.1, - label="段间停顿(秒),需要大于句间停顿才有效", - ) - opt_cut_by_sent = gr.Checkbox( - label="按句切分 在按段落切分的基础上再按句子切分文本" - ) - slicer = gr.Button("切分生成", variant="primary") - text_output = gr.Textbox(label="状态信息") - audio_output = gr.Audio(label="输出音频") - # explain_image = gr.Image( - # label="参数解释信息", - # show_label=True, - # show_share_button=False, - # show_download_button=False, - # value=os.path.abspath("./img/参数说明.png"), - # ) - btn.click( - tts_fn, - inputs=[ - text, - speaker, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - language, - audio_prompt, - text_prompt, - prompt_mode, - style_text, - style_weight, - ], - outputs=[text_output, audio_output], - ) - - trans.click( - translate, - inputs=[text], - outputs=[text], - ) - slicer.click( - tts_split, - inputs=[ - text, - speaker, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - language, - opt_cut_by_sent, - interval_between_para, - interval_between_sent, - audio_prompt, - text_prompt, - style_text, - style_weight, - ], - outputs=[text_output, audio_output], - ) - - prompt_mode.change( - lambda x: gr_util(x), - inputs=[prompt_mode], - outputs=[text_prompt, audio_prompt], - ) - - audio_prompt.upload( - lambda x: load_audio(x), - inputs=[audio_prompt], - outputs=[audio_prompt], - ) - - formatter.click( - format_utils, - inputs=[text, speaker], - outputs=[language, text], - ) - - print("推理页面已开启!") - webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}") - app.launch(share=config.webui_config.share, server_port=config.webui_config.port) From d36401849b74e155ac0972569d6fa4e040eb6928 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Thu, 7 Mar 2024 04:36:05 +0000 Subject: [PATCH 028/148] 
Refactor: moved monotonic_align/ to style_bert_vits2/models/monotonic_alignment.py --- monotonic_align/__init__.py | 16 ---- monotonic_align/core.py | 46 ---------- style_bert_vits2/models/infer.py | 8 +- style_bert_vits2/models/models.py | 4 +- style_bert_vits2/models/models_jp_extra.py | 4 +- style_bert_vits2/models/modules.py | 1 + .../models/monotonic_alignment.py | 88 +++++++++++++++++++ 7 files changed, 97 insertions(+), 70 deletions(-) delete mode 100644 monotonic_align/__init__.py delete mode 100644 monotonic_align/core.py create mode 100644 style_bert_vits2/models/monotonic_alignment.py diff --git a/monotonic_align/__init__.py b/monotonic_align/__init__.py deleted file mode 100644 index 15d8e60c4..000000000 --- a/monotonic_align/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from numpy import zeros, int32, float32 -from torch import from_numpy - -from .core import maximum_path_jit - - -def maximum_path(neg_cent, mask): - device = neg_cent.device - dtype = neg_cent.dtype - neg_cent = neg_cent.data.cpu().numpy().astype(float32) - path = zeros(neg_cent.shape, dtype=int32) - - t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32) - t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32) - maximum_path_jit(path, neg_cent, t_t_max, t_s_max) - return from_numpy(path).to(device=device, dtype=dtype) diff --git a/monotonic_align/core.py b/monotonic_align/core.py deleted file mode 100644 index ffa489da5..000000000 --- a/monotonic_align/core.py +++ /dev/null @@ -1,46 +0,0 @@ -import numba - - -@numba.jit( - numba.void( - numba.int32[:, :, ::1], - numba.float32[:, :, ::1], - numba.int32[::1], - numba.int32[::1], - ), - nopython=True, - nogil=True, -) -def maximum_path_jit(paths, values, t_ys, t_xs): - b = paths.shape[0] - max_neg_val = -1e9 - for i in range(int(b)): - path = paths[i] - value = values[i] - t_y = t_ys[i] - t_x = t_xs[i] - - v_prev = v_cur = 0.0 - index = t_x - 1 - - for y in range(t_y): - for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): - if x == y: - v_cur = max_neg_val - else: - v_cur = value[y - 1, x] - if x == 0: - if y == 0: - v_prev = 0.0 - else: - v_prev = max_neg_val - else: - v_prev = value[y - 1, x - 1] - value[y, x] += max(v_prev, v_cur) - - for y in range(t_y - 1, -1, -1): - path[y, index] = 1 - if index != 0 and ( - index == y or value[y - 1, index] < value[y - 1, index - 1] - ): - index = index - 1 diff --git a/style_bert_vits2/models/infer.py b/style_bert_vits2/models/infer.py index 0e556d092..5eb9ebda5 100644 --- a/style_bert_vits2/models/infer.py +++ b/style_bert_vits2/models/infer.py @@ -10,10 +10,6 @@ from style_bert_vits2.text_processing.symbols import SYMBOLS -class InvalidToneError(ValueError): - pass - - def get_net_g(model_path: str, version: str, device: str, hps): if version.endswith("JP-Extra"): logger.info("Using JP-Extra model") @@ -315,3 +311,7 @@ def infer_multilang( if torch.cuda.is_available(): torch.cuda.empty_cache() return audio + + +class InvalidToneError(ValueError): + pass diff --git a/style_bert_vits2/models/models.py b/style_bert_vits2/models/models.py index a8c669596..7f14be431 100644 --- a/style_bert_vits2/models/models.py +++ b/style_bert_vits2/models/models.py @@ -6,10 +6,10 @@ from torch.nn import functional as F from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm -import monotonic_align from style_bert_vits2.models import attentions from style_bert_vits2.models import commons from style_bert_vits2.models import modules +from style_bert_vits2.models import monotonic_alignment from 
style_bert_vits2.models.commons import get_padding, init_weights from style_bert_vits2.text_processing.symbols import NUM_LANGUAGES, NUM_TONES, SYMBOLS @@ -932,7 +932,7 @@ def forward( attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) attn = ( - monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)) + monotonic_alignment.maximum_path(neg_cent, attn_mask.squeeze(1)) .unsqueeze(1) .detach() ) diff --git a/style_bert_vits2/models/models_jp_extra.py b/style_bert_vits2/models/models_jp_extra.py index 8c4a4ecf0..54591f6fb 100644 --- a/style_bert_vits2/models/models_jp_extra.py +++ b/style_bert_vits2/models/models_jp_extra.py @@ -6,10 +6,10 @@ from torch.nn import functional as F from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm -import monotonic_align from style_bert_vits2.models import attentions from style_bert_vits2.models import commons from style_bert_vits2.models import modules +from style_bert_vits2.models import monotonic_alignment from style_bert_vits2.text_processing.symbols import SYMBOLS, NUM_TONES, NUM_LANGUAGES @@ -979,7 +979,7 @@ def forward( attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) attn = ( - monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)) + monotonic_alignment.maximum_path(neg_cent, attn_mask.squeeze(1)) .unsqueeze(1) .detach() ) diff --git a/style_bert_vits2/models/modules.py b/style_bert_vits2/models/modules.py index e0885c4b8..c076e212b 100644 --- a/style_bert_vits2/models/modules.py +++ b/style_bert_vits2/models/modules.py @@ -10,6 +10,7 @@ from style_bert_vits2.models import commons from style_bert_vits2.models.attentions import Encoder + LRELU_SLOPE = 0.1 diff --git a/style_bert_vits2/models/monotonic_alignment.py b/style_bert_vits2/models/monotonic_alignment.py new file mode 100644 index 000000000..0f393c19b --- /dev/null +++ b/style_bert_vits2/models/monotonic_alignment.py @@ -0,0 +1,88 @@ +""" +以下に記述されている関数のコメントはリファクタリング時に GPT-4 に生成させたもので、 +コードと完全に一致している保証はない。あくまで参考程度とすること。 +""" + +import numba +import torch +from numpy import int32, float32, zeros +from typing import Any + + +def maximum_path(neg_cent: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + 与えられた負の中心とマスクを使用して最大パスを計算する + + Args: + neg_cent (torch.Tensor): 負の中心を表すテンソル + mask (torch.Tensor): マスクを表すテンソル + + Returns: + Tensor: 計算された最大パスを表すテンソル + """ + + device = neg_cent.device + dtype = neg_cent.dtype + neg_cent = neg_cent.data.cpu().numpy().astype(float32) + path = zeros(neg_cent.shape, dtype=int32) + + t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32) + t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32) + maximum_path_jit(path, neg_cent, t_t_max, t_s_max) + + return torch.from_numpy(path).to(device=device, dtype=dtype) + + +@numba.jit( + numba.void( + numba.int32[:, :, ::1], + numba.float32[:, :, ::1], + numba.int32[::1], + numba.int32[::1], + ), + nopython = True, + nogil = True, +) # type: ignore +def maximum_path_jit(paths: Any, values: Any, t_ys: Any, t_xs: Any) -> None: + """ + 与えられたパス、値、およびターゲットの y と x 座標を使用して JIT で最大パスを計算する + + Args: + paths: 計算されたパスを格納するための整数型の 3 次元配列 + values: 値を格納するための浮動小数点型の 3 次元配列 + t_ys: ターゲットの y 座標を格納するための整数型の 1 次元配列 + t_xs: ターゲットの x 座標を格納するための整数型の 1 次元配列 + """ + + b = paths.shape[0] + max_neg_val = -1e9 + for i in range(int(b)): + path = paths[i] + value = values[i] + t_y = t_ys[i] + t_x = t_xs[i] + + v_prev = v_cur = 0.0 + index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + 
else: + v_cur = value[y - 1, x] + if x == 0: + if y == 0: + v_prev = 0.0 + else: + v_prev = max_neg_val + else: + v_prev = value[y - 1, x - 1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and ( + index == y or value[y - 1, index] < value[y - 1, index - 1] + ): + index = index - 1 From f8f798d10a63739c9329d386e49657715b26c7b3 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Thu, 7 Mar 2024 04:48:11 +0000 Subject: [PATCH 029/148] Refactor: moved text/ to style_bert_vits2/text_processing/(language)/ --- app.py | 2 +- server_editor.py | 2 +- style_bert_vits2/text_processing/__init__.py | 15 +++++++-------- .../text_processing/chinese/__init__.py | 4 ++-- .../text_processing/chinese}/tone_sandhi.py | 0 .../text_processing/english/__init__.py | 4 ++-- .../text_processing/english}/cmudict.rep | 0 .../english}/cmudict_cache.pickle | Bin .../text_processing/english}/opencpop-strict.txt | 0 .../text_processing/japanese/__init__.py | 2 ++ 10 files changed, 15 insertions(+), 14 deletions(-) rename text/chinese.py => style_bert_vits2/text_processing/chinese/__init__.py (98%) rename {text => style_bert_vits2/text_processing/chinese}/tone_sandhi.py (100%) rename text/english.py => style_bert_vits2/text_processing/english/__init__.py (99%) rename {text => style_bert_vits2/text_processing/english}/cmudict.rep (100%) rename {text => style_bert_vits2/text_processing/english}/cmudict_cache.pickle (100%) rename {text => style_bert_vits2/text_processing/english}/opencpop-strict.txt (100%) create mode 100644 style_bert_vits2/text_processing/japanese/__init__.py diff --git a/app.py b/app.py index 02c01cdf3..f03e91b69 100644 --- a/app.py +++ b/app.py @@ -27,8 +27,8 @@ ) from style_bert_vits2.logging import logger from style_bert_vits2.models.infer import InvalidToneError +from style_bert_vits2.text_processing.japanese import normalize_text from style_bert_vits2.text_processing.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone -from style_bert_vits2.text_processing.japanese.normalizer import normalize_text # Get path settings diff --git a/server_editor.py b/server_editor.py index eb1cffd87..a1e249dea 100644 --- a/server_editor.py +++ b/server_editor.py @@ -42,8 +42,8 @@ ) from style_bert_vits2.logging import logger from style_bert_vits2.text_processing import bert_models +from style_bert_vits2.text_processing.japanese import normalize_text from style_bert_vits2.text_processing.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone -from style_bert_vits2.text_processing.japanese.normalizer import normalize_text from style_bert_vits2.text_processing.japanese.user_dict import ( apply_word, update_dict, diff --git a/style_bert_vits2/text_processing/__init__.py b/style_bert_vits2/text_processing/__init__.py index cd56ee7d4..5e77df624 100644 --- a/style_bert_vits2/text_processing/__init__.py +++ b/style_bert_vits2/text_processing/__init__.py @@ -67,18 +67,17 @@ def clean_text( # Changed to import inside if condition to avoid unnecessary import if language == Languages.JP: - from style_bert_vits2.text_processing.japanese.g2p import g2p - from style_bert_vits2.text_processing.japanese.normalizer import normalize_text + from style_bert_vits2.text_processing.japanese import g2p, normalize_text norm_text = normalize_text(text) phones, tones, word2ph = g2p(norm_text, use_jp_extra, raise_yomi_error) elif language == Languages.EN: - from ...text import english as language_module - norm_text = language_module.normalize_text(text) - phones, tones, word2ph 
= language_module.g2p(norm_text) + from style_bert_vits2.text_processing.english import g2p, normalize_text + norm_text = normalize_text(text) + phones, tones, word2ph = g2p(norm_text) elif language == Languages.ZH: - from ...text import chinese as language_module - norm_text = language_module.normalize_text(text) - phones, tones, word2ph = language_module.g2p(norm_text) + from style_bert_vits2.text_processing.chinese import g2p, normalize_text + norm_text = normalize_text(text) + phones, tones, word2ph = g2p(norm_text) else: raise ValueError(f"Language {language} not supported") diff --git a/text/chinese.py b/style_bert_vits2/text_processing/chinese/__init__.py similarity index 98% rename from text/chinese.py rename to style_bert_vits2/text_processing/chinese/__init__.py index 94266c247..92c821425 100644 --- a/text/chinese.py +++ b/style_bert_vits2/text_processing/chinese/__init__.py @@ -66,7 +66,7 @@ def replace_punctuation(text): return replaced_text -def g2p(text): +def g2p(text: str) -> tuple[list[str], list[int], list[int]]: pattern = r"(?<=[{0}])\s*".format("".join(PUNCTUATIONS)) sentences = [i for i in re.split(pattern, text) if i.strip() != ""] phones, tones, word2ph = _g2p(sentences) @@ -168,7 +168,7 @@ def _g2p(segments): return phones_list, tones_list, word2ph -def normalize_text(text): +def normalize_text(text: str) -> str: numbers = re.findall(r"\d+(?:\.?\d+)?", text) for number in numbers: text = text.replace(number, cn2an.an2cn(number), 1) diff --git a/text/tone_sandhi.py b/style_bert_vits2/text_processing/chinese/tone_sandhi.py similarity index 100% rename from text/tone_sandhi.py rename to style_bert_vits2/text_processing/chinese/tone_sandhi.py diff --git a/text/english.py b/style_bert_vits2/text_processing/english/__init__.py similarity index 99% rename from text/english.py rename to style_bert_vits2/text_processing/english/__init__.py index 419be3b9d..852431aaa 100644 --- a/text/english.py +++ b/style_bert_vits2/text_processing/english/__init__.py @@ -369,7 +369,7 @@ def normalize_numbers(text): return text -def normalize_text(text): +def normalize_text(text: str) -> str: text = normalize_numbers(text) text = replace_punctuation(text) text = re.sub(r"([,;.\?\!])([\w])", r"\1 \2", text) @@ -419,7 +419,7 @@ def text_to_words(text): return words -def g2p(text): +def g2p(text: str) -> tuple[list[str], list[int], list[int]]: phones = [] tones = [] phone_len = [] diff --git a/text/cmudict.rep b/style_bert_vits2/text_processing/english/cmudict.rep similarity index 100% rename from text/cmudict.rep rename to style_bert_vits2/text_processing/english/cmudict.rep diff --git a/text/cmudict_cache.pickle b/style_bert_vits2/text_processing/english/cmudict_cache.pickle similarity index 100% rename from text/cmudict_cache.pickle rename to style_bert_vits2/text_processing/english/cmudict_cache.pickle diff --git a/text/opencpop-strict.txt b/style_bert_vits2/text_processing/english/opencpop-strict.txt similarity index 100% rename from text/opencpop-strict.txt rename to style_bert_vits2/text_processing/english/opencpop-strict.txt diff --git a/style_bert_vits2/text_processing/japanese/__init__.py b/style_bert_vits2/text_processing/japanese/__init__.py new file mode 100644 index 000000000..17e1785fe --- /dev/null +++ b/style_bert_vits2/text_processing/japanese/__init__.py @@ -0,0 +1,2 @@ +from style_bert_vits2.text_processing.japanese.g2p import g2p # type: ignore +from style_bert_vits2.text_processing.japanese.normalizer import normalize_text # type: ignore From 
bffd5a67bb93174869db2f1164211970cec67f63 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Thu, 7 Mar 2024 08:34:54 +0000 Subject: [PATCH 030/148] Fix: import error --- .../text_processing/chinese/__init__.py | 6 ++-- .../text_processing/chinese/tone_sandhi.py | 36 +++++++++---------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/style_bert_vits2/text_processing/chinese/__init__.py b/style_bert_vits2/text_processing/chinese/__init__.py index 92c821425..23acbf3a1 100644 --- a/style_bert_vits2/text_processing/chinese/__init__.py +++ b/style_bert_vits2/text_processing/chinese/__init__.py @@ -2,10 +2,12 @@ import re import cn2an +import jieba.posseg as psg from pypinyin import lazy_pinyin, Style +from style_bert_vits2.text_processing.chinese.tone_sandhi import ToneSandhi from style_bert_vits2.text_processing.symbols import PUNCTUATIONS -from text.tone_sandhi import ToneSandhi + current_file_path = os.path.dirname(__file__) pinyin_to_symbol_map = { @@ -13,8 +15,6 @@ for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines() } -import jieba.posseg as psg - rep_map = { ":": ",", diff --git a/style_bert_vits2/text_processing/chinese/tone_sandhi.py b/style_bert_vits2/text_processing/chinese/tone_sandhi.py index 38f313785..5832434fd 100644 --- a/style_bert_vits2/text_processing/chinese/tone_sandhi.py +++ b/style_bert_vits2/text_processing/chinese/tone_sandhi.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List -from typing import Tuple import jieba from pypinyin import lazy_pinyin @@ -463,7 +461,7 @@ def __init__(self): # word: "家里" # pos: "s" # finals: ['ia1', 'i3'] - def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]: + def _neural_sandhi(self, word: str, pos: str, finals: list[str]) -> list[str]: # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺 for j, item in enumerate(word): if ( @@ -522,7 +520,7 @@ def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]: finals = sum(finals_list, []) return finals - def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]: + def _bu_sandhi(self, word: str, finals: list[str]) -> list[str]: # e.g. 看不懂 if len(word) == 3 and word[1] == "不": finals[1] = finals[1][:-1] + "5" @@ -533,7 +531,7 @@ def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]: finals[i] = finals[i][:-1] + "2" return finals - def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]: + def _yi_sandhi(self, word: str, finals: list[str]) -> list[str]: # "一" in number sequences, e.g. 
一零零, 二一零 if word.find("一") != -1 and all( [item.isnumeric() for item in word if item != "一"] @@ -558,9 +556,9 @@ def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]: finals[i] = finals[i][:-1] + "4" return finals - def _split_word(self, word: str) -> List[str]: + def _split_word(self, word: str) -> list[str]: word_list = jieba.cut_for_search(word) - word_list = sorted(word_list, key=lambda i: len(i), reverse=False) + word_list = sorted(word_list, key=lambda i: len(i), reverse=False) # type: ignore first_subword = word_list[0] first_begin_idx = word.find(first_subword) if first_begin_idx == 0: @@ -571,7 +569,7 @@ def _split_word(self, word: str) -> List[str]: new_word_list = [second_subword, first_subword] return new_word_list - def _three_sandhi(self, word: str, finals: List[str]) -> List[str]: + def _three_sandhi(self, word: str, finals: list[str]) -> list[str]: if len(word) == 2 and self._all_tone_three(finals): finals[0] = finals[0][:-1] + "2" elif len(word) == 3: @@ -611,12 +609,12 @@ def _three_sandhi(self, word: str, finals: List[str]) -> List[str]: return finals - def _all_tone_three(self, finals: List[str]) -> bool: + def _all_tone_three(self, finals: list[str]) -> bool: return all(x[-1] == "3" for x in finals) # merge "不" and the word behind it # if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error - def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + def _merge_bu(self, seg: list[tuple[str, str]]) -> list[tuple[str, str]]: new_seg = [] last_word = "" for word, pos in seg: @@ -636,7 +634,7 @@ def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: # e.g. # input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')] # output seg: [['听一听', 'v']] - def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + def _merge_yi(self, seg: list[tuple[str, str]]) -> list[tuple[str, str]]: new_seg = [] * len(seg) # function 1 i = 0 @@ -674,8 +672,8 @@ def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: # the first and the second words are all_tone_three def _merge_continuous_three_tones( - self, seg: List[Tuple[str, str]] - ) -> List[Tuple[str, str]]: + self, seg: list[tuple[str, str]] + ) -> list[tuple[str, str]]: new_seg = [] sub_finals_list = [ lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) @@ -709,8 +707,8 @@ def _is_reduplication(self, word: str) -> bool: # the last char of first word and the first char of second word is tone_three def _merge_continuous_three_tones_2( - self, seg: List[Tuple[str, str]] - ) -> List[Tuple[str, str]]: + self, seg: list[tuple[str, str]] + ) -> list[tuple[str, str]]: new_seg = [] sub_finals_list = [ lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) @@ -738,7 +736,7 @@ def _merge_continuous_three_tones_2( new_seg.append([word, pos]) return new_seg - def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + def _merge_er(self, seg: list[tuple[str, str]]) -> list[tuple[str, str]]: new_seg = [] for i, (word, pos) in enumerate(seg): if i - 1 >= 0 and word == "儿" and seg[i - 1][0] != "#": @@ -747,7 +745,7 @@ def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: new_seg.append([word, pos]) return new_seg - def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + def _merge_reduplication(self, seg: list[tuple[str, str]]) -> list[tuple[str, str]]: new_seg = [] for i, (word, pos) in enumerate(seg): if new_seg and word == new_seg[-1][0]: 
@@ -756,7 +754,7 @@ def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, st new_seg.append([word, pos]) return new_seg - def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + def pre_merge_for_modify(self, seg: list[tuple[str, str]]) -> list[tuple[str, str]]: seg = self._merge_bu(seg) try: seg = self._merge_yi(seg) @@ -768,7 +766,7 @@ def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, st seg = self._merge_er(seg) return seg - def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]: + def modified_tone(self, word: str, pos: str, finals: list[str]) -> list[str]: finals = self._bu_sandhi(word, finals) finals = self._yi_sandhi(word, finals) finals = self._neural_sandhi(word, pos, finals) From e57cfbf072c2a34dcfd22201ba9c8c67632f9122 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Thu, 7 Mar 2024 08:46:13 +0000 Subject: [PATCH 031/148] Remove: remove currently unused code in style_bert_vits2/models/commons.py --- style_bert_vits2/models/commons.py | 126 ------------------ .../text_processing/japanese/g2p.py | 4 +- .../text_processing/japanese/normalizer.py | 46 +++---- 3 files changed, 25 insertions(+), 151 deletions(-) diff --git a/style_bert_vits2/models/commons.py b/style_bert_vits2/models/commons.py index 064ef5f58..969ed3694 100644 --- a/style_bert_vits2/models/commons.py +++ b/style_bert_vits2/models/commons.py @@ -3,7 +3,6 @@ コードと完全に一致している保証はない。あくまで参考程度とすること。 """ -import math import torch from torch.nn import functional as F from typing import Any @@ -68,54 +67,6 @@ def intersperse(lst: list[Any], item: Any) -> list[Any]: return result -def kl_divergence(m_p: torch.Tensor, logs_p: torch.Tensor, m_q: torch.Tensor, logs_q: torch.Tensor) -> torch.Tensor: - """ - 2つの正規分布間の KL ダイバージェンスを計算する - - Args: - m_p (torch.Tensor): P の平均 - logs_p (torch.Tensor): P の対数標準偏差 - m_q (torch.Tensor): Q の平均 - logs_q (torch.Tensor): Q の対数標準偏差 - - Returns: - torch.Tensor: KL ダイバージェンスの値。 - """ - kl = (logs_q - logs_p) - 0.5 - kl += ( - 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) - ) - return kl - - -def rand_gumbel(shape: torch.Size) -> torch.Tensor: - """ - Gumbel 分布からサンプリングし、オーバーフローを防ぐ - - Args: - shape (torch.Size): サンプルの形状 - - Returns: - torch.Tensor: Gumbel 分布からのサンプル - """ - uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 - return -torch.log(-torch.log(uniform_samples)) - - -def rand_gumbel_like(x: torch.Tensor) -> torch.Tensor: - """ - 引数と同じ形状のテンソルで、Gumbel 分布からサンプリングする - - Args: - x (torch.Tensor): 形状を基にするテンソル - - Returns: - torch.Tensor: Gumbel 分布からのサンプル - """ - g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) - return g - - def slice_segments(x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4) -> torch.Tensor: """ テンソルからセグメントをスライスする @@ -155,69 +106,6 @@ def rand_slice_segments(x: torch.Tensor, x_lengths: torch.Tensor | None = None, return ret, ids_str -def get_timing_signal_1d(length: int, channels: int, min_timescale: float = 1.0, max_timescale: float = 1.0e4) -> torch.Tensor: - """ - 1D タイミング信号を取得する - - Args: - length (int): シグナルの長さ - channels (int): シグナルのチャネル数 - min_timescale (float, optional): 最小のタイムスケール (デフォルト: 1.0) - max_timescale (float, optional): 最大のタイムスケール (デフォルト: 1.0e4) - - Returns: - torch.Tensor: タイミング信号 - """ - position = torch.arange(length, dtype=torch.float) - num_timescales = channels // 2 - log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( - num_timescales - 1 - ) - inv_timescales = 
min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment - ) - scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) - signal = F.pad(signal, [0, 0, 0, channels % 2]) - signal = signal.view(1, channels, length) - return signal - - -def add_timing_signal_1d(x: torch.Tensor, min_timescale: float = 1.0, max_timescale: float = 1.0e4) -> torch.Tensor: - """ - 1D タイミング信号をテンソルに追加する - - Args: - x (torch.Tensor): 入力テンソル - min_timescale (float, optional): 最小のタイムスケール (デフォルト: 1.0) - max_timescale (float, optional): 最大のタイムスケール (デフォルト: 1.0e4) - - Returns: - torch.Tensor: タイミング信号が追加されたテンソル - """ - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return x + signal.to(dtype=x.dtype, device=x.device) - - -def cat_timing_signal_1d(x: torch.Tensor, min_timescale: float = 1.0, max_timescale: float = 1.0e4, axis: int = 1) -> torch.Tensor: - """ - 1D タイミング信号をテンソルに連結する - - Args: - x (torch.Tensor): 入力テンソル - min_timescale (float, optional): 最小のタイムスケール (デフォルト: 1.0) - max_timescale (float, optional): 最大のタイムスケール (デフォルト: 1.0e4) - axis (int, optional): 連結する軸 (デフォルト: 1) - - Returns: - torch.Tensor: タイミング信号が連結されたテンソル - """ - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) - - def subsequent_mask(length: int) -> torch.Tensor: """ 後続のマスクを生成する @@ -253,20 +141,6 @@ def fused_add_tanh_sigmoid_multiply(input_a: torch.Tensor, input_b: torch.Tensor return acts -def shift_1d(x: torch.Tensor) -> torch.Tensor: - """ - 与えられたテンソルを 1D でシフトする - - Args: - x (torch.Tensor): シフトするテンソル - - Returns: - torch.Tensor: シフトされたテンソル - """ - x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] - return x - - def sequence_mask(length: torch.Tensor, max_length: int | None = None) -> torch.Tensor: """ シーケンスマスクを生成する diff --git a/style_bert_vits2/text_processing/japanese/g2p.py b/style_bert_vits2/text_processing/japanese/g2p.py index 79687511f..c04d07848 100644 --- a/style_bert_vits2/text_processing/japanese/g2p.py +++ b/style_bert_vits2/text_processing/japanese/g2p.py @@ -171,7 +171,7 @@ def __g2phone_tone_wo_punct(text: str) -> list[tuple[str, int]]: list[tuple[str, int]]: 音素とアクセントのペアのリスト """ - prosodies = pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True) + prosodies = __pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True) # logger.debug(f"prosodies: {prosodies}") result: list[tuple[str, int]] = [] current_phrase: list[tuple[str, int]] = [] @@ -212,7 +212,7 @@ def __g2phone_tone_wo_punct(text: str) -> list[tuple[str, int]]: return result -def pyopenjtalk_g2p_prosody(text: str, drop_unvoiced_vowels: bool = True) -> list[str]: +def __pyopenjtalk_g2p_prosody(text: str, drop_unvoiced_vowels: bool = True) -> list[str]: """ ESPnet の実装から引用、変更点無し。「ん」は「N」なことに注意。 ref: https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py diff --git a/style_bert_vits2/text_processing/japanese/normalizer.py b/style_bert_vits2/text_processing/japanese/normalizer.py index 92c2e879f..827633878 100644 --- a/style_bert_vits2/text_processing/japanese/normalizer.py +++ b/style_bert_vits2/text_processing/japanese/normalizer.py @@ -48,29 +48,6 @@ def normalize_text(text: str) -> str: return res -def __convert_numbers_to_words(text: str) -> str: - """ - 記号や数字を日本語の文字表現に変換する。 - - Args: - text (str): 変換するテキスト - - 
Returns: - str: 変換されたテキスト - """ - - NUMBER_WITH_SEPARATOR_PATTERN = re.compile("[0-9]{1,3}(,[0-9]{3})+") - CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"} - CURRENCY_PATTERN = re.compile(r"([$¥£€])([0-9.]*[0-9])") - NUMBER_PATTERN = re.compile(r"[0-9]+(\.[0-9]+)?") - - res = NUMBER_WITH_SEPARATOR_PATTERN.sub(lambda m: m[0].replace(",", ""), text) - res = CURRENCY_PATTERN.sub(lambda m: m[2] + CURRENCY_MAP.get(m[1], m[1]), res) - res = NUMBER_PATTERN.sub(lambda m: num2words(m[0], lang="ja"), res) - - return res - - def replace_punctuation(text: str) -> str: """ 句読点等を「.」「,」「!」「?」「'」「-」に正規化し、OpenJTalk で読みが取得できるもののみ残す: @@ -159,3 +136,26 @@ def replace_punctuation(text: str) -> str: ) return replaced_text + + +def __convert_numbers_to_words(text: str) -> str: + """ + 記号や数字を日本語の文字表現に変換する。 + + Args: + text (str): 変換するテキスト + + Returns: + str: 変換されたテキスト + """ + + NUMBER_WITH_SEPARATOR_PATTERN = re.compile("[0-9]{1,3}(,[0-9]{3})+") + CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"} + CURRENCY_PATTERN = re.compile(r"([$¥£€])([0-9.]*[0-9])") + NUMBER_PATTERN = re.compile(r"[0-9]+(\.[0-9]+)?") + + res = NUMBER_WITH_SEPARATOR_PATTERN.sub(lambda m: m[0].replace(",", ""), text) + res = CURRENCY_PATTERN.sub(lambda m: m[2] + CURRENCY_MAP.get(m[1], m[1]), res) + res = NUMBER_PATTERN.sub(lambda m: num2words(m[0], lang="ja"), res) + + return res From 70f8d53a1e41f83ef33f20062bd45b4ff193cb64 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Thu, 7 Mar 2024 09:09:24 +0000 Subject: [PATCH 032/148] Add: empty __init__.py --- style_bert_vits2/models/__init__.py | 0 style_bert_vits2/utils/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 style_bert_vits2/models/__init__.py create mode 100644 style_bert_vits2/utils/__init__.py diff --git a/style_bert_vits2/models/__init__.py b/style_bert_vits2/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/style_bert_vits2/utils/__init__.py b/style_bert_vits2/utils/__init__.py new file mode 100644 index 000000000..e69de29bb From d024b71340630d20a20f4ea4b8c4198e52e1dcc2 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 8 Mar 2024 10:10:53 +0900 Subject: [PATCH 033/148] Fix typo --- common/log.py | 1 + text/pyopenjtalk_worker/__init__.py | 4 ++-- text/pyopenjtalk_worker/__main__.py | 4 ++-- text/pyopenjtalk_worker/worker_common.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/common/log.py b/common/log.py index 679bb2c77..71b5e6394 100644 --- a/common/log.py +++ b/common/log.py @@ -14,4 +14,5 @@ "{time:MM-DD HH:mm:ss} |{level:^8}| {file}:{line} | {message}" ) +# logger.add(SAFE_STDOUT, format=log_format, backtrace=True, diagnose=True, level="TRACE") logger.add(SAFE_STDOUT, format=log_format, backtrace=True, diagnose=True) diff --git a/text/pyopenjtalk_worker/__init__.py b/text/pyopenjtalk_worker/__init__.py index 9677666b2..7aa0d0ea8 100644 --- a/text/pyopenjtalk_worker/__init__.py +++ b/text/pyopenjtalk_worker/__init__.py @@ -5,7 +5,7 @@ from typing import Optional, Any -from .worker_common import WOKER_PORT +from .worker_common import WORKER_PORT from .worker_client import WorkerClient from common.log import logger @@ -49,7 +49,7 @@ def unset_user_dict(): # initialize module when imported -def initialize(port: int = WOKER_PORT): +def initialize(port: int = WORKER_PORT): import time import socket import sys diff --git a/text/pyopenjtalk_worker/__main__.py b/text/pyopenjtalk_worker/__main__.py index 8b67aa0bb..3bb6b53a2 100644 --- a/text/pyopenjtalk_worker/__main__.py +++ 
b/text/pyopenjtalk_worker/__main__.py @@ -1,12 +1,12 @@ import argparse from .worker_server import WorkerServer -from .worker_common import WOKER_PORT +from .worker_common import WORKER_PORT def main(): parser = argparse.ArgumentParser() - parser.add_argument("--port", type=int, default=WOKER_PORT) + parser.add_argument("--port", type=int, default=WORKER_PORT) args = parser.parse_args() server = WorkerServer() server.start_server(port=args.port) diff --git a/text/pyopenjtalk_worker/worker_common.py b/text/pyopenjtalk_worker/worker_common.py index bea552e6a..606d0c336 100644 --- a/text/pyopenjtalk_worker/worker_common.py +++ b/text/pyopenjtalk_worker/worker_common.py @@ -3,7 +3,7 @@ import socket import json -WOKER_PORT: Final[int] = 7861 +WORKER_PORT: Final[int] = 7861 HEADER_SIZE: Final[int] = 4 From 25ed226acbe13ba9425169c89bc1452c842f424e Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 8 Mar 2024 11:19:03 +0900 Subject: [PATCH 034/148] Add speaker list api --- common/tts_model.py | 21 ++++++++------------- server_editor.py | 17 +++++++++++++---- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/common/tts_model.py b/common/tts_model.py index e09787e7e..9e2171dc7 100644 --- a/common/tts_model.py +++ b/common/tts_model.py @@ -222,12 +222,14 @@ def __init__(self, root_dir: Path, device: str): self.current_model: Optional[Model] = None self.model_names: list[str] = [] self.models: list[Model] = [] + self.models_info: list[dict[str, Union[str, list[str]]]] = [] self.refresh() def refresh(self): self.model_files_dict = {} self.model_names = [] self.current_model = None + self.models_info = [] model_dirs = [d for d in self.root_dir.iterdir() if d.is_dir()] for model_dir in model_dirs: @@ -247,26 +249,19 @@ def refresh(self): continue self.model_files_dict[model_dir.name] = model_files self.model_names.append(model_dir.name) - - def models_info(self): - if hasattr(self, "_models_info"): - return self._models_info - result = [] - for name, files in self.model_files_dict.items(): - # Get styles - config_path = self.root_dir / name / "config.json" hps = utils.get_hparams_from_file(config_path) style2id: dict[str, int] = hps.data.style2id styles = list(style2id.keys()) - result.append( + spk2id: dict[str, int] = hps.data.spk2id + speakers = list(spk2id.keys()) + self.models_info.append( { - "name": name, - "files": [str(f) for f in files], + "name": model_dir.name, + "files": [str(f) for f in model_files], "styles": styles, + "speakers": speakers, } ) - self._models_info = result - return result def load_model(self, model_name: str, model_path_str: str): model_path = Path(model_path_str) diff --git a/server_editor.py b/server_editor.py index 1ae323f32..402116f5d 100644 --- a/server_editor.py +++ b/server_editor.py @@ -16,12 +16,13 @@ from datetime import datetime from io import BytesIO from pathlib import Path -import yaml +from typing import Optional import numpy as np import requests import torch import uvicorn +import yaml from fastapi import APIRouter, FastAPI, HTTPException, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response @@ -42,8 +43,7 @@ from common.log import logger from common.tts_model import ModelHolder from text.japanese import g2kata_tone, kata_tone2phone_tone, text_normalize -from text.user_dict import apply_word, update_dict, read_dict, rewrite_word, delete_word - +from text.user_dict import apply_word, delete_word, read_dict, rewrite_word, update_dict # ---フロントエンド部分に関する処理--- @@ -229,7 +229,7 @@ async 
def normalize_text(item: TextRequest): @router.get("/models_info") def models_info(): - return model_holder.models_info() + return model_holder.models_info class SynthesisRequest(BaseModel): @@ -249,6 +249,7 @@ class SynthesisRequest(BaseModel): silenceAfter: float = 0.5 pitchScale: float = 1.0 intonationScale: float = 1.0 + speaker: Optional[str] = None @router.post("/synthesis", response_class=AudioResponse) @@ -274,6 +275,13 @@ def synthesis(request: SynthesisRequest): ] phone_tone = kata_tone2phone_tone(kata_tone_list) tone = [t for _, t in phone_tone] + try: + sid = 0 if request.speaker is None else model.spk2id[request.speaker] + except KeyError: + raise HTTPException( + status_code=400, + detail=f"Speaker {request.speaker} not found in {model.spk2id}", + ) sr, audio = model.infer( text=text, language=request.language.value, @@ -290,6 +298,7 @@ def synthesis(request: SynthesisRequest): line_split=False, pitch_scale=request.pitchScale, intonation_scale=request.intonationScale, + sid=sid, ) with BytesIO() as wavContent: From 1cbe9648b0f103e803b88c1d3976e37571b45da3 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 8 Mar 2024 12:59:27 +0900 Subject: [PATCH 035/148] Initialize pyopenjtalk worker for multi-threading --- bert_gen.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bert_gen.py b/bert_gen.py index a5f7c258f..1e4fb61f4 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -6,12 +6,15 @@ from tqdm import tqdm import commons +import text.pyopenjtalk_worker as pyopenjtalk import utils from common.log import logger from common.stdout_wrapper import SAFE_STDOUT from config import config from text import cleaned_text_to_sequence, get_bert +pyopenjtalk.initialize() + def process_line(x): line, add_blank = x From 9f01e54d2dc72f1c62aab504ec4795e3ef619158 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 8 Mar 2024 13:00:15 +0900 Subject: [PATCH 036/148] Unify webui to single webui --- app.py | 51 +++++--- common/constants.py | 2 +- webui/dataset.py | 40 ++----- webui/inference.py | 262 +++++++++++++++++------------------------ webui/merge.py | 30 +---- webui/style_vectors.py | 26 +--- webui/train.py | 29 +---- 7 files changed, 168 insertions(+), 272 deletions(-) diff --git a/app.py b/app.py index 7057f459c..997bd047f 100644 --- a/app.py +++ b/app.py @@ -1,30 +1,51 @@ -import pyopenjtalk +import argparse +from pathlib import Path + import gradio as gr +import torch +import yaml + +from common.constants import GRADIO_THEME, LATEST_VERSION +from common.tts_model import ModelHolder from webui import ( create_dataset_app, - create_train_app, + create_inference_app, create_merge_app, create_style_vectors_app, + create_train_app, ) -from pathlib import Path -pyopenjtalk.unset_user_dict() +# Get path settings +with Path("configs/paths.yml").open("r", encoding="utf-8") as f: + path_config: dict[str, str] = yaml.safe_load(f.read()) + # dataset_root = path_config["dataset_root"] + assets_root = path_config["assets_root"] + +parser = argparse.ArgumentParser() +parser.add_argument("--device", type=str, default="cuda") +parser.add_argument("--no_autolaunch", action="store_true") +parser.add_argument("--share", action="store_true") + +args = parser.parse_args() +device = args.device +if device == "cuda" and not torch.cuda.is_available(): + device = "cpu" -setting_json = Path("webui/setting.json") +model_holder = ModelHolder(Path(assets_root), device) -with gr.Blocks() as app: +with gr.Blocks(theme=GRADIO_THEME) as app: + gr.Markdown(f"# Style-Bert-VITS2 WebUI (version {LATEST_VERSION})") with 
gr.Tabs(): - with gr.Tab("Hello"): - gr.Markdown("## Hello, Gradio!") - gr.Textbox("input", label="Input Text") - with gr.Tab("Dataset"): + with gr.Tab("音声合成"): + create_inference_app(model_holder=model_holder) + with gr.Tab("データセット作成"): create_dataset_app() - with gr.Tab("Train"): + with gr.Tab("学習"): create_train_app() - with gr.Tab("Merge"): - create_merge_app() - with gr.Tab("Create Style Vectors"): + with gr.Tab("スタイル作成"): create_style_vectors_app() + with gr.Tab("マージ"): + create_merge_app(model_holder=model_holder) -app.launch(inbrowser=True) +app.launch(inbrowser=not args.no_autolaunch, share=args.share) diff --git a/common/constants.py b/common/constants.py index fe620195d..f3ccb69a1 100644 --- a/common/constants.py +++ b/common/constants.py @@ -4,7 +4,7 @@ # See https://huggingface.co/spaces/gradio/theme-gallery for more themes GRADIO_THEME: str = "NoCrypt/miku" -LATEST_VERSION: str = "2.3.1" +LATEST_VERSION: str = "2.4" USER_DICT_DIR = "dict_data" diff --git a/webui/dataset.py b/webui/dataset.py index 5ed656c4d..3169fd25e 100644 --- a/webui/dataset.py +++ b/webui/dataset.py @@ -8,12 +8,6 @@ from common.log import logger from common.subprocess_utils import run_script_with_log -# Get path settings -with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: - path_config: dict[str, str] = yaml.safe_load(f.read()) - dataset_root = path_config["dataset_root"] - # assets_root = path_config["assets_root"] - def do_slice( model_name: str, @@ -69,13 +63,11 @@ def do_transcribe( ] ) if not success: - return f"Error: {message}" + return f"Error: {message}. しかし何故かエラーが起きても正常に終了している場合がほとんどなので、書き起こし結果を確認して問題なければ学習に使えます。" return "音声の文字起こしが完了しました。" -initial_md = """ -# 簡易学習用データセット作成ツール - +how_to_md = """ Style-Bert-VITS2の学習用データセットを作成するためのツールです。以下の2つからなります。 - 与えられた音声からちょうどいい長さの発話区間を切り取りスライス @@ -107,10 +99,10 @@ def do_transcribe( """ -def create_dataset_app(): - - with gr.Blocks(theme=GRADIO_THEME) as app: - gr.Markdown(initial_md) +def create_dataset_app() -> gr.Blocks: + with gr.Blocks() as app: + with gr.Accordion("使い方", open=False): + gr.Markdown(how_to_md) model_name = gr.Textbox( label="モデル名を入力してください(話者名としても使われます)。" ) @@ -118,8 +110,8 @@ def create_dataset_app(): with gr.Row(): with gr.Column(): input_dir = gr.Textbox( - label="入力フォルダ名(デフォルトはinputs)", - placeholder="inputs", + label="元音声の入っているフォルダパス", + value="inputs", info="下記フォルダにwavファイルを入れておいてください", ) min_sec = gr.Slider( @@ -201,20 +193,4 @@ def create_dataset_app(): outputs=[result2], ) - parser = argparse.ArgumentParser() - parser.add_argument( - "--server-name", - type=str, - default=None, - help="Server name for Gradio app", - ) - parser.add_argument( - "--no-autolaunch", - action="store_true", - default=False, - help="Do not launch app automatically", - ) - args = parser.parse_args() - - # app.launch(inbrowser=not args.no_autolaunch, server_name=args.server_name) return app diff --git a/webui/inference.py b/webui/inference.py index 663d5a637..94124b980 100644 --- a/webui/inference.py +++ b/webui/inference.py @@ -1,14 +1,9 @@ import argparse import datetime import json -import os -import sys -from pathlib import Path from typing import Optional import gradio as gr -import torch -import yaml from common.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, @@ -21,7 +16,6 @@ DEFAULT_STYLE, DEFAULT_STYLE_WEIGHT, GRADIO_THEME, - LATEST_VERSION, Languages, ) from common.log import logger @@ -29,119 +23,9 @@ from infer import InvalidToneError from text.japanese import g2kata_tone, kata_tone2phone_tone, text_normalize -# Get path 
settings -with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: - path_config: dict[str, str] = yaml.safe_load(f.read()) - # dataset_root = path_config["dataset_root"] - assets_root = path_config["assets_root"] - languages = [l.value for l in Languages] -def tts_fn( - model_name, - model_path, - text, - language, - reference_audio_path, - sdp_ratio, - noise_scale, - noise_scale_w, - length_scale, - line_split, - split_interval, - assist_text, - assist_text_weight, - use_assist_text, - style, - style_weight, - kata_tone_json_str, - use_tone, - speaker, - pitch_scale, - intonation_scale, -): - model_holder.load_model_gr(model_name, model_path) - - wrong_tone_message = "" - kata_tone: Optional[list[tuple[str, int]]] = None - if use_tone and kata_tone_json_str != "": - if language != "JP": - logger.warning("Only Japanese is supported for tone generation.") - wrong_tone_message = "アクセント指定は現在日本語のみ対応しています。" - if line_split: - logger.warning("Tone generation is not supported for line split.") - wrong_tone_message = ( - "アクセント指定は改行で分けて生成を使わない場合のみ対応しています。" - ) - try: - kata_tone = [] - json_data = json.loads(kata_tone_json_str) - # tupleを使うように変換 - for kana, tone in json_data: - assert isinstance(kana, str) and tone in (0, 1), f"{kana}, {tone}" - kata_tone.append((kana, tone)) - except Exception as e: - logger.warning(f"Error occurred when parsing kana_tone_json: {e}") - wrong_tone_message = f"アクセント指定が不正です: {e}" - kata_tone = None - - # toneは実際に音声合成に代入される際のみnot Noneになる - tone: Optional[list[int]] = None - if kata_tone is not None: - phone_tone = kata_tone2phone_tone(kata_tone) - tone = [t for _, t in phone_tone] - - speaker_id = model_holder.current_model.spk2id[speaker] - - start_time = datetime.datetime.now() - - assert model_holder.current_model is not None - - try: - sr, audio = model_holder.current_model.infer( - text=text, - language=language, - reference_audio_path=reference_audio_path, - sdp_ratio=sdp_ratio, - noise=noise_scale, - noisew=noise_scale_w, - length=length_scale, - line_split=line_split, - split_interval=split_interval, - assist_text=assist_text, - assist_text_weight=assist_text_weight, - use_assist_text=use_assist_text, - style=style, - style_weight=style_weight, - given_tone=tone, - sid=speaker_id, - pitch_scale=pitch_scale, - intonation_scale=intonation_scale, - ) - except InvalidToneError as e: - logger.error(f"Tone error: {e}") - return f"Error: アクセント指定が不正です:\n{e}", None, kata_tone_json_str - except ValueError as e: - logger.error(f"Value error: {e}") - return f"Error: {e}", None, kata_tone_json_str - - end_time = datetime.datetime.now() - duration = (end_time - start_time).total_seconds() - - if tone is None and language == "JP": - # アクセント指定に使えるようにアクセント情報を返す - norm_text = text_normalize(text) - kata_tone = g2kata_tone(norm_text) - kata_tone_json_str = json.dumps(kata_tone, ensure_ascii=False) - elif tone is None: - kata_tone_json_str = "" - message = f"Success, time: {duration} seconds." - if wrong_tone_message != "": - message = wrong_tone_message + "\n" + message - return message, (sr, audio), kata_tone_json_str - - initial_text = "こんにちは、初めまして。あなたの名前はなんていうの?" 
examples = [ @@ -202,9 +86,7 @@ def tts_fn( ] initial_md = f""" -# Style-Bert-VITS2 ver {LATEST_VERSION} 音声合成 - -- Ver 2.3で追加されたエディターのほうが実際に読み上げさせるには使いやすいかもしれません。`Editor.bat`か`python server_editor.py`で起動できます。 +- Ver 2.3で追加されたエディターのほうが実際に読み上げさせるには使いやすいかもしれません。`Editor.bat`か`python server_editor.py --inbrowser`で起動できます。 - 初期からある[jvnvのモデル](https://huggingface.co/litagin/style_bert_vits2_jvnv)は、[JVNVコーパス(言語音声と非言語音声を持つ日本語感情音声コーパス)](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvnv_corpus)で学習されたモデルです。ライセンスは[CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/deed.ja)です。 """ @@ -254,43 +136,119 @@ def gr_util(item): return (gr.update(visible=False), gr.update(visible=True)) -def create_inference_app(): - parser = argparse.ArgumentParser() - parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU") - parser.add_argument( - "--dir", "-d", type=str, help="Model directory", default=assets_root - ) - parser.add_argument( - "--share", action="store_true", help="Share this app publicly", default=False - ) - parser.add_argument( - "--server-name", - type=str, - default=None, - help="Server name for Gradio app", - ) - parser.add_argument( - "--no-autolaunch", - action="store_true", - default=False, - help="Do not launch app automatically", - ) - args = parser.parse_args() - model_dir = Path(args.dir) - - if args.cpu: - device = "cpu" - else: - device = "cuda" if torch.cuda.is_available() else "cpu" +def create_inference_app(model_holder: ModelHolder) -> gr.Blocks: + def tts_fn( + model_name, + model_path, + text, + language, + reference_audio_path, + sdp_ratio, + noise_scale, + noise_scale_w, + length_scale, + line_split, + split_interval, + assist_text, + assist_text_weight, + use_assist_text, + style, + style_weight, + kata_tone_json_str, + use_tone, + speaker, + pitch_scale, + intonation_scale, + ): + model_holder.load_model(model_name, model_path) + assert model_holder.current_model is not None + + wrong_tone_message = "" + kata_tone: Optional[list[tuple[str, int]]] = None + if use_tone and kata_tone_json_str != "": + if language != "JP": + logger.warning("Only Japanese is supported for tone generation.") + wrong_tone_message = "アクセント指定は現在日本語のみ対応しています。" + if line_split: + logger.warning("Tone generation is not supported for line split.") + wrong_tone_message = ( + "アクセント指定は改行で分けて生成を使わない場合のみ対応しています。" + ) + try: + kata_tone = [] + json_data = json.loads(kata_tone_json_str) + # tupleを使うように変換 + for kana, tone in json_data: + assert isinstance(kana, str) and tone in (0, 1), f"{kana}, {tone}" + kata_tone.append((kana, tone)) + except Exception as e: + logger.warning(f"Error occurred when parsing kana_tone_json: {e}") + wrong_tone_message = f"アクセント指定が不正です: {e}" + kata_tone = None + + # toneは実際に音声合成に代入される際のみnot Noneになる + tone: Optional[list[int]] = None + if kata_tone is not None: + phone_tone = kata_tone2phone_tone(kata_tone) + tone = [t for _, t in phone_tone] + + speaker_id = model_holder.current_model.spk2id[speaker] + + start_time = datetime.datetime.now() - model_holder = ModelHolder(model_dir, device) + try: + sr, audio = model_holder.current_model.infer( + text=text, + language=language, + reference_audio_path=reference_audio_path, + sdp_ratio=sdp_ratio, + noise=noise_scale, + noisew=noise_scale_w, + length=length_scale, + line_split=line_split, + split_interval=split_interval, + assist_text=assist_text, + assist_text_weight=assist_text_weight, + use_assist_text=use_assist_text, + style=style, + style_weight=style_weight, + given_tone=tone, + 
sid=speaker_id, + pitch_scale=pitch_scale, + intonation_scale=intonation_scale, + ) + except InvalidToneError as e: + logger.error(f"Tone error: {e}") + return f"Error: アクセント指定が不正です:\n{e}", None, kata_tone_json_str + except ValueError as e: + logger.error(f"Value error: {e}") + return f"Error: {e}", None, kata_tone_json_str + + end_time = datetime.datetime.now() + duration = (end_time - start_time).total_seconds() + + if tone is None and language == "JP": + # アクセント指定に使えるようにアクセント情報を返す + norm_text = text_normalize(text) + kata_tone = g2kata_tone(norm_text) + kata_tone_json_str = json.dumps(kata_tone, ensure_ascii=False) + elif tone is None: + kata_tone_json_str = "" + message = f"Success, time: {duration} seconds." + if wrong_tone_message != "": + message = wrong_tone_message + "\n" + message + return message, (sr, audio), kata_tone_json_str model_names = model_holder.model_names if len(model_names) == 0: logger.error( - f"モデルが見つかりませんでした。{model_dir}にモデルを置いてください。" + f"モデルが見つかりませんでした。{model_holder.root_dir}にモデルを置いてください。" ) - sys.exit(1) + with gr.Blocks() as app: + gr.Markdown( + f"Error: モデルが見つかりませんでした。{model_holder.root_dir}にモデルを置いてください。" + ) + return app initial_id = 0 initial_pth_files = model_holder.model_files_dict[model_names[initial_id]] @@ -497,8 +455,4 @@ def create_inference_app(): outputs=[style, ref_audio_path], ) - # app.launch( - # inbrowser=not args.no_autolaunch, share=args.share, server_name=args.server_name - # ) - return app diff --git a/webui/merge.py b/webui/merge.py index c002e3388..c9386efbc 100644 --- a/webui/merge.py +++ b/webui/merge.py @@ -1,7 +1,6 @@ import argparse import json import os -import sys from pathlib import Path import gradio as gr @@ -28,8 +27,6 @@ # dataset_root = path_config["dataset_root"] assets_root = path_config["assets_root"] -model_holder = ModelHolder(Path(assets_root), device) - def merge_style(model_name_a, model_name_b, weight, output_name, style_triple_list): """ @@ -331,13 +328,17 @@ def load_styles_gr(model_name_a, model_name_b): """ -def create_merge_app(): +def create_merge_app(model_holder: ModelHolder) -> gr.Blocks: model_names = model_holder.model_names if len(model_names) == 0: logger.error( f"モデルが見つかりませんでした。{assets_root}にモデルを置いてください。" ) - sys.exit(1) + with gr.Blocks() as app: + gr.Markdown( + f"Error: モデルが見つかりませんでした。{assets_root}にモデルを置いてください。" + ) + return app initial_id = 0 initial_model_files = model_holder.model_files_dict[model_names[initial_id]] @@ -499,23 +500,4 @@ def create_merge_app(): outputs=[audio_output], ) - parser = argparse.ArgumentParser() - parser.add_argument( - "--server-name", - type=str, - default=None, - help="Server name for Gradio app", - ) - parser.add_argument( - "--no-autolaunch", - action="store_true", - default=False, - help="Do not launch app automatically", - ) - parser.add_argument("--share", action="store_true", default=False) - args = parser.parse_args() - - # app.launch( - # inbrowser=not args.no_autolaunch, server_name=args.server_name, share=args.share - # ) return app diff --git a/webui/style_vectors.py b/webui/style_vectors.py index 27b95bb65..8effa375a 100644 --- a/webui/style_vectors.py +++ b/webui/style_vectors.py @@ -277,9 +277,7 @@ def save_style_vectors_from_files( return f"成功!\n{style_vector_path}に保存し{config_path}を更新しました。" -initial_md = f""" -# Style Bert-VITS2 スタイルベクトルの作成 - +how_to_md = f""" Style-Bert-VITS2でこまかくスタイルを指定して音声合成するには、モデルごとにスタイルベクトルのファイル`style_vectors.npy`を手動で作成する必要があります。 ただし、学習の過程で自動的に平均スタイル「{DEFAULT_STYLE}」のみは作成されるので、それをそのまま使うこともできます(その場合はこのWebUIは使いません)。 @@ -326,7 
+324,8 @@ def save_style_vectors_from_files( def create_style_vectors_app(): with gr.Blocks(theme=GRADIO_THEME) as app: - gr.Markdown(initial_md) + with gr.Accordion("使い方", open=False): + gr.Markdown(how_to_md) with gr.Row(): model_name = gr.Textbox(placeholder="your_model_name", label="モデル名") reduction_method = gr.Radio( @@ -461,23 +460,4 @@ def create_style_vectors_app(): outputs=[info2], ) - parser = argparse.ArgumentParser() - parser.add_argument( - "--server-name", - type=str, - default=None, - help="Server name for Gradio app", - ) - parser.add_argument( - "--no-autolaunch", - action="store_true", - default=False, - help="Do not launch app automatically", - ) - parser.add_argument("--share", action="store_true", default=False) - args = parser.parse_args() - - # app.launch( - # inbrowser=not args.no_autolaunch, server_name=args.server_name, share=args.share - # ) return app diff --git a/webui/train.py b/webui/train.py index 1c9a2145e..dee018e66 100644 --- a/webui/train.py +++ b/webui/train.py @@ -14,7 +14,6 @@ import gradio as gr import yaml -from common.constants import GRADIO_THEME, LATEST_VERSION from common.log import logger from common.stdout_wrapper import SAFE_STDOUT from common.subprocess_utils import run_script_with_log, second_elem_of @@ -398,9 +397,7 @@ def run_tensorboard(model_name): yield gr.Button("Tensorboardを開く") -initial_md = f""" -# Style-Bert-VITS2 ver {LATEST_VERSION} 学習用WebUI - +how_to_md = f""" ## 使い方 - データを準備して、モデル名を入力して、必要なら設定を調整してから、「自動前処理を実行」ボタンを押してください。進捗状況等はターミナルに表示されます。 @@ -452,10 +449,11 @@ def run_tensorboard(model_name): def create_train_app(): - with gr.Blocks(theme=GRADIO_THEME).queue() as app: - gr.Markdown(initial_md) - with gr.Accordion(label="データの前準備", open=False): - gr.Markdown(prepare_md) + with gr.Blocks().queue() as app: + with gr.Accordion("使い方", open=False): + gr.Markdown(how_to_md) + with gr.Accordion(label="データの前準備", open=False): + gr.Markdown(prepare_md) model_name = gr.Textbox(label="モデル名") gr.Markdown("### 自動前処理") with gr.Row(variant="panel"): @@ -794,20 +792,5 @@ def create_train_app(): outputs=[use_jp_extra_train], ) - parser = argparse.ArgumentParser() - parser.add_argument( - "--server-name", - type=str, - default=None, - help="Server name for Gradio app", - ) - parser.add_argument( - "--no-autolaunch", - action="store_true", - default=False, - help="Do not launch app automatically", - ) - args = parser.parse_args() - # app.launch(inbrowser=not args.no_autolaunch, server_name=args.server_name) return app From 3f07c256e300b4111201fee7fb2b3b0ddd538f04 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Thu, 7 Mar 2024 19:33:21 +0000 Subject: [PATCH 037/148] Refactor: make variables private that are not used externally --- default_style.py | 2 +- gen_yaml.py | 3 +- .../models/monotonic_alignment.py | 4 +- style_bert_vits2/text_processing/__init__.py | 4 +- .../text_processing/bert_models.py | 72 ++++++++++++++++--- .../text_processing/chinese/bert_feature.py | 10 +-- .../text_processing/english/bert_feature.py | 10 +-- .../text_processing/japanese/bert_feature.py | 10 +-- 8 files changed, 85 insertions(+), 30 deletions(-) diff --git a/default_style.py b/default_style.py index 763e29140..67b6fc353 100644 --- a/default_style.py +++ b/default_style.py @@ -1,6 +1,6 @@ import os -from style_bert_vits2.logging import logger from style_bert_vits2.constants import DEFAULT_STYLE +from style_bert_vits2.logging import logger import numpy as np import json diff --git a/gen_yaml.py b/gen_yaml.py index 91301accb..76df20646 100644 --- a/gen_yaml.py +++ 
b/gen_yaml.py @@ -1,7 +1,8 @@ +import argparse import os import shutil import yaml -import argparse + parser = argparse.ArgumentParser( description="config.ymlの生成。あらかじめ前準備をしたデータをバッチファイルなどで連続で学習する時にtrain_ms.pyより前に使用する。" diff --git a/style_bert_vits2/models/monotonic_alignment.py b/style_bert_vits2/models/monotonic_alignment.py index 0f393c19b..b499ad05f 100644 --- a/style_bert_vits2/models/monotonic_alignment.py +++ b/style_bert_vits2/models/monotonic_alignment.py @@ -28,7 +28,7 @@ def maximum_path(neg_cent: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32) t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32) - maximum_path_jit(path, neg_cent, t_t_max, t_s_max) + __maximum_path_jit(path, neg_cent, t_t_max, t_s_max) return torch.from_numpy(path).to(device=device, dtype=dtype) @@ -43,7 +43,7 @@ def maximum_path(neg_cent: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: nopython = True, nogil = True, ) # type: ignore -def maximum_path_jit(paths: Any, values: Any, t_ys: Any, t_xs: Any) -> None: +def __maximum_path_jit(paths: Any, values: Any, t_ys: Any, t_xs: Any) -> None: """ 与えられたパス、値、およびターゲットの y と x 座標を使用して JIT で最大パスを計算する diff --git a/style_bert_vits2/text_processing/__init__.py b/style_bert_vits2/text_processing/__init__.py index 5e77df624..471952504 100644 --- a/style_bert_vits2/text_processing/__init__.py +++ b/style_bert_vits2/text_processing/__init__.py @@ -8,7 +8,7 @@ ) -_symbol_to_id = {s: i for i, s in enumerate(SYMBOLS)} +__symbol_to_id = {s: i for i, s in enumerate(SYMBOLS)} def extract_bert_feature( @@ -97,7 +97,7 @@ def cleaned_text_to_sequence(cleaned_phones: list[str], tones: list[int], langua tuple[list[int], list[int], list[int]]: List of integers corresponding to the symbols in the text """ - phones = [_symbol_to_id[symbol] for symbol in cleaned_phones] + phones = [__symbol_to_id[symbol] for symbol in cleaned_phones] tone_start = LANGUAGE_TONE_START_MAP[language] tones = [i + tone_start for i in tones] lang_id = LANGUAGE_ID_MAP[language] diff --git a/style_bert_vits2/text_processing/bert_models.py b/style_bert_vits2/text_processing/bert_models.py index 9132082c9..e8ef4b460 100644 --- a/style_bert_vits2/text_processing/bert_models.py +++ b/style_bert_vits2/text_processing/bert_models.py @@ -5,11 +5,13 @@ 場合によっては多重にロードされて非効率なほか、BERT モデルのロード元のパスがハードコードされているためライブラリ化ができない。 そこで、ライブラリの利用前に、音声合成に利用する言語の BERT モデルだけを「明示的に」ロードできるようにした。 -一度 load_tokenizer() で当該言語の BERT モデルがロードされていれば、ライブラリ内部のどこからでもロード済みのモデル/トークナイザーを取得できる。 +一度 load_model/tokenizer() で当該言語の BERT モデルがロードされていれば、ライブラリ内部のどこからでもロード済みのモデル/トークナイザーを取得できる。 """ +import gc from typing import cast +import torch from transformers import ( AutoModelForMaskedLM, AutoTokenizer, @@ -25,10 +27,10 @@ # 各言語ごとのロード済みの BERT モデルを格納する辞書 -loaded_models: dict[Languages, PreTrainedModel | DebertaV2Model] = {} +__loaded_models: dict[Languages, PreTrainedModel | DebertaV2Model] = {} # 各言語ごとのロード済みの BERT トークナイザーを格納する辞書 -loaded_tokenizers: dict[Languages, PreTrainedTokenizer | PreTrainedTokenizerFast | DebertaV2Tokenizer] = {} +__loaded_tokenizers: dict[Languages, PreTrainedTokenizer | PreTrainedTokenizerFast | DebertaV2Tokenizer] = {} def load_model( @@ -56,8 +58,8 @@ def load_model( """ # すでにロード済みの場合はそのまま返す - if language in loaded_models: - return loaded_models[language] + if language in __loaded_models: + return __loaded_models[language] # pretrained_model_name_or_path が指定されていない場合はデフォルトのパスを利用 if pretrained_model_name_or_path is None: @@ -71,7 +73,7 @@ def load_model( model = 
cast(DebertaV2Model, DebertaV2Model.from_pretrained(pretrained_model_name_or_path)) else: model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path) - loaded_models[language] = model + __loaded_models[language] = model logger.info(f"Loaded the {language} BERT model from {pretrained_model_name_or_path}") return model @@ -102,8 +104,8 @@ def load_tokenizer( """ # すでにロード済みの場合はそのまま返す - if language in loaded_tokenizers: - return loaded_tokenizers[language] + if language in __loaded_tokenizers: + return __loaded_tokenizers[language] # pretrained_model_name_or_path が指定されていない場合はデフォルトのパスを利用 if pretrained_model_name_or_path is None: @@ -117,7 +119,59 @@ def load_tokenizer( tokenizer = DebertaV2Tokenizer.from_pretrained(pretrained_model_name_or_path) else: tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) - loaded_tokenizers[language] = tokenizer + __loaded_tokenizers[language] = tokenizer logger.info(f"Loaded the {language} BERT tokenizer from {pretrained_model_name_or_path}") return tokenizer + + +def unload_model(language: Languages) -> None: + """ + 指定された言語の BERT モデルをアンロードする + + Args: + language (Languages): アンロードする BERT モデルの言語 + """ + + if language in __loaded_models: + del __loaded_models[language] + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logger.info(f"Unloaded the {language} BERT model") + + +def unload_tokenizer(language: Languages) -> None: + """ + 指定された言語の BERT トークナイザーをアンロードする + + Args: + language (Languages): アンロードする BERT トークナイザーの言語 + """ + + if language in __loaded_tokenizers: + del __loaded_tokenizers[language] + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logger.info(f"Unloaded the {language} BERT tokenizer") + + +def unload_all_models() -> None: + """ + すべての BERT モデルをアンロードする + """ + + for language in list(__loaded_models.keys()): + unload_model(language) + logger.info("Unloaded all BERT models") + + +def unload_all_tokenizers() -> None: + """ + すべての BERT トークナイザーをアンロードする + """ + + for language in list(__loaded_tokenizers.keys()): + unload_tokenizer(language) + logger.info("Unloaded all BERT tokenizers") diff --git a/style_bert_vits2/text_processing/chinese/bert_feature.py b/style_bert_vits2/text_processing/chinese/bert_feature.py index 25024cb25..3178565a1 100644 --- a/style_bert_vits2/text_processing/chinese/bert_feature.py +++ b/style_bert_vits2/text_processing/chinese/bert_feature.py @@ -7,7 +7,7 @@ from style_bert_vits2.text_processing import bert_models -models: dict[torch.device | str, PreTrainedModel] = {} +__models: dict[torch.device | str, PreTrainedModel] = {} def extract_bert_feature( @@ -41,8 +41,8 @@ def extract_bert_feature( device = "cuda" if device == "cuda" and not torch.cuda.is_available(): device = "cpu" - if device not in models.keys(): - models[device] = bert_models.load_model(Languages.ZH).to(device) # type: ignore + if device not in __models.keys(): + __models[device] = bert_models.load_model(Languages.ZH).to(device) # type: ignore style_res_mean = None with torch.no_grad(): @@ -50,13 +50,13 @@ def extract_bert_feature( inputs = tokenizer(text, return_tensors="pt") for i in inputs: inputs[i] = inputs[i].to(device) # type: ignore - res = models[device](**inputs, output_hidden_states=True) + res = __models[device](**inputs, output_hidden_states=True) res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu() if assist_text: style_inputs = tokenizer(assist_text, return_tensors="pt") for i in style_inputs: style_inputs[i] = style_inputs[i].to(device) # type: ignore 
- style_res = models[device](**style_inputs, output_hidden_states=True) + style_res = __models[device](**style_inputs, output_hidden_states=True) style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() style_res_mean = style_res.mean(0) diff --git a/style_bert_vits2/text_processing/english/bert_feature.py b/style_bert_vits2/text_processing/english/bert_feature.py index ec556c234..b29d53180 100644 --- a/style_bert_vits2/text_processing/english/bert_feature.py +++ b/style_bert_vits2/text_processing/english/bert_feature.py @@ -7,7 +7,7 @@ from style_bert_vits2.text_processing import bert_models -models: dict[torch.device | str, PreTrainedModel] = {} +__models: dict[torch.device | str, PreTrainedModel] = {} def extract_bert_feature( @@ -41,8 +41,8 @@ def extract_bert_feature( device = "cuda" if device == "cuda" and not torch.cuda.is_available(): device = "cpu" - if device not in models.keys(): - models[device] = bert_models.load_model(Languages.EN).to(device) # type: ignore + if device not in __models.keys(): + __models[device] = bert_models.load_model(Languages.EN).to(device) # type: ignore style_res_mean = None with torch.no_grad(): @@ -50,13 +50,13 @@ def extract_bert_feature( inputs = tokenizer(text, return_tensors="pt") for i in inputs: inputs[i] = inputs[i].to(device) # type: ignore - res = models[device](**inputs, output_hidden_states=True) + res = __models[device](**inputs, output_hidden_states=True) res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu() if assist_text: style_inputs = tokenizer(assist_text, return_tensors="pt") for i in style_inputs: style_inputs[i] = style_inputs[i].to(device) # type: ignore - style_res = models[device](**style_inputs, output_hidden_states=True) + style_res = __models[device](**style_inputs, output_hidden_states=True) style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() style_res_mean = style_res.mean(0) diff --git a/style_bert_vits2/text_processing/japanese/bert_feature.py b/style_bert_vits2/text_processing/japanese/bert_feature.py index 3ff9d7b81..d1809fe10 100644 --- a/style_bert_vits2/text_processing/japanese/bert_feature.py +++ b/style_bert_vits2/text_processing/japanese/bert_feature.py @@ -8,7 +8,7 @@ from style_bert_vits2.text_processing.japanese.g2p import text_to_sep_kata -models: dict[torch.device | str, PreTrainedModel] = {} +__models: dict[torch.device | str, PreTrainedModel] = {} def extract_bert_feature( @@ -48,8 +48,8 @@ def extract_bert_feature( device = "cuda" if device == "cuda" and not torch.cuda.is_available(): device = "cpu" - if device not in models.keys(): - models[device] = bert_models.load_model(Languages.JP).to(device) # type: ignore + if device not in __models.keys(): + __models[device] = bert_models.load_model(Languages.JP).to(device) # type: ignore style_res_mean = None with torch.no_grad(): @@ -57,13 +57,13 @@ def extract_bert_feature( inputs = tokenizer(text, return_tensors="pt") for i in inputs: inputs[i] = inputs[i].to(device) # type: ignore - res = models[device](**inputs, output_hidden_states=True) + res = __models[device](**inputs, output_hidden_states=True) res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu() if assist_text: style_inputs = tokenizer(assist_text, return_tensors="pt") for i in style_inputs: style_inputs[i] = style_inputs[i].to(device) # type: ignore - style_res = models[device](**style_inputs, output_hidden_states=True) + style_res = __models[device](**style_inputs, output_hidden_states=True) style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() 
style_res_mean = style_res.mean(0) From 4d5c537f959a18ee978a0ce7cc9ebaea5e754057 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Thu, 7 Mar 2024 19:52:34 +0000 Subject: [PATCH 038/148] Refactor: introducing Ruff --- losses.py | 2 -- pyproject.toml | 9 +++++++++ server_fastapi.py | 4 ++-- .../text_processing/english/__init__.py | 1 - .../text_processing/japanese/__init__.py | 4 ++-- train_ms_jp_extra.py | 1 - webui_style_vectors.py | 6 +++--- webui_train.py | 20 +++++++++---------- 8 files changed, 26 insertions(+), 21 deletions(-) create mode 100644 pyproject.toml diff --git a/losses.py b/losses.py index 4a890ba30..9bb50afdb 100644 --- a/losses.py +++ b/losses.py @@ -2,8 +2,6 @@ import torchaudio from transformers import AutoModel -from style_bert_vits2.logging import logger - def feature_loss(fmap_r, fmap_g): loss = 0 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..0249363fd --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,9 @@ +[tool.ruff] +# インデント幅を 4 に設定 +indent-width = 4 + +# 行の長さを 100 文字に設定 +line-length = 100 + +# Python 3.10 向けにフォーマット +target-version = "py310" diff --git a/server_fastapi.py b/server_fastapi.py index ca9520c17..b7ebb772d 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -104,7 +104,7 @@ def load_models(model_holder: ModelHolder): @app.get("/voice", response_class=AudioResponse) async def voice( request: Request, - text: str = Query(..., min_length=1, max_length=limit, description=f"セリフ"), + text: str = Query(..., min_length=1, max_length=limit, description="セリフ"), encoding: str = Query(None, description="textをURLデコードする(ex, `utf-8`)"), model_id: int = Query( 0, description="モデルID。`GET /models/info`のkeyの値を指定ください" @@ -132,7 +132,7 @@ async def voice( DEFAULT_LENGTH, description="話速。基準は1で大きくするほど音声は長くなり読み上げが遅まる", ), - language: Languages = Query(ln, description=f"textの言語"), + language: Languages = Query(ln, description="textの言語"), auto_split: bool = Query(DEFAULT_LINE_SPLIT, description="改行で分けて生成"), split_interval: float = Query( DEFAULT_SPLIT_INTERVAL, description="分けた場合に挟む無音の長さ(秒)" diff --git a/style_bert_vits2/text_processing/english/__init__.py b/style_bert_vits2/text_processing/english/__init__.py index 852431aaa..b57e25067 100644 --- a/style_bert_vits2/text_processing/english/__init__.py +++ b/style_bert_vits2/text_processing/english/__init__.py @@ -234,7 +234,6 @@ def refine_syllables(syllables): return phonemes, tones -import re import inflect _inflect = inflect.engine() diff --git a/style_bert_vits2/text_processing/japanese/__init__.py b/style_bert_vits2/text_processing/japanese/__init__.py index 17e1785fe..91233773e 100644 --- a/style_bert_vits2/text_processing/japanese/__init__.py +++ b/style_bert_vits2/text_processing/japanese/__init__.py @@ -1,2 +1,2 @@ -from style_bert_vits2.text_processing.japanese.g2p import g2p # type: ignore -from style_bert_vits2.text_processing.japanese.normalizer import normalize_text # type: ignore +from style_bert_vits2.text_processing.japanese.g2p import g2p # noqa: F401 +from style_bert_vits2.text_processing.japanese.normalizer import normalize_text # noqa: F401 diff --git a/train_ms_jp_extra.py b/train_ms_jp_extra.py index 1a287d9c2..f04d2aef6 100644 --- a/train_ms_jp_extra.py +++ b/train_ms_jp_extra.py @@ -1,6 +1,5 @@ import argparse import datetime -import gc import os import platform diff --git a/webui_style_vectors.py b/webui_style_vectors.py index 1cbabbd51..b149e2b3e 100644 --- a/webui_style_vectors.py +++ b/webui_style_vectors.py @@ -153,7 +153,7 @@ def do_dbscan_gradio(eps=2.5, 
min_samples=15): return [ plt, gr.Slider(maximum=MAX_CLUSTER_NUM), - f"クラスタが数が0です。パラメータを変えてみてください。", + "クラスタが数が0です。パラメータを変えてみてください。", ] + [gr.Audio(visible=False)] * MAX_AUDIO_NUM return [plt, gr.Slider(maximum=n_clusters, value=1), n_clusters] + [ @@ -212,7 +212,7 @@ def save_style_vectors_from_clustering(model_name, style_names_str: str): if len(style_name_list) != len(centroids) + 1: return f"スタイルの数が合いません。`,`で正しく{len(centroids)}個に区切られているか確認してください: {style_names_str}" if len(set(style_names)) != len(style_names): - return f"スタイル名が重複しています。" + return "スタイル名が重複しています。" logger.info(f"Backup {config_path} to {config_path}.bak") shutil.copy(config_path, f"{config_path}.bak") @@ -243,7 +243,7 @@ def save_style_vectors_from_files( return f"音声ファイルとスタイル名の数が合いません。`,`で正しく{len(style_names)}個に区切られているか確認してください: {audio_files_str}と{style_names_str}" style_name_list = [DEFAULT_STYLE] + style_names if len(set(style_names)) != len(style_names): - return f"スタイル名が重複しています。" + return "スタイル名が重複しています。" style_vectors = [mean] wavs_dir = os.path.join(dataset_root, model_name, "wavs") diff --git a/webui_train.py b/webui_train.py index fda31ca2b..0a7000caf 100644 --- a/webui_train.py +++ b/webui_train.py @@ -147,10 +147,10 @@ def resample(model_name, normalize, trim, num_processes): cmd.append("--trim") success, message = run_script_with_log(cmd) if not success: - logger.error(f"Step 2: resampling failed.") + logger.error("Step 2: resampling failed.") return False, f"Step 2, Error: 音声ファイルの前処理に失敗しました:\n{message}" elif message: - logger.warning(f"Step 2: resampling finished with stderr.") + logger.warning("Step 2: resampling finished with stderr.") return True, f"Step 2, Success: 音声ファイルの前処理が完了しました:\n{message}" logger.success("Step 2: resampling finished.") return True, "Step 2, Success: 音声ファイルの前処理が完了しました" @@ -197,13 +197,13 @@ def preprocess_text(model_name, use_jp_extra, val_per_lang, yomi_error): cmd.append("--use_jp_extra") success, message = run_script_with_log(cmd) if not success: - logger.error(f"Step 3: preprocessing text failed.") + logger.error("Step 3: preprocessing text failed.") return ( False, f"Step 3, Error: 書き起こしファイルの前処理に失敗しました:\n{message}", ) elif message: - logger.warning(f"Step 3: preprocessing text finished with stderr.") + logger.warning("Step 3: preprocessing text finished with stderr.") return ( True, f"Step 3, Success: 書き起こしファイルの前処理が完了しました:\n{message}", @@ -225,10 +225,10 @@ def bert_gen(model_name): ] ) if not success: - logger.error(f"Step 4: bert_gen failed.") + logger.error("Step 4: bert_gen failed.") return False, f"Step 4, Error: BERT特徴ファイルの生成に失敗しました:\n{message}" elif message: - logger.warning(f"Step 4: bert_gen finished with stderr.") + logger.warning("Step 4: bert_gen finished with stderr.") return ( True, f"Step 4, Success: BERT特徴ファイルの生成が完了しました:\n{message}", @@ -250,13 +250,13 @@ def style_gen(model_name, num_processes): ] ) if not success: - logger.error(f"Step 5: style_gen failed.") + logger.error("Step 5: style_gen failed.") return ( False, f"Step 5, Error: スタイル特徴ファイルの生成に失敗しました:\n{message}", ) elif message: - logger.warning(f"Step 5: style_gen finished with stderr.") + logger.warning("Step 5: style_gen finished with stderr.") return ( True, f"Step 5, Success: スタイル特徴ファイルの生成が完了しました:\n{message}", @@ -350,10 +350,10 @@ def train(model_name, skip_style=False, use_jp_extra=True, speedup=False): cmd.append("--speedup") success, message = run_script_with_log(cmd, ignore_warning=True) if not success: - logger.error(f"Train failed.") + logger.error("Train failed.") return False, f"Error: 
学習に失敗しました:\n{message}" elif message: - logger.warning(f"Train finished with stderr.") + logger.warning("Train finished with stderr.") return True, f"Success: 学習が完了しました:\n{message}" logger.success("Train finished.") return True, "Success: 学習が完了しました" From 4a3519c4b934b820ddb46f451b830c461c04962a Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 05:51:54 +0000 Subject: [PATCH 039/148] Remove: Ruff I have determined that this is excessive for this project at this time. --- pyproject.toml | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 0249363fd..000000000 --- a/pyproject.toml +++ /dev/null @@ -1,9 +0,0 @@ -[tool.ruff] -# インデント幅を 4 に設定 -indent-width = 4 - -# 行の長さを 100 文字に設定 -line-length = 100 - -# Python 3.10 向けにフォーマット -target-version = "py310" From 8add1b42023f23b82efb69d981dd9dbe7fbb14bf Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 06:14:50 +0000 Subject: [PATCH 040/148] Fix: maintain compatibility with Python 3.9 --- app.py | 4 +- server_editor.py | 7 ++-- style_bert_vits2/constants.py | 2 +- style_bert_vits2/models/commons.py | 14 +++---- style_bert_vits2/models/infer.py | 9 ++-- style_bert_vits2/text_processing/__init__.py | 5 ++- .../text_processing/bert_models.py | 42 +++++++++---------- .../text_processing/chinese/bert_feature.py | 5 ++- .../text_processing/english/bert_feature.py | 5 ++- .../text_processing/japanese/bert_feature.py | 5 ++- 10 files changed, 52 insertions(+), 46 deletions(-) diff --git a/app.py b/app.py index 56f0cd812..13929cd46 100644 --- a/app.py +++ b/app.py @@ -5,7 +5,7 @@ import torch import yaml -from style_bert_vits2.constants import GRADIO_THEME, VERSION +from style_bert_vits2.constants import GRADIO_THEME, LATEST_VERSION from common.tts_model import ModelHolder from webui import ( create_dataset_app, @@ -34,7 +34,7 @@ model_holder = ModelHolder(Path(assets_root), device) with gr.Blocks(theme=GRADIO_THEME) as app: - gr.Markdown(f"# Style-Bert-VITS2 WebUI (version {VERSION})") + gr.Markdown(f"# Style-Bert-VITS2 WebUI (version {LATEST_VERSION})") with gr.Tabs(): with gr.Tab("音声合成"): create_inference_app(model_holder=model_holder) diff --git a/server_editor.py b/server_editor.py index 73f76e809..b74d74f7b 100644 --- a/server_editor.py +++ b/server_editor.py @@ -16,6 +16,7 @@ from datetime import datetime from io import BytesIO from pathlib import Path +from typing import Optional import numpy as np import requests @@ -37,7 +38,7 @@ DEFAULT_SDP_RATIO, DEFAULT_STYLE, DEFAULT_STYLE_WEIGHT, - VERSION, + LATEST_VERSION, Languages, ) from style_bert_vits2.logging import logger @@ -212,7 +213,7 @@ class AudioResponse(Response): @router.get("/version") def version() -> str: - return VERSION + return LATEST_VERSION class MoraTone(BaseModel): @@ -265,7 +266,7 @@ class SynthesisRequest(BaseModel): silenceAfter: float = 0.5 pitchScale: float = 1.0 intonationScale: float = 1.0 - speaker: str | None = None + speaker: Optional[str] = None @router.post("/synthesis", response_class=AudioResponse) diff --git a/style_bert_vits2/constants.py b/style_bert_vits2/constants.py index 7b83595b7..064221ef2 100644 --- a/style_bert_vits2/constants.py +++ b/style_bert_vits2/constants.py @@ -3,7 +3,7 @@ # Style-Bert-VITS2 のバージョン -VERSION = "2.4" +LATEST_VERSION = "2.4" # Style-Bert-VITS2 のベースディレクトリ BASE_DIR = Path(__file__).parent.parent diff --git a/style_bert_vits2/models/commons.py b/style_bert_vits2/models/commons.py index 969ed3694..89d07d507 100644 --- 
a/style_bert_vits2/models/commons.py +++ b/style_bert_vits2/models/commons.py @@ -5,7 +5,7 @@ import torch from torch.nn import functional as F -from typing import Any +from typing import Any, Optional def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None: @@ -85,13 +85,13 @@ def slice_segments(x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4 return torch.gather(x, 2, gather_indices) -def rand_slice_segments(x: torch.Tensor, x_lengths: torch.Tensor | None = None, segment_size: int = 4) -> tuple[torch.Tensor, torch.Tensor]: +def rand_slice_segments(x: torch.Tensor, x_lengths: Optional[torch.Tensor] = None, segment_size: int = 4) -> tuple[torch.Tensor, torch.Tensor]: """ ランダムなセグメントをスライスする Args: x (torch.Tensor): 入力テンソル - x_lengths (torch.Tensor, optional): 各バッチの長さ (デフォルト: None) + x_lengths (Optional[torch.Tensor], optional): 各バッチの長さ (デフォルト: None) segment_size (int, optional): スライスのサイズ (デフォルト: 4) Returns: @@ -141,13 +141,13 @@ def fused_add_tanh_sigmoid_multiply(input_a: torch.Tensor, input_b: torch.Tensor return acts -def sequence_mask(length: torch.Tensor, max_length: int | None = None) -> torch.Tensor: +def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None) -> torch.Tensor: """ シーケンスマスクを生成する Args: length (torch.Tensor): 各シーケンスの長さ - max_length (int | None): 最大のシーケンス長さ。指定されていない場合は length の最大値を使用 + max_length (Optional[int]): 最大のシーケンス長さ。指定されていない場合は length の最大値を使用 Returns: torch.Tensor: 生成されたシーケンスマスク @@ -180,13 +180,13 @@ def generate_path(duration: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: return path -def clip_grad_value_(parameters: torch.Tensor | list[torch.Tensor], clip_value: float | None, norm_type: float = 2.0) -> float: +def clip_grad_value_(parameters: torch.Tensor | list[torch.Tensor], clip_value: Optional[float], norm_type: float = 2.0) -> float: """ 勾配の値をクリップする Args: parameters (torch.Tensor | list[torch.Tensor]): クリップするパラメータ - clip_value (float | None): クリップする値。None の場合はクリップしない + clip_value (Optional[float]): クリップする値。None の場合はクリップしない norm_type (float): ノルムの種類 Returns: diff --git a/style_bert_vits2/models/infer.py b/style_bert_vits2/models/infer.py index 5eb9ebda5..a5ce70f61 100644 --- a/style_bert_vits2/models/infer.py +++ b/style_bert_vits2/models/infer.py @@ -1,4 +1,5 @@ import torch +from typing import Optional import utils from style_bert_vits2.constants import Languages @@ -45,9 +46,9 @@ def get_text( language_str: Languages, hps, device: str, - assist_text: str | None = None, + assist_text: Optional[str] = None, assist_text_weight: float = 0.7, - given_tone: list[int] | None = None, + given_tone: Optional[list[int]] = None, ): use_jp_extra = hps.version.endswith("JP-Extra") # 推論時のみ呼び出されるので、raise_yomi_error は False に設定 @@ -122,9 +123,9 @@ def infer( device: str, skip_start: bool = False, skip_end: bool = False, - assist_text: str | None = None, + assist_text: Optional[str] = None, assist_text_weight: float = 0.7, - given_tone: list[int] | None = None, + given_tone: Optional[list[int]] = None, ): is_jp_extra = hps.version.endswith("JP-Extra") bert, ja_bert, en_bert, phones, tones, lang_ids = get_text( diff --git a/style_bert_vits2/text_processing/__init__.py b/style_bert_vits2/text_processing/__init__.py index 471952504..4dadaed73 100644 --- a/style_bert_vits2/text_processing/__init__.py +++ b/style_bert_vits2/text_processing/__init__.py @@ -1,4 +1,5 @@ import torch +from typing import Optional from style_bert_vits2.constants import Languages from style_bert_vits2.text_processing.symbols import ( @@ -16,7 
+17,7 @@ def extract_bert_feature( word2ph: list[int], language: Languages, device: torch.device | str, - assist_text: str | None = None, + assist_text: Optional[str] = None, assist_text_weight: float = 0.7, ) -> torch.Tensor: """ @@ -27,7 +28,7 @@ def extract_bert_feature( word2ph (list[int]): 元のテキストの各文字に音素が何個割り当てられるかを表すリスト language (Languages): テキストの言語 device (torch.device | str): 推論に利用するデバイス - assist_text (str | None, optional): 補助テキスト (デフォルト: None) + assist_text (Optional[str], optional): 補助テキスト (デフォルト: None) assist_text_weight (float, optional): 補助テキストの重み (デフォルト: 0.7) Returns: diff --git a/style_bert_vits2/text_processing/bert_models.py b/style_bert_vits2/text_processing/bert_models.py index e8ef4b460..df7a0162b 100644 --- a/style_bert_vits2/text_processing/bert_models.py +++ b/style_bert_vits2/text_processing/bert_models.py @@ -9,7 +9,7 @@ """ import gc -from typing import cast +from typing import cast, Optional import torch from transformers import ( @@ -35,23 +35,23 @@ def load_model( language: Languages, - pretrained_model_name_or_path: str | None = None, + pretrained_model_name_or_path: Optional[str] = None, ) -> PreTrainedModel | DebertaV2Model: """ - 指定された言語の BERT モデルをロードし、ロード済みの BERT モデルを返す - 一度ロードされていれば、ロード済みの BERT モデルを即座に返す - ライブラリ利用時は常に pretrain_model_name_or_path (Hugging Face のリポジトリ名 or ローカルのファイルパス) を指定する必要がある - ロードにはそれなりに時間がかかるため、ライブラリ利用前に明示的に pretrained_model_name_or_path を指定してロードしておくべき + 指定された言語の BERT モデルをロードし、ロード済みの BERT モデルを返す。 + 一度ロードされていれば、ロード済みの BERT モデルを即座に返す。 + ライブラリ利用時は常に pretrain_model_name_or_path (Hugging Face のリポジトリ名 or ローカルのファイルパス) を指定する必要がある。 + ロードにはそれなりに時間がかかるため、ライブラリ利用前に明示的に pretrained_model_name_or_path を指定してロードしておくべき。 - Style-Bert-VITS2 では、BERT モデルに下記の 3 つが利用されている - これ以外の BERT モデルを指定した場合は正常に動作しない可能性が高い + Style-Bert-VITS2 では、BERT モデルに下記の 3 つが利用されている。 + これ以外の BERT モデルを指定した場合は正常に動作しない可能性が高い。 - 日本語: ku-nlp/deberta-v2-large-japanese-char-wwm - 英語: microsoft/deberta-v3-large - 中国語: hfl/chinese-roberta-wwm-ext-large Args: language (Languages): ロードする学習済みモデルの対象言語 - pretrained_model_name_or_path (str | None): ロードする学習済みモデルの名前またはパス。指定しない場合はデフォルトのパスが利用される (デフォルト: None) + pretrained_model_name_or_path (Optional[str]): ロードする学習済みモデルの名前またはパス。指定しない場合はデフォルトのパスが利用される (デフォルト: None) Returns: PreTrainedModel | DebertaV2Model: ロード済みの BERT モデル @@ -81,23 +81,23 @@ def load_model( def load_tokenizer( language: Languages, - pretrained_model_name_or_path: str | None = None, + pretrained_model_name_or_path: Optional[str] = None, ) -> PreTrainedTokenizer | PreTrainedTokenizerFast | DebertaV2Tokenizer: """ - 指定された言語の BERT モデルをロードし、ロード済みの BERT トークナイザーを返す - 一度ロードされていれば、ロード済みの BERT トークナイザーを即座に返す - ライブラリ利用時は常に pretrain_model_name_or_path (Hugging Face のリポジトリ名 or ローカルのファイルパス) を指定する必要がある - ロードにはそれなりに時間がかかるため、ライブラリ利用前に明示的に pretrained_model_name_or_path を指定してロードしておくべき + 指定された言語の BERT モデルをロードし、ロード済みの BERT トークナイザーを返す。 + 一度ロードされていれば、ロード済みの BERT トークナイザーを即座に返す。 + ライブラリ利用時は常に pretrain_model_name_or_path (Hugging Face のリポジトリ名 or ローカルのファイルパス) を指定する必要がある。 + ロードにはそれなりに時間がかかるため、ライブラリ利用前に明示的に pretrained_model_name_or_path を指定してロードしておくべき。 - Style-Bert-VITS2 では、BERT モデルに下記の 3 つが利用されている - これ以外の BERT モデルを指定した場合は正常に動作しない可能性が高い + Style-Bert-VITS2 では、BERT モデルに下記の 3 つが利用されている。 + これ以外の BERT モデルを指定した場合は正常に動作しない可能性が高い。 - 日本語: ku-nlp/deberta-v2-large-japanese-char-wwm - 英語: microsoft/deberta-v3-large - 中国語: hfl/chinese-roberta-wwm-ext-large Args: language (Languages): ロードする学習済みモデルの対象言語 - pretrained_model_name_or_path (str | None): ロードする学習済みモデルの名前またはパス。指定しない場合はデフォルトのパスが利用される (デフォルト: None) + 
pretrained_model_name_or_path (Optional[str]): ロードする学習済みモデルの名前またはパス。指定しない場合はデフォルトのパスが利用される (デフォルト: None) Returns: PreTrainedTokenizer | PreTrainedTokenizerFast | DebertaV2Tokenizer: ロード済みの BERT トークナイザー @@ -127,7 +127,7 @@ def load_tokenizer( def unload_model(language: Languages) -> None: """ - 指定された言語の BERT モデルをアンロードする + 指定された言語の BERT モデルをアンロードする。 Args: language (Languages): アンロードする BERT モデルの言語 @@ -143,7 +143,7 @@ def unload_model(language: Languages) -> None: def unload_tokenizer(language: Languages) -> None: """ - 指定された言語の BERT トークナイザーをアンロードする + 指定された言語の BERT トークナイザーをアンロードする。 Args: language (Languages): アンロードする BERT トークナイザーの言語 @@ -159,7 +159,7 @@ def unload_tokenizer(language: Languages) -> None: def unload_all_models() -> None: """ - すべての BERT モデルをアンロードする + すべての BERT モデルをアンロードする。 """ for language in list(__loaded_models.keys()): @@ -169,7 +169,7 @@ def unload_all_models() -> None: def unload_all_tokenizers() -> None: """ - すべての BERT トークナイザーをアンロードする + すべての BERT トークナイザーをアンロードする。 """ for language in list(__loaded_tokenizers.keys()): diff --git a/style_bert_vits2/text_processing/chinese/bert_feature.py b/style_bert_vits2/text_processing/chinese/bert_feature.py index 3178565a1..f8bce0416 100644 --- a/style_bert_vits2/text_processing/chinese/bert_feature.py +++ b/style_bert_vits2/text_processing/chinese/bert_feature.py @@ -1,4 +1,5 @@ import sys +from typing import Optional import torch from transformers import PreTrainedModel @@ -14,7 +15,7 @@ def extract_bert_feature( text: str, word2ph: list[int], device: torch.device | str, - assist_text: str | None = None, + assist_text: Optional[str] = None, assist_text_weight: float = 0.7, ) -> torch.Tensor: """ @@ -24,7 +25,7 @@ def extract_bert_feature( text (str): 中国語のテキスト word2ph (list[int]): 元のテキストの各文字に音素が何個割り当てられるかを表すリスト device (torch.device | str): 推論に利用するデバイス - assist_text (str | None, optional): 補助テキスト (デフォルト: None) + assist_text (Optional[str], optional): 補助テキスト (デフォルト: None) assist_text_weight (float, optional): 補助テキストの重み (デフォルト: 0.7) Returns: diff --git a/style_bert_vits2/text_processing/english/bert_feature.py b/style_bert_vits2/text_processing/english/bert_feature.py index b29d53180..79f24c9c2 100644 --- a/style_bert_vits2/text_processing/english/bert_feature.py +++ b/style_bert_vits2/text_processing/english/bert_feature.py @@ -1,4 +1,5 @@ import sys +from typing import Optional import torch from transformers import PreTrainedModel @@ -14,7 +15,7 @@ def extract_bert_feature( text: str, word2ph: list[int], device: torch.device | str, - assist_text: str | None = None, + assist_text: Optional[str] = None, assist_text_weight: float = 0.7, ) -> torch.Tensor: """ @@ -24,7 +25,7 @@ def extract_bert_feature( text (str): 英語のテキスト word2ph (list[int]): 元のテキストの各文字に音素が何個割り当てられるかを表すリスト device (torch.device | str): 推論に利用するデバイス - assist_text (str | None, optional): 補助テキスト (デフォルト: None) + assist_text (Optional[str], optional): 補助テキスト (デフォルト: None) assist_text_weight (float, optional): 補助テキストの重み (デフォルト: 0.7) Returns: diff --git a/style_bert_vits2/text_processing/japanese/bert_feature.py b/style_bert_vits2/text_processing/japanese/bert_feature.py index d1809fe10..3a44a8d6d 100644 --- a/style_bert_vits2/text_processing/japanese/bert_feature.py +++ b/style_bert_vits2/text_processing/japanese/bert_feature.py @@ -1,4 +1,5 @@ import sys +from typing import Optional import torch from transformers import PreTrainedModel @@ -15,7 +16,7 @@ def extract_bert_feature( text: str, word2ph: list[int], device: torch.device | str, - assist_text: str | None = None, + 
assist_text: Optional[str] = None, assist_text_weight: float = 0.7, ) -> torch.Tensor: """ @@ -25,7 +26,7 @@ def extract_bert_feature( text (str): 日本語のテキスト word2ph (list[int]): 元のテキストの各文字に音素が何個割り当てられるかを表すリスト device (torch.device | str): 推論に利用するデバイス - assist_text (str | None, optional): 補助テキスト (デフォルト: None) + assist_text (Optional[str], optional): 補助テキスト (デフォルト: None) assist_text_weight (float, optional): 補助テキストの重み (デフォルト: 0.7) Returns: From fac4f9a8ab5f8bdf5eb714cc149658210925163f Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 06:20:44 +0000 Subject: [PATCH 041/148] Refactor: rename text_processing to nlp "text_processing" is clearer, but the import statement is longer. "nlp" is shorter and makes it clear that it is natural language processing. --- bert_gen.py | 4 ++-- data_utils.py | 2 +- preprocess_text.py | 2 +- server_editor.py | 8 ++++---- style_bert_vits2/constants.py | 2 +- style_bert_vits2/models/infer.py | 4 ++-- style_bert_vits2/models/models.py | 2 +- style_bert_vits2/models/models_jp_extra.py | 2 +- .../{text_processing => nlp}/__init__.py | 14 +++++++------- .../{text_processing => nlp}/bert_models.py | 0 .../{text_processing => nlp}/chinese/__init__.py | 6 +++--- .../chinese/bert_feature.py | 2 +- .../chinese/tone_sandhi.py | 0 .../{text_processing => nlp}/english/__init__.py | 4 ++-- .../english/bert_feature.py | 2 +- .../{text_processing => nlp}/english/cmudict.rep | 0 .../english/cmudict_cache.pickle | Bin .../english/opencpop-strict.txt | 0 style_bert_vits2/nlp/japanese/__init__.py | 2 ++ .../japanese/bert_feature.py | 4 ++-- .../{text_processing => nlp}/japanese/g2p.py | 10 +++++----- .../japanese/g2p_utils.py | 6 +++--- .../japanese/mora_list.py | 0 .../japanese/normalizer.py | 2 +- .../japanese/pyopenjtalk_worker/__init__.py | 4 ++-- .../japanese/pyopenjtalk_worker/__main__.py | 4 ++-- .../japanese/pyopenjtalk_worker/worker_client.py | 2 +- .../japanese/pyopenjtalk_worker/worker_common.py | 0 .../japanese/pyopenjtalk_worker/worker_server.py | 2 +- .../japanese/user_dict/README.md | 0 .../japanese/user_dict/__init__.py | 6 +++--- .../japanese/user_dict/part_of_speech_data.py | 2 +- .../japanese/user_dict/word_model.py | 0 .../{text_processing => nlp}/symbols.py | 0 .../text_processing/japanese/__init__.py | 2 -- train_ms.py | 2 +- train_ms_jp_extra.py | 2 +- webui/inference.py | 4 ++-- 38 files changed, 54 insertions(+), 54 deletions(-) rename style_bert_vits2/{text_processing => nlp}/__init__.py (86%) rename style_bert_vits2/{text_processing => nlp}/bert_models.py (100%) rename style_bert_vits2/{text_processing => nlp}/chinese/__init__.py (95%) rename style_bert_vits2/{text_processing => nlp}/chinese/bert_feature.py (98%) rename style_bert_vits2/{text_processing => nlp}/chinese/tone_sandhi.py (100%) rename style_bert_vits2/{text_processing => nlp}/english/__init__.py (98%) rename style_bert_vits2/{text_processing => nlp}/english/bert_feature.py (98%) rename style_bert_vits2/{text_processing => nlp}/english/cmudict.rep (100%) rename style_bert_vits2/{text_processing => nlp}/english/cmudict_cache.pickle (100%) rename style_bert_vits2/{text_processing => nlp}/english/opencpop-strict.txt (100%) create mode 100644 style_bert_vits2/nlp/japanese/__init__.py rename style_bert_vits2/{text_processing => nlp}/japanese/bert_feature.py (96%) rename style_bert_vits2/{text_processing => nlp}/japanese/g2p.py (98%) rename style_bert_vits2/{text_processing => nlp}/japanese/g2p_utils.py (93%) rename style_bert_vits2/{text_processing => nlp}/japanese/mora_list.py (100%) 
rename style_bert_vits2/{text_processing => nlp}/japanese/normalizer.py (98%) rename style_bert_vits2/{text_processing => nlp}/japanese/pyopenjtalk_worker/__init__.py (94%) rename style_bert_vits2/{text_processing => nlp}/japanese/pyopenjtalk_worker/__main__.py (58%) rename style_bert_vits2/{text_processing => nlp}/japanese/pyopenjtalk_worker/worker_client.py (93%) rename style_bert_vits2/{text_processing => nlp}/japanese/pyopenjtalk_worker/worker_common.py (100%) rename style_bert_vits2/{text_processing => nlp}/japanese/pyopenjtalk_worker/worker_server.py (98%) rename style_bert_vits2/{text_processing => nlp}/japanese/user_dict/README.md (100%) rename style_bert_vits2/{text_processing => nlp}/japanese/user_dict/__init__.py (98%) rename style_bert_vits2/{text_processing => nlp}/japanese/user_dict/part_of_speech_data.py (97%) rename style_bert_vits2/{text_processing => nlp}/japanese/user_dict/word_model.py (100%) rename style_bert_vits2/{text_processing => nlp}/symbols.py (100%) delete mode 100644 style_bert_vits2/text_processing/japanese/__init__.py diff --git a/bert_gen.py b/bert_gen.py index f22c2e872..79ea65f40 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -9,8 +9,8 @@ from config import config from style_bert_vits2.logging import logger from style_bert_vits2.models import commons -from style_bert_vits2.text_processing import cleaned_text_to_sequence, extract_bert_feature -from style_bert_vits2.text_processing.japanese import pyopenjtalk_worker as pyopenjtalk +from style_bert_vits2.nlp import cleaned_text_to_sequence, extract_bert_feature +from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT pyopenjtalk.initialize() diff --git a/data_utils.py b/data_utils.py index 7738247ea..460c15fd4 100644 --- a/data_utils.py +++ b/data_utils.py @@ -12,7 +12,7 @@ from utils import load_filepaths_and_text, load_wav_to_torch from style_bert_vits2.logging import logger from style_bert_vits2.models import commons -from style_bert_vits2.text_processing import cleaned_text_to_sequence +from style_bert_vits2.nlp import cleaned_text_to_sequence """Multi speaker version""" diff --git a/preprocess_text.py b/preprocess_text.py index 03e1232ad..4966305e3 100644 --- a/preprocess_text.py +++ b/preprocess_text.py @@ -9,7 +9,7 @@ from config import config from style_bert_vits2.logging import logger -from style_bert_vits2.text_processing import clean_text +from style_bert_vits2.nlp import clean_text from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT preprocess_text_config = config.preprocess_text_config diff --git a/server_editor.py b/server_editor.py index b74d74f7b..f72800de4 100644 --- a/server_editor.py +++ b/server_editor.py @@ -42,10 +42,10 @@ Languages, ) from style_bert_vits2.logging import logger -from style_bert_vits2.text_processing import bert_models -from style_bert_vits2.text_processing.japanese import normalize_text -from style_bert_vits2.text_processing.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone -from style_bert_vits2.text_processing.japanese.user_dict import ( +from style_bert_vits2.nlp import bert_models +from style_bert_vits2.nlp.japanese import normalize_text +from style_bert_vits2.nlp.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone +from style_bert_vits2.nlp.japanese.user_dict import ( apply_word, delete_word, read_dict, diff --git a/style_bert_vits2/constants.py b/style_bert_vits2/constants.py index 064221ef2..d5a46bfd7 100644 --- a/style_bert_vits2/constants.py +++ 
b/style_bert_vits2/constants.py @@ -23,7 +23,7 @@ class Languages(StrEnum): } # デフォルトのユーザー辞書ディレクトリ -## style_bert_vits2.text_processing.japanese.user_dict モジュールのデフォルト値として利用される +## style_bert_vits2.nlp.japanese.user_dict モジュールのデフォルト値として利用される ## ライブラリとしての利用などで外部のユーザー辞書を指定したい場合は、user_dict 以下の各関数の実行時、引数に辞書データファイルのパスを指定する DEFAULT_USER_DICT_DIR = BASE_DIR / "dict_data" diff --git a/style_bert_vits2/models/infer.py b/style_bert_vits2/models/infer.py index a5ce70f61..9265eeec2 100644 --- a/style_bert_vits2/models/infer.py +++ b/style_bert_vits2/models/infer.py @@ -7,8 +7,8 @@ from style_bert_vits2.models import commons from style_bert_vits2.models.models import SynthesizerTrn from style_bert_vits2.models.models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra -from style_bert_vits2.text_processing import clean_text, cleaned_text_to_sequence, extract_bert_feature -from style_bert_vits2.text_processing.symbols import SYMBOLS +from style_bert_vits2.nlp import clean_text, cleaned_text_to_sequence, extract_bert_feature +from style_bert_vits2.nlp.symbols import SYMBOLS def get_net_g(model_path: str, version: str, device: str, hps): diff --git a/style_bert_vits2/models/models.py b/style_bert_vits2/models/models.py index 7f14be431..b4e1eb718 100644 --- a/style_bert_vits2/models/models.py +++ b/style_bert_vits2/models/models.py @@ -11,7 +11,7 @@ from style_bert_vits2.models import modules from style_bert_vits2.models import monotonic_alignment from style_bert_vits2.models.commons import get_padding, init_weights -from style_bert_vits2.text_processing.symbols import NUM_LANGUAGES, NUM_TONES, SYMBOLS +from style_bert_vits2.nlp.symbols import NUM_LANGUAGES, NUM_TONES, SYMBOLS class DurationDiscriminator(nn.Module): # vits2 diff --git a/style_bert_vits2/models/models_jp_extra.py b/style_bert_vits2/models/models_jp_extra.py index 54591f6fb..7bae8b45d 100644 --- a/style_bert_vits2/models/models_jp_extra.py +++ b/style_bert_vits2/models/models_jp_extra.py @@ -10,7 +10,7 @@ from style_bert_vits2.models import commons from style_bert_vits2.models import modules from style_bert_vits2.models import monotonic_alignment -from style_bert_vits2.text_processing.symbols import SYMBOLS, NUM_TONES, NUM_LANGUAGES +from style_bert_vits2.nlp.symbols import SYMBOLS, NUM_TONES, NUM_LANGUAGES class DurationDiscriminator(nn.Module): # vits2 diff --git a/style_bert_vits2/text_processing/__init__.py b/style_bert_vits2/nlp/__init__.py similarity index 86% rename from style_bert_vits2/text_processing/__init__.py rename to style_bert_vits2/nlp/__init__.py index 4dadaed73..47786d128 100644 --- a/style_bert_vits2/text_processing/__init__.py +++ b/style_bert_vits2/nlp/__init__.py @@ -2,7 +2,7 @@ from typing import Optional from style_bert_vits2.constants import Languages -from style_bert_vits2.text_processing.symbols import ( +from style_bert_vits2.nlp.symbols import ( LANGUAGE_ID_MAP, LANGUAGE_TONE_START_MAP, SYMBOLS, @@ -36,11 +36,11 @@ def extract_bert_feature( """ if language == Languages.JP: - from style_bert_vits2.text_processing.japanese.bert_feature import extract_bert_feature + from style_bert_vits2.nlp.japanese.bert_feature import extract_bert_feature elif language == Languages.EN: - from style_bert_vits2.text_processing.english.bert_feature import extract_bert_feature + from style_bert_vits2.nlp.english.bert_feature import extract_bert_feature elif language == Languages.ZH: - from style_bert_vits2.text_processing.chinese.bert_feature import extract_bert_feature + from style_bert_vits2.nlp.chinese.bert_feature import 
extract_bert_feature else: raise ValueError(f"Language {language} not supported") @@ -68,15 +68,15 @@ def clean_text( # Changed to import inside if condition to avoid unnecessary import if language == Languages.JP: - from style_bert_vits2.text_processing.japanese import g2p, normalize_text + from style_bert_vits2.nlp.japanese import g2p, normalize_text norm_text = normalize_text(text) phones, tones, word2ph = g2p(norm_text, use_jp_extra, raise_yomi_error) elif language == Languages.EN: - from style_bert_vits2.text_processing.english import g2p, normalize_text + from style_bert_vits2.nlp.english import g2p, normalize_text norm_text = normalize_text(text) phones, tones, word2ph = g2p(norm_text) elif language == Languages.ZH: - from style_bert_vits2.text_processing.chinese import g2p, normalize_text + from style_bert_vits2.nlp.chinese import g2p, normalize_text norm_text = normalize_text(text) phones, tones, word2ph = g2p(norm_text) else: diff --git a/style_bert_vits2/text_processing/bert_models.py b/style_bert_vits2/nlp/bert_models.py similarity index 100% rename from style_bert_vits2/text_processing/bert_models.py rename to style_bert_vits2/nlp/bert_models.py diff --git a/style_bert_vits2/text_processing/chinese/__init__.py b/style_bert_vits2/nlp/chinese/__init__.py similarity index 95% rename from style_bert_vits2/text_processing/chinese/__init__.py rename to style_bert_vits2/nlp/chinese/__init__.py index 23acbf3a1..5b708636d 100644 --- a/style_bert_vits2/text_processing/chinese/__init__.py +++ b/style_bert_vits2/nlp/chinese/__init__.py @@ -5,8 +5,8 @@ import jieba.posseg as psg from pypinyin import lazy_pinyin, Style -from style_bert_vits2.text_processing.chinese.tone_sandhi import ToneSandhi -from style_bert_vits2.text_processing.symbols import PUNCTUATIONS +from style_bert_vits2.nlp.chinese.tone_sandhi import ToneSandhi +from style_bert_vits2.nlp.symbols import PUNCTUATIONS current_file_path = os.path.dirname(__file__) @@ -177,7 +177,7 @@ def normalize_text(text: str) -> str: if __name__ == "__main__": - from style_bert_vits2.text_processing.chinese.bert_feature import extract_bert_feature + from style_bert_vits2.nlp.chinese.bert_feature import extract_bert_feature text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏" text = normalize_text(text) diff --git a/style_bert_vits2/text_processing/chinese/bert_feature.py b/style_bert_vits2/nlp/chinese/bert_feature.py similarity index 98% rename from style_bert_vits2/text_processing/chinese/bert_feature.py rename to style_bert_vits2/nlp/chinese/bert_feature.py index f8bce0416..b97950b0c 100644 --- a/style_bert_vits2/text_processing/chinese/bert_feature.py +++ b/style_bert_vits2/nlp/chinese/bert_feature.py @@ -5,7 +5,7 @@ from transformers import PreTrainedModel from style_bert_vits2.constants import Languages -from style_bert_vits2.text_processing import bert_models +from style_bert_vits2.nlp import bert_models __models: dict[torch.device | str, PreTrainedModel] = {} diff --git a/style_bert_vits2/text_processing/chinese/tone_sandhi.py b/style_bert_vits2/nlp/chinese/tone_sandhi.py similarity index 100% rename from style_bert_vits2/text_processing/chinese/tone_sandhi.py rename to style_bert_vits2/nlp/chinese/tone_sandhi.py diff --git a/style_bert_vits2/text_processing/english/__init__.py b/style_bert_vits2/nlp/english/__init__.py similarity index 98% rename from style_bert_vits2/text_processing/english/__init__.py rename to style_bert_vits2/nlp/english/__init__.py index b57e25067..3610b13bd 100644 --- 
a/style_bert_vits2/text_processing/english/__init__.py +++ b/style_bert_vits2/nlp/english/__init__.py @@ -4,8 +4,8 @@ from g2p_en import G2p from style_bert_vits2.constants import Languages -from style_bert_vits2.text_processing import bert_models -from style_bert_vits2.text_processing.symbols import PUNCTUATIONS, SYMBOLS +from style_bert_vits2.nlp import bert_models +from style_bert_vits2.nlp.symbols import PUNCTUATIONS, SYMBOLS current_file_path = os.path.dirname(__file__) diff --git a/style_bert_vits2/text_processing/english/bert_feature.py b/style_bert_vits2/nlp/english/bert_feature.py similarity index 98% rename from style_bert_vits2/text_processing/english/bert_feature.py rename to style_bert_vits2/nlp/english/bert_feature.py index 79f24c9c2..647920d23 100644 --- a/style_bert_vits2/text_processing/english/bert_feature.py +++ b/style_bert_vits2/nlp/english/bert_feature.py @@ -5,7 +5,7 @@ from transformers import PreTrainedModel from style_bert_vits2.constants import Languages -from style_bert_vits2.text_processing import bert_models +from style_bert_vits2.nlp import bert_models __models: dict[torch.device | str, PreTrainedModel] = {} diff --git a/style_bert_vits2/text_processing/english/cmudict.rep b/style_bert_vits2/nlp/english/cmudict.rep similarity index 100% rename from style_bert_vits2/text_processing/english/cmudict.rep rename to style_bert_vits2/nlp/english/cmudict.rep diff --git a/style_bert_vits2/text_processing/english/cmudict_cache.pickle b/style_bert_vits2/nlp/english/cmudict_cache.pickle similarity index 100% rename from style_bert_vits2/text_processing/english/cmudict_cache.pickle rename to style_bert_vits2/nlp/english/cmudict_cache.pickle diff --git a/style_bert_vits2/text_processing/english/opencpop-strict.txt b/style_bert_vits2/nlp/english/opencpop-strict.txt similarity index 100% rename from style_bert_vits2/text_processing/english/opencpop-strict.txt rename to style_bert_vits2/nlp/english/opencpop-strict.txt diff --git a/style_bert_vits2/nlp/japanese/__init__.py b/style_bert_vits2/nlp/japanese/__init__.py new file mode 100644 index 000000000..5c7f19f97 --- /dev/null +++ b/style_bert_vits2/nlp/japanese/__init__.py @@ -0,0 +1,2 @@ +from style_bert_vits2.nlp.japanese.g2p import g2p # noqa: F401 +from style_bert_vits2.nlp.japanese.normalizer import normalize_text # noqa: F401 diff --git a/style_bert_vits2/text_processing/japanese/bert_feature.py b/style_bert_vits2/nlp/japanese/bert_feature.py similarity index 96% rename from style_bert_vits2/text_processing/japanese/bert_feature.py rename to style_bert_vits2/nlp/japanese/bert_feature.py index 3a44a8d6d..ede1f83b9 100644 --- a/style_bert_vits2/text_processing/japanese/bert_feature.py +++ b/style_bert_vits2/nlp/japanese/bert_feature.py @@ -5,8 +5,8 @@ from transformers import PreTrainedModel from style_bert_vits2.constants import Languages -from style_bert_vits2.text_processing import bert_models -from style_bert_vits2.text_processing.japanese.g2p import text_to_sep_kata +from style_bert_vits2.nlp import bert_models +from style_bert_vits2.nlp.japanese.g2p import text_to_sep_kata __models: dict[torch.device | str, PreTrainedModel] = {} diff --git a/style_bert_vits2/text_processing/japanese/g2p.py b/style_bert_vits2/nlp/japanese/g2p.py similarity index 98% rename from style_bert_vits2/text_processing/japanese/g2p.py rename to style_bert_vits2/nlp/japanese/g2p.py index 2f9746b31..45e781778 100644 --- a/style_bert_vits2/text_processing/japanese/g2p.py +++ b/style_bert_vits2/nlp/japanese/g2p.py @@ -2,11 +2,11 @@ from 
style_bert_vits2.constants import Languages from style_bert_vits2.logging import logger -from style_bert_vits2.text_processing import bert_models -from style_bert_vits2.text_processing.japanese import pyopenjtalk_worker as pyopenjtalk -from style_bert_vits2.text_processing.japanese.mora_list import MORA_KATA_TO_MORA_PHONEMES -from style_bert_vits2.text_processing.japanese.normalizer import replace_punctuation -from style_bert_vits2.text_processing.symbols import PUNCTUATIONS +from style_bert_vits2.nlp import bert_models +from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk +from style_bert_vits2.nlp.japanese.mora_list import MORA_KATA_TO_MORA_PHONEMES +from style_bert_vits2.nlp.japanese.normalizer import replace_punctuation +from style_bert_vits2.nlp.symbols import PUNCTUATIONS pyopenjtalk.initialize() diff --git a/style_bert_vits2/text_processing/japanese/g2p_utils.py b/style_bert_vits2/nlp/japanese/g2p_utils.py similarity index 93% rename from style_bert_vits2/text_processing/japanese/g2p_utils.py rename to style_bert_vits2/nlp/japanese/g2p_utils.py index 4ea56e819..893d3b531 100644 --- a/style_bert_vits2/text_processing/japanese/g2p_utils.py +++ b/style_bert_vits2/nlp/japanese/g2p_utils.py @@ -1,9 +1,9 @@ -from style_bert_vits2.text_processing.japanese.g2p import g2p -from style_bert_vits2.text_processing.japanese.mora_list import ( +from style_bert_vits2.nlp.japanese.g2p import g2p +from style_bert_vits2.nlp.japanese.mora_list import ( MORA_KATA_TO_MORA_PHONEMES, MORA_PHONEMES_TO_MORA_KATA, ) -from style_bert_vits2.text_processing.symbols import PUNCTUATIONS +from style_bert_vits2.nlp.symbols import PUNCTUATIONS def g2kata_tone(norm_text: str) -> list[tuple[str, int]]: diff --git a/style_bert_vits2/text_processing/japanese/mora_list.py b/style_bert_vits2/nlp/japanese/mora_list.py similarity index 100% rename from style_bert_vits2/text_processing/japanese/mora_list.py rename to style_bert_vits2/nlp/japanese/mora_list.py diff --git a/style_bert_vits2/text_processing/japanese/normalizer.py b/style_bert_vits2/nlp/japanese/normalizer.py similarity index 98% rename from style_bert_vits2/text_processing/japanese/normalizer.py rename to style_bert_vits2/nlp/japanese/normalizer.py index 827633878..b8cad9045 100644 --- a/style_bert_vits2/text_processing/japanese/normalizer.py +++ b/style_bert_vits2/nlp/japanese/normalizer.py @@ -2,7 +2,7 @@ import unicodedata from num2words import num2words -from style_bert_vits2.text_processing.symbols import PUNCTUATIONS +from style_bert_vits2.nlp.symbols import PUNCTUATIONS def normalize_text(text: str) -> str: diff --git a/style_bert_vits2/text_processing/japanese/pyopenjtalk_worker/__init__.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py similarity index 94% rename from style_bert_vits2/text_processing/japanese/pyopenjtalk_worker/__init__.py rename to style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py index c5fc6974d..fc8f6dab8 100644 --- a/style_bert_vits2/text_processing/japanese/pyopenjtalk_worker/__init__.py +++ b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py @@ -6,8 +6,8 @@ from typing import Any, Optional from style_bert_vits2.logging import logger -from style_bert_vits2.text_processing.japanese.pyopenjtalk_worker.worker_client import WorkerClient -from style_bert_vits2.text_processing.japanese.pyopenjtalk_worker.worker_common import WORKER_PORT +from style_bert_vits2.nlp.japanese.pyopenjtalk_worker.worker_client import WorkerClient +from 
style_bert_vits2.nlp.japanese.pyopenjtalk_worker.worker_common import WORKER_PORT WORKER_CLIENT: Optional[WorkerClient] = None diff --git a/style_bert_vits2/text_processing/japanese/pyopenjtalk_worker/__main__.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__main__.py similarity index 58% rename from style_bert_vits2/text_processing/japanese/pyopenjtalk_worker/__main__.py rename to style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__main__.py index f022901f6..2452a1643 100644 --- a/style_bert_vits2/text_processing/japanese/pyopenjtalk_worker/__main__.py +++ b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__main__.py @@ -1,7 +1,7 @@ import argparse -from style_bert_vits2.text_processing.japanese.pyopenjtalk_worker.worker_common import WORKER_PORT -from style_bert_vits2.text_processing.japanese.pyopenjtalk_worker.worker_server import WorkerServer +from style_bert_vits2.nlp.japanese.pyopenjtalk_worker.worker_common import WORKER_PORT +from style_bert_vits2.nlp.japanese.pyopenjtalk_worker.worker_server import WorkerServer def main() -> None: diff --git a/style_bert_vits2/text_processing/japanese/pyopenjtalk_worker/worker_client.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_client.py similarity index 93% rename from style_bert_vits2/text_processing/japanese/pyopenjtalk_worker/worker_client.py rename to style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_client.py index 0bfe87a06..b87a93757 100644 --- a/style_bert_vits2/text_processing/japanese/pyopenjtalk_worker/worker_client.py +++ b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_client.py @@ -2,7 +2,7 @@ from typing import Any, cast from style_bert_vits2.logging import logger -from style_bert_vits2.text_processing.japanese.pyopenjtalk_worker.worker_common import RequestType, receive_data, send_data +from style_bert_vits2.nlp.japanese.pyopenjtalk_worker.worker_common import RequestType, receive_data, send_data class WorkerClient: diff --git a/style_bert_vits2/text_processing/japanese/pyopenjtalk_worker/worker_common.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_common.py similarity index 100% rename from style_bert_vits2/text_processing/japanese/pyopenjtalk_worker/worker_common.py rename to style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_common.py diff --git a/style_bert_vits2/text_processing/japanese/pyopenjtalk_worker/worker_server.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_server.py similarity index 98% rename from style_bert_vits2/text_processing/japanese/pyopenjtalk_worker/worker_server.py rename to style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_server.py index 4282c789b..e9a3dc140 100644 --- a/style_bert_vits2/text_processing/japanese/pyopenjtalk_worker/worker_server.py +++ b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_server.py @@ -6,7 +6,7 @@ import pyopenjtalk from style_bert_vits2.logging import logger -from style_bert_vits2.text_processing.japanese.pyopenjtalk_worker.worker_common import ( +from style_bert_vits2.nlp.japanese.pyopenjtalk_worker.worker_common import ( ConnectionClosedException, RequestType, receive_data, diff --git a/style_bert_vits2/text_processing/japanese/user_dict/README.md b/style_bert_vits2/nlp/japanese/user_dict/README.md similarity index 100% rename from style_bert_vits2/text_processing/japanese/user_dict/README.md rename to style_bert_vits2/nlp/japanese/user_dict/README.md diff --git a/style_bert_vits2/text_processing/japanese/user_dict/__init__.py b/style_bert_vits2/nlp/japanese/user_dict/__init__.py similarity 
index 98% rename from style_bert_vits2/text_processing/japanese/user_dict/__init__.py rename to style_bert_vits2/nlp/japanese/user_dict/__init__.py index 041f37c42..097cc6aad 100644 --- a/style_bert_vits2/text_processing/japanese/user_dict/__init__.py +++ b/style_bert_vits2/nlp/japanese/user_dict/__init__.py @@ -16,9 +16,9 @@ from fastapi import HTTPException from style_bert_vits2.constants import DEFAULT_USER_DICT_DIR -from style_bert_vits2.text_processing.japanese import pyopenjtalk_worker as pyopenjtalk -from style_bert_vits2.text_processing.japanese.user_dict.word_model import UserDictWord, WordTypes -from style_bert_vits2.text_processing.japanese.user_dict.part_of_speech_data import MAX_PRIORITY, MIN_PRIORITY, part_of_speech_data +from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk +from style_bert_vits2.nlp.japanese.user_dict.word_model import UserDictWord, WordTypes +from style_bert_vits2.nlp.japanese.user_dict.part_of_speech_data import MAX_PRIORITY, MIN_PRIORITY, part_of_speech_data pyopenjtalk.initialize() diff --git a/style_bert_vits2/text_processing/japanese/user_dict/part_of_speech_data.py b/style_bert_vits2/nlp/japanese/user_dict/part_of_speech_data.py similarity index 97% rename from style_bert_vits2/text_processing/japanese/user_dict/part_of_speech_data.py rename to style_bert_vits2/nlp/japanese/user_dict/part_of_speech_data.py index db42d3869..443bdc521 100644 --- a/style_bert_vits2/text_processing/japanese/user_dict/part_of_speech_data.py +++ b/style_bert_vits2/nlp/japanese/user_dict/part_of_speech_data.py @@ -7,7 +7,7 @@ from typing import Dict -from style_bert_vits2.text_processing.japanese.user_dict.word_model import ( +from style_bert_vits2.nlp.japanese.user_dict.word_model import ( USER_DICT_MAX_PRIORITY, USER_DICT_MIN_PRIORITY, PartOfSpeechDetail, diff --git a/style_bert_vits2/text_processing/japanese/user_dict/word_model.py b/style_bert_vits2/nlp/japanese/user_dict/word_model.py similarity index 100% rename from style_bert_vits2/text_processing/japanese/user_dict/word_model.py rename to style_bert_vits2/nlp/japanese/user_dict/word_model.py diff --git a/style_bert_vits2/text_processing/symbols.py b/style_bert_vits2/nlp/symbols.py similarity index 100% rename from style_bert_vits2/text_processing/symbols.py rename to style_bert_vits2/nlp/symbols.py diff --git a/style_bert_vits2/text_processing/japanese/__init__.py b/style_bert_vits2/text_processing/japanese/__init__.py deleted file mode 100644 index 91233773e..000000000 --- a/style_bert_vits2/text_processing/japanese/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from style_bert_vits2.text_processing.japanese.g2p import g2p # noqa: F401 -from style_bert_vits2.text_processing.japanese.normalizer import normalize_text # noqa: F401 diff --git a/train_ms.py b/train_ms.py index 05cd4cb02..0cb7a8216 100644 --- a/train_ms.py +++ b/train_ms.py @@ -31,7 +31,7 @@ MultiPeriodDiscriminator, SynthesizerTrn, ) -from style_bert_vits2.text_processing.symbols import SYMBOLS +from style_bert_vits2.nlp.symbols import SYMBOLS from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT torch.backends.cuda.matmul.allow_tf32 = True diff --git a/train_ms_jp_extra.py b/train_ms_jp_extra.py index f04d2aef6..7d0636b95 100644 --- a/train_ms_jp_extra.py +++ b/train_ms_jp_extra.py @@ -32,7 +32,7 @@ SynthesizerTrn, WavLMDiscriminator, ) -from style_bert_vits2.text_processing.symbols import SYMBOLS +from style_bert_vits2.nlp.symbols import SYMBOLS from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT 
torch.backends.cuda.matmul.allow_tf32 = True diff --git a/webui/inference.py b/webui/inference.py index f8de8c34a..026481f30 100644 --- a/webui/inference.py +++ b/webui/inference.py @@ -20,8 +20,8 @@ ) from style_bert_vits2.logging import logger from style_bert_vits2.models.infer import InvalidToneError -from style_bert_vits2.text_processing.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone -from style_bert_vits2.text_processing.japanese.normalizer import normalize_text +from style_bert_vits2.nlp.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone +from style_bert_vits2.nlp.japanese.normalizer import normalize_text languages = [l.value for l in Languages] From 75467936d96e4396fcb60eeaeaea48ce379a9027 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 06:39:51 +0000 Subject: [PATCH 042/148] Refactor: cleanup style_bert_vits2/nlp/english/__init__.py --- style_bert_vits2/nlp/english/__init__.py | 564 ++++++++++------------- 1 file changed, 245 insertions(+), 319 deletions(-) diff --git a/style_bert_vits2/nlp/english/__init__.py b/style_bert_vits2/nlp/english/__init__.py index 3610b13bd..e5041d107 100644 --- a/style_bert_vits2/nlp/english/__init__.py +++ b/style_bert_vits2/nlp/english/__init__.py @@ -1,6 +1,9 @@ import pickle import os import re +from pathlib import Path + +import inflect from g2p_en import G2p from style_bert_vits2.constants import Languages @@ -8,88 +11,164 @@ from style_bert_vits2.nlp.symbols import PUNCTUATIONS, SYMBOLS -current_file_path = os.path.dirname(__file__) -CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep") -CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle") -_g2p = G2p() - -arpa = { - "AH0", - "S", - "AH1", - "EY2", - "AE2", - "EH0", - "OW2", - "UH0", - "NG", - "B", - "G", - "AY0", - "M", - "AA0", - "F", - "AO0", - "ER2", - "UH1", - "IY1", - "AH2", - "DH", - "IY0", - "EY1", - "IH0", - "K", - "N", - "W", - "IY2", - "T", - "AA1", - "ER1", - "EH2", - "OY0", - "UH2", - "UW1", - "Z", - "AW2", - "AW1", - "V", - "UW2", - "AA2", - "ER", - "AW0", - "UW0", - "R", - "OW1", - "EH1", - "ZH", - "AE0", - "IH2", - "IH", - "Y", - "JH", - "P", - "AY1", - "EY0", - "OY2", - "TH", - "HH", - "D", - "ER0", - "CH", - "AO1", - "AE1", - "AO2", - "OY1", - "AY2", - "IH1", - "OW0", - "L", - "SH", -} - - -def post_replace_ph(ph): - rep_map = { +CMU_DICT_PATH = Path(__file__).parent / "cmudict.rep" +CACHE_PATH = Path(__file__).parent / "cmudict_cache.pickle" + + +def g2p(text: str) -> tuple[list[str], list[int], list[int]]: + + ARPA = { + "AH0", + "S", + "AH1", + "EY2", + "AE2", + "EH0", + "OW2", + "UH0", + "NG", + "B", + "G", + "AY0", + "M", + "AA0", + "F", + "AO0", + "ER2", + "UH1", + "IY1", + "AH2", + "DH", + "IY0", + "EY1", + "IH0", + "K", + "N", + "W", + "IY2", + "T", + "AA1", + "ER1", + "EH2", + "OY0", + "UH2", + "UW1", + "Z", + "AW2", + "AW1", + "V", + "UW2", + "AA2", + "ER", + "AW0", + "UW0", + "R", + "OW1", + "EH1", + "ZH", + "AE0", + "IH2", + "IH", + "Y", + "JH", + "P", + "AY1", + "EY0", + "OY2", + "TH", + "HH", + "D", + "ER0", + "CH", + "AO1", + "AE1", + "AO2", + "OY1", + "AY2", + "IH1", + "OW0", + "L", + "SH", + } + + _g2p = G2p() + + phones = [] + tones = [] + phone_len = [] + # tokens = [tokenizer.tokenize(i) for i in words] + words = __text_to_words(text) + eng_dict = __get_dict() + + for word in words: + temp_phones, temp_tones = [], [] + if len(word) > 1: + if "'" in word: + word = ["".join(word)] + for w in word: + if w in PUNCTUATIONS: + temp_phones.append(w) + temp_tones.append(0) + continue + if w.upper() in 
eng_dict: + phns, tns = __refine_syllables(eng_dict[w.upper()]) + temp_phones += [__post_replace_ph(i) for i in phns] + temp_tones += tns + # w2ph.append(len(phns)) + else: + phone_list = list(filter(lambda p: p != " ", _g2p(w))) # type: ignore + phns = [] + tns = [] + for ph in phone_list: + if ph in ARPA: + ph, tn = __refine_ph(ph) + phns.append(ph) + tns.append(tn) + else: + phns.append(ph) + tns.append(0) + temp_phones += [__post_replace_ph(i) for i in phns] + temp_tones += tns + phones += temp_phones + tones += temp_tones + phone_len.append(len(temp_phones)) + # phones = [post_replace_ph(i) for i in phones] + + word2ph = [] + for token, pl in zip(words, phone_len): + word_len = len(token) + + aaa = __distribute_phone(pl, word_len) + word2ph += aaa + + phones = ["_"] + phones + ["_"] + tones = [0] + tones + [0] + word2ph = [1] + word2ph + [1] + assert len(phones) == len(tones), text + assert len(phones) == sum(word2ph), text + + return phones, tones, word2ph + + +def normalize_text(text: str) -> str: + text = __normalize_numbers(text) + text = __replace_punctuation(text) + text = re.sub(r"([,;.\?\!])([\w])", r"\1 \2", text) + return text + + +def __normalize_numbers(text: str) -> str: + text = re.sub(__comma_number_re, __remove_commas, text) + text = re.sub(__pounds_re, r"\1 pounds", text) + text = re.sub(__dollars_re, __expand_dollars, text) + text = re.sub(__decimal_number_re, __expand_decimal_point, text) + text = re.sub(__ordinal_re, __expand_ordinal, text) + text = re.sub(__number_re, __expand_number, text) + return text + + +def __replace_punctuation(text: str) -> str: + REPLACE_MAP = { ":": ",", ";": ",", ",": ",", @@ -97,67 +176,38 @@ def post_replace_ph(ph): "!": "!", "?": "?", "\n": ".", - "·": ",", - "、": ",", + ".": ".", "…": "...", "···": "...", "・・・": "...", - "v": "V", + "·": ",", + "・": ",", + "、": ",", + "$": ".", + "“": "'", + "”": "'", + '"': "'", + "‘": "'", + "’": "'", + "(": "'", + ")": "'", + "(": "'", + ")": "'", + "《": "'", + "》": "'", + "【": "'", + "】": "'", + "[": "'", + "]": "'", + "—": "-", + "−": "-", + "~": "-", + "~": "-", + "「": "'", + "」": "'", } - if ph in rep_map.keys(): - ph = rep_map[ph] - if ph in SYMBOLS: - return ph - if ph not in SYMBOLS: - ph = "UNK" - return ph - - -rep_map = { - ":": ",", - ";": ",", - ",": ",", - "。": ".", - "!": "!", - "?": "?", - "\n": ".", - ".": ".", - "…": "...", - "···": "...", - "・・・": "...", - "·": ",", - "・": ",", - "、": ",", - "$": ".", - "“": "'", - "”": "'", - '"': "'", - "‘": "'", - "’": "'", - "(": "'", - ")": "'", - "(": "'", - ")": "'", - "《": "'", - "》": "'", - "【": "'", - "】": "'", - "[": "'", - "]": "'", - "—": "-", - "−": "-", - "~": "-", - "~": "-", - "「": "'", - "」": "'", -} - - -def replace_punctuation(text): - pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) - - replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) - + pattern = re.compile("|".join(re.escape(p) for p in REPLACE_MAP.keys())) + replaced_text = pattern.sub(lambda x: REPLACE_MAP[x.group()], text) # replaced_text = re.sub( # r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005" # + "".join(punctuation) @@ -165,11 +215,35 @@ def replace_punctuation(text): # "", # replaced_text, # ) - return replaced_text -def read_dict(): +def __post_replace_ph(ph: str) -> str: + REPLACE_MAP = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "…": "...", + "···": "...", + "・・・": "...", + "v": "V", + } + if ph in REPLACE_MAP.keys(): + ph = REPLACE_MAP[ph] 
+ if ph in SYMBOLS: + return ph + if ph not in SYMBOLS: + ph = "UNK" + return ph + + +def __read_dict() -> dict[str, list[list[str]]]: g2p_dict = {} start_line = 49 with open(CMU_DICT_PATH) as f: @@ -193,26 +267,23 @@ def read_dict(): return g2p_dict -def cache_dict(g2p_dict, file_path): +def __cache_dict(g2p_dict: dict[str, list[list[str]]], file_path: Path) -> None: with open(file_path, "wb") as pickle_file: pickle.dump(g2p_dict, pickle_file) -def get_dict(): +def __get_dict() -> dict[str, list[list[str]]]: if os.path.exists(CACHE_PATH): with open(CACHE_PATH, "rb") as pickle_file: g2p_dict = pickle.load(pickle_file) else: - g2p_dict = read_dict() - cache_dict(g2p_dict, CACHE_PATH) + g2p_dict = __read_dict() + __cache_dict(g2p_dict, CACHE_PATH) return g2p_dict -eng_dict = get_dict() - - -def refine_ph(phn): +def __refine_ph(phn: str) -> tuple[str, int]: tone = 0 if re.search(r"\d$", phn): tone = int(phn[-1]) + 1 @@ -222,93 +293,28 @@ def refine_ph(phn): return phn.lower(), tone -def refine_syllables(syllables): +def __refine_syllables(syllables: list[list[str]]) -> tuple[list[str], list[int]]: tones = [] phonemes = [] for phn_list in syllables: for i in range(len(phn_list)): phn = phn_list[i] - phn, tone = refine_ph(phn) + phn, tone = __refine_ph(phn) phonemes.append(phn) tones.append(tone) return phonemes, tones -import inflect +__inflect = inflect.engine() +__comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +__decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +__pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +__dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +__ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +__number_re = re.compile(r"[0-9]+") -_inflect = inflect.engine() -_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") -_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") -_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") -_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") -_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") -_number_re = re.compile(r"[0-9]+") - -# List of (regular expression, replacement) pairs for abbreviations: -_abbreviations = [ - (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) - for x in [ - ("mrs", "misess"), - ("mr", "mister"), - ("dr", "doctor"), - ("st", "saint"), - ("co", "company"), - ("jr", "junior"), - ("maj", "major"), - ("gen", "general"), - ("drs", "doctors"), - ("rev", "reverend"), - ("lt", "lieutenant"), - ("hon", "honorable"), - ("sgt", "sergeant"), - ("capt", "captain"), - ("esq", "esquire"), - ("ltd", "limited"), - ("col", "colonel"), - ("ft", "fort"), - ] -] - - -# List of (ipa, lazy ipa) pairs: -_lazy_ipa = [ - (re.compile("%s" % x[0]), x[1]) - for x in [ - ("r", "ɹ"), - ("æ", "e"), - ("ɑ", "a"), - ("ɔ", "o"), - ("ð", "z"), - ("θ", "s"), - ("ɛ", "e"), - ("ɪ", "i"), - ("ʊ", "u"), - ("ʒ", "ʥ"), - ("ʤ", "ʥ"), - ("ˈ", "↓"), - ] -] - -# List of (ipa, lazy ipa2) pairs: -_lazy_ipa2 = [ - (re.compile("%s" % x[0]), x[1]) - for x in [ - ("r", "ɹ"), - ("ð", "z"), - ("θ", "s"), - ("ʒ", "ʑ"), - ("ʤ", "dʑ"), - ("ˈ", "↓"), - ] -] - -# List of (ipa, ipa2) pairs -_ipa_to_ipa2 = [ - (re.compile("%s" % x[0]), x[1]) for x in [("r", "ɹ"), ("ʤ", "dʒ"), ("ʧ", "tʃ")] -] - - -def _expand_dollars(m): + +def __expand_dollars(m: re.Match[str]) -> str: match = m.group(1) parts = match.split(".") if len(parts) > 2: @@ -329,53 +335,36 @@ def _expand_dollars(m): return "zero dollars" -def _remove_commas(m): +def __remove_commas(m: re.Match[str]) -> str: return m.group(1).replace(",", "") -def _expand_ordinal(m): - return _inflect.number_to_words(m.group(0)) +def __expand_ordinal(m: re.Match[str]) -> str: + return __inflect.number_to_words(m.group(0)) # type: ignore -def _expand_number(m): +def __expand_number(m: re.Match[str]) -> str: num = int(m.group(0)) if num > 1000 and num < 3000: if num == 2000: return "two thousand" elif num > 2000 and num < 2010: - return "two thousand " + _inflect.number_to_words(num % 100) + return "two thousand " + __inflect.number_to_words(num % 100) # type: ignore elif num % 100 == 0: - return _inflect.number_to_words(num // 100) + " hundred" + return __inflect.number_to_words(num // 100) + " hundred" # type: ignore else: - return _inflect.number_to_words( - num, andword="", zero="oh", group=2 - ).replace(", ", " ") + return __inflect.number_to_words( + num, andword="", zero="oh", group=2 # type: ignore + ).replace(", ", " ") # type: ignore else: - return _inflect.number_to_words(num, andword="") + return __inflect.number_to_words(num, andword="") # type: ignore -def _expand_decimal_point(m): +def __expand_decimal_point(m: re.Match[str]) -> str: return m.group(1).replace(".", " point ") -def normalize_numbers(text): - text = re.sub(_comma_number_re, _remove_commas, text) - text = re.sub(_pounds_re, r"\1 pounds", text) - text = re.sub(_dollars_re, _expand_dollars, text) - text = re.sub(_decimal_number_re, _expand_decimal_point, text) - text = re.sub(_ordinal_re, _expand_ordinal, text) - text = re.sub(_number_re, _expand_number, text) - return text - - -def normalize_text(text: str) -> str: - text = normalize_numbers(text) - text = replace_punctuation(text) - text = re.sub(r"([,;.\?\!])([\w])", r"\1 \2", text) - return text - - -def distribute_phone(n_phone, n_word): +def __distribute_phone(n_phone: int, n_word: int) -> list[int]: phones_per_word = [0] * n_word for task in range(n_phone): min_tasks = min(phones_per_word) @@ -384,13 +373,7 @@ def distribute_phone(n_phone, n_word): return phones_per_word -def sep_text(text): - words = re.split(r"([,;.\?\!\s+])", text) - words = [word for word in words if word.strip() != ""] - return words - - -def text_to_words(text): +def __text_to_words(text: str) -> list[list[str]]: 
tokenizer = bert_models.load_tokenizer(Languages.EN) tokens = tokenizer.tokenize(text) words = [] @@ -418,69 +401,12 @@ def text_to_words(text): return words -def g2p(text: str) -> tuple[list[str], list[int], list[int]]: - phones = [] - tones = [] - phone_len = [] - # words = sep_text(text) - # tokens = [tokenizer.tokenize(i) for i in words] - words = text_to_words(text) - - for word in words: - temp_phones, temp_tones = [], [] - if len(word) > 1: - if "'" in word: - word = ["".join(word)] - for w in word: - if w in PUNCTUATIONS: - temp_phones.append(w) - temp_tones.append(0) - continue - if w.upper() in eng_dict: - phns, tns = refine_syllables(eng_dict[w.upper()]) - temp_phones += [post_replace_ph(i) for i in phns] - temp_tones += tns - # w2ph.append(len(phns)) - else: - phone_list = list(filter(lambda p: p != " ", _g2p(w))) - phns = [] - tns = [] - for ph in phone_list: - if ph in arpa: - ph, tn = refine_ph(ph) - phns.append(ph) - tns.append(tn) - else: - phns.append(ph) - tns.append(0) - temp_phones += [post_replace_ph(i) for i in phns] - temp_tones += tns - phones += temp_phones - tones += temp_tones - phone_len.append(len(temp_phones)) - # phones = [post_replace_ph(i) for i in phones] - - word2ph = [] - for token, pl in zip(words, phone_len): - word_len = len(token) - - aaa = distribute_phone(pl, word_len) - word2ph += aaa - - phones = ["_"] + phones + ["_"] - tones = [0] + tones + [0] - word2ph = [1] + word2ph + [1] - assert len(phones) == len(tones), text - assert len(phones) == sum(word2ph), text - - return phones, tones, word2ph - - if __name__ == "__main__": # print(get_dict()) # print(eng_word_to_phoneme("hello")) print(g2p("In this paper, we propose 1 DSPGAN, a GAN-based universal vocoder.")) # all_phones = set() + # eng_dict = get_dict() # for k, syllables in eng_dict.items(): # for group in syllables: # for ph in group: From 5de4884075a4b40c1b20c95d39397726f31083df Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 07:31:33 +0000 Subject: [PATCH 043/148] Fix: import error pyopenjtalk_worker.initialize() has the side effect of starting another process and should not be executed automatically on import. 
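As an illustration of the new call pattern, a minimal sketch (all names are taken from the hunks below; the signature of text_to_sep_kata is abbreviated here, and the actual worker start-up and retry logic live in pyopenjtalk_worker.initialize()):

    # Before this patch: importing the module spawned the worker process as a side effect.
    from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk
    pyopenjtalk.initialize()  # executed at import time

    # After this patch: the worker is started lazily, right before it is actually used.
    def text_to_sep_kata(norm_text: str):
        pyopenjtalk.initialize()  # initialize the worker only when it is needed
        parsed = pyopenjtalk.run_frontend(norm_text)
        ...

Callers such as bert_gen.py no longer import pyopenjtalk_worker directly; they reach it only indirectly (for example through g2p or user_dict), which now call initialize() themselves.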
--- bert_gen.py | 3 --- style_bert_vits2/nlp/english/__init__.py | 2 +- style_bert_vits2/nlp/japanese/g2p.py | 5 +++-- .../nlp/japanese/pyopenjtalk_worker/__init__.py | 14 +++++++------- .../nlp/japanese/user_dict/__init__.py | 7 +++++-- 5 files changed, 16 insertions(+), 15 deletions(-) diff --git a/bert_gen.py b/bert_gen.py index 79ea65f40..26df64d0f 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -10,11 +10,8 @@ from style_bert_vits2.logging import logger from style_bert_vits2.models import commons from style_bert_vits2.nlp import cleaned_text_to_sequence, extract_bert_feature -from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT -pyopenjtalk.initialize() - def process_line(x): line, add_blank = x diff --git a/style_bert_vits2/nlp/english/__init__.py b/style_bert_vits2/nlp/english/__init__.py index e5041d107..b2067c33e 100644 --- a/style_bert_vits2/nlp/english/__init__.py +++ b/style_bert_vits2/nlp/english/__init__.py @@ -273,7 +273,7 @@ def __cache_dict(g2p_dict: dict[str, list[list[str]]], file_path: Path) -> None: def __get_dict() -> dict[str, list[list[str]]]: - if os.path.exists(CACHE_PATH): + if CACHE_PATH.exists(): with open(CACHE_PATH, "rb") as pickle_file: g2p_dict = pickle.load(pickle_file) else: diff --git a/style_bert_vits2/nlp/japanese/g2p.py b/style_bert_vits2/nlp/japanese/g2p.py index 45e781778..0c44563b8 100644 --- a/style_bert_vits2/nlp/japanese/g2p.py +++ b/style_bert_vits2/nlp/japanese/g2p.py @@ -8,8 +8,6 @@ from style_bert_vits2.nlp.japanese.normalizer import replace_punctuation from style_bert_vits2.nlp.symbols import PUNCTUATIONS -pyopenjtalk.initialize() - def g2p( norm_text: str, @@ -114,6 +112,9 @@ def text_to_sep_kata( tuple[list[str], list[str]]: 分割された単語リストと、その読み(カタカナ or 記号1文字)のリスト """ + # pyopenjtalk_worker を初期化 + pyopenjtalk.initialize() + # parsed: OpenJTalkの解析結果 parsed = pyopenjtalk.run_frontend(norm_text) sep_text: list[str] = [] diff --git a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py index fc8f6dab8..90419f5bf 100644 --- a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py +++ b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py @@ -50,11 +50,11 @@ def unset_user_dict(): def initialize(port: int = WORKER_PORT) -> None: - import time - import socket - import sys import atexit import signal + import socket + import sys + import time logger.debug("initialize") global WORKER_CLIENT @@ -83,7 +83,7 @@ def initialize(port: int = WORKER_PORT) -> None: else: # align with Windows behavior # start_new_session is same as specifying setsid in preexec_fn - subprocess.Popen(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True) # type: ignore + subprocess.Popen(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True) # wait until server listening count = 0 @@ -92,10 +92,10 @@ def initialize(port: int = WORKER_PORT) -> None: client = WorkerClient(port) break except socket.error: - time.sleep(1) + time.sleep(0.5) count += 1 - # 10: max number of retries - if count == 10: + # 20: max number of retries + if count == 20: raise TimeoutError("サーバーに接続できませんでした") WORKER_CLIENT = client diff --git a/style_bert_vits2/nlp/japanese/user_dict/__init__.py b/style_bert_vits2/nlp/japanese/user_dict/__init__.py index 097cc6aad..3032c019c 100644 --- a/style_bert_vits2/nlp/japanese/user_dict/__init__.py +++ 
b/style_bert_vits2/nlp/japanese/user_dict/__init__.py @@ -20,8 +20,6 @@ from style_bert_vits2.nlp.japanese.user_dict.word_model import UserDictWord, WordTypes from style_bert_vits2.nlp.japanese.user_dict.part_of_speech_data import MAX_PRIORITY, MIN_PRIORITY, part_of_speech_data -pyopenjtalk.initialize() - # root_dir = engine_root() # save_dir = get_save_dir() @@ -81,6 +79,11 @@ def update_dict( compiled_dict_path : Path コンパイル済み辞書ファイルのパス """ + + # pyopenjtalk_worker を初期化 + # ファイルを開く前に実行する必要がある + pyopenjtalk.initialize() + random_string = uuid4() tmp_csv_path = compiled_dict_path.with_suffix( f".dict_csv-{random_string}.tmp" From e2daa550002a13f578d2b09cb7dcd07ccae6d751 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 08:43:54 +0000 Subject: [PATCH 044/148] Refactor: split style_bert_vits2.nlp.english package --- style_bert_vits2/nlp/__init__.py | 20 +- style_bert_vits2/nlp/english/__init__.py | 414 ------------------- style_bert_vits2/nlp/english/bert_feature.py | 6 +- style_bert_vits2/nlp/english/cmudict.py | 46 +++ style_bert_vits2/nlp/english/g2p.py | 240 +++++++++++ style_bert_vits2/nlp/english/normalizer.py | 130 ++++++ 6 files changed, 432 insertions(+), 424 deletions(-) create mode 100644 style_bert_vits2/nlp/english/cmudict.py create mode 100644 style_bert_vits2/nlp/english/g2p.py create mode 100644 style_bert_vits2/nlp/english/normalizer.py diff --git a/style_bert_vits2/nlp/__init__.py b/style_bert_vits2/nlp/__init__.py index 47786d128..99f56a9f6 100644 --- a/style_bert_vits2/nlp/__init__.py +++ b/style_bert_vits2/nlp/__init__.py @@ -1,5 +1,4 @@ -import torch -from typing import Optional +from typing import Optional, TYPE_CHECKING from style_bert_vits2.constants import Languages from style_bert_vits2.nlp.symbols import ( @@ -8,6 +7,11 @@ SYMBOLS, ) +# __init__.py は配下のモジュールをインポートした時点で実行される +# Pytorch のインポートは重いので、型チェック時以外はインポートしない +if TYPE_CHECKING: + import torch + __symbol_to_id = {s: i for i, s in enumerate(SYMBOLS)} @@ -16,10 +20,10 @@ def extract_bert_feature( text: str, word2ph: list[int], language: Languages, - device: torch.device | str, + device: str, assist_text: Optional[str] = None, assist_text_weight: float = 0.7, -) -> torch.Tensor: +) -> "torch.Tensor": """ テキストから BERT の特徴量を抽出する @@ -27,7 +31,7 @@ def extract_bert_feature( text (str): テキスト word2ph (list[int]): 元のテキストの各文字に音素が何個割り当てられるかを表すリスト language (Languages): テキストの言語 - device (torch.device | str): 推論に利用するデバイス + device (str): 推論に利用するデバイス assist_text (Optional[str], optional): 補助テキスト (デフォルト: None) assist_text_weight (float, optional): 補助テキストの重み (デフォルト: 0.7) @@ -68,11 +72,13 @@ def clean_text( # Changed to import inside if condition to avoid unnecessary import if language == Languages.JP: - from style_bert_vits2.nlp.japanese import g2p, normalize_text + from style_bert_vits2.nlp.japanese.g2p import g2p + from style_bert_vits2.nlp.japanese.normalizer import normalize_text norm_text = normalize_text(text) phones, tones, word2ph = g2p(norm_text, use_jp_extra, raise_yomi_error) elif language == Languages.EN: - from style_bert_vits2.nlp.english import g2p, normalize_text + from style_bert_vits2.nlp.english.g2p import g2p + from style_bert_vits2.nlp.english.normalizer import normalize_text norm_text = normalize_text(text) phones, tones, word2ph = g2p(norm_text) elif language == Languages.ZH: diff --git a/style_bert_vits2/nlp/english/__init__.py b/style_bert_vits2/nlp/english/__init__.py index b2067c33e..e69de29bb 100644 --- a/style_bert_vits2/nlp/english/__init__.py +++ b/style_bert_vits2/nlp/english/__init__.py 
@@ -1,414 +0,0 @@ -import pickle -import os -import re -from pathlib import Path - -import inflect -from g2p_en import G2p - -from style_bert_vits2.constants import Languages -from style_bert_vits2.nlp import bert_models -from style_bert_vits2.nlp.symbols import PUNCTUATIONS, SYMBOLS - - -CMU_DICT_PATH = Path(__file__).parent / "cmudict.rep" -CACHE_PATH = Path(__file__).parent / "cmudict_cache.pickle" - - -def g2p(text: str) -> tuple[list[str], list[int], list[int]]: - - ARPA = { - "AH0", - "S", - "AH1", - "EY2", - "AE2", - "EH0", - "OW2", - "UH0", - "NG", - "B", - "G", - "AY0", - "M", - "AA0", - "F", - "AO0", - "ER2", - "UH1", - "IY1", - "AH2", - "DH", - "IY0", - "EY1", - "IH0", - "K", - "N", - "W", - "IY2", - "T", - "AA1", - "ER1", - "EH2", - "OY0", - "UH2", - "UW1", - "Z", - "AW2", - "AW1", - "V", - "UW2", - "AA2", - "ER", - "AW0", - "UW0", - "R", - "OW1", - "EH1", - "ZH", - "AE0", - "IH2", - "IH", - "Y", - "JH", - "P", - "AY1", - "EY0", - "OY2", - "TH", - "HH", - "D", - "ER0", - "CH", - "AO1", - "AE1", - "AO2", - "OY1", - "AY2", - "IH1", - "OW0", - "L", - "SH", - } - - _g2p = G2p() - - phones = [] - tones = [] - phone_len = [] - # tokens = [tokenizer.tokenize(i) for i in words] - words = __text_to_words(text) - eng_dict = __get_dict() - - for word in words: - temp_phones, temp_tones = [], [] - if len(word) > 1: - if "'" in word: - word = ["".join(word)] - for w in word: - if w in PUNCTUATIONS: - temp_phones.append(w) - temp_tones.append(0) - continue - if w.upper() in eng_dict: - phns, tns = __refine_syllables(eng_dict[w.upper()]) - temp_phones += [__post_replace_ph(i) for i in phns] - temp_tones += tns - # w2ph.append(len(phns)) - else: - phone_list = list(filter(lambda p: p != " ", _g2p(w))) # type: ignore - phns = [] - tns = [] - for ph in phone_list: - if ph in ARPA: - ph, tn = __refine_ph(ph) - phns.append(ph) - tns.append(tn) - else: - phns.append(ph) - tns.append(0) - temp_phones += [__post_replace_ph(i) for i in phns] - temp_tones += tns - phones += temp_phones - tones += temp_tones - phone_len.append(len(temp_phones)) - # phones = [post_replace_ph(i) for i in phones] - - word2ph = [] - for token, pl in zip(words, phone_len): - word_len = len(token) - - aaa = __distribute_phone(pl, word_len) - word2ph += aaa - - phones = ["_"] + phones + ["_"] - tones = [0] + tones + [0] - word2ph = [1] + word2ph + [1] - assert len(phones) == len(tones), text - assert len(phones) == sum(word2ph), text - - return phones, tones, word2ph - - -def normalize_text(text: str) -> str: - text = __normalize_numbers(text) - text = __replace_punctuation(text) - text = re.sub(r"([,;.\?\!])([\w])", r"\1 \2", text) - return text - - -def __normalize_numbers(text: str) -> str: - text = re.sub(__comma_number_re, __remove_commas, text) - text = re.sub(__pounds_re, r"\1 pounds", text) - text = re.sub(__dollars_re, __expand_dollars, text) - text = re.sub(__decimal_number_re, __expand_decimal_point, text) - text = re.sub(__ordinal_re, __expand_ordinal, text) - text = re.sub(__number_re, __expand_number, text) - return text - - -def __replace_punctuation(text: str) -> str: - REPLACE_MAP = { - ":": ",", - ";": ",", - ",": ",", - "。": ".", - "!": "!", - "?": "?", - "\n": ".", - ".": ".", - "…": "...", - "···": "...", - "・・・": "...", - "·": ",", - "・": ",", - "、": ",", - "$": ".", - "“": "'", - "”": "'", - '"': "'", - "‘": "'", - "’": "'", - "(": "'", - ")": "'", - "(": "'", - ")": "'", - "《": "'", - "》": "'", - "【": "'", - "】": "'", - "[": "'", - "]": "'", - "—": "-", - "−": "-", - "~": "-", - "~": "-", - "「": "'", - 
"」": "'", - } - pattern = re.compile("|".join(re.escape(p) for p in REPLACE_MAP.keys())) - replaced_text = pattern.sub(lambda x: REPLACE_MAP[x.group()], text) - # replaced_text = re.sub( - # r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005" - # + "".join(punctuation) - # + r"]+", - # "", - # replaced_text, - # ) - return replaced_text - - -def __post_replace_ph(ph: str) -> str: - REPLACE_MAP = { - ":": ",", - ";": ",", - ",": ",", - "。": ".", - "!": "!", - "?": "?", - "\n": ".", - "·": ",", - "、": ",", - "…": "...", - "···": "...", - "・・・": "...", - "v": "V", - } - if ph in REPLACE_MAP.keys(): - ph = REPLACE_MAP[ph] - if ph in SYMBOLS: - return ph - if ph not in SYMBOLS: - ph = "UNK" - return ph - - -def __read_dict() -> dict[str, list[list[str]]]: - g2p_dict = {} - start_line = 49 - with open(CMU_DICT_PATH) as f: - line = f.readline() - line_index = 1 - while line: - if line_index >= start_line: - line = line.strip() - word_split = line.split(" ") - word = word_split[0] - - syllable_split = word_split[1].split(" - ") - g2p_dict[word] = [] - for syllable in syllable_split: - phone_split = syllable.split(" ") - g2p_dict[word].append(phone_split) - - line_index = line_index + 1 - line = f.readline() - - return g2p_dict - - -def __cache_dict(g2p_dict: dict[str, list[list[str]]], file_path: Path) -> None: - with open(file_path, "wb") as pickle_file: - pickle.dump(g2p_dict, pickle_file) - - -def __get_dict() -> dict[str, list[list[str]]]: - if CACHE_PATH.exists(): - with open(CACHE_PATH, "rb") as pickle_file: - g2p_dict = pickle.load(pickle_file) - else: - g2p_dict = __read_dict() - __cache_dict(g2p_dict, CACHE_PATH) - - return g2p_dict - - -def __refine_ph(phn: str) -> tuple[str, int]: - tone = 0 - if re.search(r"\d$", phn): - tone = int(phn[-1]) + 1 - phn = phn[:-1] - else: - tone = 3 - return phn.lower(), tone - - -def __refine_syllables(syllables: list[list[str]]) -> tuple[list[str], list[int]]: - tones = [] - phonemes = [] - for phn_list in syllables: - for i in range(len(phn_list)): - phn = phn_list[i] - phn, tone = __refine_ph(phn) - phonemes.append(phn) - tones.append(tone) - return phonemes, tones - - -__inflect = inflect.engine() -__comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") -__decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") -__pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") -__dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") -__ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") -__number_re = re.compile(r"[0-9]+") - - -def __expand_dollars(m: re.Match[str]) -> str: - match = m.group(1) - parts = match.split(".") - if len(parts) > 2: - return match + " dollars" # Unexpected format - dollars = int(parts[0]) if parts[0] else 0 - cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 - if dollars and cents: - dollar_unit = "dollar" if dollars == 1 else "dollars" - cent_unit = "cent" if cents == 1 else "cents" - return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) - elif dollars: - dollar_unit = "dollar" if dollars == 1 else "dollars" - return "%s %s" % (dollars, dollar_unit) - elif cents: - cent_unit = "cent" if cents == 1 else "cents" - return "%s %s" % (cents, cent_unit) - else: - return "zero dollars" - - -def __remove_commas(m: re.Match[str]) -> str: - return m.group(1).replace(",", "") - - -def __expand_ordinal(m: re.Match[str]) -> str: - return __inflect.number_to_words(m.group(0)) # type: ignore - - -def __expand_number(m: re.Match[str]) -> str: - num = int(m.group(0)) - if num > 1000 and num < 3000: - if num == 2000: - return "two thousand" 
- elif num > 2000 and num < 2010: - return "two thousand " + __inflect.number_to_words(num % 100) # type: ignore - elif num % 100 == 0: - return __inflect.number_to_words(num // 100) + " hundred" # type: ignore - else: - return __inflect.number_to_words( - num, andword="", zero="oh", group=2 # type: ignore - ).replace(", ", " ") # type: ignore - else: - return __inflect.number_to_words(num, andword="") # type: ignore - - -def __expand_decimal_point(m: re.Match[str]) -> str: - return m.group(1).replace(".", " point ") - - -def __distribute_phone(n_phone: int, n_word: int) -> list[int]: - phones_per_word = [0] * n_word - for task in range(n_phone): - min_tasks = min(phones_per_word) - min_index = phones_per_word.index(min_tasks) - phones_per_word[min_index] += 1 - return phones_per_word - - -def __text_to_words(text: str) -> list[list[str]]: - tokenizer = bert_models.load_tokenizer(Languages.EN) - tokens = tokenizer.tokenize(text) - words = [] - for idx, t in enumerate(tokens): - if t.startswith("▁"): - words.append([t[1:]]) - else: - if t in PUNCTUATIONS: - if idx == len(tokens) - 1: - words.append([f"{t}"]) - else: - if ( - not tokens[idx + 1].startswith("▁") - and tokens[idx + 1] not in PUNCTUATIONS - ): - if idx == 0: - words.append([]) - words[-1].append(f"{t}") - else: - words.append([f"{t}"]) - else: - if idx == 0: - words.append([]) - words[-1].append(f"{t}") - return words - - -if __name__ == "__main__": - # print(get_dict()) - # print(eng_word_to_phoneme("hello")) - print(g2p("In this paper, we propose 1 DSPGAN, a GAN-based universal vocoder.")) - # all_phones = set() - # eng_dict = get_dict() - # for k, syllables in eng_dict.items(): - # for group in syllables: - # for ph in group: - # all_phones.add(ph) - # print(all_phones) diff --git a/style_bert_vits2/nlp/english/bert_feature.py b/style_bert_vits2/nlp/english/bert_feature.py index 647920d23..27fd5018a 100644 --- a/style_bert_vits2/nlp/english/bert_feature.py +++ b/style_bert_vits2/nlp/english/bert_feature.py @@ -8,13 +8,13 @@ from style_bert_vits2.nlp import bert_models -__models: dict[torch.device | str, PreTrainedModel] = {} +__models: dict[str, PreTrainedModel] = {} def extract_bert_feature( text: str, word2ph: list[int], - device: torch.device | str, + device: str, assist_text: Optional[str] = None, assist_text_weight: float = 0.7, ) -> torch.Tensor: @@ -24,7 +24,7 @@ def extract_bert_feature( Args: text (str): 英語のテキスト word2ph (list[int]): 元のテキストの各文字に音素が何個割り当てられるかを表すリスト - device (torch.device | str): 推論に利用するデバイス + device (str): 推論に利用するデバイス assist_text (Optional[str], optional): 補助テキスト (デフォルト: None) assist_text_weight (float, optional): 補助テキストの重み (デフォルト: 0.7) diff --git a/style_bert_vits2/nlp/english/cmudict.py b/style_bert_vits2/nlp/english/cmudict.py new file mode 100644 index 000000000..e6afb89a7 --- /dev/null +++ b/style_bert_vits2/nlp/english/cmudict.py @@ -0,0 +1,46 @@ +import pickle +from pathlib import Path + + +CMU_DICT_PATH = Path(__file__).parent / "cmudict.rep" +CACHE_PATH = Path(__file__).parent / "cmudict_cache.pickle" + + +def get_dict() -> dict[str, list[list[str]]]: + if CACHE_PATH.exists(): + with open(CACHE_PATH, "rb") as pickle_file: + g2p_dict = pickle.load(pickle_file) + else: + g2p_dict = read_dict() + cache_dict(g2p_dict, CACHE_PATH) + + return g2p_dict + + +def read_dict() -> dict[str, list[list[str]]]: + g2p_dict = {} + start_line = 49 + with open(CMU_DICT_PATH) as f: + line = f.readline() + line_index = 1 + while line: + if line_index >= start_line: + line = line.strip() + word_split = 
line.split(" ") + word = word_split[0] + + syllable_split = word_split[1].split(" - ") + g2p_dict[word] = [] + for syllable in syllable_split: + phone_split = syllable.split(" ") + g2p_dict[word].append(phone_split) + + line_index = line_index + 1 + line = f.readline() + + return g2p_dict + + +def cache_dict(g2p_dict: dict[str, list[list[str]]], file_path: Path) -> None: + with open(file_path, "wb") as pickle_file: + pickle.dump(g2p_dict, pickle_file) diff --git a/style_bert_vits2/nlp/english/g2p.py b/style_bert_vits2/nlp/english/g2p.py new file mode 100644 index 000000000..db2a87f97 --- /dev/null +++ b/style_bert_vits2/nlp/english/g2p.py @@ -0,0 +1,240 @@ +import re + +from g2p_en import G2p + +from style_bert_vits2.constants import Languages +from style_bert_vits2.nlp import bert_models +from style_bert_vits2.nlp.english.cmudict import get_dict +from style_bert_vits2.nlp.symbols import PUNCTUATIONS, SYMBOLS + + +def g2p(text: str) -> tuple[list[str], list[int], list[int]]: + + ARPA = { + "AH0", + "S", + "AH1", + "EY2", + "AE2", + "EH0", + "OW2", + "UH0", + "NG", + "B", + "G", + "AY0", + "M", + "AA0", + "F", + "AO0", + "ER2", + "UH1", + "IY1", + "AH2", + "DH", + "IY0", + "EY1", + "IH0", + "K", + "N", + "W", + "IY2", + "T", + "AA1", + "ER1", + "EH2", + "OY0", + "UH2", + "UW1", + "Z", + "AW2", + "AW1", + "V", + "UW2", + "AA2", + "ER", + "AW0", + "UW0", + "R", + "OW1", + "EH1", + "ZH", + "AE0", + "IH2", + "IH", + "Y", + "JH", + "P", + "AY1", + "EY0", + "OY2", + "TH", + "HH", + "D", + "ER0", + "CH", + "AO1", + "AE1", + "AO2", + "OY1", + "AY2", + "IH1", + "OW0", + "L", + "SH", + } + + _g2p = G2p() + + phones = [] + tones = [] + phone_len = [] + # tokens = [tokenizer.tokenize(i) for i in words] + words = __text_to_words(text) + eng_dict = get_dict() + + for word in words: + temp_phones, temp_tones = [], [] + if len(word) > 1: + if "'" in word: + word = ["".join(word)] + for w in word: + if w in PUNCTUATIONS: + temp_phones.append(w) + temp_tones.append(0) + continue + if w.upper() in eng_dict: + phns, tns = __refine_syllables(eng_dict[w.upper()]) + temp_phones += [__post_replace_ph(i) for i in phns] + temp_tones += tns + # w2ph.append(len(phns)) + else: + phone_list = list(filter(lambda p: p != " ", _g2p(w))) # type: ignore + phns = [] + tns = [] + for ph in phone_list: + if ph in ARPA: + ph, tn = __refine_ph(ph) + phns.append(ph) + tns.append(tn) + else: + phns.append(ph) + tns.append(0) + temp_phones += [__post_replace_ph(i) for i in phns] + temp_tones += tns + phones += temp_phones + tones += temp_tones + phone_len.append(len(temp_phones)) + # phones = [post_replace_ph(i) for i in phones] + + word2ph = [] + for token, pl in zip(words, phone_len): + word_len = len(token) + + aaa = __distribute_phone(pl, word_len) + word2ph += aaa + + phones = ["_"] + phones + ["_"] + tones = [0] + tones + [0] + word2ph = [1] + word2ph + [1] + assert len(phones) == len(tones), text + assert len(phones) == sum(word2ph), text + + return phones, tones, word2ph + + +def __post_replace_ph(ph: str) -> str: + REPLACE_MAP = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "…": "...", + "···": "...", + "・・・": "...", + "v": "V", + } + if ph in REPLACE_MAP.keys(): + ph = REPLACE_MAP[ph] + if ph in SYMBOLS: + return ph + if ph not in SYMBOLS: + ph = "UNK" + return ph + + +def __refine_ph(phn: str) -> tuple[str, int]: + tone = 0 + if re.search(r"\d$", phn): + tone = int(phn[-1]) + 1 + phn = phn[:-1] + else: + tone = 3 + return phn.lower(), tone + + +def 
__refine_syllables(syllables: list[list[str]]) -> tuple[list[str], list[int]]: + tones = [] + phonemes = [] + for phn_list in syllables: + for i in range(len(phn_list)): + phn = phn_list[i] + phn, tone = __refine_ph(phn) + phonemes.append(phn) + tones.append(tone) + return phonemes, tones + + +def __distribute_phone(n_phone: int, n_word: int) -> list[int]: + phones_per_word = [0] * n_word + for task in range(n_phone): + min_tasks = min(phones_per_word) + min_index = phones_per_word.index(min_tasks) + phones_per_word[min_index] += 1 + return phones_per_word + + +def __text_to_words(text: str) -> list[list[str]]: + tokenizer = bert_models.load_tokenizer(Languages.EN) + tokens = tokenizer.tokenize(text) + words = [] + for idx, t in enumerate(tokens): + if t.startswith("▁"): + words.append([t[1:]]) + else: + if t in PUNCTUATIONS: + if idx == len(tokens) - 1: + words.append([f"{t}"]) + else: + if ( + not tokens[idx + 1].startswith("▁") + and tokens[idx + 1] not in PUNCTUATIONS + ): + if idx == 0: + words.append([]) + words[-1].append(f"{t}") + else: + words.append([f"{t}"]) + else: + if idx == 0: + words.append([]) + words[-1].append(f"{t}") + return words + + +if __name__ == "__main__": + # print(get_dict()) + # print(eng_word_to_phoneme("hello")) + print(g2p("In this paper, we propose 1 DSPGAN, a GAN-based universal vocoder.")) + # all_phones = set() + # eng_dict = get_dict() + # for k, syllables in eng_dict.items(): + # for group in syllables: + # for ph in group: + # all_phones.add(ph) + # print(all_phones) diff --git a/style_bert_vits2/nlp/english/normalizer.py b/style_bert_vits2/nlp/english/normalizer.py new file mode 100644 index 000000000..0886581b9 --- /dev/null +++ b/style_bert_vits2/nlp/english/normalizer.py @@ -0,0 +1,130 @@ +import re + +import inflect + + +__INFLECT = inflect.engine() +__COMMA_NUMBER_PATTERN = re.compile(r"([0-9][0-9\,]+[0-9])") +__DECIMAL_NUMBER_PATTERN = re.compile(r"([0-9]+\.[0-9]+)") +__POUNDS_PATTERN = re.compile(r"£([0-9\,]*[0-9]+)") +__DOLLARS_PATTERN = re.compile(r"\$([0-9\.\,]*[0-9]+)") +__ORDINAL_PATTERN = re.compile(r"[0-9]+(st|nd|rd|th)") +__NUMBER_PATTERN = re.compile(r"[0-9]+") + + +def normalize_text(text: str) -> str: + text = __normalize_numbers(text) + text = __replace_punctuation(text) + text = re.sub(r"([,;.\?\!])([\w])", r"\1 \2", text) + return text + + +def __normalize_numbers(text: str) -> str: + text = re.sub(__COMMA_NUMBER_PATTERN, __remove_commas, text) + text = re.sub(__POUNDS_PATTERN, r"\1 pounds", text) + text = re.sub(__DOLLARS_PATTERN, __expand_dollars, text) + text = re.sub(__DECIMAL_NUMBER_PATTERN, __expand_decimal_point, text) + text = re.sub(__ORDINAL_PATTERN, __expand_ordinal, text) + text = re.sub(__NUMBER_PATTERN, __expand_number, text) + return text + + +def __replace_punctuation(text: str) -> str: + REPLACE_MAP = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + ".": ".", + "…": "...", + "···": "...", + "・・・": "...", + "·": ",", + "・": ",", + "、": ",", + "$": ".", + "“": "'", + "”": "'", + '"': "'", + "‘": "'", + "’": "'", + "(": "'", + ")": "'", + "(": "'", + ")": "'", + "《": "'", + "》": "'", + "【": "'", + "】": "'", + "[": "'", + "]": "'", + "—": "-", + "−": "-", + "~": "-", + "~": "-", + "「": "'", + "」": "'", + } + pattern = re.compile("|".join(re.escape(p) for p in REPLACE_MAP.keys())) + replaced_text = pattern.sub(lambda x: REPLACE_MAP[x.group()], text) + # replaced_text = re.sub( + # r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005" + # + "".join(punctuation) + # + 
r"]+", + # "", + # replaced_text, + # ) + return replaced_text + + +def __expand_dollars(m: re.Match[str]) -> str: + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " dollars" # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return "%s %s" % (dollars, dollar_unit) + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s" % (cents, cent_unit) + else: + return "zero dollars" + + +def __remove_commas(m: re.Match[str]) -> str: + return m.group(1).replace(",", "") + + +def __expand_ordinal(m: re.Match[str]) -> str: + return __INFLECT.number_to_words(m.group(0)) # type: ignore + + +def __expand_number(m: re.Match[str]) -> str: + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return "two thousand" + elif num > 2000 and num < 2010: + return "two thousand " + __INFLECT.number_to_words(num % 100) # type: ignore + elif num % 100 == 0: + return __INFLECT.number_to_words(num // 100) + " hundred" # type: ignore + else: + return __INFLECT.number_to_words( + num, andword="", zero="oh", group=2 # type: ignore + ).replace(", ", " ") # type: ignore + else: + return __INFLECT.number_to_words(num, andword="") # type: ignore + + +def __expand_decimal_point(m: re.Match[str]) -> str: + return m.group(1).replace(".", " point ") From a672aeefd93188de76e9978cd9487bed0f91d91d Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 08:45:56 +0000 Subject: [PATCH 045/148] Fix: maintain compatibility with Python 3.9 --- server_editor.py | 2 +- style_bert_vits2/models/commons.py | 6 +++--- style_bert_vits2/nlp/bert_models.py | 14 +++++++------- style_bert_vits2/nlp/chinese/bert_feature.py | 6 +++--- style_bert_vits2/nlp/japanese/__init__.py | 2 -- style_bert_vits2/nlp/japanese/bert_feature.py | 6 +++--- 6 files changed, 17 insertions(+), 19 deletions(-) diff --git a/server_editor.py b/server_editor.py index f72800de4..890a074b8 100644 --- a/server_editor.py +++ b/server_editor.py @@ -43,8 +43,8 @@ ) from style_bert_vits2.logging import logger from style_bert_vits2.nlp import bert_models -from style_bert_vits2.nlp.japanese import normalize_text from style_bert_vits2.nlp.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone +from style_bert_vits2.nlp.japanese.normalizer import normalize_text from style_bert_vits2.nlp.japanese.user_dict import ( apply_word, delete_word, diff --git a/style_bert_vits2/models/commons.py b/style_bert_vits2/models/commons.py index 89d07d507..1106b219a 100644 --- a/style_bert_vits2/models/commons.py +++ b/style_bert_vits2/models/commons.py @@ -5,7 +5,7 @@ import torch from torch.nn import functional as F -from typing import Any, Optional +from typing import Any, Optional, Union def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None: @@ -180,12 +180,12 @@ def generate_path(duration: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: return path -def clip_grad_value_(parameters: torch.Tensor | list[torch.Tensor], clip_value: Optional[float], norm_type: float = 2.0) -> float: +def clip_grad_value_(parameters: Union[torch.Tensor, list[torch.Tensor]], clip_value: Optional[float], norm_type: float = 2.0) -> float: """ 勾配の値をクリップする Args: - 
parameters (torch.Tensor | list[torch.Tensor]): クリップするパラメータ + parameters (Union[torch.Tensor, list[torch.Tensor]]): クリップするパラメータ clip_value (Optional[float]): クリップする値。None の場合はクリップしない norm_type (float): ノルムの種類 diff --git a/style_bert_vits2/nlp/bert_models.py b/style_bert_vits2/nlp/bert_models.py index df7a0162b..4385d648c 100644 --- a/style_bert_vits2/nlp/bert_models.py +++ b/style_bert_vits2/nlp/bert_models.py @@ -9,7 +9,7 @@ """ import gc -from typing import cast, Optional +from typing import cast, Optional, Union import torch from transformers import ( @@ -27,16 +27,16 @@ # 各言語ごとのロード済みの BERT モデルを格納する辞書 -__loaded_models: dict[Languages, PreTrainedModel | DebertaV2Model] = {} +__loaded_models: dict[Languages, Union[PreTrainedModel, DebertaV2Model]] = {} # 各言語ごとのロード済みの BERT トークナイザーを格納する辞書 -__loaded_tokenizers: dict[Languages, PreTrainedTokenizer | PreTrainedTokenizerFast | DebertaV2Tokenizer] = {} +__loaded_tokenizers: dict[Languages, Union[PreTrainedTokenizer, PreTrainedTokenizerFast, DebertaV2Tokenizer]] = {} def load_model( language: Languages, pretrained_model_name_or_path: Optional[str] = None, -) -> PreTrainedModel | DebertaV2Model: +) -> Union[PreTrainedModel, DebertaV2Model]: """ 指定された言語の BERT モデルをロードし、ロード済みの BERT モデルを返す。 一度ロードされていれば、ロード済みの BERT モデルを即座に返す。 @@ -54,7 +54,7 @@ def load_model( pretrained_model_name_or_path (Optional[str]): ロードする学習済みモデルの名前またはパス。指定しない場合はデフォルトのパスが利用される (デフォルト: None) Returns: - PreTrainedModel | DebertaV2Model: ロード済みの BERT モデル + Union[PreTrainedModel, DebertaV2Model]: ロード済みの BERT モデル """ # すでにロード済みの場合はそのまま返す @@ -82,7 +82,7 @@ def load_model( def load_tokenizer( language: Languages, pretrained_model_name_or_path: Optional[str] = None, -) -> PreTrainedTokenizer | PreTrainedTokenizerFast | DebertaV2Tokenizer: +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast, DebertaV2Tokenizer]: """ 指定された言語の BERT モデルをロードし、ロード済みの BERT トークナイザーを返す。 一度ロードされていれば、ロード済みの BERT トークナイザーを即座に返す。 @@ -100,7 +100,7 @@ def load_tokenizer( pretrained_model_name_or_path (Optional[str]): ロードする学習済みモデルの名前またはパス。指定しない場合はデフォルトのパスが利用される (デフォルト: None) Returns: - PreTrainedTokenizer | PreTrainedTokenizerFast | DebertaV2Tokenizer: ロード済みの BERT トークナイザー + Union[PreTrainedTokenizer, PreTrainedTokenizerFast, DebertaV2Tokenizer]: ロード済みの BERT トークナイザー """ # すでにロード済みの場合はそのまま返す diff --git a/style_bert_vits2/nlp/chinese/bert_feature.py b/style_bert_vits2/nlp/chinese/bert_feature.py index b97950b0c..f448b301d 100644 --- a/style_bert_vits2/nlp/chinese/bert_feature.py +++ b/style_bert_vits2/nlp/chinese/bert_feature.py @@ -8,13 +8,13 @@ from style_bert_vits2.nlp import bert_models -__models: dict[torch.device | str, PreTrainedModel] = {} +__models: dict[str, PreTrainedModel] = {} def extract_bert_feature( text: str, word2ph: list[int], - device: torch.device | str, + device: str, assist_text: Optional[str] = None, assist_text_weight: float = 0.7, ) -> torch.Tensor: @@ -24,7 +24,7 @@ def extract_bert_feature( Args: text (str): 中国語のテキスト word2ph (list[int]): 元のテキストの各文字に音素が何個割り当てられるかを表すリスト - device (torch.device | str): 推論に利用するデバイス + device (str): 推論に利用するデバイス assist_text (Optional[str], optional): 補助テキスト (デフォルト: None) assist_text_weight (float, optional): 補助テキストの重み (デフォルト: 0.7) diff --git a/style_bert_vits2/nlp/japanese/__init__.py b/style_bert_vits2/nlp/japanese/__init__.py index 5c7f19f97..e69de29bb 100644 --- a/style_bert_vits2/nlp/japanese/__init__.py +++ b/style_bert_vits2/nlp/japanese/__init__.py @@ -1,2 +0,0 @@ -from style_bert_vits2.nlp.japanese.g2p import g2p # noqa: F401 -from 
style_bert_vits2.nlp.japanese.normalizer import normalize_text # noqa: F401 diff --git a/style_bert_vits2/nlp/japanese/bert_feature.py b/style_bert_vits2/nlp/japanese/bert_feature.py index ede1f83b9..0d70014fd 100644 --- a/style_bert_vits2/nlp/japanese/bert_feature.py +++ b/style_bert_vits2/nlp/japanese/bert_feature.py @@ -9,13 +9,13 @@ from style_bert_vits2.nlp.japanese.g2p import text_to_sep_kata -__models: dict[torch.device | str, PreTrainedModel] = {} +__models: dict[str, PreTrainedModel] = {} def extract_bert_feature( text: str, word2ph: list[int], - device: torch.device | str, + device: str, assist_text: Optional[str] = None, assist_text_weight: float = 0.7, ) -> torch.Tensor: @@ -25,7 +25,7 @@ def extract_bert_feature( Args: text (str): 日本語のテキスト word2ph (list[int]): 元のテキストの各文字に音素が何個割り当てられるかを表すリスト - device (torch.device | str): 推論に利用するデバイス + device (str): 推論に利用するデバイス assist_text (Optional[str], optional): 補助テキスト (デフォルト: None) assist_text_weight (float, optional): 補助テキストの重み (デフォルト: 0.7) From df687716512a8e5f0c89e9dda959c75348f6fb0e Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 09:04:40 +0000 Subject: [PATCH 046/148] Refactor: split style_bert_vits2.nlp.chinese package Configured so that the same public function is exported from the module with the same name for each language. --- style_bert_vits2/nlp/__init__.py | 3 +- style_bert_vits2/nlp/chinese/__init__.py | 193 ------------------ style_bert_vits2/nlp/chinese/g2p.py | 135 ++++++++++++ style_bert_vits2/nlp/chinese/normalizer.py | 61 ++++++ .../{english => chinese}/opencpop-strict.txt | 0 style_bert_vits2/nlp/english/normalizer.py | 24 +-- 6 files changed, 210 insertions(+), 206 deletions(-) create mode 100644 style_bert_vits2/nlp/chinese/g2p.py create mode 100644 style_bert_vits2/nlp/chinese/normalizer.py rename style_bert_vits2/nlp/{english => chinese}/opencpop-strict.txt (100%) diff --git a/style_bert_vits2/nlp/__init__.py b/style_bert_vits2/nlp/__init__.py index 99f56a9f6..afc6cb09e 100644 --- a/style_bert_vits2/nlp/__init__.py +++ b/style_bert_vits2/nlp/__init__.py @@ -82,7 +82,8 @@ def clean_text( norm_text = normalize_text(text) phones, tones, word2ph = g2p(norm_text) elif language == Languages.ZH: - from style_bert_vits2.nlp.chinese import g2p, normalize_text + from style_bert_vits2.nlp.chinese.g2p import g2p + from style_bert_vits2.nlp.chinese.normalizer import normalize_text norm_text = normalize_text(text) phones, tones, word2ph = g2p(norm_text) else: diff --git a/style_bert_vits2/nlp/chinese/__init__.py b/style_bert_vits2/nlp/chinese/__init__.py index 5b708636d..e69de29bb 100644 --- a/style_bert_vits2/nlp/chinese/__init__.py +++ b/style_bert_vits2/nlp/chinese/__init__.py @@ -1,193 +0,0 @@ -import os -import re - -import cn2an -import jieba.posseg as psg -from pypinyin import lazy_pinyin, Style - -from style_bert_vits2.nlp.chinese.tone_sandhi import ToneSandhi -from style_bert_vits2.nlp.symbols import PUNCTUATIONS - - -current_file_path = os.path.dirname(__file__) -pinyin_to_symbol_map = { - line.split("\t")[0]: line.strip().split("\t")[1] - for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines() -} - - -rep_map = { - ":": ",", - ";": ",", - ",": ",", - "。": ".", - "!": "!", - "?": "?", - "\n": ".", - "·": ",", - "、": ",", - "...": "…", - "$": ".", - "“": "'", - "”": "'", - '"': "'", - "‘": "'", - "’": "'", - "(": "'", - ")": "'", - "(": "'", - ")": "'", - "《": "'", - "》": "'", - "【": "'", - "】": "'", - "[": "'", - "]": "'", - "—": "-", - "~": "-", - "~": "-", - "「": "'", - "」": 
"'", -} - -tone_modifier = ToneSandhi() - - -def replace_punctuation(text): - text = text.replace("嗯", "恩").replace("呣", "母") - pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) - - replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) - - replaced_text = re.sub( - r"[^\u4e00-\u9fa5" + "".join(PUNCTUATIONS) + r"]+", "", replaced_text - ) - - return replaced_text - - -def g2p(text: str) -> tuple[list[str], list[int], list[int]]: - pattern = r"(?<=[{0}])\s*".format("".join(PUNCTUATIONS)) - sentences = [i for i in re.split(pattern, text) if i.strip() != ""] - phones, tones, word2ph = _g2p(sentences) - assert sum(word2ph) == len(phones) - assert len(word2ph) == len(text) # Sometimes it will crash,you can add a try-catch. - phones = ["_"] + phones + ["_"] - tones = [0] + tones + [0] - word2ph = [1] + word2ph + [1] - return phones, tones, word2ph - - -def _get_initials_finals(word): - initials = [] - finals = [] - orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS) - orig_finals = lazy_pinyin( - word, neutral_tone_with_five=True, style=Style.FINALS_TONE3 - ) - for c, v in zip(orig_initials, orig_finals): - initials.append(c) - finals.append(v) - return initials, finals - - -def _g2p(segments): - phones_list = [] - tones_list = [] - word2ph = [] - for seg in segments: - # Replace all English words in the sentence - seg = re.sub("[a-zA-Z]+", "", seg) - seg_cut = psg.lcut(seg) - initials = [] - finals = [] - seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) - for word, pos in seg_cut: - if pos == "eng": - continue - sub_initials, sub_finals = _get_initials_finals(word) - sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) - initials.append(sub_initials) - finals.append(sub_finals) - - # assert len(sub_initials) == len(sub_finals) == len(word) - initials = sum(initials, []) - finals = sum(finals, []) - # - for c, v in zip(initials, finals): - raw_pinyin = c + v - # NOTE: post process for pypinyin outputs - # we discriminate i, ii and iii - if c == v: - assert c in PUNCTUATIONS - phone = [c] - tone = "0" - word2ph.append(1) - else: - v_without_tone = v[:-1] - tone = v[-1] - - pinyin = c + v_without_tone - assert tone in "12345" - - if c: - # 多音节 - v_rep_map = { - "uei": "ui", - "iou": "iu", - "uen": "un", - } - if v_without_tone in v_rep_map.keys(): - pinyin = c + v_rep_map[v_without_tone] - else: - # 单音节 - pinyin_rep_map = { - "ing": "ying", - "i": "yi", - "in": "yin", - "u": "wu", - } - if pinyin in pinyin_rep_map.keys(): - pinyin = pinyin_rep_map[pinyin] - else: - single_rep_map = { - "v": "yu", - "e": "e", - "i": "y", - "u": "w", - } - if pinyin[0] in single_rep_map.keys(): - pinyin = single_rep_map[pinyin[0]] + pinyin[1:] - - assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin) - phone = pinyin_to_symbol_map[pinyin].split(" ") - word2ph.append(len(phone)) - - phones_list += phone - tones_list += [int(tone)] * len(phone) - return phones_list, tones_list, word2ph - - -def normalize_text(text: str) -> str: - numbers = re.findall(r"\d+(?:\.?\d+)?", text) - for number in numbers: - text = text.replace(number, cn2an.an2cn(number), 1) - text = replace_punctuation(text) - return text - - -if __name__ == "__main__": - from style_bert_vits2.nlp.chinese.bert_feature import extract_bert_feature - - text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏" - text = normalize_text(text) - print(text) - phones, tones, word2ph = g2p(text) - bert = extract_bert_feature(text, word2ph, 'cuda') - - print(phones, tones, word2ph, 
bert.shape) - - -# # 示例用法 -# text = "这是一个示例文本:,你好!这是一个测试...." -# print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试 diff --git a/style_bert_vits2/nlp/chinese/g2p.py b/style_bert_vits2/nlp/chinese/g2p.py new file mode 100644 index 000000000..1cb3839f6 --- /dev/null +++ b/style_bert_vits2/nlp/chinese/g2p.py @@ -0,0 +1,135 @@ +import re +from pathlib import Path + +import jieba.posseg as psg +from pypinyin import lazy_pinyin, Style + +from style_bert_vits2.nlp.chinese.tone_sandhi import ToneSandhi +from style_bert_vits2.nlp.symbols import PUNCTUATIONS + + +__PINYIN_TO_SYMBOL_MAP = { + line.split("\t")[0]: line.strip().split("\t")[1] + for line in open(Path(__file__).parent / "opencpop-strict.txt").readlines() +} + + +def g2p(text: str) -> tuple[list[str], list[int], list[int]]: + pattern = r"(?<=[{0}])\s*".format("".join(PUNCTUATIONS)) + sentences = [i for i in re.split(pattern, text) if i.strip() != ""] + phones, tones, word2ph = __g2p(sentences) + assert sum(word2ph) == len(phones) + assert len(word2ph) == len(text) # Sometimes it will crash,you can add a try-catch. + phones = ["_"] + phones + ["_"] + tones = [0] + tones + [0] + word2ph = [1] + word2ph + [1] + return phones, tones, word2ph + + +def __g2p(segments: list[str]) -> tuple[list[str], list[int], list[int]]: + phones_list = [] + tones_list = [] + word2ph = [] + tone_modifier = ToneSandhi() + for seg in segments: + # Replace all English words in the sentence + seg = re.sub("[a-zA-Z]+", "", seg) + seg_cut = psg.lcut(seg) + initials = [] + finals = [] + seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) # type: ignore + for word, pos in seg_cut: + if pos == "eng": + continue + sub_initials, sub_finals = __get_initials_finals(word) + sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) + initials.append(sub_initials) + finals.append(sub_finals) + + # assert len(sub_initials) == len(sub_finals) == len(word) + initials = sum(initials, []) + finals = sum(finals, []) + # + for c, v in zip(initials, finals): + raw_pinyin = c + v + # NOTE: post process for pypinyin outputs + # we discriminate i, ii and iii + if c == v: + assert c in PUNCTUATIONS + phone = [c] + tone = "0" + word2ph.append(1) + else: + v_without_tone = v[:-1] + tone = v[-1] + + pinyin = c + v_without_tone + assert tone in "12345" + + if c: + # 多音节 + v_rep_map = { + "uei": "ui", + "iou": "iu", + "uen": "un", + } + if v_without_tone in v_rep_map.keys(): + pinyin = c + v_rep_map[v_without_tone] + else: + # 单音节 + pinyin_rep_map = { + "ing": "ying", + "i": "yi", + "in": "yin", + "u": "wu", + } + if pinyin in pinyin_rep_map.keys(): + pinyin = pinyin_rep_map[pinyin] + else: + single_rep_map = { + "v": "yu", + "e": "e", + "i": "y", + "u": "w", + } + if pinyin[0] in single_rep_map.keys(): + pinyin = single_rep_map[pinyin[0]] + pinyin[1:] + + assert pinyin in __PINYIN_TO_SYMBOL_MAP.keys(), (pinyin, seg, raw_pinyin) + phone = __PINYIN_TO_SYMBOL_MAP[pinyin].split(" ") + word2ph.append(len(phone)) + + phones_list += phone + tones_list += [int(tone)] * len(phone) + return phones_list, tones_list, word2ph + + +def __get_initials_finals(word: str) -> tuple[list[str], list[str]]: + initials = [] + finals = [] + orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS) + orig_finals = lazy_pinyin( + word, neutral_tone_with_five=True, style=Style.FINALS_TONE3 + ) + for c, v in zip(orig_initials, orig_finals): + initials.append(c) + finals.append(v) + return initials, finals + + +if __name__ == "__main__": + from style_bert_vits2.nlp.chinese.bert_feature 
import extract_bert_feature + from style_bert_vits2.nlp.chinese.normalizer import normalize_text + + text = "啊!但是《原神》是由,米哈游自主, [研发]的一款全.新开放世界.冒险游戏" + text = normalize_text(text) + print(text) + phones, tones, word2ph = g2p(text) + bert = extract_bert_feature(text, word2ph, 'cuda') + + print(phones, tones, word2ph, bert.shape) + + +# 示例用法 +# text = "这是一个示例文本:,你好!这是一个测试...." +# print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试 diff --git a/style_bert_vits2/nlp/chinese/normalizer.py b/style_bert_vits2/nlp/chinese/normalizer.py new file mode 100644 index 000000000..c56c636f3 --- /dev/null +++ b/style_bert_vits2/nlp/chinese/normalizer.py @@ -0,0 +1,61 @@ +import re + +import cn2an + +from style_bert_vits2.nlp.symbols import PUNCTUATIONS + + +def normalize_text(text: str) -> str: + numbers = re.findall(r"\d+(?:\.?\d+)?", text) + for number in numbers: + text = text.replace(number, cn2an.an2cn(number), 1) + text = replace_punctuation(text) + return text + + +def replace_punctuation(text: str) -> str: + + REPLACE_MAP = { + ":": ",", + ";": ",", + ",": ",", + "。": ".", + "!": "!", + "?": "?", + "\n": ".", + "·": ",", + "、": ",", + "...": "…", + "$": ".", + "“": "'", + "”": "'", + '"': "'", + "‘": "'", + "’": "'", + "(": "'", + ")": "'", + "(": "'", + ")": "'", + "《": "'", + "》": "'", + "【": "'", + "】": "'", + "[": "'", + "]": "'", + "—": "-", + "~": "-", + "~": "-", + "「": "'", + "」": "'", + } + + text = text.replace("嗯", "恩").replace("呣", "母") + pattern = re.compile("|".join(re.escape(p) for p in REPLACE_MAP.keys())) + + replaced_text = pattern.sub(lambda x: REPLACE_MAP[x.group()], text) + + replaced_text = re.sub( + r"[^\u4e00-\u9fa5" + "".join(PUNCTUATIONS) + r"]+", "", replaced_text + ) + + return replaced_text diff --git a/style_bert_vits2/nlp/english/opencpop-strict.txt b/style_bert_vits2/nlp/chinese/opencpop-strict.txt similarity index 100% rename from style_bert_vits2/nlp/english/opencpop-strict.txt rename to style_bert_vits2/nlp/chinese/opencpop-strict.txt diff --git a/style_bert_vits2/nlp/english/normalizer.py b/style_bert_vits2/nlp/english/normalizer.py index 0886581b9..81b71d7cc 100644 --- a/style_bert_vits2/nlp/english/normalizer.py +++ b/style_bert_vits2/nlp/english/normalizer.py @@ -14,22 +14,12 @@ def normalize_text(text: str) -> str: text = __normalize_numbers(text) - text = __replace_punctuation(text) + text = replace_punctuation(text) text = re.sub(r"([,;.\?\!])([\w])", r"\1 \2", text) return text -def __normalize_numbers(text: str) -> str: - text = re.sub(__COMMA_NUMBER_PATTERN, __remove_commas, text) - text = re.sub(__POUNDS_PATTERN, r"\1 pounds", text) - text = re.sub(__DOLLARS_PATTERN, __expand_dollars, text) - text = re.sub(__DECIMAL_NUMBER_PATTERN, __expand_decimal_point, text) - text = re.sub(__ORDINAL_PATTERN, __expand_ordinal, text) - text = re.sub(__NUMBER_PATTERN, __expand_number, text) - return text - - -def __replace_punctuation(text: str) -> str: +def replace_punctuation(text: str) -> str: REPLACE_MAP = { ":": ",", ";": ",", @@ -80,6 +70,16 @@ def __replace_punctuation(text: str) -> str: return replaced_text +def __normalize_numbers(text: str) -> str: + text = re.sub(__COMMA_NUMBER_PATTERN, __remove_commas, text) + text = re.sub(__POUNDS_PATTERN, r"\1 pounds", text) + text = re.sub(__DOLLARS_PATTERN, __expand_dollars, text) + text = re.sub(__DECIMAL_NUMBER_PATTERN, __expand_decimal_point, text) + text = re.sub(__ORDINAL_PATTERN, __expand_ordinal, text) + text = re.sub(__NUMBER_PATTERN, __expand_number, text) + return text + + def __expand_dollars(m: re.Match[str]) -> 
str: match = m.group(1) parts = match.split(".") From ebb37ad7bb7e54a1757b9da0820e88a2349b476e Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 8 Mar 2024 18:18:38 +0900 Subject: [PATCH 047/148] Fix multiprocessing bert_gen error of pyopenjtalk_worker --- bert_gen.py | 5 ++--- infer.py | 3 ++- preprocess_all.py | 7 +++++-- text/__init__.py | 13 ------------- text/japanese.py | 6 ++---- text/japanese_bert.py | 2 +- webui/merge.py | 4 ++-- 7 files changed, 14 insertions(+), 26 deletions(-) diff --git a/bert_gen.py b/bert_gen.py index 1e4fb61f4..19dbbb61c 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -6,14 +6,13 @@ from tqdm import tqdm import commons +from text.get_bert import get_bert import text.pyopenjtalk_worker as pyopenjtalk import utils from common.log import logger from common.stdout_wrapper import SAFE_STDOUT from config import config -from text import cleaned_text_to_sequence, get_bert - -pyopenjtalk.initialize() +from text import cleaned_text_to_sequence def process_line(x): diff --git a/infer.py b/infer.py index 3707df1f9..436ea8bf0 100644 --- a/infer.py +++ b/infer.py @@ -1,10 +1,11 @@ import torch import commons +from text.get_bert import get_bert import utils from models import SynthesizerTrn from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra -from text import cleaned_text_to_sequence, get_bert +from text import cleaned_text_to_sequence from text.cleaner import clean_text from text.symbols import symbols from common.log import logger diff --git a/preprocess_all.py b/preprocess_all.py index fd3dcd5b7..62c0b4f6e 100644 --- a/preprocess_all.py +++ b/preprocess_all.py @@ -1,5 +1,5 @@ import argparse -from webui_train import preprocess_all +from webui.train import preprocess_all from multiprocessing import cpu_count if __name__ == "__main__": @@ -75,7 +75,10 @@ default=200, ) parser.add_argument( - "--yomi_error", type=str, help="Yomi error. raise, skip, use", default="raise" + "--yomi_error", + type=str, + help="Yomi error. Options: raise, skip, use", + default="raise", ) args = parser.parse_args() diff --git a/text/__init__.py b/text/__init__.py index d8ae88dea..7ba6e2095 100644 --- a/text/__init__.py +++ b/text/__init__.py @@ -16,16 +16,3 @@ def cleaned_text_to_sequence(cleaned_text, tones, language): lang_id = language_id_map[language] lang_ids = [lang_id for i in phones] return phones, tones, lang_ids - - -def get_bert(text, word2ph, language, device, assist_text=None, assist_text_weight=0.7): - if language == "ZH": - from .chinese_bert import get_bert_feature - elif language == "EN": - from .english_bert_mock import get_bert_feature - elif language == "JP": - from .japanese_bert import get_bert_feature - else: - raise ValueError(f"Language {language} not supported") - - return get_bert_feature(text, word2ph, device, assist_text, assist_text_weight) diff --git a/text/japanese.py b/text/japanese.py index fea0eaa5d..f7bfcae8f 100644 --- a/text/japanese.py +++ b/text/japanese.py @@ -4,9 +4,6 @@ import unicodedata from pathlib import Path -from . import pyopenjtalk_worker as pyopenjtalk - -pyopenjtalk.initialize() from num2words import num2words from transformers import AutoTokenizer @@ -16,9 +13,10 @@ mora_kata_to_mora_phonemes, mora_phonemes_to_mora_kata, ) - from text.user_dict import update_dict +from . 
import pyopenjtalk_worker as pyopenjtalk + # 最初にpyopenjtalkの辞書を更新 update_dict() diff --git a/text/japanese_bert.py b/text/japanese_bert.py index dcee0f3d2..9ebc64a09 100644 --- a/text/japanese_bert.py +++ b/text/japanese_bert.py @@ -52,7 +52,7 @@ def get_bert_feature( style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() style_res_mean = style_res.mean(0) - assert len(word2ph) == len(text) + 2, text + assert len(word2ph) == len(text) + 2, f"word2ph: {word2ph}, text: {text}" word2phone = word2ph phone_level_feature = [] for i in range(len(word2phone)): diff --git a/webui/merge.py b/webui/merge.py index c9386efbc..aeab08d5d 100644 --- a/webui/merge.py +++ b/webui/merge.py @@ -254,7 +254,7 @@ def simple_tts(model_name, text, style=DEFAULT_STYLE, style_weight=1.0): return model.infer(text, style=style, style_weight=style_weight) -def update_two_model_names_dropdown(): +def update_two_model_names_dropdown(model_holder: ModelHolder): new_names, new_files, _ = model_holder.update_model_names_gr() return new_names, new_files, new_names, new_files @@ -455,7 +455,7 @@ def create_merge_app(model_holder: ModelHolder) -> gr.Blocks: ) refresh_button.click( - update_two_model_names_dropdown, + lambda: update_two_model_names_dropdown(model_holder), outputs=[model_name_a, model_path_a, model_name_b, model_path_b], ) From 766699e812084322822027dee76b1003dbe02918 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 09:21:09 +0000 Subject: [PATCH 048/148] Fix: pyopenjtalk_worker not working --- style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py | 4 ++-- .../nlp/japanese/pyopenjtalk_worker/worker_server.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py index 90419f5bf..670461de5 100644 --- a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py +++ b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py @@ -56,7 +56,6 @@ def initialize(port: int = WORKER_PORT) -> None: import sys import time - logger.debug("initialize") global WORKER_CLIENT if WORKER_CLIENT: return @@ -98,6 +97,7 @@ def initialize(port: int = WORKER_PORT) -> None: if count == 20: raise TimeoutError("サーバーに接続できませんでした") + logger.debug("pyopenjtalk worker server started") WORKER_CLIENT = client atexit.register(terminate) @@ -110,7 +110,7 @@ def signal_handler(signum: int, frame: Any): # top-level declaration def terminate() -> None: - logger.debug("terminate") + logger.debug("pyopenjtalk worker server terminated") global WORKER_CLIENT if not WORKER_CLIENT: return diff --git a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_server.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_server.py index e9a3dc140..149323ad8 100644 --- a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_server.py +++ b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_server.py @@ -16,7 +16,7 @@ # To make it as fast as possible # Probably faster than calling getattr every time -__PYOPENJTALK_FUNC_DICT = { +PYOPENJTALK_FUNC_DICT = { "run_frontend": pyopenjtalk.run_frontend, "make_label": pyopenjtalk.make_label, "mecab_dict_index": pyopenjtalk.mecab_dict_index, @@ -57,7 +57,7 @@ def handle_request(self, request: dict[str, Any]) -> dict[str, Any]: elif request_type == RequestType.PYOPENJTALK: func_name = request.get("func") assert isinstance(func_name, str) - func = __PYOPENJTALK_FUNC_DICT[func_name] + func = PYOPENJTALK_FUNC_DICT[func_name] args = 
request.get("args") kwargs = request.get("kwargs") assert isinstance(args, list) From fe7e31e0806a82972c44569bbb9ac9abc2be4c87 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 09:34:44 +0000 Subject: [PATCH 049/148] Refactor: moved common/tts_model.py to style_bert_vits2/ --- app.py | 2 +- server_editor.py | 2 +- server_fastapi.py | 2 +- speech_mos.py | 4 ++-- {common => style_bert_vits2}/tts_model.py | 0 webui/inference.py | 2 +- webui/merge.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) rename {common => style_bert_vits2}/tts_model.py (100%) diff --git a/app.py b/app.py index 13929cd46..5d05d7cfc 100644 --- a/app.py +++ b/app.py @@ -6,7 +6,7 @@ import yaml from style_bert_vits2.constants import GRADIO_THEME, LATEST_VERSION -from common.tts_model import ModelHolder +from style_bert_vits2.tts_model import ModelHolder from webui import ( create_dataset_app, create_inference_app, diff --git a/server_editor.py b/server_editor.py index 890a074b8..022a95adf 100644 --- a/server_editor.py +++ b/server_editor.py @@ -30,7 +30,6 @@ from pydantic import BaseModel from scipy.io import wavfile -from common.tts_model import ModelHolder from style_bert_vits2.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, DEFAULT_NOISE, @@ -52,6 +51,7 @@ rewrite_word, update_dict, ) +from style_bert_vits2.tts_model import ModelHolder # ---フロントエンド部分に関する処理--- diff --git a/server_fastapi.py b/server_fastapi.py index b7ebb772d..833af4ab5 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -20,7 +20,6 @@ from fastapi.responses import FileResponse, Response from scipy.io import wavfile -from common.tts_model import Model, ModelHolder from config import config from style_bert_vits2.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, @@ -35,6 +34,7 @@ Languages, ) from style_bert_vits2.logging import logger +from style_bert_vits2.tts_model import Model, ModelHolder ln = config.server_config.language diff --git a/speech_mos.py b/speech_mos.py index 15cccef5f..79221adcc 100644 --- a/speech_mos.py +++ b/speech_mos.py @@ -10,9 +10,9 @@ import torch from tqdm import tqdm -from style_bert_vits2.logging import logger -from common.tts_model import Model from config import config +from style_bert_vits2.logging import logger +from style_bert_vits2.tts_model import Model warnings.filterwarnings("ignore") diff --git a/common/tts_model.py b/style_bert_vits2/tts_model.py similarity index 100% rename from common/tts_model.py rename to style_bert_vits2/tts_model.py diff --git a/webui/inference.py b/webui/inference.py index 026481f30..02d6f3abb 100644 --- a/webui/inference.py +++ b/webui/inference.py @@ -4,7 +4,6 @@ import gradio as gr -from common.tts_model import ModelHolder from style_bert_vits2.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, DEFAULT_LENGTH, @@ -22,6 +21,7 @@ from style_bert_vits2.models.infer import InvalidToneError from style_bert_vits2.nlp.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone from style_bert_vits2.nlp.japanese.normalizer import normalize_text +from style_bert_vits2.tts_model import ModelHolder languages = [l.value for l in Languages] diff --git a/webui/merge.py b/webui/merge.py index 1fc3ceef3..e9dd1f004 100644 --- a/webui/merge.py +++ b/webui/merge.py @@ -9,9 +9,9 @@ from safetensors import safe_open from safetensors.torch import save_file -from common.tts_model import Model, ModelHolder from style_bert_vits2.constants import DEFAULT_STYLE, GRADIO_THEME from style_bert_vits2.logging import logger +from style_bert_vits2.tts_model import Model, ModelHolder voice_keys = ["dec"] 
From 67ff3105c1a7065971319bf71cebe1fef4478b75 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 09:40:27 +0000 Subject: [PATCH 050/148] Refactor: moved utils.py to style_bert_vits2/models/ --- bert_gen.py | 2 +- data_utils.py | 2 +- style_bert_vits2/models/infer.py | 2 +- utils.py => style_bert_vits2/models/utils.py | 0 style_bert_vits2/tts_model.py | 2 +- style_gen.py | 2 +- train_ms.py | 2 +- train_ms_jp_extra.py | 2 +- 8 files changed, 7 insertions(+), 7 deletions(-) rename utils.py => style_bert_vits2/models/utils.py (100%) diff --git a/bert_gen.py b/bert_gen.py index 26df64d0f..09359295a 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -5,10 +5,10 @@ import torch.multiprocessing as mp from tqdm import tqdm -import utils from config import config from style_bert_vits2.logging import logger from style_bert_vits2.models import commons +from style_bert_vits2.models import utils from style_bert_vits2.nlp import cleaned_text_to_sequence, extract_bert_feature from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT diff --git a/data_utils.py b/data_utils.py index 460c15fd4..99da2e4db 100644 --- a/data_utils.py +++ b/data_utils.py @@ -9,9 +9,9 @@ from config import config from mel_processing import mel_spectrogram_torch, spectrogram_torch -from utils import load_filepaths_and_text, load_wav_to_torch from style_bert_vits2.logging import logger from style_bert_vits2.models import commons +from style_bert_vits2.models.utils import load_filepaths_and_text, load_wav_to_torch from style_bert_vits2.nlp import cleaned_text_to_sequence """Multi speaker version""" diff --git a/style_bert_vits2/models/infer.py b/style_bert_vits2/models/infer.py index 9265eeec2..7394d0b68 100644 --- a/style_bert_vits2/models/infer.py +++ b/style_bert_vits2/models/infer.py @@ -1,10 +1,10 @@ import torch from typing import Optional -import utils from style_bert_vits2.constants import Languages from style_bert_vits2.logging import logger from style_bert_vits2.models import commons +from style_bert_vits2.models import utils from style_bert_vits2.models.models import SynthesizerTrn from style_bert_vits2.models.models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra from style_bert_vits2.nlp import clean_text, cleaned_text_to_sequence, extract_bert_feature diff --git a/utils.py b/style_bert_vits2/models/utils.py similarity index 100% rename from utils.py rename to style_bert_vits2/models/utils.py diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index a5901e112..bf04fe24b 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -7,7 +7,6 @@ import torch from gradio.processing_utils import convert_to_16_bit_wav -import utils from style_bert_vits2.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, DEFAULT_LENGTH, @@ -19,6 +18,7 @@ DEFAULT_STYLE, DEFAULT_STYLE_WEIGHT, ) +from style_bert_vits2.models import utils from style_bert_vits2.models.infer import get_net_g, infer from style_bert_vits2.models.models import SynthesizerTrn from style_bert_vits2.models.models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra diff --git a/style_gen.py b/style_gen.py index 1c1f0340f..d7f692f50 100644 --- a/style_gen.py +++ b/style_gen.py @@ -6,8 +6,8 @@ import torch from tqdm import tqdm -import utils from style_bert_vits2.logging import logger +from style_bert_vits2.models import utils from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT from config import config diff --git a/train_ms.py b/train_ms.py index 0cb7a8216..977b393ee 100644 --- a/train_ms.py +++ 
b/train_ms.py @@ -15,7 +15,6 @@ # logging.getLogger("numba").setLevel(logging.WARNING) import default_style -import utils from config import config from data_utils import ( DistributedBucketSampler, @@ -26,6 +25,7 @@ from mel_processing import mel_spectrogram_torch, spec_to_mel_torch from style_bert_vits2.logging import logger from style_bert_vits2.models import commons +from style_bert_vits2.models import utils from style_bert_vits2.models.models import ( DurationDiscriminator, MultiPeriodDiscriminator, diff --git a/train_ms_jp_extra.py b/train_ms_jp_extra.py index 7d0636b95..3b1c01aad 100644 --- a/train_ms_jp_extra.py +++ b/train_ms_jp_extra.py @@ -15,7 +15,6 @@ # logging.getLogger("numba").setLevel(logging.WARNING) import default_style -import utils from config import config from data_utils import ( DistributedBucketSampler, @@ -26,6 +25,7 @@ from mel_processing import mel_spectrogram_torch, spec_to_mel_torch from style_bert_vits2.logging import logger from style_bert_vits2.models import commons +from style_bert_vits2.models import utils from style_bert_vits2.models.models_jp_extra import ( DurationDiscriminator, MultiPeriodDiscriminator, From c915215ad2c3d99b01f0d7843dc8e35efcd4cf1c Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 09:43:31 +0000 Subject: [PATCH 051/148] Refactor: moved transforms.py to style_bert_vits2/models/ --- style_bert_vits2/models/modules.py | 2 +- transforms.py => style_bert_vits2/models/transforms.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename transforms.py => style_bert_vits2/models/transforms.py (100%) diff --git a/style_bert_vits2/models/modules.py b/style_bert_vits2/models/modules.py index c076e212b..eede771fb 100644 --- a/style_bert_vits2/models/modules.py +++ b/style_bert_vits2/models/modules.py @@ -5,10 +5,10 @@ from torch.nn import Conv1d from torch.nn import functional as F from torch.nn.utils import remove_weight_norm, weight_norm -from transforms import piecewise_rational_quadratic_transform from style_bert_vits2.models import commons from style_bert_vits2.models.attentions import Encoder +from style_bert_vits2.models.transforms import piecewise_rational_quadratic_transform LRELU_SLOPE = 0.1 diff --git a/transforms.py b/style_bert_vits2/models/transforms.py similarity index 100% rename from transforms.py rename to style_bert_vits2/models/transforms.py From 8fad591ac0894a7cae11ceb9afa617ecf15a1231 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 11:59:14 +0000 Subject: [PATCH 052/148] Fix: backported StrEnum from Python 3.11 --- style_bert_vits2/constants.py | 3 ++- style_bert_vits2/utils/strenum.py | 36 +++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 style_bert_vits2/utils/strenum.py diff --git a/style_bert_vits2/constants.py b/style_bert_vits2/constants.py index d5a46bfd7..5f32a15a4 100644 --- a/style_bert_vits2/constants.py +++ b/style_bert_vits2/constants.py @@ -1,6 +1,7 @@ -from enum import StrEnum from pathlib import Path +from style_bert_vits2.utils.strenum import StrEnum + # Style-Bert-VITS2 のバージョン LATEST_VERSION = "2.4" diff --git a/style_bert_vits2/utils/strenum.py b/style_bert_vits2/utils/strenum.py new file mode 100644 index 000000000..40d3b0cb5 --- /dev/null +++ b/style_bert_vits2/utils/strenum.py @@ -0,0 +1,36 @@ +import enum + + +class StrEnum(str, enum.Enum): + """ + Enum where members are also (and must be) strings (backported from Python 3.11). 
+ """ + + def __new__(cls, *values: str) -> "StrEnum": + "values must already be of type `str`" + if len(values) > 3: + raise TypeError('too many arguments for str(): %r' % (values, )) + if len(values) == 1: + # it must be a string + if not isinstance(values[0], str): # type: ignore + raise TypeError('%r is not a string' % (values[0], )) + if len(values) >= 2: + # check that encoding argument is a string + if not isinstance(values[1], str): # type: ignore + raise TypeError('encoding must be a string, not %r' % (values[1], )) + if len(values) == 3: + # check that errors argument is a string + if not isinstance(values[2], str): # type: ignore + raise TypeError('errors must be a string, not %r' % (values[2])) + value = str(*values) + member = str.__new__(cls, value) + member._value_ = value + return member + + + @staticmethod + def _generate_next_value_(name: str, start: int, count: int, last_values: list[str]) -> str: + """ + Return the lower-cased version of the member name. + """ + return name.lower() From 7f0b2528066e3b7e02bfed100dcef3ca5806dca3 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 12:49:11 +0000 Subject: [PATCH 053/148] Refactor: add Pydantic model representing hyper-parameters of Style-Bert-VITS2 model --- requirements.txt | 1 + style_bert_vits2/models/hyper_parameters.py | 116 ++++++++++++++++++++ style_bert_vits2/models/utils.py | 33 ------ 3 files changed, 117 insertions(+), 33 deletions(-) create mode 100644 style_bert_vits2/models/hyper_parameters.py diff --git a/requirements.txt b/requirements.txt index 6bf329ffb..669af2f94 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ numba numpy psutil pyannote.audio>=3.1.0 +pydantic pyloudnorm # pyopenjtalk-prebuilt # Should be manually uninstalled pyopenjtalk-dict diff --git a/style_bert_vits2/models/hyper_parameters.py b/style_bert_vits2/models/hyper_parameters.py new file mode 100644 index 000000000..cee79246d --- /dev/null +++ b/style_bert_vits2/models/hyper_parameters.py @@ -0,0 +1,116 @@ +""" +Style-Bert-VITS2 モデルのハイパーパラメータを表す Pydantic モデル。 +デフォルト値は configs/configs_jp_extra.json 内の定義と同一で、 +万が一ロードした config.json に存在しないキーがあった際のフェイルセーフとして適用される。 +""" + +from pathlib import Path +from typing import Optional, Union + +from pydantic import BaseModel + + +class __HyperParametersTrain(BaseModel): + log_interval: int = 200 + eval_interval: int = 1000 + seed: int = 42 + epochs: int = 1000 + learning_rate: float = 0.0001 + betas: list[float] = [0.8, 0.99] + eps: float = 1e-9 + batch_size: int = 2 + bf16_run: bool = False + fp16_run: bool = False + lr_decay: float = 0.99996 + segment_size: int = 16384 + init_lr_ratio: int = 1 + warmup_epochs: int = 0 + c_mel: int = 45 + c_kl: float = 1.0 + c_commit: int = 100 + skip_optimizer: bool = False + freeze_ZH_bert: bool = False + freeze_JP_bert: bool = False + freeze_EN_bert: bool = False + freeze_emo: bool = False + freeze_style: bool = False + freeze_decoder: bool = False + +class __HyperParametersData(BaseModel): + use_jp_extra: bool = True + training_files: str = "Data/dummy/train.list" + validation_files: str = "Data/dummy/val.list" + max_wav_value: float = 32768.0 + sampling_rate: int = 44100 + filter_length: int = 2048 + hop_length: int = 512 + win_length: int = 2048 + n_mel_channels: int = 128 + mel_fmin: float = 0.0 + mel_fmax: Optional[float] = None + add_blank: bool = True + n_speakers: int = 512 + cleaned_text: bool = True + spk2id: dict[str, int] = { + "dummy": 0 + } + num_styles: int = 1 + style2id: dict[str, int] = { + "Neutral": 0, + } + +class 
__HyperParametersModel(BaseModel): + use_spk_conditioned_encoder: bool = True + use_noise_scaled_mas: bool = True + use_mel_posterior_encoder: bool = False + use_duration_discriminator: bool = False + use_wavlm_discriminator: bool = True + inter_channels: int = 192 + hidden_channels: int = 192 + filter_channels: int = 768 + n_heads: int = 2 + n_layers: int = 6 + kernel_size: int = 3 + p_dropout: float = 0.1 + resblock: str = "1" + resblock_kernel_sizes: list[int] = [3, 7, 11] + resblock_dilation_sizes: list[list[int]] = [ + [1, 3, 5], + [1, 3, 5], + [1, 3, 5] + ] + upsample_rates: list[int] = [8, 8, 2, 2, 2] + upsample_initial_channel: int = 512 + upsample_kernel_sizes: list[int] = [16, 16, 8, 2, 2] + n_layers_q: int = 3 + use_spectral_norm: bool = False + gin_channels: int = 512 + slm: dict[str, Union[int, str]] = { + "model": "./slm/wavlm-base-plus", + "sr": 16000, + "hidden": 768, + "nlayers": 13, + "initial_channel": 64 + } + +class HyperParameters(BaseModel): + version: str = "2.0-JP-Extra" + model_name: str = 'dummy' + train: __HyperParametersTrain + data: __HyperParametersData + model: __HyperParametersModel + + + @staticmethod + def load_from_json(json_path: Union[str, Path]) -> "HyperParameters": + """ + 与えられた JSON ファイルからハイパーパラメータを読み込む。 + + Args: + json_path (Union[str, Path]): JSON ファイルのパス + + Returns: + HyperParameters: ハイパーパラメータ + """ + with open(json_path, "r") as f: + return HyperParameters.model_validate_json(f.read()) diff --git a/style_bert_vits2/models/utils.py b/style_bert_vits2/models/utils.py index 8c7e84265..9a7947f0f 100644 --- a/style_bert_vits2/models/utils.py +++ b/style_bert_vits2/models/utils.py @@ -357,39 +357,6 @@ def check_git_hash(model_dir): open(path, "w").write(cur_hash) -def get_hparams(init=True): - parser = argparse.ArgumentParser() - parser.add_argument( - "-c", - "--config", - type=str, - default="./configs/base.json", - help="JSON file for configuration", - ) - parser.add_argument("-m", "--model", type=str, required=True, help="Model name") - - args = parser.parse_args() - model_dir = os.path.join("./logs", args.model) - - if not os.path.exists(model_dir): - os.makedirs(model_dir) - - config_path = args.config - config_save_path = os.path.join(model_dir, "config.json") - if init: - with open(config_path, "r", encoding="utf-8") as f: - data = f.read() - with open(config_save_path, "w", encoding="utf-8") as f: - f.write(data) - else: - with open(config_save_path, "r", vencoding="utf-8") as f: - data = f.read() - config = json.loads(data) - hparams = HParams(**config) - hparams.model_dir = model_dir - return hparams - - def get_hparams_from_file(config_path): # print("config_path: ", config_path) with open(config_path, "r", encoding="utf-8") as f: From ea3c05dcbb1ccb92cdc75f9c78ae52808587b0c7 Mon Sep 17 00:00:00 2001 From: Aka Diffusion <116415100+aka7774@users.noreply.github.com> Date: Fri, 8 Mar 2024 23:45:41 +0900 Subject: [PATCH 054/148] /voice method GET to POST --- server_fastapi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server_fastapi.py b/server_fastapi.py index ce6ed0432..5a68b8b20 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -101,7 +101,7 @@ def load_models(model_holder: ModelHolder): ) app.logger = logger - @app.get("/voice", response_class=AudioResponse) + @app.post("/voice", response_class=AudioResponse) async def voice( request: Request, text: str = Query(..., min_length=1, max_length=limit, description=f"セリフ"), From 7046f89d74075b6ae63d771cd94f99472260ace2 Mon Sep 17 00:00:00 2001 From: litagin02 
Date: Fri, 8 Mar 2024 23:46:03 +0900 Subject: [PATCH 055/148] Add missing file --- text/get_bert.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 text/get_bert.py diff --git a/text/get_bert.py b/text/get_bert.py new file mode 100644 index 000000000..07e4bd769 --- /dev/null +++ b/text/get_bert.py @@ -0,0 +1,16 @@ +from .japanese_bert import get_bert_feature as get_japanese_bert_feature + + +def get_bert(text, word2ph, language, device, assist_text=None, assist_text_weight=0.7): + if language == "ZH": + from .chinese_bert import get_bert_feature + elif language == "EN": + from .english_bert_mock import get_bert_feature + elif language == "JP": + # pyopenjtalkのworkerを1度だけ起動するため、ここでのimportは避ける + # 他言語のようにimportすると、get_bertが呼ばれるたびにpyopenjtalkのworkerが起動してしまう + get_bert_feature = get_japanese_bert_feature + else: + raise ValueError(f"Language {language} not supported") + + return get_bert_feature(text, word2ph, device, assist_text, assist_text_weight) From a0b5cd3d1b0f0b78aa3af3931dfba78916639479 Mon Sep 17 00:00:00 2001 From: kale4eat Date: Sat, 9 Mar 2024 00:23:23 +0900 Subject: [PATCH 056/148] Extended timeout in client side --- text/pyopenjtalk_worker/worker_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/text/pyopenjtalk_worker/worker_client.py b/text/pyopenjtalk_worker/worker_client.py index 86d8969a9..710243d24 100644 --- a/text/pyopenjtalk_worker/worker_client.py +++ b/text/pyopenjtalk_worker/worker_client.py @@ -9,8 +9,8 @@ class WorkerClient: def __init__(self, port: int) -> None: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # 5: timeout - sock.settimeout(5) + # 60: timeout + sock.settimeout(60) sock.connect((socket.gethostname(), port)) self.sock = sock From a84783a6cc75ef17c8f6773728368c1b1642285c Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 15:52:37 +0000 Subject: [PATCH 057/148] Refactor: replace utils.HParams with HyperParameters Pydantic model HyperParameters is largely a drop-in replacement for utils.HParams, which ensures type safety for hyper-parameters. 
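
A minimal usage sketch (illustrative only; the config path is a placeholder and the
actual values depend on the loaded config.json):

    from style_bert_vits2.models.hyper_parameters import HyperParameters

    # Attribute access stays the same as with utils.HParams, but fields are now
    # typed and validated by Pydantic when the JSON is loaded.
    hps = HyperParameters.load_from_json("Data/your_model/config.json")
    assert isinstance(hps.train.batch_size, int)
    assert isinstance(hps.data.spk2id, dict)
    # Keys missing from config.json fall back to the defaults declared on the model.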
--- bert_gen.py | 4 +- data_utils.py | 3 +- style_bert_vits2/models/hyper_parameters.py | 30 ++- style_bert_vits2/models/infer.py | 192 +++++++------------- style_bert_vits2/models/models.py | 4 +- style_bert_vits2/models/models_jp_extra.py | 4 +- style_bert_vits2/models/utils.py | 42 ----- style_bert_vits2/tts_model.py | 71 ++++---- style_gen.py | 3 +- train_ms.py | 29 ++- train_ms_jp_extra.py | 27 ++- 11 files changed, 189 insertions(+), 220 deletions(-) diff --git a/bert_gen.py b/bert_gen.py index 09359295a..5a16af7e5 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -8,7 +8,7 @@ from config import config from style_bert_vits2.logging import logger from style_bert_vits2.models import commons -from style_bert_vits2.models import utils +from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.nlp import cleaned_text_to_sequence, extract_bert_feature from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT @@ -62,7 +62,7 @@ def process_line(x): ) args, _ = parser.parse_known_args() config_path = args.config - hps = utils.get_hparams_from_file(config_path) + hps = HyperParameters.load_from_json(config_path) lines = [] with open(hps.data.training_files, encoding="utf-8") as f: lines.extend(f.readlines()) diff --git a/data_utils.py b/data_utils.py index 99da2e4db..04047e210 100644 --- a/data_utils.py +++ b/data_utils.py @@ -11,6 +11,7 @@ from mel_processing import mel_spectrogram_torch, spectrogram_torch from style_bert_vits2.logging import logger from style_bert_vits2.models import commons +from style_bert_vits2.models.hyper_parameters import HyperParametersData from style_bert_vits2.models.utils import load_filepaths_and_text, load_wav_to_torch from style_bert_vits2.nlp import cleaned_text_to_sequence @@ -24,7 +25,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset): 3) computes spectrograms from audio files. 
""" - def __init__(self, audiopaths_sid_text, hparams): + def __init__(self, audiopaths_sid_text: str, hparams: HyperParametersData): self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text) self.max_wav_value = hparams.max_wav_value self.sampling_rate = hparams.sampling_rate diff --git a/style_bert_vits2/models/hyper_parameters.py b/style_bert_vits2/models/hyper_parameters.py index cee79246d..9dc5afb8d 100644 --- a/style_bert_vits2/models/hyper_parameters.py +++ b/style_bert_vits2/models/hyper_parameters.py @@ -1,16 +1,16 @@ """ Style-Bert-VITS2 モデルのハイパーパラメータを表す Pydantic モデル。 -デフォルト値は configs/configs_jp_extra.json 内の定義と同一で、 +デフォルト値は configs/configs_jp_extra.json 内の定義と概ね同一で、 万が一ロードした config.json に存在しないキーがあった際のフェイルセーフとして適用される。 """ from pathlib import Path from typing import Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict -class __HyperParametersTrain(BaseModel): +class HyperParametersTrain(BaseModel): log_interval: int = 200 eval_interval: int = 1000 seed: int = 42 @@ -36,7 +36,8 @@ class __HyperParametersTrain(BaseModel): freeze_style: bool = False freeze_decoder: bool = False -class __HyperParametersData(BaseModel): + +class HyperParametersData(BaseModel): use_jp_extra: bool = True training_files: str = "Data/dummy/train.list" validation_files: str = "Data/dummy/val.list" @@ -59,7 +60,8 @@ class __HyperParametersData(BaseModel): "Neutral": 0, } -class __HyperParametersModel(BaseModel): + +class HyperParametersModel(BaseModel): use_spk_conditioned_encoder: bool = True use_noise_scaled_mas: bool = True use_mel_posterior_encoder: bool = False @@ -93,12 +95,21 @@ class __HyperParametersModel(BaseModel): "initial_channel": 64 } + class HyperParameters(BaseModel): - version: str = "2.0-JP-Extra" model_name: str = 'dummy' - train: __HyperParametersTrain - data: __HyperParametersData - model: __HyperParametersModel + version: str = "2.0-JP-Extra" + train: HyperParametersTrain + data: HyperParametersData + model: HyperParametersModel + + # 以下は学習時にのみ動的に設定されるパラメータ (通常 config.json には存在しない) + model_dir: Optional[str] = None + speedup: bool = False + repo_id: Optional[str] = None + + # model_ 以下を Pydantic の保護対象から除外する + model_config = ConfigDict(protected_namespaces=()) @staticmethod @@ -112,5 +123,6 @@ def load_from_json(json_path: Union[str, Path]) -> "HyperParameters": Returns: HyperParameters: ハイパーパラメータ """ + with open(json_path, "r") as f: return HyperParameters.model_validate_json(f.read()) diff --git a/style_bert_vits2/models/infer.py b/style_bert_vits2/models/infer.py index 7394d0b68..93bf8c285 100644 --- a/style_bert_vits2/models/infer.py +++ b/style_bert_vits2/models/infer.py @@ -1,34 +1,81 @@ +from typing import Any, cast, Optional, Union + import torch -from typing import Optional +from numpy.typing import NDArray from style_bert_vits2.constants import Languages from style_bert_vits2.logging import logger from style_bert_vits2.models import commons from style_bert_vits2.models import utils +from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.models.models import SynthesizerTrn from style_bert_vits2.models.models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra from style_bert_vits2.nlp import clean_text, cleaned_text_to_sequence, extract_bert_feature from style_bert_vits2.nlp.symbols import SYMBOLS -def get_net_g(model_path: str, version: str, device: str, hps): +def get_net_g(model_path: str, version: str, device: str, hps: HyperParameters): if version.endswith("JP-Extra"): 
logger.info("Using JP-Extra model") net_g = SynthesizerTrnJPExtra( - len(SYMBOLS), - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, - n_speakers=hps.data.n_speakers, - **hps.model, + n_vocab = len(SYMBOLS), + spec_channels = hps.data.filter_length // 2 + 1, + segment_size = hps.train.segment_size // hps.data.hop_length, + n_speakers = hps.data.n_speakers, + # hps.model 以下のすべての値を引数に渡す + use_spk_conditioned_encoder = hps.model.use_spk_conditioned_encoder, + use_noise_scaled_mas = hps.model.use_noise_scaled_mas, + use_mel_posterior_encoder = hps.model.use_mel_posterior_encoder, + use_duration_discriminator = hps.model.use_duration_discriminator, + use_wavlm_discriminator = hps.model.use_wavlm_discriminator, + inter_channels = hps.model.inter_channels, + hidden_channels = hps.model.hidden_channels, + filter_channels = hps.model.filter_channels, + n_heads = hps.model.n_heads, + n_layers = hps.model.n_layers, + kernel_size = hps.model.kernel_size, + p_dropout = hps.model.p_dropout, + resblock = hps.model.resblock, + resblock_kernel_sizes = hps.model.resblock_kernel_sizes, + resblock_dilation_sizes = hps.model.resblock_dilation_sizes, + upsample_rates = hps.model.upsample_rates, + upsample_initial_channel = hps.model.upsample_initial_channel, + upsample_kernel_sizes = hps.model.upsample_kernel_sizes, + n_layers_q = hps.model.n_layers_q, + use_spectral_norm = hps.model.use_spectral_norm, + gin_channels = hps.model.gin_channels, + slm = hps.model.slm, ).to(device) else: logger.info("Using normal model") net_g = SynthesizerTrn( - len(SYMBOLS), - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, + n_vocab = len(SYMBOLS), + spec_channels = hps.data.filter_length // 2 + 1, + segment_size = hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, - **hps.model, + # hps.model 以下のすべての値を引数に渡す + use_spk_conditioned_encoder = hps.model.use_spk_conditioned_encoder, + use_noise_scaled_mas = hps.model.use_noise_scaled_mas, + use_mel_posterior_encoder = hps.model.use_mel_posterior_encoder, + use_duration_discriminator = hps.model.use_duration_discriminator, + use_wavlm_discriminator = hps.model.use_wavlm_discriminator, + inter_channels = hps.model.inter_channels, + hidden_channels = hps.model.hidden_channels, + filter_channels = hps.model.filter_channels, + n_heads = hps.model.n_heads, + n_layers = hps.model.n_layers, + kernel_size = hps.model.kernel_size, + p_dropout = hps.model.p_dropout, + resblock = hps.model.resblock, + resblock_kernel_sizes = hps.model.resblock_kernel_sizes, + resblock_dilation_sizes = hps.model.resblock_dilation_sizes, + upsample_rates = hps.model.upsample_rates, + upsample_initial_channel = hps.model.upsample_initial_channel, + upsample_kernel_sizes = hps.model.upsample_kernel_sizes, + n_layers_q = hps.model.n_layers_q, + use_spectral_norm = hps.model.use_spectral_norm, + gin_channels = hps.model.gin_channels, + slm = hps.model.slm, ).to(device) net_g.state_dict() _ = net_g.eval() @@ -44,7 +91,7 @@ def get_net_g(model_path: str, version: str, device: str, hps): def get_text( text: str, language_str: Languages, - hps, + hps: HyperParameters, device: str, assist_text: Optional[str] = None, assist_text_weight: float = 0.7, @@ -111,15 +158,15 @@ def get_text( def infer( text: str, - style_vec, + style_vec: NDArray[Any], sdp_ratio: float, noise_scale: float, noise_scale_w: float, length_scale: float, sid: int, # In the original Bert-VITS2, its speaker_name: str, but here it's id language: Languages, - 
hps, - net_g, + hps: HyperParameters, + net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra], device: str, skip_start: bool = False, skip_end: bool = False, @@ -159,25 +206,25 @@ def infer( ja_bert = ja_bert.to(device).unsqueeze(0) en_bert = en_bert.to(device).unsqueeze(0) x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device) - style_vec = torch.from_numpy(style_vec).to(device).unsqueeze(0) + style_vec_tensor = torch.from_numpy(style_vec).to(device).unsqueeze(0) del phones sid_tensor = torch.LongTensor([sid]).to(device) if is_jp_extra: - output = net_g.infer( + output = cast(SynthesizerTrnJPExtra, net_g).infer( x_tst, x_tst_lengths, sid_tensor, tones, lang_ids, ja_bert, - style_vec=style_vec, + style_vec=style_vec_tensor, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale, ) else: - output = net_g.infer( + output = cast(SynthesizerTrn, net_g).infer( x_tst, x_tst_lengths, sid_tensor, @@ -186,7 +233,7 @@ def infer( bert, ja_bert, en_bert, - style_vec=style_vec, + style_vec=style_vec_tensor, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, @@ -209,110 +256,5 @@ def infer( return audio -def infer_multilang( - text: str, - style_vec, - sdp_ratio: float, - noise_scale: float, - noise_scale_w: float, - length_scale: float, - sid: int, - language: Languages, - hps, - net_g, - device: str, - skip_start: bool = False, - skip_end: bool = False, -): - bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], [] - # emo = get_emo_(reference_audio, emotion, sid) - # if isinstance(reference_audio, np.ndarray): - # emo = get_clap_audio_feature(reference_audio, device) - # else: - # emo = get_clap_text_feature(emotion, device) - # emo = torch.squeeze(emo, dim=1) - for idx, (txt, lang) in enumerate(zip(text, language)): - _skip_start = (idx != 0) or (skip_start and idx == 0) - _skip_end = (idx != len(language) - 1) or skip_end - ( - temp_bert, - temp_ja_bert, - temp_en_bert, - temp_phones, - temp_tones, - temp_lang_ids, - ) = get_text(txt, lang, hps, device) # type: ignore - if _skip_start: - temp_bert = temp_bert[:, 3:] - temp_ja_bert = temp_ja_bert[:, 3:] - temp_en_bert = temp_en_bert[:, 3:] - temp_phones = temp_phones[3:] - temp_tones = temp_tones[3:] - temp_lang_ids = temp_lang_ids[3:] - if _skip_end: - temp_bert = temp_bert[:, :-2] - temp_ja_bert = temp_ja_bert[:, :-2] - temp_en_bert = temp_en_bert[:, :-2] - temp_phones = temp_phones[:-2] - temp_tones = temp_tones[:-2] - temp_lang_ids = temp_lang_ids[:-2] - bert.append(temp_bert) - ja_bert.append(temp_ja_bert) - en_bert.append(temp_en_bert) - phones.append(temp_phones) - tones.append(temp_tones) - lang_ids.append(temp_lang_ids) - bert = torch.concatenate(bert, dim=1) - ja_bert = torch.concatenate(ja_bert, dim=1) - en_bert = torch.concatenate(en_bert, dim=1) - phones = torch.concatenate(phones, dim=0) - tones = torch.concatenate(tones, dim=0) - lang_ids = torch.concatenate(lang_ids, dim=0) - with torch.no_grad(): - x_tst = phones.to(device).unsqueeze(0) - tones = tones.to(device).unsqueeze(0) - lang_ids = lang_ids.to(device).unsqueeze(0) - bert = bert.to(device).unsqueeze(0) - ja_bert = ja_bert.to(device).unsqueeze(0) - en_bert = en_bert.to(device).unsqueeze(0) - # emo = emo.to(device).unsqueeze(0) - x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device) - del phones - speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device) - audio = ( - net_g.infer( - x_tst, - x_tst_lengths, - speakers, - tones, - lang_ids, - bert, - ja_bert, - en_bert, - 
style_vec=style_vec, - sdp_ratio=sdp_ratio, - noise_scale=noise_scale, - noise_scale_w=noise_scale_w, - length_scale=length_scale, - )[0][0, 0] - .data.cpu() - .float() - .numpy() - ) - del ( - x_tst, - tones, - lang_ids, - bert, - x_tst_lengths, - speakers, - ja_bert, - en_bert, - ) # , emo - if torch.cuda.is_available(): - torch.cuda.empty_cache() - return audio - - class InvalidToneError(ValueError): pass diff --git a/style_bert_vits2/models/models.py b/style_bert_vits2/models/models.py index b4e1eb718..4efd10059 100644 --- a/style_bert_vits2/models/models.py +++ b/style_bert_vits2/models/models.py @@ -983,10 +983,10 @@ def infer( en_bert, style_vec, noise_scale=0.667, - length_scale=1, + length_scale=1.0, noise_scale_w=0.8, max_len=None, - sdp_ratio=0, + sdp_ratio=0.0, y=None, ): # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert) diff --git a/style_bert_vits2/models/models_jp_extra.py b/style_bert_vits2/models/models_jp_extra.py index 7bae8b45d..3a43d5191 100644 --- a/style_bert_vits2/models/models_jp_extra.py +++ b/style_bert_vits2/models/models_jp_extra.py @@ -1029,10 +1029,10 @@ def infer( bert, style_vec, noise_scale=0.667, - length_scale=1, + length_scale=1.0, noise_scale_w=0.8, max_len=None, - sdp_ratio=0, + sdp_ratio=0.0, y=None, ): # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert) diff --git a/style_bert_vits2/models/utils.py b/style_bert_vits2/models/utils.py index 9a7947f0f..37ea269dd 100644 --- a/style_bert_vits2/models/utils.py +++ b/style_bert_vits2/models/utils.py @@ -355,45 +355,3 @@ def check_git_hash(model_dir): ) else: open(path, "w").write(cur_hash) - - -def get_hparams_from_file(config_path): - # print("config_path: ", config_path) - with open(config_path, "r", encoding="utf-8") as f: - data = f.read() - config = json.loads(data) - - hparams = HParams(**config) - return hparams - - -class HParams: - def __init__(self, **kwargs): - for k, v in kwargs.items(): - if type(v) == dict: - v = HParams(**v) - self[k] = v - - def keys(self): - return self.__dict__.keys() - - def items(self): - return self.__dict__.items() - - def values(self): - return self.__dict__.values() - - def __len__(self): - return len(self.__dict__) - - def __getitem__(self, key): - return getattr(self, key) - - def __setitem__(self, key, value): - return setattr(self, key, value) - - def __contains__(self, key): - return key in self.__dict__ - - def __repr__(self): - return self.__dict__.__repr__() diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index bf04fe24b..86c8425e5 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -1,11 +1,12 @@ import warnings from pathlib import Path -from typing import Optional, Union +from typing import Any, Optional, Union import gradio as gr import numpy as np import torch from gradio.processing_utils import convert_to_16_bit_wav +from numpy.typing import NDArray from style_bert_vits2.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, @@ -17,15 +18,22 @@ DEFAULT_SPLIT_INTERVAL, DEFAULT_STYLE, DEFAULT_STYLE_WEIGHT, + Languages, ) -from style_bert_vits2.models import utils +from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.models.infer import get_net_g, infer from style_bert_vits2.models.models import SynthesizerTrn from style_bert_vits2.models.models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra from style_bert_vits2.logging import logger -def adjust_voice(fs, wave, pitch_scale, intonation_scale): +def adjust_voice( + 
fs: int, + wave: NDArray[Any], + pitch_scale: float, + intonation_scale: float, +) -> tuple[int, NDArray[Any]]: + if pitch_scale == 1.0 and intonation_scale == 1.0: # 初期値の場合は、音質劣化を避けるためにそのまま返す return fs, wave @@ -37,15 +45,17 @@ def adjust_voice(fs, wave, pitch_scale, intonation_scale): "pyworld is not installed. Please install it by `pip install pyworld`" ) - # pyworldでf0を加工して合成 - # pyworldよりもよいのがあるかもしれないが…… + # pyworld で f0 を加工して合成 + # pyworld よりもよいのがあるかもしれないが…… + ## pyworld は Cython で書かれているが、スタブファイルがないため型補完が全く効かない… wave = wave.astype(np.double) - f0, t = pyworld.harvest(wave, fs) + # 質が高そうだしとりあえずharvestにしておく + f0, t = pyworld.harvest(wave, fs) # type: ignore - sp = pyworld.cheaptrick(wave, f0, t, fs) - ap = pyworld.d4c(wave, f0, t, fs) + sp = pyworld.cheaptrick(wave, f0, t, fs) # type: ignore + ap = pyworld.d4c(wave, f0, t, fs) # type: ignore non_zero_f0 = [f for f in f0 if f != 0] f0_mean = sum(non_zero_f0) / len(non_zero_f0) @@ -55,7 +65,7 @@ def adjust_voice(fs, wave, pitch_scale, intonation_scale): continue f0[i] = pitch_scale * f0_mean + intonation_scale * (f - f0_mean) - wave = pyworld.synthesize(f0, sp, ap, fs) + wave = pyworld.synthesize(f0, sp, ap, fs) # type: ignore return fs, wave @@ -67,7 +77,7 @@ def __init__( self.config_path: Path = config_path self.style_vec_path: Path = style_vec_path self.device: str = device - self.hps: utils.HParams = utils.get_hparams_from_file(self.config_path) + self.hps: HyperParameters = HyperParameters.load_from_json(self.config_path) self.spk2id: dict[str, int] = self.hps.data.spk2id self.id2spk: dict[int, str] = {v: k for k, v in self.spk2id.items()} @@ -81,7 +91,7 @@ def __init__( f"Number of styles ({self.num_styles}) does not match the number of style2id ({len(self.style2id)})" ) - self.style_vectors: np.ndarray = np.load(self.style_vec_path) + self.style_vectors: NDArray[Any] = np.load(self.style_vec_path) if self.style_vectors.shape[0] != self.num_styles: raise ValueError( f"The number of styles ({self.num_styles}) does not match the number of style vectors ({self.style_vectors.shape[0]})" @@ -97,7 +107,7 @@ def load_net_g(self): hps=self.hps, ) - def get_style_vector(self, style_id: int, weight: float = 1.0) -> np.ndarray: + def get_style_vector(self, style_id: int, weight: float = 1.0) -> NDArray[Any]: mean = self.style_vectors[0] style_vec = self.style_vectors[style_id] style_vec = mean + (style_vec - mean) * weight @@ -105,7 +115,7 @@ def get_style_vector(self, style_id: int, weight: float = 1.0) -> np.ndarray: def get_style_vector_from_audio( self, audio_path: str, weight: float = 1.0 - ) -> np.ndarray: + ) -> NDArray[Any]: from style_gen import get_style_vector xvec = get_style_vector(audio_path) @@ -116,7 +126,7 @@ def get_style_vector_from_audio( def infer( self, text: str, - language: str = "JP", + language: Languages = Languages.JP, sid: int = 0, reference_audio_path: Optional[str] = None, sdp_ratio: float = DEFAULT_SDP_RATIO, @@ -133,7 +143,7 @@ def infer( given_tone: Optional[list[int]] = None, pitch_scale: float = 1.0, intonation_scale: float = 1.0, - ) -> tuple[int, np.ndarray]: + ) -> tuple[int, NDArray[Any]]: logger.info(f"Start generating audio data from text:\n{text}") if language != "JP" and self.hps.version.endswith("JP-Extra"): raise ValueError( @@ -146,6 +156,7 @@ def infer( if self.net_g is None: self.load_net_g() + assert self.net_g is not None if reference_audio_path is None: style_id = self.style2id[style] style_vector = self.get_style_vector(style_id, style_weight) @@ -246,19 +257,17 @@ def refresh(self): 
continue self.model_files_dict[model_dir.name] = model_files self.model_names.append(model_dir.name) - hps = utils.get_hparams_from_file(config_path) + hps = HyperParameters.load_from_json(config_path) style2id: dict[str, int] = hps.data.style2id styles = list(style2id.keys()) spk2id: dict[str, int] = hps.data.spk2id speakers = list(spk2id.keys()) - self.models_info.append( - { - "name": model_dir.name, - "files": [str(f) for f in model_files], - "styles": styles, - "speakers": speakers, - } - ) + self.models_info.append({ + "name": model_dir.name, + "files": [str(f) for f in model_files], + "styles": styles, + "speakers": speakers, + }) def load_model(self, model_name: str, model_path_str: str): model_path = Path(model_path_str) @@ -291,9 +300,9 @@ def load_model_gr( speakers = list(self.current_model.spk2id.keys()) styles = list(self.current_model.style2id.keys()) return ( - gr.Dropdown(choices=styles, value=styles[0]), + gr.Dropdown(choices=styles, value=styles[0]), # type: ignore gr.Button(interactive=True, value="音声合成"), - gr.Dropdown(choices=speakers, value=speakers[0]), + gr.Dropdown(choices=speakers, value=speakers[0]), # type: ignore ) self.current_model = Model( model_path=model_path, @@ -304,21 +313,21 @@ def load_model_gr( speakers = list(self.current_model.spk2id.keys()) styles = list(self.current_model.style2id.keys()) return ( - gr.Dropdown(choices=styles, value=styles[0]), + gr.Dropdown(choices=styles, value=styles[0]), # type: ignore gr.Button(interactive=True, value="音声合成"), - gr.Dropdown(choices=speakers, value=speakers[0]), + gr.Dropdown(choices=speakers, value=speakers[0]), # type: ignore ) def update_model_files_gr(self, model_name: str) -> gr.Dropdown: model_files = self.model_files_dict[model_name] - return gr.Dropdown(choices=model_files, value=model_files[0]) + return gr.Dropdown(choices=model_files, value=model_files[0]) # type: ignore def update_model_names_gr(self) -> tuple[gr.Dropdown, gr.Dropdown, gr.Button]: self.refresh() initial_model_name = self.model_names[0] initial_model_files = self.model_files_dict[initial_model_name] return ( - gr.Dropdown(choices=self.model_names, value=initial_model_name), - gr.Dropdown(choices=initial_model_files, value=initial_model_files[0]), + gr.Dropdown(choices=self.model_names, value=initial_model_name), # type: ignore + gr.Dropdown(choices=initial_model_files, value=initial_model_files[0]), # type: ignore gr.Button(interactive=False), # For tts_button ) diff --git a/style_gen.py b/style_gen.py index d7f692f50..ec0b50778 100644 --- a/style_gen.py +++ b/style_gen.py @@ -8,6 +8,7 @@ from style_bert_vits2.logging import logger from style_bert_vits2.models import utils +from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT from config import config @@ -72,7 +73,7 @@ def save_average_style_vector(style_vectors, filename="style_vectors.npy"): config_path = args.config num_processes = args.num_processes - hps = utils.get_hparams_from_file(config_path) + hps = HyperParameters.load_from_json(config_path) device = config.style_gen_config.device diff --git a/train_ms.py b/train_ms.py index 977b393ee..b2cb02bf4 100644 --- a/train_ms.py +++ b/train_ms.py @@ -26,6 +26,7 @@ from style_bert_vits2.logging import logger from style_bert_vits2.models import commons from style_bert_vits2.models import utils +from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.models.models import ( DurationDiscriminator, MultiPeriodDiscriminator, 
@@ -130,7 +131,7 @@ def run(): local_rank = int(os.environ["LOCAL_RANK"]) n_gpus = dist.get_world_size() - hps = utils.get_hparams_from_file(args.config) + hps = HyperParameters.load_from_json(args.config) # This is needed because we have to pass values to `train_and_evaluate()` hps.model_dir = model_dir hps.speedup = args.speedup @@ -288,7 +289,29 @@ def run(): n_speakers=hps.data.n_speakers, mas_noise_scale_initial=mas_noise_scale_initial, noise_scale_delta=noise_scale_delta, - **hps.model, + # hps.model 以下のすべての値を引数に渡す + use_spk_conditioned_encoder = hps.model.use_spk_conditioned_encoder, + use_noise_scaled_mas = hps.model.use_noise_scaled_mas, + use_mel_posterior_encoder = hps.model.use_mel_posterior_encoder, + use_duration_discriminator = hps.model.use_duration_discriminator, + use_wavlm_discriminator = hps.model.use_wavlm_discriminator, + inter_channels = hps.model.inter_channels, + hidden_channels = hps.model.hidden_channels, + filter_channels = hps.model.filter_channels, + n_heads = hps.model.n_heads, + n_layers = hps.model.n_layers, + kernel_size = hps.model.kernel_size, + p_dropout = hps.model.p_dropout, + resblock = hps.model.resblock, + resblock_kernel_sizes = hps.model.resblock_kernel_sizes, + resblock_dilation_sizes = hps.model.resblock_dilation_sizes, + upsample_rates = hps.model.upsample_rates, + upsample_initial_channel = hps.model.upsample_initial_channel, + upsample_kernel_sizes = hps.model.upsample_kernel_sizes, + n_layers_q = hps.model.n_layers_q, + use_spectral_norm = hps.model.use_spectral_norm, + gin_channels = hps.model.gin_channels, + slm = hps.model.slm, ).cuda(local_rank) if getattr(hps.train, "freeze_ZH_bert", False): @@ -547,7 +570,7 @@ def train_and_evaluate( rank, local_rank, epoch, - hps, + hps: HyperParameters, nets, optims, schedulers, diff --git a/train_ms_jp_extra.py b/train_ms_jp_extra.py index 3b1c01aad..a464d29b9 100644 --- a/train_ms_jp_extra.py +++ b/train_ms_jp_extra.py @@ -26,6 +26,7 @@ from style_bert_vits2.logging import logger from style_bert_vits2.models import commons from style_bert_vits2.models import utils +from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.models.models_jp_extra import ( DurationDiscriminator, MultiPeriodDiscriminator, @@ -129,7 +130,7 @@ def run(): local_rank = int(os.environ["LOCAL_RANK"]) n_gpus = dist.get_world_size() - hps = utils.get_hparams_from_file(args.config) + hps = HyperParameters.load_from_json(args.config) # This is needed because we have to pass values to `train_and_evaluate() hps.model_dir = model_dir hps.speedup = args.speedup @@ -298,7 +299,29 @@ def run(): n_speakers=hps.data.n_speakers, mas_noise_scale_initial=mas_noise_scale_initial, noise_scale_delta=noise_scale_delta, - **hps.model, + # hps.model 以下のすべての値を引数に渡す + use_spk_conditioned_encoder = hps.model.use_spk_conditioned_encoder, + use_noise_scaled_mas = hps.model.use_noise_scaled_mas, + use_mel_posterior_encoder = hps.model.use_mel_posterior_encoder, + use_duration_discriminator = hps.model.use_duration_discriminator, + use_wavlm_discriminator = hps.model.use_wavlm_discriminator, + inter_channels = hps.model.inter_channels, + hidden_channels = hps.model.hidden_channels, + filter_channels = hps.model.filter_channels, + n_heads = hps.model.n_heads, + n_layers = hps.model.n_layers, + kernel_size = hps.model.kernel_size, + p_dropout = hps.model.p_dropout, + resblock = hps.model.resblock, + resblock_kernel_sizes = hps.model.resblock_kernel_sizes, + resblock_dilation_sizes = hps.model.resblock_dilation_sizes, + 
upsample_rates = hps.model.upsample_rates, + upsample_initial_channel = hps.model.upsample_initial_channel, + upsample_kernel_sizes = hps.model.upsample_kernel_sizes, + n_layers_q = hps.model.n_layers_q, + use_spectral_norm = hps.model.use_spectral_norm, + gin_channels = hps.model.gin_channels, + slm = hps.model.slm, ).cuda(local_rank) if getattr(hps.train, "freeze_JP_bert", False): logger.info("Freezing (JP) bert encoder !!!") From b9e486e72a3fedcbe91f60f32ffdfc213d11c2ec Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 16:44:56 +0000 Subject: [PATCH 058/148] Fix: extend timeout for style_bert_vits2.nlp.japanese.pyopenjtalk_worker.worker_client ref: https://github.com/litagin02/Style-Bert-VITS2/pull/91 --- .../nlp/japanese/pyopenjtalk_worker/worker_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_client.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_client.py index b87a93757..4507af414 100644 --- a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_client.py +++ b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_client.py @@ -11,8 +11,8 @@ class WorkerClient: def __init__(self, port: int) -> None: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # 5: timeout - sock.settimeout(5) + # 60: timeout + sock.settimeout(60) sock.connect((socket.gethostname(), port)) self.sock = sock From 30ea08d6ea662b26f9f57698465027eda687f186 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 18:22:40 +0000 Subject: [PATCH 059/148] Fix: forgot to write pyopenjtalk.initialize() --- style_bert_vits2/nlp/japanese/g2p.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/style_bert_vits2/nlp/japanese/g2p.py b/style_bert_vits2/nlp/japanese/g2p.py index 839fb16a6..70b4e811c 100644 --- a/style_bert_vits2/nlp/japanese/g2p.py +++ b/style_bert_vits2/nlp/japanese/g2p.py @@ -249,6 +249,10 @@ def _numeric_feature_by_regex(regex: str, s: str) -> int: return -50 return int(match.group(1)) + # pyopenjtalk_worker を初期化 + ## 一度 worker を起動すれば、明示的に終了するかプロセス終了まで同一の worker に接続される + pyopenjtalk.initialize() + labels = pyopenjtalk.make_label(pyopenjtalk.run_frontend(text)) N = len(labels) From d22a11ebb2fb71bde61155305fe7842aa0201a9c Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 22:09:47 +0000 Subject: [PATCH 060/148] Fix: a bug that prevented speech synthesis in app.py --- app.py | 1 + server_editor.py | 8 +++++--- server_fastapi.py | 17 +++++++++++++++++ .../japanese/pyopenjtalk_worker/__init__.py | 7 ++++++- webui/inference.py | 18 +++++++++++++++++- 5 files changed, 46 insertions(+), 5 deletions(-) diff --git a/app.py b/app.py index 5d05d7cfc..c2237b590 100644 --- a/app.py +++ b/app.py @@ -15,6 +15,7 @@ create_train_app, ) + # Get path settings with Path("configs/paths.yml").open("r", encoding="utf-8") as f: path_config: dict[str, str] = yaml.safe_load(f.read()) diff --git a/server_editor.py b/server_editor.py index 022a95adf..75fa4727c 100644 --- a/server_editor.py +++ b/server_editor.py @@ -149,9 +149,11 @@ def save_last_download(latest_release): # 以降はAPIの設定 # 最初に pyopenjtalk の辞書を更新 +## pyopenjtalk_worker の起動も同時に行われる update_dict() -# 単語分割に使う BERT モデル/トークナイザーを事前にロードしておく +# 事前に BERT モデル/トークナイザーをロードしておく +## ここでロードしなくても必要になった際に自動ロードされるが、時間がかかるため事前にロードしておいた方が体験が良い ## server_editor.py は日本語にしか対応していないため、日本語の BERT モデル/トークナイザーのみロードする bert_models.load_model(Languages.JP) bert_models.load_tokenizer(Languages.JP) @@ -301,7 +303,7 @@ def synthesis(request: SynthesisRequest): ) sr, 
audio = model.infer( text=text, - language=request.language.value, + language=request.language, sdp_ratio=request.sdpRatio, noise=request.noise, noisew=request.noisew, @@ -361,7 +363,7 @@ def multi_synthesis(request: MultiSynthesisRequest): tone = [t for _, t in phone_tone] sr, audio = model.infer( text=text, - language=req.language.value, + language=req.language, sdp_ratio=req.sdpRatio, noise=req.noise, noisew=req.noisew, diff --git a/server_fastapi.py b/server_fastapi.py index 833af4ab5..d8da95611 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -34,11 +34,28 @@ Languages, ) from style_bert_vits2.logging import logger +from style_bert_vits2.nlp import bert_models +from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk from style_bert_vits2.tts_model import Model, ModelHolder ln = config.server_config.language +# pyopenjtalk_worker を起動 +## Gradio はマルチスレッドだが、initialize() 内部で利用されている signal はマルチスレッドから設定できない +## さらに起動には若干時間がかかるため、事前に起動しておいた方が体験が良い +pyopenjtalk.initialize() + +# 事前に BERT モデル/トークナイザーをロードしておく +## ここでロードしなくても必要になった際に自動ロードされるが、時間がかかるため事前にロードしておいた方が体験が良い +bert_models.load_model(Languages.JP) +bert_models.load_tokenizer(Languages.JP) +bert_models.load_model(Languages.EN) +bert_models.load_tokenizer(Languages.EN) +bert_models.load_model(Languages.ZH) +bert_models.load_tokenizer(Languages.ZH) + + def raise_validation_error(msg: str, param: str): logger.warning(f"Validation error: {msg}") raise HTTPException( diff --git a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py index 670461de5..6c8712313 100644 --- a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py +++ b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py @@ -105,7 +105,12 @@ def initialize(port: int = WORKER_PORT) -> None: def signal_handler(signum: int, frame: Any): terminate() - signal.signal(signal.SIGTERM, signal_handler) + try: + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + except ValueError: + # signal only works in main thread + pass # top-level declaration diff --git a/webui/inference.py b/webui/inference.py index 02d6f3abb..ef1f97f1c 100644 --- a/webui/inference.py +++ b/webui/inference.py @@ -19,12 +19,28 @@ ) from style_bert_vits2.logging import logger from style_bert_vits2.models.infer import InvalidToneError +from style_bert_vits2.nlp import bert_models +from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk from style_bert_vits2.nlp.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone from style_bert_vits2.nlp.japanese.normalizer import normalize_text from style_bert_vits2.tts_model import ModelHolder -languages = [l.value for l in Languages] +# pyopenjtalk_worker を起動 +## Gradio はマルチスレッドだが、initialize() 内部で利用されている signal はマルチスレッドから設定できない +## さらに起動には若干時間がかかるため、事前に起動しておいた方が体験が良い +pyopenjtalk.initialize() + +# 事前に BERT モデル/トークナイザーをロードしておく +## ここでロードしなくても必要になった際に自動ロードされるが、時間がかかるため事前にロードしておいた方が体験が良い +bert_models.load_model(Languages.JP) +bert_models.load_tokenizer(Languages.JP) +bert_models.load_model(Languages.EN) +bert_models.load_tokenizer(Languages.EN) +bert_models.load_model(Languages.ZH) +bert_models.load_tokenizer(Languages.ZH) + +languages = [l.value for l in Languages] initial_text = "こんにちは、初めまして。あなたの名前はなんていうの?" 
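The same warm-up pattern can be reused from a standalone script before the first request; a minimal sketch using only the entry points added in this patch (model loading and server setup omitted):

    from style_bert_vits2.constants import Languages
    from style_bert_vits2.nlp import bert_models
    from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk

    # Start the pyopenjtalk worker once; subsequent g2p calls reuse the same
    # worker process until it is terminated or the interpreter exits.
    pyopenjtalk.initialize()

    # Pre-load the Japanese BERT model/tokenizer so the first synthesis request
    # does not pay the loading cost (EN/ZH can be loaded the same way).
    bert_models.load_model(Languages.JP)
    bert_models.load_tokenizer(Languages.JP)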
From b98fecb46d18108ea1c5fdd36a256dcf7e8f655d Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 22:14:39 +0000 Subject: [PATCH 061/148] Add: --host/--port option to app.py to allow specifying listening host/port --- app.py | 4 +++- webui/train.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/app.py b/app.py index c2237b590..f0f94d036 100644 --- a/app.py +++ b/app.py @@ -24,6 +24,8 @@ parser = argparse.ArgumentParser() parser.add_argument("--device", type=str, default="cuda") +parser.add_argument("--host", type=str, default="127.0.0.1") +parser.add_argument("--port", type=int, default=7860) parser.add_argument("--no_autolaunch", action="store_true") parser.add_argument("--share", action="store_true") @@ -49,4 +51,4 @@ create_merge_app(model_holder=model_holder) -app.launch(inbrowser=not args.no_autolaunch, share=args.share) +app.launch(server_name=args.host, server_port=args.port, inbrowser=not args.no_autolaunch, share=args.share) diff --git a/webui/train.py b/webui/train.py index fa29719e9..d83747149 100644 --- a/webui/train.py +++ b/webui/train.py @@ -792,5 +792,4 @@ def create_train_app(): outputs=[use_jp_extra_train], ) - # app.launch(inbrowser=not args.no_autolaunch, server_name=args.server_name) return app From 717ba7925f5fc76debbedfaf14a99c0226df88c9 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 22:27:21 +0000 Subject: [PATCH 062/148] Fix: app.py cannot be closed with Ctrl+C --- style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py index 6c8712313..3593e60a7 100644 --- a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py +++ b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py @@ -106,7 +106,6 @@ def signal_handler(signum: int, frame: Any): terminate() try: - signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) except ValueError: # signal only works in main thread From e1fad54c9950e001df4ef4a9439702c493c4fdde Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 22:58:40 +0000 Subject: [PATCH 063/148] Refactor: add type hints to attentions.py / modules.py / transforms.py I didn't add docstring because it is very technical code and I don't understand what is being implemented. 
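A minimal sketch of how the annotated interface reads from calling code (shapes are illustrative only):

    import torch
    from style_bert_vits2.models.modules import LayerNorm

    norm = LayerNorm(channels=192)   # __init__(self, channels: int, eps: float = 1e-5)
    x = torch.randn(1, 192, 64)      # (batch, channels, frames)
    y = norm(x)                      # forward(x: torch.Tensor) -> torch.Tensor

    # Mistakes such as LayerNorm(channels="192") are now caught by a static
    # type checker (e.g. mypy/pyright) instead of surfacing only at runtime.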
--- style_bert_vits2/models/attentions.py | 110 +++++++++++---------- style_bert_vits2/models/modules.py | 137 +++++++++++++++----------- style_bert_vits2/models/transforms.py | 84 ++++++++-------- 3 files changed, 182 insertions(+), 149 deletions(-) diff --git a/style_bert_vits2/models/attentions.py b/style_bert_vits2/models/attentions.py index 6d43e0864..b262681b8 100644 --- a/style_bert_vits2/models/attentions.py +++ b/style_bert_vits2/models/attentions.py @@ -1,3 +1,5 @@ +from typing import Any, Optional + import math import torch from torch import nn @@ -7,7 +9,7 @@ class LayerNorm(nn.Module): - def __init__(self, channels, eps=1e-5): + def __init__(self, channels: int, eps: float = 1e-5): super().__init__() self.channels = channels self.eps = eps @@ -15,14 +17,14 @@ def __init__(self, channels, eps=1e-5): self.gamma = nn.Parameter(torch.ones(channels)) self.beta = nn.Parameter(torch.zeros(channels)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.transpose(1, -1) x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) return x.transpose(1, -1) -@torch.jit.script -def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): +@torch.jit.script # type: ignore +def fused_add_tanh_sigmoid_multiply(input_a: torch.Tensor, input_b: torch.Tensor, n_channels: list[int]) -> torch.Tensor: n_channels_int = n_channels[0] in_act = input_a + input_b t_act = torch.tanh(in_act[:, :n_channels_int, :]) @@ -34,15 +36,15 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): class Encoder(nn.Module): def __init__( self, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size=1, - p_dropout=0.0, - window_size=4, - isflow=True, - **kwargs + hidden_channels: int, + filter_channels: int, + n_heads: int, + n_layers: int, + kernel_size: int = 1, + p_dropout: float = 0.0, + window_size: int = 4, + isflow: bool = True, + **kwargs: Any ): super().__init__() self.hidden_channels = hidden_channels @@ -97,12 +99,13 @@ def __init__( ) self.norm_layers_2.append(LayerNorm(hidden_channels)) - def forward(self, x, x_mask, g=None): + def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor: attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) x = x * x_mask for i in range(self.n_layers): if i == self.cond_layer_idx and g is not None: g = self.spk_emb_linear(g.transpose(1, 2)) + assert g is not None g = g.transpose(1, 2) x = x + g x = x * x_mask @@ -120,15 +123,15 @@ def forward(self, x, x_mask, g=None): class Decoder(nn.Module): def __init__( self, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size=1, - p_dropout=0.0, - proximal_bias=False, - proximal_init=True, - **kwargs + hidden_channels: int, + filter_channels: int, + n_heads: int, + n_layers: int, + kernel_size: int = 1, + p_dropout: float = 0.0, + proximal_bias: bool = False, + proximal_init: bool = True, + **kwargs: Any ): super().__init__() self.hidden_channels = hidden_channels @@ -177,7 +180,7 @@ def __init__( ) self.norm_layers_2.append(LayerNorm(hidden_channels)) - def forward(self, x, x_mask, h, h_mask): + def forward(self, x: torch.Tensor, x_mask: torch.Tensor, h: torch.Tensor, h_mask: torch.Tensor): """ x: decoder input h: encoder output @@ -206,15 +209,15 @@ def forward(self, x, x_mask, h, h_mask): class MultiHeadAttention(nn.Module): def __init__( self, - channels, - out_channels, - n_heads, - p_dropout=0.0, - window_size=None, - heads_share=True, - block_length=None, - proximal_bias=False, - 
proximal_init=False, + channels: int, + out_channels: int, + n_heads: int, + p_dropout: float = 0.0, + window_size: Optional[int] = None, + heads_share: bool = True, + block_length: Optional[int] = None, + proximal_bias: bool = False, + proximal_init: bool = False, ): super().__init__() assert channels % n_heads == 0 @@ -255,9 +258,11 @@ def __init__( if proximal_init: with torch.no_grad(): self.conv_k.weight.copy_(self.conv_q.weight) + assert self.conv_k.bias is not None + assert self.conv_q.bias is not None self.conv_k.bias.copy_(self.conv_q.bias) - def forward(self, x, c, attn_mask=None): + def forward(self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: q = self.conv_q(x) k = self.conv_k(c) v = self.conv_v(c) @@ -267,7 +272,7 @@ def forward(self, x, c, attn_mask=None): x = self.conv_o(x) return x - def attention(self, query, key, value, mask=None): + def attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor]: # reshape [b, d, t] -> [b, n_h, t, d_k] b, d, t_s, t_t = (*key.size(), query.size(2)) query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) @@ -318,7 +323,7 @@ def attention(self, query, key, value, mask=None): ) # [b, n_h, t_t, d_k] -> [b, d, t_t] return output, p_attn - def _matmul_with_relative_values(self, x, y): + def _matmul_with_relative_values(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ x: [b, h, l, m] y: [h or 1, m, d] @@ -327,7 +332,7 @@ def _matmul_with_relative_values(self, x, y): ret = torch.matmul(x, y.unsqueeze(0)) return ret - def _matmul_with_relative_keys(self, x, y): + def _matmul_with_relative_keys(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ x: [b, h, l, d] y: [h or 1, m, d] @@ -336,8 +341,9 @@ def _matmul_with_relative_keys(self, x, y): ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) return ret - def _get_relative_embeddings(self, relative_embeddings, length): - 2 * self.window_size + 1 + def _get_relative_embeddings(self, relative_embeddings: torch.Tensor, length: int) -> torch.Tensor: + assert self.window_size is not None + 2 * self.window_size + 1 # type: ignore # Pad first before slice to avoid using cond ops. pad_length = max(length - (self.window_size + 1), 0) slice_start_position = max((self.window_size + 1) - length, 0) @@ -354,7 +360,7 @@ def _get_relative_embeddings(self, relative_embeddings, length): ] return used_relative_embeddings - def _relative_position_to_absolute_position(self, x): + def _relative_position_to_absolute_position(self, x: torch.Tensor) -> torch.Tensor: """ x: [b, h, l, 2*l-1] ret: [b, h, l, l] @@ -375,7 +381,7 @@ def _relative_position_to_absolute_position(self, x): ] return x_final - def _absolute_position_to_relative_position(self, x): + def _absolute_position_to_relative_position(self, x: torch.Tensor) -> torch.Tensor: """ x: [b, h, l, l] ret: [b, h, l, 2*l-1] @@ -391,7 +397,7 @@ def _absolute_position_to_relative_position(self, x): x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] return x_final - def _attention_bias_proximal(self, length): + def _attention_bias_proximal(self, length: int) -> torch.Tensor: """Bias for self-attention to encourage attention to close positions. Args: length: an integer scalar. 
@@ -406,13 +412,13 @@ def _attention_bias_proximal(self, length): class FFN(nn.Module): def __init__( self, - in_channels, - out_channels, - filter_channels, - kernel_size, - p_dropout=0.0, - activation=None, - causal=False, + in_channels: int, + out_channels: int, + filter_channels: int, + kernel_size: int, + p_dropout: float = 0.0, + activation: Optional[str] = None, + causal: bool = False, ): super().__init__() self.in_channels = in_channels @@ -432,7 +438,7 @@ def __init__( self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) self.drop = nn.Dropout(p_dropout) - def forward(self, x, x_mask): + def forward(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor: x = self.conv_1(self.padding(x * x_mask)) if self.activation == "gelu": x = x * torch.sigmoid(1.702 * x) @@ -442,7 +448,7 @@ def forward(self, x, x_mask): x = self.conv_2(self.padding(x * x_mask)) return x * x_mask - def _causal_padding(self, x): + def _causal_padding(self, x: torch.Tensor) -> torch.Tensor: if self.kernel_size == 1: return x pad_l = self.kernel_size - 1 @@ -451,7 +457,7 @@ def _causal_padding(self, x): x = F.pad(x, commons.convert_pad_shape(padding)) return x - def _same_padding(self, x): + def _same_padding(self, x: torch.Tensor) -> torch.Tensor: if self.kernel_size == 1: return x pad_l = (self.kernel_size - 1) // 2 diff --git a/style_bert_vits2/models/modules.py b/style_bert_vits2/models/modules.py index eede771fb..df3807123 100644 --- a/style_bert_vits2/models/modules.py +++ b/style_bert_vits2/models/modules.py @@ -1,4 +1,5 @@ import math +from typing import Any, Optional, Union import torch from torch import nn @@ -15,7 +16,7 @@ class LayerNorm(nn.Module): - def __init__(self, channels, eps=1e-5): + def __init__(self, channels: int, eps: float = 1e-5): super().__init__() self.channels = channels self.eps = eps @@ -23,7 +24,7 @@ def __init__(self, channels, eps=1e-5): self.gamma = nn.Parameter(torch.ones(channels)) self.beta = nn.Parameter(torch.zeros(channels)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.transpose(1, -1) x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) return x.transpose(1, -1) @@ -32,12 +33,12 @@ def forward(self, x): class ConvReluNorm(nn.Module): def __init__( self, - in_channels, - hidden_channels, - out_channels, - kernel_size, - n_layers, - p_dropout, + in_channels: int, + hidden_channels: int, + out_channels: int, + kernel_size: int, + n_layers: int, + p_dropout: float, ): super().__init__() self.in_channels = in_channels @@ -69,9 +70,10 @@ def __init__( self.norm_layers.append(LayerNorm(hidden_channels)) self.proj = nn.Conv1d(hidden_channels, out_channels, 1) self.proj.weight.data.zero_() + assert self.proj.bias is not None self.proj.bias.data.zero_() - def forward(self, x, x_mask): + def forward(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor: x_org = x for i in range(self.n_layers): x = self.conv_layers[i](x * x_mask) @@ -86,7 +88,7 @@ class DDSConv(nn.Module): Dialted and Depth-Separable Convolution """ - def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): + def __init__(self, channels: int, kernel_size: int, n_layers: int, p_dropout: float = 0.0): super().__init__() self.channels = channels self.kernel_size = kernel_size @@ -115,7 +117,7 @@ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): self.norms_1.append(LayerNorm(channels)) self.norms_2.append(LayerNorm(channels)) - def forward(self, x, x_mask, g=None): + def forward(self, x: torch.Tensor, x_mask: 
torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor: if g is not None: x = x + g for i in range(self.n_layers): @@ -133,12 +135,12 @@ def forward(self, x, x_mask, g=None): class WN(torch.nn.Module): def __init__( self, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=0, - p_dropout=0, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + n_layers: int, + gin_channels: int = 0, + p_dropout: float = 0, ): super(WN, self).__init__() assert kernel_size % 2 == 1 @@ -182,7 +184,7 @@ def __init__( res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") self.res_skip_layers.append(res_skip_layer) - def forward(self, x, x_mask, g=None, **kwargs): + def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, **kwargs: Any) -> torch.Tensor: output = torch.zeros_like(x) n_channels_tensor = torch.IntTensor([self.hidden_channels]) @@ -209,7 +211,7 @@ def forward(self, x, x_mask, g=None, **kwargs): output = output + res_skip_acts return output * x_mask - def remove_weight_norm(self): + def remove_weight_norm(self) -> None: if self.gin_channels != 0: torch.nn.utils.remove_weight_norm(self.cond_layer) for l in self.in_layers: @@ -219,7 +221,7 @@ def remove_weight_norm(self): class ResBlock1(torch.nn.Module): - def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + def __init__(self, channels: int, kernel_size: int = 3, dilation: tuple[int, int, int] = (1, 3, 5)): super(ResBlock1, self).__init__() self.convs1 = nn.ModuleList( [ @@ -293,7 +295,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): ) self.convs2.apply(commons.init_weights) - def forward(self, x, x_mask=None): + def forward(self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None) -> torch.Tensor: for c1, c2 in zip(self.convs1, self.convs2): xt = F.leaky_relu(x, LRELU_SLOPE) if x_mask is not None: @@ -308,7 +310,7 @@ def forward(self, x, x_mask=None): x = x * x_mask return x - def remove_weight_norm(self): + def remove_weight_norm(self) -> None: for l in self.convs1: remove_weight_norm(l) for l in self.convs2: @@ -316,7 +318,7 @@ def remove_weight_norm(self): class ResBlock2(torch.nn.Module): - def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + def __init__(self, channels: int, kernel_size: int = 3, dilation: tuple[int, int] = (1, 3)): super(ResBlock2, self).__init__() self.convs = nn.ModuleList( [ @@ -344,7 +346,7 @@ def __init__(self, channels, kernel_size=3, dilation=(1, 3)): ) self.convs.apply(commons.init_weights) - def forward(self, x, x_mask=None): + def forward(self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None) -> torch.Tensor: for c in self.convs: xt = F.leaky_relu(x, LRELU_SLOPE) if x_mask is not None: @@ -355,13 +357,13 @@ def forward(self, x, x_mask=None): x = x * x_mask return x - def remove_weight_norm(self): + def remove_weight_norm(self) -> None: for l in self.convs: remove_weight_norm(l) class Log(nn.Module): - def forward(self, x, x_mask, reverse=False, **kwargs): + def forward(self, x: torch.Tensor, x_mask: torch.Tensor, reverse: bool = False, **kwargs: Any): if not reverse: y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask logdet = torch.sum(-y, [1, 2]) @@ -372,7 +374,13 @@ def forward(self, x, x_mask, reverse=False, **kwargs): class Flip(nn.Module): - def forward(self, x, *args, reverse=False, **kwargs): + def forward( + self, + x: torch.Tensor, + *args: Any, + reverse: bool = False, + **kwargs: Any, + ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]: x = 
torch.flip(x, [1]) if not reverse: logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) @@ -382,13 +390,19 @@ def forward(self, x, *args, reverse=False, **kwargs): class ElementwiseAffine(nn.Module): - def __init__(self, channels): + def __init__(self, channels: int): super().__init__() self.channels = channels self.m = nn.Parameter(torch.zeros(channels, 1)) self.logs = nn.Parameter(torch.zeros(channels, 1)) - def forward(self, x, x_mask, reverse=False, **kwargs): + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + reverse: bool = False, + **kwargs: Any, + ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if not reverse: y = self.m + torch.exp(self.logs) * x y = y * x_mask @@ -402,14 +416,14 @@ def forward(self, x, x_mask, reverse=False, **kwargs): class ResidualCouplingLayer(nn.Module): def __init__( self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - p_dropout=0, - gin_channels=0, - mean_only=False, + channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + n_layers: int, + p_dropout: float = 0, + gin_channels: int = 0, + mean_only: bool = False, ): assert channels % 2 == 0, "channels should be divisible by 2" super().__init__() @@ -432,9 +446,10 @@ def __init__( ) self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) self.post.weight.data.zero_() + assert self.post.bias is not None self.post.bias.data.zero_() - def forward(self, x, x_mask, g=None, reverse=False): + def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False): x0, x1 = torch.split(x, [self.half_channels] * 2, 1) h = self.pre(x0) * x_mask h = self.enc(h, x_mask, g=g) @@ -459,12 +474,12 @@ def forward(self, x, x_mask, g=None, reverse=False): class ConvFlow(nn.Module): def __init__( self, - in_channels, - filter_channels, - kernel_size, - n_layers, - num_bins=10, - tail_bound=5.0, + in_channels: int, + filter_channels: int, + kernel_size: int, + n_layers: int, + num_bins: int = 10, + tail_bound: float = 5.0, ): super().__init__() self.in_channels = in_channels @@ -481,9 +496,10 @@ def __init__( filter_channels, self.half_channels * (num_bins * 3 - 1), 1 ) self.proj.weight.data.zero_() + assert self.proj.bias is not None self.proj.bias.data.zero_() - def forward(self, x, x_mask, g=None, reverse=False): + def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False): x0, x1 = torch.split(x, [self.half_channels] * 2, 1) h = self.pre(x0) h = self.convs(h, x_mask, g=g) @@ -519,17 +535,17 @@ def forward(self, x, x_mask, g=None, reverse=False): class TransformerCouplingLayer(nn.Module): def __init__( self, - channels, - hidden_channels, - kernel_size, - n_layers, - n_heads, - p_dropout=0, - filter_channels=0, - mean_only=False, - wn_sharing_parameter=None, - gin_channels=0, - ): + channels: int, + hidden_channels: int, + kernel_size: int, + n_layers: int, + n_heads: int, + p_dropout: float = 0, + filter_channels: int = 0, + mean_only: bool = False, + wn_sharing_parameter: Optional[nn.Module] = None, + gin_channels: int = 0, + ) -> None: assert channels % 2 == 0, "channels should be divisible by 2" super().__init__() self.channels = channels @@ -556,9 +572,16 @@ def __init__( ) self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) self.post.weight.data.zero_() + assert self.post.bias is not None self.post.bias.data.zero_() - def forward(self, x, x_mask, g=None, reverse=False): + 
def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + reverse: bool = False, + ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]: x0, x1 = torch.split(x, [self.half_channels] * 2, 1) h = self.pre(x0) * x_mask h = self.enc(h, x_mask, g=g) diff --git a/style_bert_vits2/models/transforms.py b/style_bert_vits2/models/transforms.py index a11f799e0..61306adc7 100644 --- a/style_bert_vits2/models/transforms.py +++ b/style_bert_vits2/models/transforms.py @@ -1,7 +1,8 @@ -import torch -from torch.nn import functional as F +from typing import Optional import numpy as np +import torch +from torch.nn import functional as F DEFAULT_MIN_BIN_WIDTH = 1e-3 @@ -10,17 +11,18 @@ def piecewise_rational_quadratic_transform( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails=None, - tail_bound=1.0, - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE, -): + inputs: torch.Tensor, + unnormalized_widths: torch.Tensor, + unnormalized_heights: torch.Tensor, + unnormalized_derivatives: torch.Tensor, + inverse: bool = False, + tails: Optional[str] = None, + tail_bound: float = 1.0, + min_bin_width: float = DEFAULT_MIN_BIN_WIDTH, + min_bin_height: float = DEFAULT_MIN_BIN_HEIGHT, + min_derivative: float = DEFAULT_MIN_DERIVATIVE, +) -> tuple[torch.Tensor, torch.Tensor]: + if tails is None: spline_fn = rational_quadratic_spline spline_kwargs = {} @@ -37,28 +39,29 @@ def piecewise_rational_quadratic_transform( min_bin_width=min_bin_width, min_bin_height=min_bin_height, min_derivative=min_derivative, - **spline_kwargs + **spline_kwargs # type: ignore ) return outputs, logabsdet -def searchsorted(bin_locations, inputs, eps=1e-6): +def searchsorted(bin_locations: torch.Tensor, inputs: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: bin_locations[..., -1] += eps return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 def unconstrained_rational_quadratic_spline( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails="linear", - tail_bound=1.0, - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE, -): + inputs: torch.Tensor, + unnormalized_widths: torch.Tensor, + unnormalized_heights: torch.Tensor, + unnormalized_derivatives: torch.Tensor, + inverse: bool = False, + tails: str = "linear", + tail_bound: float = 1.0, + min_bin_width: float = DEFAULT_MIN_BIN_WIDTH, + min_bin_height: float = DEFAULT_MIN_BIN_HEIGHT, + min_derivative: float = DEFAULT_MIN_DERIVATIVE, +) -> tuple[torch.Tensor, torch.Tensor]: + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) outside_interval_mask = ~inside_interval_mask @@ -74,7 +77,7 @@ def unconstrained_rational_quadratic_spline( outputs[outside_interval_mask] = inputs[outside_interval_mask] logabsdet[outside_interval_mask] = 0 else: - raise RuntimeError("{} tails are not implemented.".format(tails)) + raise RuntimeError(f"{tails} tails are not implemented.") ( outputs[inside_interval_mask], @@ -98,19 +101,20 @@ def unconstrained_rational_quadratic_spline( def rational_quadratic_spline( - inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - left=0.0, - right=1.0, - bottom=0.0, - top=1.0, - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE, -): + inputs: 
torch.Tensor, + unnormalized_widths: torch.Tensor, + unnormalized_heights: torch.Tensor, + unnormalized_derivatives: torch.Tensor, + inverse: bool = False, + left: float = 0.0, + right: float = 1.0, + bottom: float = 0.0, + top: float = 1.0, + min_bin_width: float = DEFAULT_MIN_BIN_WIDTH, + min_bin_height: float = DEFAULT_MIN_BIN_HEIGHT, + min_derivative: float = DEFAULT_MIN_DERIVATIVE, +) -> tuple[torch.Tensor, torch.Tensor]: + if torch.min(inputs) < left or torch.max(inputs) > right: raise ValueError("Input to a transform is not within its domain") From 8feef04cef8e49bbb895d2c954d112ad43c8d063 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Fri, 8 Mar 2024 23:52:44 +0000 Subject: [PATCH 064/148] Refactor: add type hints to models.py / models_jp_extra.py I didn't add docstring because it is very technical code and I don't understand what is being implemented. --- style_bert_vits2/models/attentions.py | 20 +- style_bert_vits2/models/models.py | 384 ++++++++++++--------- style_bert_vits2/models/models_jp_extra.py | 369 ++++++++++++-------- style_bert_vits2/models/modules.py | 42 ++- 4 files changed, 490 insertions(+), 325 deletions(-) diff --git a/style_bert_vits2/models/attentions.py b/style_bert_vits2/models/attentions.py index b262681b8..03b238d8a 100644 --- a/style_bert_vits2/models/attentions.py +++ b/style_bert_vits2/models/attentions.py @@ -9,7 +9,7 @@ class LayerNorm(nn.Module): - def __init__(self, channels: int, eps: float = 1e-5): + def __init__(self, channels: int, eps: float = 1e-5) -> None: super().__init__() self.channels = channels self.eps = eps @@ -45,7 +45,7 @@ def __init__( window_size: int = 4, isflow: bool = True, **kwargs: Any - ): + ) -> None: super().__init__() self.hidden_channels = hidden_channels self.filter_channels = filter_channels @@ -132,7 +132,7 @@ def __init__( proximal_bias: bool = False, proximal_init: bool = True, **kwargs: Any - ): + ) -> None: super().__init__() self.hidden_channels = hidden_channels self.filter_channels = filter_channels @@ -180,7 +180,7 @@ def __init__( ) self.norm_layers_2.append(LayerNorm(hidden_channels)) - def forward(self, x: torch.Tensor, x_mask: torch.Tensor, h: torch.Tensor, h_mask: torch.Tensor): + def forward(self, x: torch.Tensor, x_mask: torch.Tensor, h: torch.Tensor, h_mask: torch.Tensor) -> torch.Tensor: """ x: decoder input h: encoder output @@ -218,7 +218,7 @@ def __init__( block_length: Optional[int] = None, proximal_bias: bool = False, proximal_init: bool = False, - ): + ) -> None: super().__init__() assert channels % n_heads == 0 @@ -272,7 +272,13 @@ def forward(self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Te x = self.conv_o(x) return x - def attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor]: + def attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: # reshape [b, d, t] -> [b, n_h, t, d_k] b, d, t_s, t_t = (*key.size(), query.size(2)) query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) @@ -419,7 +425,7 @@ def __init__( p_dropout: float = 0.0, activation: Optional[str] = None, causal: bool = False, - ): + ) -> None: super().__init__() self.in_channels = in_channels self.out_channels = out_channels diff --git a/style_bert_vits2/models/models.py b/style_bert_vits2/models/models.py index 4efd10059..829ca4a06 100644 --- a/style_bert_vits2/models/models.py +++ 
b/style_bert_vits2/models/models.py @@ -1,4 +1,5 @@ import math +from typing import Any, Optional import torch from torch import nn @@ -10,14 +11,18 @@ from style_bert_vits2.models import commons from style_bert_vits2.models import modules from style_bert_vits2.models import monotonic_alignment -from style_bert_vits2.models.commons import get_padding, init_weights from style_bert_vits2.nlp.symbols import NUM_LANGUAGES, NUM_TONES, SYMBOLS class DurationDiscriminator(nn.Module): # vits2 def __init__( - self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0 - ): + self, + in_channels: int, + filter_channels: int, + kernel_size: int, + p_dropout: float, + gin_channels: int = 0 + ) -> None: super().__init__() self.in_channels = in_channels @@ -51,7 +56,13 @@ def __init__( self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid()) - def forward_probability(self, x, x_mask, dur, g=None): + def forward_probability( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + dur: torch.Tensor, + g: Optional[torch.Tensor] = None, + ) -> torch.Tensor: dur = self.dur_proj(dur) x = torch.cat([x, dur], dim=1) x = self.pre_out_conv_1(x * x_mask) @@ -67,7 +78,14 @@ def forward_probability(self, x, x_mask, dur, g=None): output_prob = self.output_layer(x) return output_prob - def forward(self, x, x_mask, dur_r, dur_hat, g=None): + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + dur_r: torch.Tensor, + dur_hat: torch.Tensor, + g: Optional[torch.Tensor] = None, + ) -> list[torch.Tensor]: x = torch.detach(x) if g is not None: g = torch.detach(g) @@ -92,17 +110,17 @@ def forward(self, x, x_mask, dur_r, dur_hat, g=None): class TransformerCouplingBlock(nn.Module): def __init__( self, - channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - n_flows=4, - gin_channels=0, - share_parameter=False, - ): + channels: int, + hidden_channels: int, + filter_channels: int, + n_heads: int, + n_layers: int, + kernel_size: int, + p_dropout: float, + n_flows: int = 4, + gin_channels: int = 0, + share_parameter: bool = False, + ) -> None: super().__init__() self.channels = channels self.hidden_channels = hidden_channels @@ -114,16 +132,17 @@ def __init__( self.flows = nn.ModuleList() self.wn = ( - attentions.FFT( - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - isflow=True, - gin_channels=self.gin_channels, - ) + # attentions.FFT( + # hidden_channels, + # filter_channels, + # n_heads, + # n_layers, + # kernel_size, + # p_dropout, + # isflow=True, + # gin_channels=self.gin_channels, + # ) + None if share_parameter else None ) @@ -145,7 +164,13 @@ def __init__( ) self.flows.append(modules.Flip()) - def forward(self, x, x_mask, g=None, reverse=False): + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + reverse: bool = False, + ) -> torch.Tensor: if not reverse: for flow in self.flows: x, _ = flow(x, x_mask, g=g, reverse=reverse) @@ -158,13 +183,13 @@ def forward(self, x, x_mask, g=None, reverse=False): class StochasticDurationPredictor(nn.Module): def __init__( self, - in_channels, - filter_channels, - kernel_size, - p_dropout, - n_flows=4, - gin_channels=0, - ): + in_channels: int, + filter_channels: int, + kernel_size: int, + p_dropout: float, + n_flows: int = 4, + gin_channels: int = 0, + ) -> None: super().__init__() filter_channels = in_channels # it needs to be removed from future version. 
self.in_channels = in_channels @@ -204,7 +229,15 @@ def __init__( if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, filter_channels, 1) - def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + w: Optional[torch.Tensor] = None, + g: Optional[torch.Tensor] = None, + reverse: bool = False, + noise_scale: float = 1.0, + ) -> torch.Tensor: x = torch.detach(x) x = self.pre(x) if g is not None: @@ -268,8 +301,13 @@ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): class DurationPredictor(nn.Module): def __init__( - self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0 - ): + self, + in_channels: int, + filter_channels: int, + kernel_size: int, + p_dropout: float, + gin_channels: int = 0, + ) -> None: super().__init__() self.in_channels = in_channels @@ -292,7 +330,7 @@ def __init__( if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, in_channels, 1) - def forward(self, x, x_mask, g=None): + def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor: x = torch.detach(x) if g is not None: g = torch.detach(g) @@ -312,17 +350,17 @@ def forward(self, x, x_mask, g=None): class TextEncoder(nn.Module): def __init__( self, - n_vocab, - out_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - n_speakers, - gin_channels=0, - ): + n_vocab: int, + out_channels: int, + hidden_channels: int, + filter_channels: int, + n_heads: int, + n_layers: int, + kernel_size: int, + p_dropout: float, + n_speakers: int, + gin_channels: int = 0, + ) -> None: super().__init__() self.n_vocab = n_vocab self.out_channels = out_channels @@ -357,17 +395,17 @@ def __init__( def forward( self, - x, - x_lengths, - tone, - language, - bert, - ja_bert, - en_bert, - style_vec, - sid, - g=None, - ): + x: torch.Tensor, + x_lengths: torch.Tensor, + tone: torch.Tensor, + language: torch.Tensor, + bert: torch.Tensor, + ja_bert: torch.Tensor, + en_bert: torch.Tensor, + style_vec: torch.Tensor, + sid: torch.Tensor, + g: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: bert_emb = self.bert_proj(bert).transpose(1, 2) ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2) en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2) @@ -399,14 +437,14 @@ def forward( class ResidualCouplingBlock(nn.Module): def __init__( self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - n_flows=4, - gin_channels=0, - ): + channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + n_layers: int, + n_flows: int = 4, + gin_channels: int = 0, + ) -> None: super().__init__() self.channels = channels self.hidden_channels = hidden_channels @@ -431,7 +469,13 @@ def __init__( ) self.flows.append(modules.Flip()) - def forward(self, x, x_mask, g=None, reverse=False): + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + reverse: bool = False, + ) -> torch.Tensor: if not reverse: for flow in self.flows: x, _ = flow(x, x_mask, g=g, reverse=reverse) @@ -444,14 +488,14 @@ def forward(self, x, x_mask, g=None, reverse=False): class PosteriorEncoder(nn.Module): def __init__( self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=0, - ): + in_channels: int, + out_channels: int, + hidden_channels: int, + kernel_size: int, + 
dilation_rate: int, + n_layers: int, + gin_channels: int = 0, + ) -> None: super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -471,7 +515,12 @@ def __init__( ) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, x, x_lengths, g=None): + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + g: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( x.dtype ) @@ -486,22 +535,22 @@ def forward(self, x, x_lengths, g=None): class Generator(torch.nn.Module): def __init__( self, - initial_channel, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels=0, - ): + initial_channel: int, + resblock_str: str, + resblock_kernel_sizes: list[int], + resblock_dilation_sizes: list[list[int]], + upsample_rates: list[int], + upsample_initial_channel: int, + upsample_kernel_sizes: list[int], + gin_channels: int = 0, + ) -> None: super(Generator, self).__init__() self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) self.conv_pre = Conv1d( initial_channel, upsample_initial_channel, 7, 1, padding=3 ) - resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + resblock = modules.ResBlock1 if resblock_str == "1" else modules.ResBlock2 self.ups = nn.ModuleList() for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): @@ -518,20 +567,22 @@ def __init__( ) self.resblocks = nn.ModuleList() + ch = None for i in range(len(self.ups)): ch = upsample_initial_channel // (2 ** (i + 1)) for j, (k, d) in enumerate( zip(resblock_kernel_sizes, resblock_dilation_sizes) ): - self.resblocks.append(resblock(ch, k, d)) + self.resblocks.append(resblock(ch, k, d)) # type: ignore + assert ch is not None self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) - self.ups.apply(init_weights) + self.ups.apply(commons.init_weights) if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) - def forward(self, x, g=None): + def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor: x = self.conv_pre(x) if g is not None: x = x + self.cond(g) @@ -545,6 +596,7 @@ def forward(self, x, g=None): xs = self.resblocks[i * self.num_kernels + j](x) else: xs += self.resblocks[i * self.num_kernels + j](x) + assert xs is not None x = xs / self.num_kernels x = F.leaky_relu(x) x = self.conv_post(x) @@ -552,7 +604,7 @@ def forward(self, x, g=None): return x - def remove_weight_norm(self): + def remove_weight_norm(self) -> None: print("Removing weight norm...") for layer in self.ups: remove_weight_norm(layer) @@ -561,7 +613,7 @@ def remove_weight_norm(self): class DiscriminatorP(torch.nn.Module): - def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + def __init__(self, period: int, kernel_size: int = 5, stride: int = 3, use_spectral_norm: bool = False) -> None: super(DiscriminatorP, self).__init__() self.period = period self.use_spectral_norm = use_spectral_norm @@ -574,7 +626,7 @@ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 32, (kernel_size, 1), (stride, 1), - padding=(get_padding(kernel_size, 1), 0), + padding=(commons.get_padding(kernel_size, 1), 0), ) ), norm_f( @@ -583,7 +635,7 @@ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 128, (kernel_size, 1), 
(stride, 1), - padding=(get_padding(kernel_size, 1), 0), + padding=(commons.get_padding(kernel_size, 1), 0), ) ), norm_f( @@ -592,7 +644,7 @@ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 512, (kernel_size, 1), (stride, 1), - padding=(get_padding(kernel_size, 1), 0), + padding=(commons.get_padding(kernel_size, 1), 0), ) ), norm_f( @@ -601,7 +653,7 @@ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 1024, (kernel_size, 1), (stride, 1), - padding=(get_padding(kernel_size, 1), 0), + padding=(commons.get_padding(kernel_size, 1), 0), ) ), norm_f( @@ -610,14 +662,14 @@ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 1024, (kernel_size, 1), 1, - padding=(get_padding(kernel_size, 1), 0), + padding=(commons.get_padding(kernel_size, 1), 0), ) ), ] ) self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - def forward(self, x): + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]: fmap = [] # 1d to 2d @@ -640,7 +692,7 @@ def forward(self, x): class DiscriminatorS(torch.nn.Module): - def __init__(self, use_spectral_norm=False): + def __init__(self, use_spectral_norm: bool = False) -> None: super(DiscriminatorS, self).__init__() norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList( @@ -655,7 +707,7 @@ def __init__(self, use_spectral_norm=False): ) self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]: fmap = [] for layer in self.convs: @@ -670,7 +722,7 @@ def forward(self, x): class MultiPeriodDiscriminator(torch.nn.Module): - def __init__(self, use_spectral_norm=False): + def __init__(self, use_spectral_norm: bool = False) -> None: super(MultiPeriodDiscriminator, self).__init__() periods = [2, 3, 5, 7, 11] @@ -680,7 +732,11 @@ def __init__(self, use_spectral_norm=False): ] self.discriminators = nn.ModuleList(discs) - def forward(self, y, y_hat): + def forward( + self, + y: torch.Tensor, + y_hat: torch.Tensor, + ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: y_d_rs = [] y_d_gs = [] fmap_rs = [] @@ -702,7 +758,7 @@ class ReferenceEncoder(nn.Module): outputs --- [N, ref_enc_gru_size] """ - def __init__(self, spec_channels, gin_channels=0): + def __init__(self, spec_channels: int, gin_channels: int = 0) -> None: super().__init__() self.spec_channels = spec_channels ref_enc_filters = [32, 32, 64, 64, 128, 128] @@ -731,7 +787,7 @@ def __init__(self, spec_channels, gin_channels=0): ) self.proj = nn.Linear(128, gin_channels) - def forward(self, inputs, mask=None): + def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: N = inputs.size(0) out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] for conv in self.convs: @@ -749,7 +805,7 @@ def forward(self, inputs, mask=None): return self.proj(out.squeeze(0)) - def calculate_channels(self, L, kernel_size, stride, pad, n_convs): + def calculate_channels(self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int: for i in range(n_convs): L = (L - kernel_size + 2 * pad) // stride + 1 return L @@ -762,31 +818,31 @@ class SynthesizerTrn(nn.Module): def __init__( self, - n_vocab, - spec_channels, - segment_size, - inter_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - resblock, - resblock_kernel_sizes, - 
resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - n_speakers=256, - gin_channels=256, - use_sdp=True, - n_flow_layer=4, - n_layers_trans_flow=4, - flow_share_parameter=False, - use_transformer_flow=True, - **kwargs, - ): + n_vocab: int, + spec_channels: int, + segment_size: int, + inter_channels: int, + hidden_channels: int, + filter_channels: int, + n_heads: int, + n_layers: int, + kernel_size: int, + p_dropout: float, + resblock: str, + resblock_kernel_sizes: list[int], + resblock_dilation_sizes: list[list[int]], + upsample_rates: list[int], + upsample_initial_channel: int, + upsample_kernel_sizes: list[int], + n_speakers: int = 256, + gin_channels: int = 256, + use_sdp: bool = True, + n_flow_layer: int = 4, + n_layers_trans_flow: int = 4, + flow_share_parameter: bool = False, + use_transformer_flow: bool = True, + **kwargs: Any, + ) -> None: super().__init__() self.n_vocab = n_vocab self.spec_channels = spec_channels @@ -884,18 +940,27 @@ def __init__( def forward( self, - x, - x_lengths, - y, - y_lengths, - sid, - tone, - language, - bert, - ja_bert, - en_bert, - style_vec, - ): + x: torch.Tensor, + x_lengths: torch.Tensor, + y: torch.Tensor, + y_lengths: torch.Tensor, + sid: torch.Tensor, + tone: torch.Tensor, + language: torch.Tensor, + bert: torch.Tensor, + ja_bert: torch.Tensor, + en_bert: torch.Tensor, + style_vec: torch.Tensor, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + tuple[torch.Tensor, ...], + tuple[torch.Tensor, ...], + ]: if self.n_speakers > 0: g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] else: @@ -973,27 +1038,28 @@ def forward( def infer( self, - x, - x_lengths, - sid, - tone, - language, - bert, - ja_bert, - en_bert, - style_vec, - noise_scale=0.667, - length_scale=1.0, - noise_scale_w=0.8, - max_len=None, - sdp_ratio=0.0, - y=None, - ): + x: torch.Tensor, + x_lengths: torch.Tensor, + sid: torch.Tensor, + tone: torch.Tensor, + language: torch.Tensor, + bert: torch.Tensor, + ja_bert: torch.Tensor, + en_bert: torch.Tensor, + style_vec: torch.Tensor, + noise_scale: float = 0.667, + length_scale: float = 1.0, + noise_scale_w: float = 0.8, + max_len: Optional[int] = None, + sdp_ratio: float = 0.0, + y: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, ...]]: # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert) # g = self.gst(y) if self.n_speakers > 0: g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] else: + assert y is not None g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1) x, m_p, logs_p, x_mask = self.enc_p( x, x_lengths, tone, language, bert, ja_bert, en_bert, style_vec, sid, g=g diff --git a/style_bert_vits2/models/models_jp_extra.py b/style_bert_vits2/models/models_jp_extra.py index 3a43d5191..a43a7157b 100644 --- a/style_bert_vits2/models/models_jp_extra.py +++ b/style_bert_vits2/models/models_jp_extra.py @@ -1,4 +1,5 @@ import math +from typing import Any, Optional import torch from torch import nn @@ -10,13 +11,18 @@ from style_bert_vits2.models import commons from style_bert_vits2.models import modules from style_bert_vits2.models import monotonic_alignment -from style_bert_vits2.nlp.symbols import SYMBOLS, NUM_TONES, NUM_LANGUAGES +from style_bert_vits2.nlp.symbols import NUM_LANGUAGES, NUM_TONES, SYMBOLS class DurationDiscriminator(nn.Module): # vits2 def __init__( - self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0 - ): + self, + in_channels: int, 
+ filter_channels: int, + kernel_size: int, + p_dropout: float, + gin_channels: int = 0 + ) -> None: super().__init__() self.in_channels = in_channels @@ -47,7 +53,7 @@ def __init__( nn.Linear(2 * filter_channels, 1), nn.Sigmoid() ) - def forward_probability(self, x, dur): + def forward_probability(self, x: torch.Tensor, dur: torch.Tensor) -> torch.Tensor: dur = self.dur_proj(dur) x = torch.cat([x, dur], dim=1) x = x.transpose(1, 2) @@ -55,7 +61,14 @@ def forward_probability(self, x, dur): output_prob = self.output_layer(x) return output_prob - def forward(self, x, x_mask, dur_r, dur_hat, g=None): + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + dur_r: torch.Tensor, + dur_hat: torch.Tensor, + g: Optional[torch.Tensor] = None, + ) -> list[torch.Tensor]: x = torch.detach(x) if g is not None: g = torch.detach(g) @@ -80,17 +93,17 @@ def forward(self, x, x_mask, dur_r, dur_hat, g=None): class TransformerCouplingBlock(nn.Module): def __init__( self, - channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - n_flows=4, - gin_channels=0, - share_parameter=False, - ): + channels: int, + hidden_channels: int, + filter_channels: int, + n_heads: int, + n_layers: int, + kernel_size: int, + p_dropout: float, + n_flows: int = 4, + gin_channels: int = 0, + share_parameter: bool = False, + ) -> None: super().__init__() self.channels = channels self.hidden_channels = hidden_channels @@ -102,16 +115,17 @@ def __init__( self.flows = nn.ModuleList() self.wn = ( - attentions.FFT( - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - isflow=True, - gin_channels=self.gin_channels, - ) + # attentions.FFT( + # hidden_channels, + # filter_channels, + # n_heads, + # n_layers, + # kernel_size, + # p_dropout, + # isflow=True, + # gin_channels=self.gin_channels, + # ) + None if share_parameter else None ) @@ -133,7 +147,13 @@ def __init__( ) self.flows.append(modules.Flip()) - def forward(self, x, x_mask, g=None, reverse=False): + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + reverse: bool = False, + ) -> torch.Tensor: if not reverse: for flow in self.flows: x, _ = flow(x, x_mask, g=g, reverse=reverse) @@ -146,13 +166,13 @@ def forward(self, x, x_mask, g=None, reverse=False): class StochasticDurationPredictor(nn.Module): def __init__( self, - in_channels, - filter_channels, - kernel_size, - p_dropout, - n_flows=4, - gin_channels=0, - ): + in_channels: int, + filter_channels: int, + kernel_size: int, + p_dropout: float, + n_flows: int = 4, + gin_channels: int = 0, + ) -> None: super().__init__() filter_channels = in_channels # it needs to be removed from future version. 
self.in_channels = in_channels @@ -192,7 +212,15 @@ def __init__( if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, filter_channels, 1) - def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + w: Optional[torch.Tensor] = None, + g: Optional[torch.Tensor] = None, + reverse: bool = False, + noise_scale: float = 1.0, + ) -> torch.Tensor: x = torch.detach(x) x = self.pre(x) if g is not None: @@ -256,8 +284,13 @@ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): class DurationPredictor(nn.Module): def __init__( - self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0 - ): + self, + in_channels: int, + filter_channels: int, + kernel_size: int, + p_dropout: float, + gin_channels: int = 0, + ) -> None: super().__init__() self.in_channels = in_channels @@ -280,7 +313,7 @@ def __init__( if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, in_channels, 1) - def forward(self, x, x_mask, g=None): + def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor: x = torch.detach(x) if g is not None: g = torch.detach(g) @@ -298,14 +331,14 @@ def forward(self, x, x_mask, g=None): class Bottleneck(nn.Sequential): - def __init__(self, in_dim, hidden_dim): + def __init__(self, in_dim: int, hidden_dim: int) -> None: c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False) c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False) - super().__init__(*[c_fc1, c_fc2]) + super().__init__(c_fc1, c_fc2) class Block(nn.Module): - def __init__(self, in_dim, hidden_dim) -> None: + def __init__(self, in_dim: int, hidden_dim: int) -> None: super().__init__() self.norm = nn.LayerNorm(in_dim) self.mlp = MLP(in_dim, hidden_dim) @@ -316,13 +349,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MLP(nn.Module): - def __init__(self, in_dim, hidden_dim): + def __init__(self, in_dim: int, hidden_dim: int) -> None: super().__init__() self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False) self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False) self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.silu(self.c_fc1(x)) * self.c_fc2(x) x = self.c_proj(x) return x @@ -331,16 +364,16 @@ def forward(self, x: torch.Tensor): class TextEncoder(nn.Module): def __init__( self, - n_vocab, - out_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - gin_channels=0, - ): + n_vocab: int, + out_channels: int, + hidden_channels: int, + filter_channels: int, + n_heads: int, + n_layers: int, + kernel_size: int, + p_dropout: float, + gin_channels: int = 0, + ) -> None: super().__init__() self.n_vocab = n_vocab self.out_channels = out_channels @@ -373,7 +406,16 @@ def __init__( ) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, x, x_lengths, tone, language, bert, style_vec, g=None): + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + tone: torch.Tensor, + language: torch.Tensor, + bert: torch.Tensor, + style_vec: torch.Tensor, + g: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: bert_emb = self.bert_proj(bert).transpose(1, 2) style_emb = self.style_proj(style_vec.unsqueeze(1)) x = ( @@ -400,14 +442,14 @@ def forward(self, x, x_lengths, tone, language, bert, style_vec, g=None): class ResidualCouplingBlock(nn.Module): 
def __init__( self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - n_flows=4, - gin_channels=0, - ): + channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + n_layers: int, + n_flows: int = 4, + gin_channels: int = 0, + ) -> None: super().__init__() self.channels = channels self.hidden_channels = hidden_channels @@ -432,7 +474,13 @@ def __init__( ) self.flows.append(modules.Flip()) - def forward(self, x, x_mask, g=None, reverse=False): + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + reverse: bool = False, + ) -> torch.Tensor: if not reverse: for flow in self.flows: x, _ = flow(x, x_mask, g=g, reverse=reverse) @@ -445,14 +493,14 @@ def forward(self, x, x_mask, g=None, reverse=False): class PosteriorEncoder(nn.Module): def __init__( self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=0, - ): + in_channels: int, + out_channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + n_layers: int, + gin_channels: int = 0, + ) -> None: super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -472,7 +520,12 @@ def __init__( ) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, x, x_lengths, g=None): + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + g: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( x.dtype ) @@ -487,22 +540,22 @@ def forward(self, x, x_lengths, g=None): class Generator(torch.nn.Module): def __init__( self, - initial_channel, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - gin_channels=0, - ): + initial_channel: int, + resblock_str: str, + resblock_kernel_sizes: list[int], + resblock_dilation_sizes: list[list[int]], + upsample_rates: list[int], + upsample_initial_channel: int, + upsample_kernel_sizes: list[int], + gin_channels: int = 0, + ) -> None: super(Generator, self).__init__() self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) self.conv_pre = Conv1d( initial_channel, upsample_initial_channel, 7, 1, padding=3 ) - resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + resblock = modules.ResBlock1 if resblock_str == "1" else modules.ResBlock2 self.ups = nn.ModuleList() for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): @@ -519,20 +572,22 @@ def __init__( ) self.resblocks = nn.ModuleList() + ch = None for i in range(len(self.ups)): ch = upsample_initial_channel // (2 ** (i + 1)) for j, (k, d) in enumerate( zip(resblock_kernel_sizes, resblock_dilation_sizes) ): - self.resblocks.append(resblock(ch, k, d)) + self.resblocks.append(resblock(ch, k, d)) # type: ignore + assert ch is not None self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) self.ups.apply(commons.init_weights) if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) - def forward(self, x, g=None): + def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor: x = self.conv_pre(x) if g is not None: x = x + self.cond(g) @@ -546,6 +601,7 @@ def forward(self, x, g=None): xs = self.resblocks[i * self.num_kernels + j](x) else: xs += self.resblocks[i * self.num_kernels + j](x) + assert xs is not None x = 
xs / self.num_kernels x = F.leaky_relu(x) x = self.conv_post(x) @@ -553,7 +609,7 @@ def forward(self, x, g=None): return x - def remove_weight_norm(self): + def remove_weight_norm(self) -> None: print("Removing weight norm...") for layer in self.ups: remove_weight_norm(layer) @@ -562,7 +618,7 @@ def remove_weight_norm(self): class DiscriminatorP(torch.nn.Module): - def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + def __init__(self, period: int, kernel_size: int = 5, stride: int = 3, use_spectral_norm: bool = False) -> None: super(DiscriminatorP, self).__init__() self.period = period self.use_spectral_norm = use_spectral_norm @@ -618,7 +674,7 @@ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): ) self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - def forward(self, x): + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]: fmap = [] # 1d to 2d @@ -641,7 +697,7 @@ def forward(self, x): class DiscriminatorS(torch.nn.Module): - def __init__(self, use_spectral_norm=False): + def __init__(self, use_spectral_norm: bool = False) -> None: super(DiscriminatorS, self).__init__() norm_f = weight_norm if use_spectral_norm is False else spectral_norm self.convs = nn.ModuleList( @@ -656,7 +712,7 @@ def __init__(self, use_spectral_norm=False): ) self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]: fmap = [] for layer in self.convs: @@ -671,7 +727,7 @@ def forward(self, x): class MultiPeriodDiscriminator(torch.nn.Module): - def __init__(self, use_spectral_norm=False): + def __init__(self, use_spectral_norm: bool = False) -> None: super(MultiPeriodDiscriminator, self).__init__() periods = [2, 3, 5, 7, 11] @@ -681,7 +737,11 @@ def __init__(self, use_spectral_norm=False): ] self.discriminators = nn.ModuleList(discs) - def forward(self, y, y_hat): + def forward( + self, + y: torch.Tensor, + y_hat: torch.Tensor, + ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: y_d_rs = [] y_d_gs = [] fmap_rs = [] @@ -701,8 +761,12 @@ class WavLMDiscriminator(nn.Module): """docstring for Discriminator.""" def __init__( - self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False - ): + self, + slm_hidden: int = 768, + slm_layers: int = 13, + initial_channel: int = 64, + use_spectral_norm: bool = False, + ) -> None: super(WavLMDiscriminator, self).__init__() norm_f = weight_norm if use_spectral_norm == False else spectral_norm self.pre = norm_f( @@ -732,7 +796,7 @@ def __init__( self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1)) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.pre(x) fmap = [] @@ -752,7 +816,7 @@ class ReferenceEncoder(nn.Module): outputs --- [N, ref_enc_gru_size] """ - def __init__(self, spec_channels, gin_channels=0): + def __init__(self, spec_channels: int, gin_channels: int = 0) -> None: super().__init__() self.spec_channels = spec_channels ref_enc_filters = [32, 32, 64, 64, 128, 128] @@ -781,7 +845,7 @@ def __init__(self, spec_channels, gin_channels=0): ) self.proj = nn.Linear(128, gin_channels) - def forward(self, inputs, mask=None): + def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: N = inputs.size(0) out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] for conv in self.convs: @@ -799,7 +863,7 @@ def 
forward(self, inputs, mask=None): return self.proj(out.squeeze(0)) - def calculate_channels(self, L, kernel_size, stride, pad, n_convs): + def calculate_channels(self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int: for i in range(n_convs): L = (L - kernel_size + 2 * pad) // stride + 1 return L @@ -812,31 +876,31 @@ class SynthesizerTrn(nn.Module): def __init__( self, - n_vocab, - spec_channels, - segment_size, - inter_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - n_speakers=256, - gin_channels=256, - use_sdp=True, - n_flow_layer=4, - n_layers_trans_flow=6, - flow_share_parameter=False, - use_transformer_flow=True, - **kwargs - ): + n_vocab: int, + spec_channels: int, + segment_size: int, + inter_channels: int, + hidden_channels: int, + filter_channels: int, + n_heads: int, + n_layers: int, + kernel_size: int, + p_dropout: float, + resblock: str, + resblock_kernel_sizes: list[int], + resblock_dilation_sizes: list[list[int]], + upsample_rates: list[int], + upsample_initial_channel: int, + upsample_kernel_sizes: list[int], + n_speakers: int = 256, + gin_channels: int = 256, + use_sdp: bool = True, + n_flow_layer: int = 4, + n_layers_trans_flow: int = 6, + flow_share_parameter: bool = False, + use_transformer_flow: bool = True, + **kwargs: Any, + ) -> None: super().__init__() self.n_vocab = n_vocab self.spec_channels = spec_channels @@ -933,16 +997,26 @@ def __init__( def forward( self, - x, - x_lengths, - y, - y_lengths, - sid, - tone, - language, - bert, - style_vec, - ): + x: torch.Tensor, + x_lengths: torch.Tensor, + y: torch.Tensor, + y_lengths: torch.Tensor, + sid: torch.Tensor, + tone: torch.Tensor, + language: torch.Tensor, + bert: torch.Tensor, + style_vec: torch.Tensor, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + tuple[torch.Tensor, ...], + tuple[torch.Tensor, ...], + ]: if self.n_speakers > 0: g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] else: @@ -1014,32 +1088,33 @@ def forward( ids_slice, x_mask, y_mask, - (z, z_p, m_p, logs_p, m_q, logs_q), + (z, z_p, m_p, logs_p, m_q, logs_q), # type: ignore (x, logw, logw_), # , logw_sdp), g, ) def infer( self, - x, - x_lengths, - sid, - tone, - language, - bert, - style_vec, - noise_scale=0.667, - length_scale=1.0, - noise_scale_w=0.8, - max_len=None, - sdp_ratio=0.0, - y=None, - ): + x: torch.Tensor, + x_lengths: torch.Tensor, + sid: torch.Tensor, + tone: torch.Tensor, + language: torch.Tensor, + bert: torch.Tensor, + style_vec: torch.Tensor, + noise_scale: float = 0.667, + length_scale: float = 1.0, + noise_scale_w: float = 0.8, + max_len: Optional[int] = None, + sdp_ratio: float = 0.0, + y: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, ...]]: # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert) # g = self.gst(y) if self.n_speakers > 0: g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] else: + assert y is not None g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1) x, m_p, logs_p, x_mask = self.enc_p( x, x_lengths, tone, language, bert, style_vec, g=g diff --git a/style_bert_vits2/models/modules.py b/style_bert_vits2/models/modules.py index df3807123..8eed9635c 100644 --- a/style_bert_vits2/models/modules.py +++ b/style_bert_vits2/models/modules.py @@ -16,7 +16,7 @@ class 
LayerNorm(nn.Module): - def __init__(self, channels: int, eps: float = 1e-5): + def __init__(self, channels: int, eps: float = 1e-5) -> None: super().__init__() self.channels = channels self.eps = eps @@ -39,7 +39,7 @@ def __init__( kernel_size: int, n_layers: int, p_dropout: float, - ): + ) -> None: super().__init__() self.in_channels = in_channels self.hidden_channels = hidden_channels @@ -88,7 +88,7 @@ class DDSConv(nn.Module): Dialted and Depth-Separable Convolution """ - def __init__(self, channels: int, kernel_size: int, n_layers: int, p_dropout: float = 0.0): + def __init__(self, channels: int, kernel_size: int, n_layers: int, p_dropout: float = 0.0) -> None: super().__init__() self.channels = channels self.kernel_size = kernel_size @@ -141,7 +141,7 @@ def __init__( n_layers: int, gin_channels: int = 0, p_dropout: float = 0, - ): + ) -> None: super(WN, self).__init__() assert kernel_size % 2 == 1 self.hidden_channels = hidden_channels @@ -221,7 +221,7 @@ def remove_weight_norm(self) -> None: class ResBlock1(torch.nn.Module): - def __init__(self, channels: int, kernel_size: int = 3, dilation: tuple[int, int, int] = (1, 3, 5)): + def __init__(self, channels: int, kernel_size: int = 3, dilation: tuple[int, int, int] = (1, 3, 5)) -> None: super(ResBlock1, self).__init__() self.convs1 = nn.ModuleList( [ @@ -318,7 +318,7 @@ def remove_weight_norm(self) -> None: class ResBlock2(torch.nn.Module): - def __init__(self, channels: int, kernel_size: int = 3, dilation: tuple[int, int] = (1, 3)): + def __init__(self, channels: int, kernel_size: int = 3, dilation: tuple[int, int] = (1, 3)) -> None: super(ResBlock2, self).__init__() self.convs = nn.ModuleList( [ @@ -363,7 +363,13 @@ def remove_weight_norm(self) -> None: class Log(nn.Module): - def forward(self, x: torch.Tensor, x_mask: torch.Tensor, reverse: bool = False, **kwargs: Any): + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + reverse: bool = False, + **kwargs: Any, + ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if not reverse: y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask logdet = torch.sum(-y, [1, 2]) @@ -390,7 +396,7 @@ def forward( class ElementwiseAffine(nn.Module): - def __init__(self, channels: int): + def __init__(self, channels: int) -> None: super().__init__() self.channels = channels self.m = nn.Parameter(torch.zeros(channels, 1)) @@ -424,7 +430,7 @@ def __init__( p_dropout: float = 0, gin_channels: int = 0, mean_only: bool = False, - ): + ) -> None: assert channels % 2 == 0, "channels should be divisible by 2" super().__init__() self.channels = channels @@ -449,7 +455,13 @@ def __init__( assert self.post.bias is not None self.post.bias.data.zero_() - def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False): + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + reverse: bool = False, + ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]: x0, x1 = torch.split(x, [self.half_channels] * 2, 1) h = self.pre(x0) * x_mask h = self.enc(h, x_mask, g=g) @@ -480,7 +492,7 @@ def __init__( n_layers: int, num_bins: int = 10, tail_bound: float = 5.0, - ): + ) -> None: super().__init__() self.in_channels = in_channels self.filter_channels = filter_channels @@ -499,7 +511,13 @@ def __init__( assert self.proj.bias is not None self.proj.bias.data.zero_() - def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse: bool = False): + def forward( + self, + x: 
torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + reverse: bool = False, + ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]: x0, x1 = torch.split(x, [self.half_channels] * 2, 1) h = self.pre(x0) h = self.convs(h, x_mask, g=g) From c594f7ea7a51e225e5bb980b8d1a947cac07f32e Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sat, 9 Mar 2024 00:26:51 +0000 Subject: [PATCH 065/148] Refactor: change execution location of pyopenjtalk.initialize() Considering library design, this function with many side effects should not be executed in a library. --- server_editor.py | 8 ++++++-- server_fastapi.py | 3 +-- style_bert_vits2/nlp/japanese/g2p.py | 8 -------- style_bert_vits2/nlp/japanese/user_dict/__init__.py | 4 ---- webui/inference.py | 3 +-- 5 files changed, 8 insertions(+), 18 deletions(-) diff --git a/server_editor.py b/server_editor.py index 75fa4727c..9a6a08af8 100644 --- a/server_editor.py +++ b/server_editor.py @@ -42,6 +42,7 @@ ) from style_bert_vits2.logging import logger from style_bert_vits2.nlp import bert_models +from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk from style_bert_vits2.nlp.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone from style_bert_vits2.nlp.japanese.normalizer import normalize_text from style_bert_vits2.nlp.japanese.user_dict import ( @@ -148,8 +149,11 @@ def save_last_download(latest_release): # ---フロントエンド部分に関する処理ここまで--- # 以降はAPIの設定 -# 最初に pyopenjtalk の辞書を更新 -## pyopenjtalk_worker の起動も同時に行われる +# pyopenjtalk_worker を起動 +## pyopenjtalk_worker は TCP ソケットサーバーのため、ここで起動する +pyopenjtalk.initialize() + +# pyopenjtalk の辞書を更新 update_dict() # 事前に BERT モデル/トークナイザーをロードしておく diff --git a/server_fastapi.py b/server_fastapi.py index d8da95611..e8309dadb 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -42,8 +42,7 @@ # pyopenjtalk_worker を起動 -## Gradio はマルチスレッドだが、initialize() 内部で利用されている signal はマルチスレッドから設定できない -## さらに起動には若干時間がかかるため、事前に起動しておいた方が体験が良い +## pyopenjtalk_worker は TCP ソケットサーバーのため、ここで起動する pyopenjtalk.initialize() # 事前に BERT モデル/トークナイザーをロードしておく diff --git a/style_bert_vits2/nlp/japanese/g2p.py b/style_bert_vits2/nlp/japanese/g2p.py index 70b4e811c..7fc97f210 100644 --- a/style_bert_vits2/nlp/japanese/g2p.py +++ b/style_bert_vits2/nlp/japanese/g2p.py @@ -112,10 +112,6 @@ def text_to_sep_kata( tuple[list[str], list[str]]: 分割された単語リストと、その読み(カタカナ or 記号1文字)のリスト """ - # pyopenjtalk_worker を初期化 - ## 一度 worker を起動すれば、明示的に終了するかプロセス終了まで同一の worker に接続される - pyopenjtalk.initialize() - # parsed: OpenJTalkの解析結果 parsed = pyopenjtalk.run_frontend(norm_text) sep_text: list[str] = [] @@ -249,10 +245,6 @@ def _numeric_feature_by_regex(regex: str, s: str) -> int: return -50 return int(match.group(1)) - # pyopenjtalk_worker を初期化 - ## 一度 worker を起動すれば、明示的に終了するかプロセス終了まで同一の worker に接続される - pyopenjtalk.initialize() - labels = pyopenjtalk.make_label(pyopenjtalk.run_frontend(text)) N = len(labels) diff --git a/style_bert_vits2/nlp/japanese/user_dict/__init__.py b/style_bert_vits2/nlp/japanese/user_dict/__init__.py index 1887c1334..a2cc43ea9 100644 --- a/style_bert_vits2/nlp/japanese/user_dict/__init__.py +++ b/style_bert_vits2/nlp/japanese/user_dict/__init__.py @@ -80,10 +80,6 @@ def update_dict( コンパイル済み辞書ファイルのパス """ - # pyopenjtalk_worker を初期化 - ## 一度 worker を起動すれば、明示的に終了するかプロセス終了まで同一の worker に接続される - pyopenjtalk.initialize() - random_string = uuid4() tmp_csv_path = compiled_dict_path.with_suffix( f".dict_csv-{random_string}.tmp" diff --git a/webui/inference.py b/webui/inference.py index ef1f97f1c..536b38b4a 100644 --- 
a/webui/inference.py +++ b/webui/inference.py @@ -27,8 +27,7 @@ # pyopenjtalk_worker を起動 -## Gradio はマルチスレッドだが、initialize() 内部で利用されている signal はマルチスレッドから設定できない -## さらに起動には若干時間がかかるため、事前に起動しておいた方が体験が良い +## pyopenjtalk_worker は TCP ソケットサーバーのため、ここで起動する pyopenjtalk.initialize() # 事前に BERT モデル/トークナイザーをロードしておく From 98ab8e79789cdf6f8f95d92a769d391994b8ead5 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sat, 9 Mar 2024 15:58:57 +0000 Subject: [PATCH 066/148] Refactor: separate module for utilities related to loading/saving checkpoints and safetensors --- style_bert_vits2/models/infer.py | 4 +- style_bert_vits2/models/utils.py | 357 ------------------- style_bert_vits2/models/utils/__init__.py | 156 ++++++++ style_bert_vits2/models/utils/checkpoints.py | 194 ++++++++++ style_bert_vits2/models/utils/safetensors.py | 91 +++++ train_ms.py | 71 ++-- train_ms_jp_extra.py | 94 ++--- 7 files changed, 510 insertions(+), 457 deletions(-) delete mode 100644 style_bert_vits2/models/utils.py create mode 100644 style_bert_vits2/models/utils/__init__.py create mode 100644 style_bert_vits2/models/utils/checkpoints.py create mode 100644 style_bert_vits2/models/utils/safetensors.py diff --git a/style_bert_vits2/models/infer.py b/style_bert_vits2/models/infer.py index 93bf8c285..d3ec963ef 100644 --- a/style_bert_vits2/models/infer.py +++ b/style_bert_vits2/models/infer.py @@ -80,9 +80,9 @@ def get_net_g(model_path: str, version: str, device: str, hps: HyperParameters): net_g.state_dict() _ = net_g.eval() if model_path.endswith(".pth") or model_path.endswith(".pt"): - _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True) + _ = utils.checkpoints.load_checkpoint(model_path, net_g, None, skip_optimizer=True) elif model_path.endswith(".safetensors"): - _ = utils.load_safetensors(model_path, net_g, True) + _ = utils.safetensors.load_safetensors(model_path, net_g, True) else: raise ValueError(f"Unknown model format: {model_path}") return net_g diff --git a/style_bert_vits2/models/utils.py b/style_bert_vits2/models/utils.py deleted file mode 100644 index 37ea269dd..000000000 --- a/style_bert_vits2/models/utils.py +++ /dev/null @@ -1,357 +0,0 @@ -import argparse -import glob -import json -import logging -import os -import re -import subprocess - -import numpy as np -import torch -from safetensors import safe_open -from safetensors.torch import save_file -from scipy.io.wavfile import read - -from style_bert_vits2.logging import logger - - -MATPLOTLIB_FLAG = False - - -def load_checkpoint( - checkpoint_path, model, optimizer=None, skip_optimizer=False, for_infer=False -): - assert os.path.isfile(checkpoint_path) - checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") - iteration = checkpoint_dict["iteration"] - learning_rate = checkpoint_dict["learning_rate"] - logger.info( - f"Loading model and optimizer at iteration {iteration} from {checkpoint_path}" - ) - if ( - optimizer is not None - and not skip_optimizer - and checkpoint_dict["optimizer"] is not None - ): - optimizer.load_state_dict(checkpoint_dict["optimizer"]) - elif optimizer is None and not skip_optimizer: - # else: Disable this line if Infer and resume checkpoint,then enable the line upper - new_opt_dict = optimizer.state_dict() - new_opt_dict_params = new_opt_dict["param_groups"][0]["params"] - new_opt_dict["param_groups"] = checkpoint_dict["optimizer"]["param_groups"] - new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params - optimizer.load_state_dict(new_opt_dict) - - saved_state_dict = checkpoint_dict["model"] - if 
hasattr(model, "module"): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() - - new_state_dict = {} - for k, v in state_dict.items(): - try: - # assert "emb_g" not in k - new_state_dict[k] = saved_state_dict[k] - assert saved_state_dict[k].shape == v.shape, ( - saved_state_dict[k].shape, - v.shape, - ) - except: - # For upgrading from the old version - if "ja_bert_proj" in k: - v = torch.zeros_like(v) - logger.warning( - f"Seems you are using the old version of the model, the {k} is automatically set to zero for backward compatibility" - ) - elif "enc_q" in k and for_infer: - continue - else: - logger.error(f"{k} is not in the checkpoint {checkpoint_path}") - - new_state_dict[k] = v - - if hasattr(model, "module"): - model.module.load_state_dict(new_state_dict, strict=False) - else: - model.load_state_dict(new_state_dict, strict=False) - - logger.info("Loaded '{}' (iteration {})".format(checkpoint_path, iteration)) - - return model, optimizer, learning_rate, iteration - - -def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): - logger.info( - "Saving model and optimizer state at iteration {} to {}".format( - iteration, checkpoint_path - ) - ) - if hasattr(model, "module"): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() - torch.save( - { - "model": state_dict, - "iteration": iteration, - "optimizer": optimizer.state_dict(), - "learning_rate": learning_rate, - }, - checkpoint_path, - ) - - -def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True): - """Freeing up space by deleting saved ckpts - - Arguments: - path_to_models -- Path to the model directory - n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth - sort_by_time -- True -> chronologically delete ckpts - False -> lexicographically delete ckpts - """ - import re - - ckpts_files = [ - f - for f in os.listdir(path_to_models) - if os.path.isfile(os.path.join(path_to_models, f)) - ] - - def name_key(_f): - return int(re.compile("._(\\d+)\\.pth").match(_f).group(1)) - - def time_key(_f): - return os.path.getmtime(os.path.join(path_to_models, _f)) - - sort_key = time_key if sort_by_time else name_key - - def x_sorted(_x): - return sorted( - [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")], - key=sort_key, - ) - - to_del = [ - os.path.join(path_to_models, fn) - for fn in ( - x_sorted("G_")[:-n_ckpts_to_keep] - + x_sorted("D_")[:-n_ckpts_to_keep] - + x_sorted("WD_")[:-n_ckpts_to_keep] - + x_sorted("DUR_")[:-n_ckpts_to_keep] - ) - ] - - def del_info(fn): - return logger.info(f"Free up space by deleting ckpt {fn}") - - def del_routine(x): - return [os.remove(x), del_info(x)] - - [del_routine(fn) for fn in to_del] - - -def load_safetensors(checkpoint_path, model, for_infer=False): - """ - Load safetensors model. 
- """ - - tensors = {} - iteration = None - with safe_open(checkpoint_path, framework="pt", device="cpu") as f: - for key in f.keys(): - if key == "iteration": - iteration = f.get_tensor(key).item() - tensors[key] = f.get_tensor(key) - if hasattr(model, "module"): - result = model.module.load_state_dict(tensors, strict=False) - else: - result = model.load_state_dict(tensors, strict=False) - for key in result.missing_keys: - if key.startswith("enc_q") and for_infer: - continue - logger.warning(f"Missing key: {key}") - for key in result.unexpected_keys: - if key == "iteration": - continue - logger.warning(f"Unexpected key: {key}") - if iteration is None: - logger.info(f"Loaded '{checkpoint_path}'") - else: - logger.info(f"Loaded '{checkpoint_path}' (iteration {iteration})") - return model, iteration - - -def save_safetensors(model, iteration, checkpoint_path, is_half=False, for_infer=False): - """ - Save model with safetensors. - """ - if hasattr(model, "module"): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() - keys = [] - for k in state_dict: - if "enc_q" in k and for_infer: - continue # noqa: E701 - keys.append(k) - - new_dict = ( - {k: state_dict[k].half() for k in keys} - if is_half - else {k: state_dict[k] for k in keys} - ) - new_dict["iteration"] = torch.LongTensor([iteration]) - logger.info(f"Saved safetensors to {checkpoint_path}") - save_file(new_dict, checkpoint_path) - - -def summarize( - writer, - global_step, - scalars={}, - histograms={}, - images={}, - audios={}, - audio_sampling_rate=22050, -): - for k, v in scalars.items(): - writer.add_scalar(k, v, global_step) - for k, v in histograms.items(): - writer.add_histogram(k, v, global_step) - for k, v in images.items(): - writer.add_image(k, v, global_step, dataformats="HWC") - for k, v in audios.items(): - writer.add_audio(k, v, global_step, audio_sampling_rate) - - -def is_resuming(dir_path): - # JP-ExtraバージョンではDURがなくWDがあったり変わるため、Gのみで判断する - g_list = glob.glob(os.path.join(dir_path, "G_*.pth")) - # d_list = glob.glob(os.path.join(dir_path, "D_*.pth")) - # dur_list = glob.glob(os.path.join(dir_path, "DUR_*.pth")) - return len(g_list) > 0 - - -def latest_checkpoint_path(dir_path, regex="G_*.pth"): - f_list = glob.glob(os.path.join(dir_path, regex)) - f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) - try: - x = f_list[-1] - except IndexError: - raise ValueError(f"No checkpoint found in {dir_path} with regex {regex}") - return x - - -def plot_spectrogram_to_numpy(spectrogram): - global MATPLOTLIB_FLAG - if not MATPLOTLIB_FLAG: - import matplotlib - - matplotlib.use("Agg") - MATPLOTLIB_FLAG = True - mpl_logger = logging.getLogger("matplotlib") - mpl_logger.setLevel(logging.WARNING) - import matplotlib.pylab as plt - import numpy as np - - fig, ax = plt.subplots(figsize=(10, 2)) - im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") - plt.colorbar(im, ax=ax) - plt.xlabel("Frames") - plt.ylabel("Channels") - plt.tight_layout() - - fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - plt.close() - return data - - -def plot_alignment_to_numpy(alignment, info=None): - global MATPLOTLIB_FLAG - if not MATPLOTLIB_FLAG: - import matplotlib - - matplotlib.use("Agg") - MATPLOTLIB_FLAG = True - mpl_logger = logging.getLogger("matplotlib") - mpl_logger.setLevel(logging.WARNING) - import matplotlib.pylab as plt - import numpy as np - - fig, ax = 
plt.subplots(figsize=(6, 4)) - im = ax.imshow( - alignment.transpose(), aspect="auto", origin="lower", interpolation="none" - ) - fig.colorbar(im, ax=ax) - xlabel = "Decoder timestep" - if info is not None: - xlabel += "\n\n" + info - plt.xlabel(xlabel) - plt.ylabel("Encoder timestep") - plt.tight_layout() - - fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - plt.close() - return data - - -def load_wav_to_torch(full_path): - sampling_rate, data = read(full_path) - return torch.FloatTensor(data.astype(np.float32)), sampling_rate - - -def load_filepaths_and_text(filename, split="|"): - with open(filename, encoding="utf-8") as f: - filepaths_and_text = [line.strip().split(split) for line in f] - return filepaths_and_text - - -def get_logger(model_dir, filename="train.log"): - global logger - logger = logging.getLogger(os.path.basename(model_dir)) - logger.setLevel(logging.DEBUG) - - formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") - if not os.path.exists(model_dir): - os.makedirs(model_dir) - h = logging.FileHandler(os.path.join(model_dir, filename)) - h.setLevel(logging.DEBUG) - h.setFormatter(formatter) - logger.addHandler(h) - return logger - - -def get_steps(model_path): - matches = re.findall(r"\d+", model_path) - return matches[-1] if matches else None - - -def check_git_hash(model_dir): - source_dir = os.path.dirname(os.path.realpath(__file__)) - if not os.path.exists(os.path.join(source_dir, ".git")): - logger.warning( - "{} is not a git repository, therefore hash value comparison will be ignored.".format( - source_dir - ) - ) - return - - cur_hash = subprocess.getoutput("git rev-parse HEAD") - - path = os.path.join(model_dir, "githash") - if os.path.exists(path): - saved_hash = open(path).read() - if saved_hash != cur_hash: - logger.warning( - "git hash values are different. 
{}(saved) != {}(current)".format( - saved_hash[:8], cur_hash[:8] - ) - ) - else: - open(path, "w").write(cur_hash) diff --git a/style_bert_vits2/models/utils/__init__.py b/style_bert_vits2/models/utils/__init__.py new file mode 100644 index 000000000..f488e091c --- /dev/null +++ b/style_bert_vits2/models/utils/__init__.py @@ -0,0 +1,156 @@ +import glob +import logging +import os +import re +import subprocess + +import numpy as np +import torch +from scipy.io.wavfile import read + +from style_bert_vits2.logging import logger +from style_bert_vits2.models.utils import checkpoints # type: ignore +from style_bert_vits2.models.utils import safetensors # type: ignore + + +MATPLOTLIB_FLAG = False + + +def summarize( + writer, + global_step, + scalars={}, + histograms={}, + images={}, + audios={}, + audio_sampling_rate=22050, +): + for k, v in scalars.items(): + writer.add_scalar(k, v, global_step) + for k, v in histograms.items(): + writer.add_histogram(k, v, global_step) + for k, v in images.items(): + writer.add_image(k, v, global_step, dataformats="HWC") + for k, v in audios.items(): + writer.add_audio(k, v, global_step, audio_sampling_rate) + + +def is_resuming(dir_path): + # JP-ExtraバージョンではDURがなくWDがあったり変わるため、Gのみで判断する + g_list = glob.glob(os.path.join(dir_path, "G_*.pth")) + # d_list = glob.glob(os.path.join(dir_path, "D_*.pth")) + # dur_list = glob.glob(os.path.join(dir_path, "DUR_*.pth")) + return len(g_list) > 0 + + +def plot_spectrogram_to_numpy(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def plot_alignment_to_numpy(alignment, info=None): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(6, 4)) + im = ax.imshow( + alignment.transpose(), aspect="auto", origin="lower", interpolation="none" + ) + fig.colorbar(im, ax=ax) + xlabel = "Decoder timestep" + if info is not None: + xlabel += "\n\n" + info + plt.xlabel(xlabel) + plt.ylabel("Encoder timestep") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def load_wav_to_torch(full_path): + sampling_rate, data = read(full_path) + return torch.FloatTensor(data.astype(np.float32)), sampling_rate + + +def load_filepaths_and_text(filename, split="|"): + with open(filename, encoding="utf-8") as f: + filepaths_and_text = [line.strip().split(split) for line in f] + return filepaths_and_text + + +def get_logger(model_dir, filename="train.log"): + global logger + logger = logging.getLogger(os.path.basename(model_dir)) + logger.setLevel(logging.DEBUG) + + formatter = 
logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") + if not os.path.exists(model_dir): + os.makedirs(model_dir) + h = logging.FileHandler(os.path.join(model_dir, filename)) + h.setLevel(logging.DEBUG) + h.setFormatter(formatter) + logger.addHandler(h) + return logger + + +def get_steps(model_path): + matches = re.findall(r"\d+", model_path) + return matches[-1] if matches else None + + +def check_git_hash(model_dir): + source_dir = os.path.dirname(os.path.realpath(__file__)) + if not os.path.exists(os.path.join(source_dir, ".git")): + logger.warning( + "{} is not a git repository, therefore hash value comparison will be ignored.".format( + source_dir + ) + ) + return + + cur_hash = subprocess.getoutput("git rev-parse HEAD") + + path = os.path.join(model_dir, "githash") + if os.path.exists(path): + saved_hash = open(path).read() + if saved_hash != cur_hash: + logger.warning( + "git hash values are different. {}(saved) != {}(current)".format( + saved_hash[:8], cur_hash[:8] + ) + ) + else: + open(path, "w").write(cur_hash) diff --git a/style_bert_vits2/models/utils/checkpoints.py b/style_bert_vits2/models/utils/checkpoints.py new file mode 100644 index 000000000..f26f8fc92 --- /dev/null +++ b/style_bert_vits2/models/utils/checkpoints.py @@ -0,0 +1,194 @@ +import glob +import os +import re +from pathlib import Path +from typing import Any, Optional, Union + +import torch + +from style_bert_vits2.logging import logger + + +def load_checkpoint( + checkpoint_path: Union[str, Path], + model: torch.nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + skip_optimizer: bool = False, + for_infer: bool = False +) -> tuple[torch.nn.Module, Optional[torch.optim.Optimizer], float, int]: + """ + 指定されたパスからチェックポイントを読み込み、モデルとオプティマイザーを更新する。 + + Args: + checkpoint_path (Union[str, Path]): チェックポイントファイルのパス + model (torch.nn.Module): 更新するモデル + optimizer (Optional[torch.optim.Optimizer]): 更新するオプティマイザー。None の場合は更新しない + skip_optimizer (bool): オプティマイザーの更新をスキップするかどうかのフラグ + for_infer (bool): 推論用に読み込むかどうかのフラグ + + Returns: + tuple[torch.nn.Module, Optional[torch.optim.Optimizer], float, int]: 更新されたモデルとオプティマイザー、学習率、イテレーション番号 + """ + + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + iteration = checkpoint_dict["iteration"] + learning_rate = checkpoint_dict["learning_rate"] + logger.info( + f"Loading model and optimizer at iteration {iteration} from {checkpoint_path}" + ) + if ( + optimizer is not None + and not skip_optimizer + and checkpoint_dict["optimizer"] is not None + ): + optimizer.load_state_dict(checkpoint_dict["optimizer"]) + elif optimizer is None and not skip_optimizer: + # else: Disable this line if Infer and resume checkpoint,then enable the line upper + new_opt_dict = optimizer.state_dict() # type: ignore + new_opt_dict_params = new_opt_dict["param_groups"][0]["params"] + new_opt_dict["param_groups"] = checkpoint_dict["optimizer"]["param_groups"] + new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params + optimizer.load_state_dict(new_opt_dict) # type: ignore + + saved_state_dict = checkpoint_dict["model"] + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + + new_state_dict = {} + for k, v in state_dict.items(): + try: + # assert "emb_g" not in k + new_state_dict[k] = saved_state_dict[k] + assert saved_state_dict[k].shape == v.shape, ( + saved_state_dict[k].shape, + v.shape, + ) + except: + # For upgrading from the old version + if "ja_bert_proj" 
in k: + v = torch.zeros_like(v) + logger.warning( + f"Seems you are using the old version of the model, the {k} is automatically set to zero for backward compatibility" + ) + elif "enc_q" in k and for_infer: + continue + else: + logger.error(f"{k} is not in the checkpoint {checkpoint_path}") + + new_state_dict[k] = v + + if hasattr(model, "module"): + model.module.load_state_dict(new_state_dict, strict=False) + else: + model.load_state_dict(new_state_dict, strict=False) + + logger.info("Loaded '{}' (iteration {})".format(checkpoint_path, iteration)) + + return model, optimizer, learning_rate, iteration + + +def save_checkpoint( + model: torch.nn.Module, + optimizer: Union[torch.optim.Optimizer, torch.optim.AdamW], + learning_rate: float, + iteration: int, + checkpoint_path: Union[str, Path], +) -> None: + """ + モデルとオプティマイザーの状態を指定されたパスに保存する。 + + Args: + model (torch.nn.Module): 保存するモデル + optimizer (Union[torch.optim.Optimizer, torch.optim.AdamW]): 保存するオプティマイザー + learning_rate (float): 学習率 + iteration (int): イテレーション数 + checkpoint_path (Union[str, Path]): 保存先のパス + """ + logger.info(f"Saving model and optimizer state at iteration {iteration} to {checkpoint_path}") + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + torch.save( + { + "model": state_dict, + "iteration": iteration, + "optimizer": optimizer.state_dict(), + "learning_rate": learning_rate, + }, + checkpoint_path, + ) + + +def clean_checkpoints(model_dir_path: Union[str, Path] = "logs/44k/", n_ckpts_to_keep: int = 2, sort_by_time: bool = True) -> None: + """ + 指定されたディレクトリから古いチェックポイントを削除して空き容量を確保する + + Args: + model_dir_path (Union[str, Path]): モデルが保存されているディレクトリのパス + n_ckpts_to_keep (int): 保持するチェックポイントの数(G_0.pth と D_0.pth を除く) + sort_by_time (bool): True の場合、時間順に削除。False の場合、名前順に削除 + """ + + ckpts_files = [ + f + for f in os.listdir(model_dir_path) + if os.path.isfile(os.path.join(model_dir_path, f)) + ] + + def name_key(_f: str) -> int: + return int(re.compile("._(\\d+)\\.pth").match(_f).group(1)) # type: ignore + + def time_key(_f: str) -> float: + return os.path.getmtime(os.path.join(model_dir_path, _f)) + + sort_key = time_key if sort_by_time else name_key + + def x_sorted(_x: str) -> list[str]: + return sorted( + [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")], + key=sort_key, + ) + + to_del = [ + os.path.join(model_dir_path, fn) + for fn in ( + x_sorted("G_")[:-n_ckpts_to_keep] + + x_sorted("D_")[:-n_ckpts_to_keep] + + x_sorted("WD_")[:-n_ckpts_to_keep] + + x_sorted("DUR_")[:-n_ckpts_to_keep] + ) + ] + + def del_info(fn: str) -> None: + return logger.info(f"Free up space by deleting ckpt {fn}") + + def del_routine(x: str) -> list[Any]: + return [os.remove(x), del_info(x)] + + [del_routine(fn) for fn in to_del] + + +def get_latest_checkpoint_path(model_dir_path: Union[str, Path], regex: str = "G_*.pth") -> str: + """ + 指定されたディレクトリから最新のチェックポイントのパスを取得する + + Args: + model_dir_path (Union[str, Path]): モデルが保存されているディレクトリのパス + regex (str): チェックポイントのファイル名の正規表現 + + Returns: + str: 最新のチェックポイントのパス + """ + + f_list = glob.glob(os.path.join(str(model_dir_path), regex)) + f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) + try: + x = f_list[-1] + except IndexError: + raise ValueError(f"No checkpoint found in {model_dir_path} with regex {regex}") + + return x diff --git a/style_bert_vits2/models/utils/safetensors.py b/style_bert_vits2/models/utils/safetensors.py new file mode 100644 index 000000000..8917c778e --- /dev/null +++ 
b/style_bert_vits2/models/utils/safetensors.py @@ -0,0 +1,91 @@ +from pathlib import Path +from typing import Any, Optional, Union + +import torch +from safetensors import safe_open +from safetensors.torch import save_file + +from style_bert_vits2.logging import logger + + +def load_safetensors( + checkpoint_path: Union[str, Path], + model: torch.nn.Module, + for_infer: bool = False, +) -> tuple[torch.nn.Module, Optional[int]]: + """ + 指定されたパスから safetensors モデルを読み込み、モデルとイテレーションを返す。 + + Args: + checkpoint_path (Union[str, Path]): モデルのチェックポイントファイルのパス + model (torch.nn.Module): 読み込む対象のモデル + for_infer (bool): 推論用に読み込むかどうかのフラグ + + Returns: + tuple[torch.nn.Module, Optional[int]]: 読み込まれたモデルとイテレーション番号(存在する場合) + """ + + tensors: dict[str, Any] = {} + iteration: Optional[int] = None + with safe_open(str(checkpoint_path), framework="pt", device="cpu") as f: # type: ignore + for key in f.keys(): + if key == "iteration": + iteration = f.get_tensor(key).item() + tensors[key] = f.get_tensor(key) + if hasattr(model, "module"): + result = model.module.load_state_dict(tensors, strict=False) + else: + result = model.load_state_dict(tensors, strict=False) + for key in result.missing_keys: + if key.startswith("enc_q") and for_infer: + continue + logger.warning(f"Missing key: {key}") + for key in result.unexpected_keys: + if key == "iteration": + continue + logger.warning(f"Unexpected key: {key}") + if iteration is None: + logger.info(f"Loaded '{checkpoint_path}'") + else: + logger.info(f"Loaded '{checkpoint_path}' (iteration {iteration})") + + return model, iteration + + +def save_safetensors( + model: torch.nn.Module, + iteration: int, + checkpoint_path: Union[str, Path], + is_half: bool = False, + for_infer: bool = False, +) -> None: + """ + モデルを safetensors 形式で保存する。 + + Args: + model (torch.nn.Module): 保存するモデル + iteration (int): イテレーション番号 + checkpoint_path (Union[str, Path]): 保存先のパス + is_half (bool): モデルを半精度で保存するかどうかのフラグ + for_infer (bool): 推論用に保存するかどうかのフラグ + """ + + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + keys = [] + for k in state_dict: + if "enc_q" in k and for_infer: + continue # noqa: E701 + keys.append(k) + + new_dict = ( + {k: state_dict[k].half() for k in keys} + if is_half + else {k: state_dict[k] for k in keys} + ) + new_dict["iteration"] = torch.LongTensor([iteration]) + logger.info(f"Saved safetensors to {checkpoint_path}") + + save_file(new_dict, checkpoint_path) diff --git a/train_ms.py b/train_ms.py index b2cb02bf4..3d50c0398 100644 --- a/train_ms.py +++ b/train_ms.py @@ -248,10 +248,7 @@ def run(): drop_last=False, collate_fn=collate_fn, ) - if ( - "use_noise_scaled_mas" in hps.model.keys() - and hps.model.use_noise_scaled_mas is True - ): + if hps.model.use_noise_scaled_mas is True: logger.info("Using noise scaled MAS for VITS2") mas_noise_scale_initial = 0.01 noise_scale_delta = 2e-6 @@ -259,10 +256,7 @@ def run(): logger.info("Using normal MAS for VITS1") mas_noise_scale_initial = 0.0 noise_scale_delta = 0.0 - if ( - "use_duration_discriminator" in hps.model.keys() - and hps.model.use_duration_discriminator is True - ): + if hps.model.use_duration_discriminator is True: logger.info("Using duration discriminator for VITS2") net_dur_disc = DurationDiscriminator( hps.model.hidden_channels, @@ -271,10 +265,7 @@ def run(): 0.1, gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0, ).cuda(local_rank) - if ( - "use_spk_conditioned_encoder" in hps.model.keys() - and hps.model.use_spk_conditioned_encoder 
is True - ): + if hps.model.use_spk_conditioned_encoder is True: if hps.data.n_speakers == 0: raise ValueError( "n_speakers must be > 0 when using spk conditioned encoder to train multi-speaker model" @@ -370,31 +361,25 @@ def run(): if utils.is_resuming(model_dir): if net_dur_disc is not None: - _, _, dur_resume_lr, epoch_str = utils.load_checkpoint( - utils.latest_checkpoint_path(model_dir, "DUR_*.pth"), + _, _, dur_resume_lr, epoch_str = utils.checkpoints.load_checkpoint( + utils.checkpoints.get_latest_checkpoint_path(model_dir, "DUR_*.pth"), net_dur_disc, optim_dur_disc, - skip_optimizer=( - hps.train.skip_optimizer if "skip_optimizer" in hps.train else True - ), + skip_optimizer=hps.train.skip_optimizer, ) if not optim_dur_disc.param_groups[0].get("initial_lr"): optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr - _, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint( - utils.latest_checkpoint_path(model_dir, "G_*.pth"), + _, optim_g, g_resume_lr, epoch_str = utils.checkpoints.load_checkpoint( + utils.checkpoints.get_latest_checkpoint_path(model_dir, "G_*.pth"), net_g, optim_g, - skip_optimizer=( - hps.train.skip_optimizer if "skip_optimizer" in hps.train else True - ), + skip_optimizer=hps.train.skip_optimizer, ) - _, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint( - utils.latest_checkpoint_path(model_dir, "D_*.pth"), + _, optim_d, d_resume_lr, epoch_str = utils.checkpoints.load_checkpoint( + utils.checkpoints.get_latest_checkpoint_path(model_dir, "D_*.pth"), net_d, optim_d, - skip_optimizer=( - hps.train.skip_optimizer if "skip_optimizer" in hps.train else True - ), + skip_optimizer=hps.train.skip_optimizer, ) if not optim_g.param_groups[0].get("initial_lr"): optim_g.param_groups[0]["initial_lr"] = g_resume_lr @@ -404,21 +389,21 @@ def run(): epoch_str = max(epoch_str, 1) # global_step = (epoch_str - 1) * len(train_loader) global_step = int( - utils.get_steps(utils.latest_checkpoint_path(model_dir, "G_*.pth")) + utils.get_steps(utils.checkpoints.get_latest_checkpoint_path(model_dir, "G_*.pth")) ) logger.info( f"******************Found the model. 
Current epoch is {epoch_str}, gloabl step is {global_step}*********************" ) else: try: - _ = utils.load_safetensors( + _ = utils.safetensors.load_safetensors( os.path.join(model_dir, "G_0.safetensors"), net_g ) - _ = utils.load_safetensors( + _ = utils.safetensors.load_safetensors( os.path.join(model_dir, "D_0.safetensors"), net_d ) if net_dur_disc is not None: - _ = utils.load_safetensors( + _ = utils.safetensors.load_safetensors( os.path.join(model_dir, "DUR_0.safetensors"), net_dur_disc ) logger.info("Loaded the pretrained models.") @@ -511,14 +496,16 @@ def lr_lambda(epoch): if epoch == hps.train.epochs: # Save the final models - utils.save_checkpoint( + assert optim_g is not None + utils.checkpoints.save_checkpoint( net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(model_dir, "G_{}.pth".format(global_step)), ) - utils.save_checkpoint( + assert optim_d is not None + utils.checkpoints.save_checkpoint( net_d, optim_d, hps.train.learning_rate, @@ -526,14 +513,15 @@ def lr_lambda(epoch): os.path.join(model_dir, "D_{}.pth".format(global_step)), ) if net_dur_disc is not None: - utils.save_checkpoint( + assert optim_dur_disc is not None + utils.checkpoints.save_checkpoint( net_dur_disc, optim_dur_disc, hps.train.learning_rate, epoch, os.path.join(model_dir, "DUR_{}.pth".format(global_step)), ) - utils.save_safetensors( + utils.safetensors.save_safetensors( net_g, epoch, os.path.join( @@ -804,14 +792,15 @@ def train_and_evaluate( ): if not hps.speedup: evaluate(hps, net_g, eval_loader, writer_eval) - utils.save_checkpoint( + assert hps.model_dir is not None + utils.checkpoints.save_checkpoint( net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)), ) - utils.save_checkpoint( + utils.checkpoints.save_checkpoint( net_d, optim_d, hps.train.learning_rate, @@ -819,7 +808,7 @@ def train_and_evaluate( os.path.join(hps.model_dir, "D_{}.pth".format(global_step)), ) if net_dur_disc is not None: - utils.save_checkpoint( + utils.checkpoints.save_checkpoint( net_dur_disc, optim_dur_disc, hps.train.learning_rate, @@ -828,13 +817,13 @@ def train_and_evaluate( ) keep_ckpts = config.train_ms_config.keep_ckpts if keep_ckpts > 0: - utils.clean_checkpoints( - path_to_models=hps.model_dir, + utils.checkpoints.clean_checkpoints( + model_dir_path=hps.model_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True, ) # Save safetensors (for inference) to `model_assets/{model_name}` - utils.save_safetensors( + utils.safetensors.save_safetensors( net_g, epoch, os.path.join( diff --git a/train_ms_jp_extra.py b/train_ms_jp_extra.py index a464d29b9..5ed68a103 100644 --- a/train_ms_jp_extra.py +++ b/train_ms_jp_extra.py @@ -247,10 +247,7 @@ def run(): drop_last=False, collate_fn=collate_fn, ) - if ( - "use_noise_scaled_mas" in hps.model.keys() - and hps.model.use_noise_scaled_mas is True - ): + if hps.model.use_noise_scaled_mas is True: logger.info("Using noise scaled MAS for VITS2") mas_noise_scale_initial = 0.01 noise_scale_delta = 2e-6 @@ -258,10 +255,7 @@ def run(): logger.info("Using normal MAS for VITS1") mas_noise_scale_initial = 0.0 noise_scale_delta = 0.0 - if ( - "use_duration_discriminator" in hps.model.keys() - and hps.model.use_duration_discriminator is True - ): + if hps.model.use_duration_discriminator is True: logger.info("Using duration discriminator for VITS2") net_dur_disc = DurationDiscriminator( hps.model.hidden_channels, @@ -272,19 +266,13 @@ def run(): ).cuda(local_rank) else: net_dur_disc = None - if ( - "use_wavlm_discriminator" in 
hps.model.keys() - and hps.model.use_wavlm_discriminator is True - ): + if hps.model.use_wavlm_discriminator is True: net_wd = WavLMDiscriminator( hps.model.slm.hidden, hps.model.slm.nlayers, hps.model.slm.initial_channel ).cuda(local_rank) else: net_wd = None - if ( - "use_spk_conditioned_encoder" in hps.model.keys() - and hps.model.use_spk_conditioned_encoder is True - ): + if hps.model.use_spk_conditioned_encoder is True: if hps.data.n_speakers == 0: raise ValueError( "n_speakers must be > 0 when using spk conditioned encoder to train multi-speaker model" @@ -394,15 +382,11 @@ def run(): if utils.is_resuming(model_dir): if net_dur_disc is not None: try: - _, _, dur_resume_lr, epoch_str = utils.load_checkpoint( - utils.latest_checkpoint_path(model_dir, "DUR_*.pth"), + _, _, dur_resume_lr, epoch_str = utils.checkpoints.load_checkpoint( + utils.checkpoints.get_latest_checkpoint_path(model_dir, "DUR_*.pth"), net_dur_disc, optim_dur_disc, - skip_optimizer=( - hps.train.skip_optimizer - if "skip_optimizer" in hps.train - else True - ), + skip_optimizer=hps.train.skip_optimizer, ) if not optim_dur_disc.param_groups[0].get("initial_lr"): optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr @@ -412,15 +396,11 @@ def run(): print("Initialize dur_disc") if net_wd is not None: try: - _, optim_wd, wd_resume_lr, epoch_str = utils.load_checkpoint( - utils.latest_checkpoint_path(model_dir, "WD_*.pth"), + _, optim_wd, wd_resume_lr, epoch_str = utils.checkpoints.load_checkpoint( + utils.checkpoints.get_latest_checkpoint_path(model_dir, "WD_*.pth"), net_wd, optim_wd, - skip_optimizer=( - hps.train.skip_optimizer - if "skip_optimizer" in hps.train - else True - ), + skip_optimizer=hps.train.skip_optimizer, ) if not optim_wd.param_groups[0].get("initial_lr"): optim_wd.param_groups[0]["initial_lr"] = wd_resume_lr @@ -430,21 +410,17 @@ def run(): logger.info("Initialize wavlm") try: - _, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint( - utils.latest_checkpoint_path(model_dir, "G_*.pth"), + _, optim_g, g_resume_lr, epoch_str = utils.checkpoints.load_checkpoint( + utils.checkpoints.get_latest_checkpoint_path(model_dir, "G_*.pth"), net_g, optim_g, - skip_optimizer=( - hps.train.skip_optimizer if "skip_optimizer" in hps.train else True - ), + skip_optimizer=hps.train.skip_optimizer, ) - _, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint( - utils.latest_checkpoint_path(model_dir, "D_*.pth"), + _, optim_d, d_resume_lr, epoch_str = utils.checkpoints.load_checkpoint( + utils.checkpoints.get_latest_checkpoint_path(model_dir, "D_*.pth"), net_d, optim_d, - skip_optimizer=( - hps.train.skip_optimizer if "skip_optimizer" in hps.train else True - ), + skip_optimizer=hps.train.skip_optimizer, ) if not optim_g.param_groups[0].get("initial_lr"): optim_g.param_groups[0]["initial_lr"] = g_resume_lr @@ -454,7 +430,7 @@ def run(): epoch_str = max(epoch_str, 1) # global_step = (epoch_str - 1) * len(train_loader) global_step = int( - utils.get_steps(utils.latest_checkpoint_path(model_dir, "G_*.pth")) + utils.get_steps(utils.checkpoints.get_latest_checkpoint_path(model_dir, "G_*.pth")) ) logger.info( f"******************Found the model. 
Current epoch is {epoch_str}, gloabl step is {global_step}*********************" @@ -468,18 +444,18 @@ def run(): global_step = 0 else: try: - _ = utils.load_safetensors( + _ = utils.safetensors.load_safetensors( os.path.join(model_dir, "G_0.safetensors"), net_g ) - _ = utils.load_safetensors( + _ = utils.safetensors.load_safetensors( os.path.join(model_dir, "D_0.safetensors"), net_d ) if net_dur_disc is not None: - _ = utils.load_safetensors( + _ = utils.safetensors.load_safetensors( os.path.join(model_dir, "DUR_0.safetensors"), net_dur_disc ) if net_wd is not None: - _ = utils.load_safetensors( + _ = utils.safetensors.load_safetensors( os.path.join(model_dir, "WD_0.safetensors"), net_wd ) logger.info("Loaded the pretrained models.") @@ -586,14 +562,16 @@ def lr_lambda(epoch): scheduler_wd.step() if epoch == hps.train.epochs: # Save the final models - utils.save_checkpoint( + assert optim_g is not None + utils.checkpoints.save_checkpoint( net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(model_dir, "G_{}.pth".format(global_step)), ) - utils.save_checkpoint( + assert optim_d is not None + utils.checkpoints.save_checkpoint( net_d, optim_d, hps.train.learning_rate, @@ -601,7 +579,8 @@ def lr_lambda(epoch): os.path.join(model_dir, "D_{}.pth".format(global_step)), ) if net_dur_disc is not None: - utils.save_checkpoint( + assert optim_dur_disc is not None + utils.checkpoints.save_checkpoint( net_dur_disc, optim_dur_disc, hps.train.learning_rate, @@ -609,14 +588,15 @@ def lr_lambda(epoch): os.path.join(model_dir, "DUR_{}.pth".format(global_step)), ) if net_wd is not None: - utils.save_checkpoint( + assert optim_wd is not None + utils.checkpoints.save_checkpoint( net_wd, optim_wd, hps.train.learning_rate, epoch, os.path.join(model_dir, "WD_{}.pth".format(global_step)), ) - utils.save_safetensors( + utils.safetensors.save_safetensors( net_g, epoch, os.path.join( @@ -949,14 +929,14 @@ def train_and_evaluate( ): if not hps.speedup: evaluate(hps, net_g, eval_loader, writer_eval) - utils.save_checkpoint( + utils.checkpoints.save_checkpoint( net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)), ) - utils.save_checkpoint( + utils.checkpoints.save_checkpoint( net_d, optim_d, hps.train.learning_rate, @@ -964,7 +944,7 @@ def train_and_evaluate( os.path.join(hps.model_dir, "D_{}.pth".format(global_step)), ) if net_dur_disc is not None: - utils.save_checkpoint( + utils.checkpoints.save_checkpoint( net_dur_disc, optim_dur_disc, hps.train.learning_rate, @@ -972,7 +952,7 @@ def train_and_evaluate( os.path.join(hps.model_dir, "DUR_{}.pth".format(global_step)), ) if net_wd is not None: - utils.save_checkpoint( + utils.checkpoints.save_checkpoint( net_wd, optim_wd, hps.train.learning_rate, @@ -981,13 +961,13 @@ def train_and_evaluate( ) keep_ckpts = config.train_ms_config.keep_ckpts if keep_ckpts > 0: - utils.clean_checkpoints( - path_to_models=hps.model_dir, + utils.checkpoints.clean_checkpoints( + model_dir_path=hps.model_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True, ) # Save safetensors (for inference) to `model_assets/{model_name}` - utils.save_safetensors( + utils.safetensors.save_safetensors( net_g, epoch, os.path.join( From 61e2a1deae543e2403ffa56e97570e11fad4a609 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sat, 9 Mar 2024 17:26:42 +0000 Subject: [PATCH 067/148] Refactor: add type hints to style_bert_vits2.models.utils --- style_bert_vits2/models/utils/__init__.py | 158 +++++++++++++++++----- style_bert_vits2/nlp/__init__.py | 2 +- 2 
files changed, 127 insertions(+), 33 deletions(-) diff --git a/style_bert_vits2/models/utils/__init__.py b/style_bert_vits2/models/utils/__init__.py index f488e091c..51e19d9a4 100644 --- a/style_bert_vits2/models/utils/__init__.py +++ b/style_bert_vits2/models/utils/__init__.py @@ -3,28 +3,44 @@ import os import re import subprocess +from pathlib import Path +from typing import Any, Optional, Union import numpy as np import torch +from numpy.typing import NDArray from scipy.io.wavfile import read +from torch.utils.tensorboard import SummaryWriter from style_bert_vits2.logging import logger from style_bert_vits2.models.utils import checkpoints # type: ignore from style_bert_vits2.models.utils import safetensors # type: ignore -MATPLOTLIB_FLAG = False +__is_matplotlib_imported = False def summarize( - writer, - global_step, - scalars={}, - histograms={}, - images={}, - audios={}, - audio_sampling_rate=22050, -): + writer: SummaryWriter, + global_step: int, + scalars: dict[str, float] = {}, + histograms: dict[str, Any] = {}, + images: dict[str, Any] = {}, + audios: dict[str, Any] = {}, + audio_sampling_rate: int = 22050, +) -> None: + """ + 指定されたデータを TensorBoard にまとめて追加する + + Args: + writer (SummaryWriter): TensorBoard への書き込みを行うオブジェクト + global_step (int): グローバルステップ数 + scalars (dict[str, float]): スカラー値の辞書 + histograms (dict[str, Any]): ヒストグラムの辞書 + images (dict[str, Any]): 画像データの辞書 + audios (dict[str, Any]): 音声データの辞書 + audio_sampling_rate (int): 音声データのサンプリングレート + """ for k, v in scalars.items(): writer.add_scalar(k, v, global_step) for k, v in histograms.items(): @@ -35,7 +51,16 @@ def summarize( writer.add_audio(k, v, global_step, audio_sampling_rate) -def is_resuming(dir_path): +def is_resuming(dir_path: Union[str, Path]) -> bool: + """ + 指定されたディレクトリパスに再開可能なモデルが存在するかどうかを返す + + Args: + dir_path: チェックするディレクトリのパス + + Returns: + bool: 再開可能なモデルが存在するかどうか + """ # JP-ExtraバージョンではDURがなくWDがあったり変わるため、Gのみで判断する g_list = glob.glob(os.path.join(dir_path, "G_*.pth")) # d_list = glob.glob(os.path.join(dir_path, "D_*.pth")) @@ -43,13 +68,23 @@ def is_resuming(dir_path): return len(g_list) > 0 -def plot_spectrogram_to_numpy(spectrogram): - global MATPLOTLIB_FLAG - if not MATPLOTLIB_FLAG: +def plot_spectrogram_to_numpy(spectrogram: NDArray[Any]) -> NDArray[Any]: + """ + 指定されたスペクトログラムを画像データに変換する + + Args: + spectrogram (NDArray[Any]): スペクトログラム + + Returns: + NDArray[Any]: 画像データ + """ + + global __is_matplotlib_imported + if not __is_matplotlib_imported: import matplotlib matplotlib.use("Agg") - MATPLOTLIB_FLAG = True + __is_matplotlib_imported = True mpl_logger = logging.getLogger("matplotlib") mpl_logger.setLevel(logging.WARNING) import matplotlib.pylab as plt @@ -63,23 +98,33 @@ def plot_spectrogram_to_numpy(spectrogram): plt.tight_layout() fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") # type: ignore data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) plt.close() return data -def plot_alignment_to_numpy(alignment, info=None): - global MATPLOTLIB_FLAG - if not MATPLOTLIB_FLAG: +def plot_alignment_to_numpy(alignment: NDArray[Any], info: Optional[str] = None) -> NDArray[Any]: + """ + 指定されたアライメントを画像データに変換する + + Args: + alignment (NDArray[Any]): アライメント + info (Optional[str]): 画像に追加する情報 + + Returns: + NDArray[Any]: 画像データ + """ + + global __is_matplotlib_imported + if not __is_matplotlib_imported: import matplotlib matplotlib.use("Agg") - MATPLOTLIB_FLAG = True + __is_matplotlib_imported 
= True mpl_logger = logging.getLogger("matplotlib") mpl_logger.setLevel(logging.WARNING) import matplotlib.pylab as plt - import numpy as np fig, ax = plt.subplots(figsize=(6, 4)) im = ax.imshow( @@ -94,44 +139,93 @@ def plot_alignment_to_numpy(alignment, info=None): plt.tight_layout() fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") # type: ignore data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) plt.close() return data -def load_wav_to_torch(full_path): +def load_wav_to_torch(full_path: Union[str, Path]) -> tuple[torch.FloatTensor, int]: + """ + 指定された音声ファイルを読み込み、PyTorch のテンソルに変換して返す + + Args: + full_path (Union[str, Path]): 音声ファイルのパス + + Returns: + tuple[torch.FloatTensor, int]: 音声データのテンソルとサンプリングレート + """ + sampling_rate, data = read(full_path) return torch.FloatTensor(data.astype(np.float32)), sampling_rate -def load_filepaths_and_text(filename, split="|"): +def load_filepaths_and_text(filename: Union[str, Path], split: str = "|") -> list[list[str]]: + """ + 指定されたファイルからファイルパスとテキストを読み込む + + Args: + filename (Union[str, Path]): ファイルのパス + split (str): ファイルの区切り文字 (デフォルト: "|") + + Returns: + list[list[str]]: ファイルパスとテキストのリスト + """ + with open(filename, encoding="utf-8") as f: filepaths_and_text = [line.strip().split(split) for line in f] return filepaths_and_text -def get_logger(model_dir, filename="train.log"): +def get_logger(model_dir_path: Union[str, Path], filename: str = "train.log") -> logging.Logger: + """ + ロガーを取得する + + Args: + model_dir_path (Union[str, Path]): ログを保存するディレクトリのパス + filename (str): ログファイルの名前 (デフォルト: "train.log") + + Returns: + logging.Logger: ロガー + """ + global logger - logger = logging.getLogger(os.path.basename(model_dir)) + logger = logging.getLogger(os.path.basename(model_dir_path)) logger.setLevel(logging.DEBUG) formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") - if not os.path.exists(model_dir): - os.makedirs(model_dir) - h = logging.FileHandler(os.path.join(model_dir, filename)) + if not os.path.exists(model_dir_path): + os.makedirs(model_dir_path) + h = logging.FileHandler(os.path.join(model_dir_path, filename)) h.setLevel(logging.DEBUG) h.setFormatter(formatter) logger.addHandler(h) return logger -def get_steps(model_path): - matches = re.findall(r"\d+", model_path) +def get_steps(model_path: Union[str, Path]) -> Optional[int]: + """ + モデルのパスからイテレーション番号を取得する + + Args: + model_path (Union[str, Path]): モデルのパス + + Returns: + Optional[int]: イテレーション番号 + """ + + matches = re.findall(r"\d+", model_path) # type: ignore return matches[-1] if matches else None -def check_git_hash(model_dir): +def check_git_hash(model_dir_path: Union[str, Path]) -> None: + """ + モデルのディレクトリに .git ディレクトリが存在する場合、ハッシュ値を比較する + + Args: + model_dir_path (Union[str, Path]): モデルのディレクトリのパス + """ + source_dir = os.path.dirname(os.path.realpath(__file__)) if not os.path.exists(os.path.join(source_dir, ".git")): logger.warning( @@ -143,7 +237,7 @@ def check_git_hash(model_dir): cur_hash = subprocess.getoutput("git rev-parse HEAD") - path = os.path.join(model_dir, "githash") + path = os.path.join(model_dir_path, "githash") if os.path.exists(path): saved_hash = open(path).read() if saved_hash != cur_hash: diff --git a/style_bert_vits2/nlp/__init__.py b/style_bert_vits2/nlp/__init__.py index afc6cb09e..683d6d479 100644 --- a/style_bert_vits2/nlp/__init__.py +++ b/style_bert_vits2/nlp/__init__.py @@ -8,7 +8,7 @@ ) # __init__.py 
は配下のモジュールをインポートした時点で実行される -# Pytorch のインポートは重いので、型チェック時以外はインポートしない +# PyTorch のインポートは重いので、型チェック時以外はインポートしない if TYPE_CHECKING: import torch From 96d22102f3f881964c047bc99507369bde62ac83 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sat, 9 Mar 2024 17:45:56 +0000 Subject: [PATCH 068/148] Refactor: separate adjust_voice() function from tts_model.py --- style_bert_vits2/tts_model.py | 169 +++++++++++++++------------------- style_bert_vits2/voice.py | 46 +++++++++ webui/inference.py | 6 +- webui/merge.py | 6 +- 4 files changed, 126 insertions(+), 101 deletions(-) create mode 100644 style_bert_vits2/voice.py diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index 86c8425e5..a3a83c316 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -25,54 +25,23 @@ from style_bert_vits2.models.models import SynthesizerTrn from style_bert_vits2.models.models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra from style_bert_vits2.logging import logger +from style_bert_vits2.voice import adjust_voice -def adjust_voice( - fs: int, - wave: NDArray[Any], - pitch_scale: float, - intonation_scale: float, -) -> tuple[int, NDArray[Any]]: - - if pitch_scale == 1.0 and intonation_scale == 1.0: - # 初期値の場合は、音質劣化を避けるためにそのまま返す - return fs, wave - - try: - import pyworld - except ImportError: - raise ImportError( - "pyworld is not installed. Please install it by `pip install pyworld`" - ) - - # pyworld で f0 を加工して合成 - # pyworld よりもよいのがあるかもしれないが…… - ## pyworld は Cython で書かれているが、スタブファイルがないため型補完が全く効かない… - - wave = wave.astype(np.double) - - # 質が高そうだしとりあえずharvestにしておく - f0, t = pyworld.harvest(wave, fs) # type: ignore - - sp = pyworld.cheaptrick(wave, f0, t, fs) # type: ignore - ap = pyworld.d4c(wave, f0, t, fs) # type: ignore - - non_zero_f0 = [f for f in f0 if f != 0] - f0_mean = sum(non_zero_f0) / len(non_zero_f0) - - for i, f in enumerate(f0): - if f == 0: - continue - f0[i] = pitch_scale * f0_mean + intonation_scale * (f - f0_mean) - - wave = pyworld.synthesize(f0, sp, ap, fs) # type: ignore - return fs, wave +class Model: + """ + Style-Bert-Vits2 の音声合成モデルを操作するためのクラス + モデル/ハイパーパラメータ/スタイルベクトルのパスとデバイスを指定して初期化し、model.infer() メソッドを呼び出すと音声合成を行える + """ -class Model: def __init__( - self, model_path: Path, config_path: Path, style_vec_path: Path, device: str - ): + self, + model_path: Path, + config_path: Path, + style_vec_path: Path, + device: str, + ) -> None: self.model_path: Path = model_path self.config_path: Path = config_path self.style_vec_path: Path = style_vec_path @@ -99,7 +68,8 @@ def __init__( self.net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra, None] = None - def load_net_g(self): + + def load_net_g(self) -> None: self.net_g = get_net_g( model_path=str(self.model_path), version=self.hps.version, @@ -107,15 +77,15 @@ def load_net_g(self): hps=self.hps, ) + def get_style_vector(self, style_id: int, weight: float = 1.0) -> NDArray[Any]: mean = self.style_vectors[0] style_vec = self.style_vectors[style_id] style_vec = mean + (style_vec - mean) * weight return style_vec - def get_style_vector_from_audio( - self, audio_path: str, weight: float = 1.0 - ) -> NDArray[Any]: + + def get_style_vector_from_audio(self, audio_path: str, weight: float = 1.0) -> NDArray[Any]: from style_gen import get_style_vector xvec = get_style_vector(audio_path) @@ -123,6 +93,7 @@ def get_style_vector_from_audio( xvec = mean + (xvec - mean) * weight return xvec + def infer( self, text: str, @@ -167,20 +138,20 @@ def infer( if not line_split: with torch.no_grad(): audio = infer( - 
text=text, - sdp_ratio=sdp_ratio, - noise_scale=noise, - noise_scale_w=noisew, - length_scale=length, - sid=sid, - language=language, - hps=self.hps, - net_g=self.net_g, - device=self.device, - assist_text=assist_text, - assist_text_weight=assist_text_weight, - style_vec=style_vector, - given_tone=given_tone, + text = text, + sdp_ratio = sdp_ratio, + noise_scale = noise, + noise_scale_w = noisew, + length_scale = length, + sid = sid, + language = language, + hps = self.hps, + net_g = self.net_g, + device = self.device, + assist_text = assist_text, + assist_text_weight = assist_text_weight, + style_vec = style_vector, + given_tone = given_tone, ) else: texts = text.split("\n") @@ -190,19 +161,19 @@ def infer( for i, t in enumerate(texts): audios.append( infer( - text=t, - sdp_ratio=sdp_ratio, - noise_scale=noise, - noise_scale_w=noisew, - length_scale=length, - sid=sid, - language=language, - hps=self.hps, - net_g=self.net_g, - device=self.device, - assist_text=assist_text, - assist_text_weight=assist_text_weight, - style_vec=style_vector, + text = t, + sdp_ratio = sdp_ratio, + noise_scale = noise, + noise_scale_w = noisew, + length_scale = length, + sid = sid, + language = language, + hps = self.hps, + net_g = self.net_g, + device = self.device, + assist_text = assist_text, + assist_text_weight = assist_text_weight, + style_vec = style_vector, ) ) if i != len(texts) - 1: @@ -211,10 +182,10 @@ def infer( logger.info("Audio data generated successfully") if not (pitch_scale == 1.0 and intonation_scale == 1.0): _, audio = adjust_voice( - fs=self.hps.data.sampling_rate, - wave=audio, - pitch_scale=pitch_scale, - intonation_scale=intonation_scale, + fs = self.hps.data.sampling_rate, + wave = audio, + pitch_scale = pitch_scale, + intonation_scale = intonation_scale, ) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -223,8 +194,13 @@ def infer( class ModelHolder: - def __init__(self, root_dir: Path, device: str): - self.root_dir: Path = root_dir + """ + Style-Bert-Vits2 の音声合成モデルを管理するためのクラス + """ + + + def __init__(self, model_root_dir: Path, device: str) -> None: + self.root_dir: Path = model_root_dir self.device: str = device self.model_files_dict: dict[str, list[Path]] = {} self.current_model: Optional[Model] = None @@ -233,7 +209,8 @@ def __init__(self, root_dir: Path, device: str): self.models_info: list[dict[str, Union[str, list[str]]]] = [] self.refresh() - def refresh(self): + + def refresh(self) -> None: self.model_files_dict = {} self.model_names = [] self.current_model = None @@ -269,7 +246,8 @@ def refresh(self): "speakers": speakers, }) - def load_model(self, model_name: str, model_path_str: str): + + def load_model(self, model_name: str, model_path_str: str) -> Model: model_path = Path(model_path_str) if model_name not in self.model_files_dict: raise ValueError(f"Model `{model_name}` is not found") @@ -277,16 +255,15 @@ def load_model(self, model_name: str, model_path_str: str): raise ValueError(f"Model file `{model_path}` is not found") if self.current_model is None or self.current_model.model_path != model_path: self.current_model = Model( - model_path=model_path, - config_path=self.root_dir / model_name / "config.json", - style_vec_path=self.root_dir / model_name / "style_vectors.npy", - device=self.device, + model_path = model_path, + config_path = self.root_dir / model_name / "config.json", + style_vec_path = self.root_dir / model_name / "style_vectors.npy", + device = self.device, ) return self.current_model - def load_model_gr( - self, model_name: str, 
model_path_str: str - ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]: + + def load_model_for_gradio(self, model_name: str, model_path_str: str) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]: model_path = Path(model_path_str) if model_name not in self.model_files_dict: raise ValueError(f"Model `{model_name}` is not found") @@ -305,10 +282,10 @@ def load_model_gr( gr.Dropdown(choices=speakers, value=speakers[0]), # type: ignore ) self.current_model = Model( - model_path=model_path, - config_path=self.root_dir / model_name / "config.json", - style_vec_path=self.root_dir / model_name / "style_vectors.npy", - device=self.device, + model_path = model_path, + config_path = self.root_dir / model_name / "config.json", + style_vec_path = self.root_dir / model_name / "style_vectors.npy", + device = self.device, ) speakers = list(self.current_model.spk2id.keys()) styles = list(self.current_model.style2id.keys()) @@ -318,11 +295,13 @@ def load_model_gr( gr.Dropdown(choices=speakers, value=speakers[0]), # type: ignore ) - def update_model_files_gr(self, model_name: str) -> gr.Dropdown: + + def update_model_files_for_gradio(self, model_name: str) -> gr.Dropdown: model_files = self.model_files_dict[model_name] return gr.Dropdown(choices=model_files, value=model_files[0]) # type: ignore - def update_model_names_gr(self) -> tuple[gr.Dropdown, gr.Dropdown, gr.Button]: + + def update_model_names_for_gradio(self) -> tuple[gr.Dropdown, gr.Dropdown, gr.Button]: self.refresh() initial_model_name = self.model_names[0] initial_model_files = self.model_files_dict[initial_model_name] diff --git a/style_bert_vits2/voice.py b/style_bert_vits2/voice.py new file mode 100644 index 000000000..f1cfcacb7 --- /dev/null +++ b/style_bert_vits2/voice.py @@ -0,0 +1,46 @@ +from typing import Any + +import numpy as np +from numpy.typing import NDArray + + +def adjust_voice( + fs: int, + wave: NDArray[Any], + pitch_scale: float = 1.0, + intonation_scale: float = 1.0, +) -> tuple[int, NDArray[Any]]: + + if pitch_scale == 1.0 and intonation_scale == 1.0: + # 初期値の場合は、音質劣化を避けるためにそのまま返す + return fs, wave + + try: + import pyworld + except ImportError: + raise ImportError( + "pyworld is not installed. 
Please install it by `pip install pyworld`" + ) + + # pyworld で f0 を加工して合成 + # pyworld よりもよいのがあるかもしれないが…… + ## pyworld は Cython で書かれているが、スタブファイルがないため型補完が全く効かない… + + wave = wave.astype(np.double) + + # 質が高そうだしとりあえずharvestにしておく + f0, t = pyworld.harvest(wave, fs) # type: ignore + + sp = pyworld.cheaptrick(wave, f0, t, fs) # type: ignore + ap = pyworld.d4c(wave, f0, t, fs) # type: ignore + + non_zero_f0 = [f for f in f0 if f != 0] + f0_mean = sum(non_zero_f0) / len(non_zero_f0) + + for i, f in enumerate(f0): + if f == 0: + continue + f0[i] = pitch_scale * f0_mean + intonation_scale * (f - f0_mean) + + wave = pyworld.synthesize(f0, sp, ap, fs) # type: ignore + return fs, wave diff --git a/webui/inference.py b/webui/inference.py index 536b38b4a..9c2bf63f6 100644 --- a/webui/inference.py +++ b/webui/inference.py @@ -446,7 +446,7 @@ def tts_fn( ) model_name.change( - model_holder.update_model_files_gr, + model_holder.update_model_files_for_gradio, inputs=[model_name], outputs=[model_path], ) @@ -454,12 +454,12 @@ def tts_fn( model_path.change(make_non_interactive, outputs=[tts_button]) refresh_button.click( - model_holder.update_model_names_gr, + model_holder.update_model_names_for_gradio, outputs=[model_name, model_path, tts_button], ) load_button.click( - model_holder.load_model_gr, + model_holder.load_model_for_gradio, inputs=[model_name, model_path], outputs=[style, tts_button, speaker], ) diff --git a/webui/merge.py b/webui/merge.py index e9dd1f004..85806f4a6 100644 --- a/webui/merge.py +++ b/webui/merge.py @@ -255,7 +255,7 @@ def simple_tts(model_name, text, style=DEFAULT_STYLE, style_weight=1.0): def update_two_model_names_dropdown(model_holder: ModelHolder): - new_names, new_files, _ = model_holder.update_model_names_gr() + new_names, new_files, _ = model_holder.update_model_names_for_gradio() return new_names, new_files, new_names, new_files @@ -444,12 +444,12 @@ def create_merge_app(model_holder: ModelHolder) -> gr.Blocks: audio_output = gr.Audio(label="結果") model_name_a.change( - model_holder.update_model_files_gr, + model_holder.update_model_files_for_gradio, inputs=[model_name_a], outputs=[model_path_a], ) model_name_b.change( - model_holder.update_model_files_gr, + model_holder.update_model_files_for_gradio, inputs=[model_name_b], outputs=[model_path_b], ) From 1d320915e0704fbaabd49d117b3ba8b8ccf363c2 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sat, 9 Mar 2024 18:10:25 +0000 Subject: [PATCH 069/148] Refactor: make tts_model.py independent of style_gen.py --- style_bert_vits2/tts_model.py | 102 ++++++++++++++++++--------- style_bert_vits2/utils/subprocess.py | 4 +- style_gen.py | 3 +- 3 files changed, 70 insertions(+), 39 deletions(-) diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index a3a83c316..42f6b84e3 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -4,6 +4,7 @@ import gradio as gr import numpy as np +import pyannote.audio import torch from gradio.processing_utils import convert_to_16_bit_wav from numpy.typing import NDArray @@ -30,8 +31,8 @@ class Model: """ - Style-Bert-Vits2 の音声合成モデルを操作するためのクラス - モデル/ハイパーパラメータ/スタイルベクトルのパスとデバイスを指定して初期化し、model.infer() メソッドを呼び出すと音声合成を行える + Style-Bert-Vits2 の音声合成モデルを操作するためのクラス。 + モデル/ハイパーパラメータ/スタイルベクトルのパスとデバイスを指定して初期化し、model.infer() メソッドを呼び出すと音声合成を行える。 """ @@ -46,50 +47,81 @@ def __init__( self.config_path: Path = config_path self.style_vec_path: Path = style_vec_path self.device: str = device - self.hps: HyperParameters = HyperParameters.load_from_json(self.config_path) - 
self.spk2id: dict[str, int] = self.hps.data.spk2id + self.hyper_parameters: HyperParameters = HyperParameters.load_from_json(self.config_path) + self.spk2id: dict[str, int] = self.hyper_parameters.data.spk2id self.id2spk: dict[int, str] = {v: k for k, v in self.spk2id.items()} - self.num_styles: int = self.hps.data.num_styles - if hasattr(self.hps.data, "style2id"): - self.style2id: dict[str, int] = self.hps.data.style2id + num_styles: int = self.hyper_parameters.data.num_styles + if hasattr(self.hyper_parameters.data, "style2id"): + self.style2id: dict[str, int] = self.hyper_parameters.data.style2id else: - self.style2id: dict[str, int] = {str(i): i for i in range(self.num_styles)} - if len(self.style2id) != self.num_styles: + self.style2id: dict[str, int] = {str(i): i for i in range(num_styles)} + if len(self.style2id) != num_styles: raise ValueError( - f"Number of styles ({self.num_styles}) does not match the number of style2id ({len(self.style2id)})" + f"Number of styles ({num_styles}) does not match the number of style2id ({len(self.style2id)})" ) - self.style_vectors: NDArray[Any] = np.load(self.style_vec_path) - if self.style_vectors.shape[0] != self.num_styles: + self.__style_vector_inference: Optional[pyannote.audio.Inference] = None + self.__style_vectors: NDArray[Any] = np.load(self.style_vec_path) + if self.__style_vectors.shape[0] != num_styles: raise ValueError( - f"The number of styles ({self.num_styles}) does not match the number of style vectors ({self.style_vectors.shape[0]})" + f"The number of styles ({num_styles}) does not match the number of style vectors ({self.__style_vectors.shape[0]})" ) - self.net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra, None] = None + self.__net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra, None] = None def load_net_g(self) -> None: - self.net_g = get_net_g( + """ + net_g をロードする。 + """ + self.__net_g = get_net_g( model_path=str(self.model_path), - version=self.hps.version, + version=self.hyper_parameters.version, device=self.device, - hps=self.hps, + hps=self.hyper_parameters, ) def get_style_vector(self, style_id: int, weight: float = 1.0) -> NDArray[Any]: - mean = self.style_vectors[0] - style_vec = self.style_vectors[style_id] + """ + スタイルベクトルを取得する。 + + Args: + style_id (int): スタイル ID + weight (float, optional): スタイルベクトルの重み. Defaults to 1.0. + + Returns: + NDArray[Any]: スタイルベクトル + """ + mean = self.__style_vectors[0] + style_vec = self.__style_vectors[style_id] style_vec = mean + (style_vec - mean) * weight return style_vec def get_style_vector_from_audio(self, audio_path: str, weight: float = 1.0) -> NDArray[Any]: - from style_gen import get_style_vector + """ + 音声からスタイルベクトルを推論する。 + + Args: + audio_path (str): 音声ファイルのパス + weight (float, optional): スタイルベクトルの重み. Defaults to 1.0. 
+ Returns: + NDArray[Any]: スタイルベクトル + """ + + # スタイルベクトルを取得するための推論モデルを初期化 + if self.__style_vector_inference is None: + self.__style_vector_inference = pyannote.audio.Inference( + model = pyannote.audio.Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM"), + window = "whole", + ) + self.__style_vector_inference.to(torch.device(self.device)) - xvec = get_style_vector(audio_path) - mean = self.style_vectors[0] + # 音声からスタイルベクトルを推論 + xvec = self.__style_vector_inference(audio_path) + mean = self.__style_vectors[0] xvec = mean + (xvec - mean) * weight return xvec @@ -116,7 +148,7 @@ def infer( intonation_scale: float = 1.0, ) -> tuple[int, NDArray[Any]]: logger.info(f"Start generating audio data from text:\n{text}") - if language != "JP" and self.hps.version.endswith("JP-Extra"): + if language != "JP" and self.hyper_parameters.version.endswith("JP-Extra"): raise ValueError( "The model is trained with JP-Extra, but the language is not JP" ) @@ -125,9 +157,9 @@ def infer( if assist_text == "" or not use_assist_text: assist_text = None - if self.net_g is None: + if self.__net_g is None: self.load_net_g() - assert self.net_g is not None + assert self.__net_g is not None if reference_audio_path is None: style_id = self.style2id[style] style_vector = self.get_style_vector(style_id, style_weight) @@ -145,8 +177,8 @@ def infer( length_scale = length, sid = sid, language = language, - hps = self.hps, - net_g = self.net_g, + hps = self.hyper_parameters, + net_g = self.__net_g, device = self.device, assist_text = assist_text, assist_text_weight = assist_text_weight, @@ -168,8 +200,8 @@ def infer( length_scale = length, sid = sid, language = language, - hps = self.hps, - net_g = self.net_g, + hps = self.hyper_parameters, + net_g = self.__net_g, device = self.device, assist_text = assist_text, assist_text_weight = assist_text_weight, @@ -182,7 +214,7 @@ def infer( logger.info("Audio data generated successfully") if not (pitch_scale == 1.0 and intonation_scale == 1.0): _, audio = adjust_voice( - fs = self.hps.data.sampling_rate, + fs = self.hyper_parameters.data.sampling_rate, wave = audio, pitch_scale = pitch_scale, intonation_scale = intonation_scale, @@ -190,12 +222,12 @@ def infer( with warnings.catch_warnings(): warnings.simplefilter("ignore") audio = convert_to_16_bit_wav(audio) - return (self.hps.data.sampling_rate, audio) + return (self.hyper_parameters.data.sampling_rate, audio) class ModelHolder: """ - Style-Bert-Vits2 の音声合成モデルを管理するためのクラス + Style-Bert-Vits2 の音声合成モデルを管理するためのクラス。 """ @@ -234,10 +266,10 @@ def refresh(self) -> None: continue self.model_files_dict[model_dir.name] = model_files self.model_names.append(model_dir.name) - hps = HyperParameters.load_from_json(config_path) - style2id: dict[str, int] = hps.data.style2id + hyper_parameters = HyperParameters.load_from_json(config_path) + style2id: dict[str, int] = hyper_parameters.data.style2id styles = list(style2id.keys()) - spk2id: dict[str, int] = hps.data.spk2id + spk2id: dict[str, int] = hyper_parameters.data.spk2id speakers = list(spk2id.keys()) self.models_info.append({ "name": model_dir.name, diff --git a/style_bert_vits2/utils/subprocess.py b/style_bert_vits2/utils/subprocess.py index 5ff267b4e..542f94b9d 100644 --- a/style_bert_vits2/utils/subprocess.py +++ b/style_bert_vits2/utils/subprocess.py @@ -8,7 +8,7 @@ def run_script_with_log(cmd: list[str], ignore_warning: bool = False) -> tuple[bool, str]: """ - 指定されたコマンドを実行し、そのログを記録する + 指定されたコマンドを実行し、そのログを記録する。 Args: cmd: 実行するコマンドのリスト @@ -39,7 +39,7 @@ def 
run_script_with_log(cmd: list[str], ignore_warning: bool = False) -> tuple[b def second_elem_of(original_function: Callable[..., tuple[Any, Any]]) -> Callable[..., Any]: """ - 与えられた関数をラップし、その戻り値の 2 番目の要素のみを返す関数を生成する + 与えられた関数をラップし、その戻り値の 2 番目の要素のみを返す関数を生成する。 Args: original_function (Callable[..., tuple[Any, Any]])): ラップする元の関数 diff --git a/style_gen.py b/style_gen.py index ec0b50778..5190575b0 100644 --- a/style_gen.py +++ b/style_gen.py @@ -6,11 +6,10 @@ import torch from tqdm import tqdm +from config import config from style_bert_vits2.logging import logger -from style_bert_vits2.models import utils from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT -from config import config warnings.filterwarnings("ignore", category=UserWarning) from pyannote.audio import Inference, Model From a79b1910fb1945151903458e4ddca938d008d096 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sat, 9 Mar 2024 18:38:16 +0000 Subject: [PATCH 070/148] Add: use hatch to build style-bert-vits2 as a library --- .gitignore | 1 + app.py | 4 +- pyproject.toml | 92 +++++++++++++++++++++++++++++++++++ server_editor.py | 4 +- style_bert_vits2/constants.py | 2 +- tests/__init__.py | 0 6 files changed, 98 insertions(+), 5 deletions(-) create mode 100644 pyproject.toml create mode 100644 tests/__init__.py diff --git a/.gitignore b/.gitignore index 88048d770..b8a19a4e7 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__/ venv/ .venv/ +dist/ .ipynb_checkpoints/ /*.yml diff --git a/app.py b/app.py index f0f94d036..eaf07b7a4 100644 --- a/app.py +++ b/app.py @@ -5,7 +5,7 @@ import torch import yaml -from style_bert_vits2.constants import GRADIO_THEME, LATEST_VERSION +from style_bert_vits2.constants import GRADIO_THEME, VERSION from style_bert_vits2.tts_model import ModelHolder from webui import ( create_dataset_app, @@ -37,7 +37,7 @@ model_holder = ModelHolder(Path(assets_root), device) with gr.Blocks(theme=GRADIO_THEME) as app: - gr.Markdown(f"# Style-Bert-VITS2 WebUI (version {LATEST_VERSION})") + gr.Markdown(f"# Style-Bert-VITS2 WebUI (version {VERSION})") with gr.Tabs(): with gr.Tab("音声合成"): create_inference_app(model_holder=model_holder) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..4db5fe9eb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,92 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "style-bert-vits2" +dynamic = ["version"] +description = 'Style-Bert-VITS2: Bert-VITS2 with more controllable voice styles.' 
+readme = "README.md" +requires-python = ">=3.9" +license = "AGPL-3.0" +keywords = [] +authors = [ + { name = "litagin02", email = "139731664+litagin02@users.noreply.github.com" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", +] +dependencies = [ + 'cmudict', + 'cn2an', + 'g2p_en', + 'gradio', + 'jieba', + 'librosa==0.9.2', + 'loguru', + 'num2words', + 'numba', + 'numpy', + 'pyannote.audio>=3.1.0', + 'pydantic', + 'pyopenjtalk-dict', + 'pypinyin', + 'pyworld', + 'safetensors', + 'scipy', + 'torch>=2.1,<2.2', + 'transformers', +] + +[project.urls] +Documentation = "https://github.com/litagin02/Style-Bert-VITS2#readme" +Issues = "https://github.com/litagin02/Style-Bert-VITS2/issues" +Source = "https://github.com/litagin02/Style-Bert-VITS2" + +[tool.hatch.version] +path = "style_bert_vits2/constants.py" + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", +] +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +cov-report = [ + "- coverage combine", + "coverage report", +] +cov = [ + "test-cov", + "cov-report", +] + +[[tool.hatch.envs.all.matrix]] +python = ["3.9", "3.10", "3.11", "3.12"] + +[tool.coverage.run] +source_pkgs = ["style_bert_vits2", "tests"] +branch = true +parallel = true +omit = [ + "style_bert_vits2/constants.py", +] + +[tool.coverage.paths] +style_bert_vits2 = ["style_bert_vits2", "*/style-bert-vits2/style_bert_vits2"] +tests = ["tests", "*/style-bert-vits2/tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] diff --git a/server_editor.py b/server_editor.py index 9a6a08af8..b12f30442 100644 --- a/server_editor.py +++ b/server_editor.py @@ -37,7 +37,7 @@ DEFAULT_SDP_RATIO, DEFAULT_STYLE, DEFAULT_STYLE_WEIGHT, - LATEST_VERSION, + VERSION, Languages, ) from style_bert_vits2.logging import logger @@ -219,7 +219,7 @@ class AudioResponse(Response): @router.get("/version") def version() -> str: - return LATEST_VERSION + return VERSION class MoraTone(BaseModel): diff --git a/style_bert_vits2/constants.py b/style_bert_vits2/constants.py index 5f32a15a4..735aebcc2 100644 --- a/style_bert_vits2/constants.py +++ b/style_bert_vits2/constants.py @@ -4,7 +4,7 @@ # Style-Bert-VITS2 のバージョン -LATEST_VERSION = "2.4" +VERSION = "2.4" # Style-Bert-VITS2 のベースディレクトリ BASE_DIR = Path(__file__).parent.parent diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb From b7d7c7820364c7d17f8bedad04b8aa5441bec6ba Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sun, 10 Mar 2024 03:04:27 +0000 Subject: [PATCH 071/148] Refactor: when pyopenjtalk_worker is called without initialization, continue processing without a worker When using style-bert-vits2 as a library, the requirement to be able to launch it in multiple processes may not be necessary. Also, if the library is embedded and exe-ed using PyInstaller or similar, it is difficult to make pyopenjtalk_worker run in a separate process. Therefore, we changed it so that the worker is used only when it is explicitly initialized. 
--- server_editor.py | 2 +- server_fastapi.py | 2 +- .../japanese/pyopenjtalk_worker/__init__.py | 76 ++++++++++++------- webui/inference.py | 2 +- 4 files changed, 51 insertions(+), 31 deletions(-) diff --git a/server_editor.py b/server_editor.py index b12f30442..3c0d2c87b 100644 --- a/server_editor.py +++ b/server_editor.py @@ -151,7 +151,7 @@ def save_last_download(latest_release): # pyopenjtalk_worker を起動 ## pyopenjtalk_worker は TCP ソケットサーバーのため、ここで起動する -pyopenjtalk.initialize() +pyopenjtalk.initialize_worker() # pyopenjtalk の辞書を更新 update_dict() diff --git a/server_fastapi.py b/server_fastapi.py index e8309dadb..30f253ffc 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -43,7 +43,7 @@ # pyopenjtalk_worker を起動 ## pyopenjtalk_worker は TCP ソケットサーバーのため、ここで起動する -pyopenjtalk.initialize() +pyopenjtalk.initialize_worker() # 事前に BERT モデル/トークナイザーをロードしておく ## ここでロードしなくても必要になった際に自動ロードされるが、時間がかかるため事前にロードしておいた方が体験が良い diff --git a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py index 3593e60a7..212d21ab2 100644 --- a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py +++ b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py @@ -18,38 +18,58 @@ def run_frontend(text: str) -> list[dict[str, Any]]: - assert WORKER_CLIENT - ret = WORKER_CLIENT.dispatch_pyopenjtalk("run_frontend", text) - assert isinstance(ret, list) - return ret + if WORKER_CLIENT is not None: + ret = WORKER_CLIENT.dispatch_pyopenjtalk("run_frontend", text) + assert isinstance(ret, list) + return ret + else: + # without worker + import pyopenjtalk + return pyopenjtalk.run_frontend(text) def make_label(njd_features: Any) -> list[str]: - assert WORKER_CLIENT - ret = WORKER_CLIENT.dispatch_pyopenjtalk("make_label", njd_features) - assert isinstance(ret, list) - return ret - - -def mecab_dict_index(path: str, out_path: str, dn_mecab: Optional[str] = None): - assert WORKER_CLIENT - WORKER_CLIENT.dispatch_pyopenjtalk("mecab_dict_index", path, out_path, dn_mecab) - - -def update_global_jtalk_with_user_dict(path: str): - assert WORKER_CLIENT - WORKER_CLIENT.dispatch_pyopenjtalk("update_global_jtalk_with_user_dict", path) - - -def unset_user_dict(): - assert WORKER_CLIENT - WORKER_CLIENT.dispatch_pyopenjtalk("unset_user_dict") + if WORKER_CLIENT is not None: + ret = WORKER_CLIENT.dispatch_pyopenjtalk("make_label", njd_features) + assert isinstance(ret, list) + return ret + else: + # without worker + import pyopenjtalk + return pyopenjtalk.make_label(njd_features) + + +def mecab_dict_index(path: str, out_path: str, dn_mecab: Optional[str] = None) -> None: + if WORKER_CLIENT is not None: + WORKER_CLIENT.dispatch_pyopenjtalk("mecab_dict_index", path, out_path, dn_mecab) + else: + # without worker + import pyopenjtalk + pyopenjtalk.mecab_dict_index(path, out_path, dn_mecab) + + +def update_global_jtalk_with_user_dict(path: str) -> None: + if WORKER_CLIENT is not None: + WORKER_CLIENT.dispatch_pyopenjtalk("update_global_jtalk_with_user_dict", path) + else: + # without worker + import pyopenjtalk + pyopenjtalk.update_global_jtalk_with_user_dict(path) + + +def unset_user_dict() -> None: + if WORKER_CLIENT is not None: + WORKER_CLIENT.dispatch_pyopenjtalk("unset_user_dict") + else: + # without worker + import pyopenjtalk + pyopenjtalk.unset_user_dict() # initialize module when imported -def initialize(port: int = WORKER_PORT) -> None: +def initialize_worker(port: int = WORKER_PORT) -> None: import atexit import signal import socket @@ -99,11 +119,11 @@ 
def initialize(port: int = WORKER_PORT) -> None: logger.debug("pyopenjtalk worker server started") WORKER_CLIENT = client - atexit.register(terminate) + atexit.register(terminate_worker) # when the process is killed def signal_handler(signum: int, frame: Any): - terminate() + terminate_worker() try: signal.signal(signal.SIGTERM, signal_handler) @@ -113,13 +133,13 @@ def signal_handler(signum: int, frame: Any): # top-level declaration -def terminate() -> None: +def terminate_worker() -> None: logger.debug("pyopenjtalk worker server terminated") global WORKER_CLIENT if not WORKER_CLIENT: return - # repare for unexpected errors + # prepare for unexpected errors try: if WORKER_CLIENT.status() == 1: WORKER_CLIENT.quit_server() diff --git a/webui/inference.py b/webui/inference.py index 9c2bf63f6..6d6eae651 100644 --- a/webui/inference.py +++ b/webui/inference.py @@ -28,7 +28,7 @@ # pyopenjtalk_worker を起動 ## pyopenjtalk_worker は TCP ソケットサーバーのため、ここで起動する -pyopenjtalk.initialize() +pyopenjtalk.initialize_worker() # 事前に BERT モデル/トークナイザーをロードしておく ## ここでロードしなくても必要になった際に自動ロードされるが、時間がかかるため事前にロードしておいた方が体験が良い From d2fd378b565239fd272520a011f2df84d1916a3b Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sun, 10 Mar 2024 03:47:17 +0000 Subject: [PATCH 072/148] Refactor: rename Model / ModelHolder to TTSModel / TTSModelHolder for clarification and add comments to each method --- app.py | 4 +- server_editor.py | 14 ++-- server_fastapi.py | 14 ++-- speech_mos.py | 4 +- style_bert_vits2/tts_model.py | 133 +++++++++++++++++++++++++++------- webui/inference.py | 12 +-- webui/merge.py | 8 +- 7 files changed, 133 insertions(+), 56 deletions(-) diff --git a/app.py b/app.py index eaf07b7a4..b4fbc3ec8 100644 --- a/app.py +++ b/app.py @@ -6,7 +6,7 @@ import yaml from style_bert_vits2.constants import GRADIO_THEME, VERSION -from style_bert_vits2.tts_model import ModelHolder +from style_bert_vits2.tts_model import TTSModelHolder from webui import ( create_dataset_app, create_inference_app, @@ -34,7 +34,7 @@ if device == "cuda" and not torch.cuda.is_available(): device = "cpu" -model_holder = ModelHolder(Path(assets_root), device) +model_holder = TTSModelHolder(Path(assets_root), device) with gr.Blocks(theme=GRADIO_THEME) as app: gr.Markdown(f"# Style-Bert-VITS2 WebUI (version {VERSION})") diff --git a/server_editor.py b/server_editor.py index 3c0d2c87b..aa85ef0ad 100644 --- a/server_editor.py +++ b/server_editor.py @@ -52,7 +52,7 @@ rewrite_word, update_dict, ) -from style_bert_vits2.tts_model import ModelHolder +from style_bert_vits2.tts_model import TTSModelHolder # ---フロントエンド部分に関する処理--- @@ -198,7 +198,7 @@ class AudioResponse(Response): model_dir = Path(args.model_dir) port = int(args.port) -model_holder = ModelHolder(model_dir, device) +model_holder = TTSModelHolder(model_dir, device) if len(model_holder.model_names) == 0: logger.error(f"Models not found in {model_dir}.") sys.exit(1) @@ -283,7 +283,7 @@ def synthesis(request: SynthesisRequest): detail=f"1行の文字数は{args.line_length}文字以下にしてください。", ) try: - model = model_holder.load_model( + model = model_holder.get_model( model_name=request.model, model_path_str=request.modelFile ) except Exception as e: @@ -310,7 +310,7 @@ def synthesis(request: SynthesisRequest): language=request.language, sdp_ratio=request.sdpRatio, noise=request.noise, - noisew=request.noisew, + noise_w=request.noisew, length=1 / request.speed, given_tone=tone, style=request.style, @@ -321,7 +321,7 @@ def synthesis(request: SynthesisRequest): line_split=False, pitch_scale=request.pitchScale, 
intonation_scale=request.intonationScale, - sid=sid, + speaker_id=sid, ) with BytesIO() as wavContent: @@ -350,7 +350,7 @@ def multi_synthesis(request: MultiSynthesisRequest): detail=f"1行の文字数は{args.line_length}文字以下にしてください。", ) try: - model = model_holder.load_model( + model = model_holder.get_model( model_name=req.model, model_path_str=req.modelFile ) except Exception as e: @@ -370,7 +370,7 @@ def multi_synthesis(request: MultiSynthesisRequest): language=req.language, sdp_ratio=req.sdpRatio, noise=req.noise, - noisew=req.noisew, + noise_w=req.noisew, length=1 / req.speed, given_tone=tone, style=req.style, diff --git a/server_fastapi.py b/server_fastapi.py index 30f253ffc..1f60b3c57 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -36,7 +36,7 @@ from style_bert_vits2.logging import logger from style_bert_vits2.nlp import bert_models from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk -from style_bert_vits2.tts_model import Model, ModelHolder +from style_bert_vits2.tts_model import TTSModel, TTSModelHolder ln = config.server_config.language @@ -67,16 +67,16 @@ class AudioResponse(Response): media_type = "audio/wav" -def load_models(model_holder: ModelHolder): +def load_models(model_holder: TTSModelHolder): model_holder.models = [] for model_name, model_paths in model_holder.model_files_dict.items(): - model = Model( + model = TTSModel( model_path=model_paths[0], config_path=model_holder.root_dir / model_name / "config.json", style_vec_path=model_holder.root_dir / model_name / "style_vectors.npy", device=model_holder.device, ) - model.load_net_g() + model.load() model_holder.models.append(model) @@ -94,7 +94,7 @@ def load_models(model_holder: ModelHolder): device = "cuda" if torch.cuda.is_available() else "cpu" model_dir = Path(args.dir) - model_holder = ModelHolder(model_dir, device) + model_holder = TTSModelHolder(model_dir, device) if len(model_holder.model_names) == 0: logger.error(f"Models not found in {model_dir}.") sys.exit(1) @@ -194,11 +194,11 @@ async def voice( sr, audio = model.infer( text=text, language=language, - sid=speaker_id, + speaker_id=speaker_id, reference_audio_path=reference_audio_path, sdp_ratio=sdp_ratio, noise=noise, - noisew=noisew, + noise_w=noisew, length=length, line_split=auto_split, split_interval=split_interval, diff --git a/speech_mos.py b/speech_mos.py index 79221adcc..6dd2caa58 100644 --- a/speech_mos.py +++ b/speech_mos.py @@ -12,7 +12,7 @@ from config import config from style_bert_vits2.logging import logger -from style_bert_vits2.tts_model import Model +from style_bert_vits2.tts_model import TTSModel warnings.filterwarnings("ignore") @@ -54,7 +54,7 @@ def get_model(model_file: Path): - return Model( + return TTSModel( model_path=str(model_file), config_path=str(model_file.parent / "config.json"), style_vec_path=str(model_file.parent / "style_vectors.npy"), diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index 42f6b84e3..109ef0e3d 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -29,9 +29,9 @@ from style_bert_vits2.voice import adjust_voice -class Model: +class TTSModel: """ - Style-Bert-Vits2 の音声合成モデルを操作するためのクラス。 + Style-Bert-Vits2 の音声合成モデルを操作するクラス。 モデル/ハイパーパラメータ/スタイルベクトルのパスとデバイスを指定して初期化し、model.infer() メソッドを呼び出すと音声合成を行える。 """ @@ -43,6 +43,17 @@ def __init__( style_vec_path: Path, device: str, ) -> None: + """ + Style-Bert-Vits2 の音声合成モデルを初期化する。 + この時点ではモデルはロードされていない (明示的にロードしたい場合は model.load() を呼び出す)。 + + Args: + model_path (Path): モデル (.safetensors) のパス + 
config_path (Path): ハイパーパラメータ (config.json) のパス + style_vec_path (Path): スタイルベクトル (style_vectors.npy) のパス + device (str): 音声合成時に利用するデバイス (cpu, cuda, mps など) + """ + self.model_path: Path = model_path self.config_path: Path = config_path self.style_vec_path: Path = style_vec_path @@ -71,24 +82,24 @@ def __init__( self.__net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra, None] = None - def load_net_g(self) -> None: + def load(self) -> None: """ - net_g をロードする。 + 音声合成モデルをデバイスにロードする。 """ self.__net_g = get_net_g( - model_path=str(self.model_path), - version=self.hyper_parameters.version, - device=self.device, - hps=self.hyper_parameters, + model_path = str(self.model_path), + version = self.hyper_parameters.version, + device = self.device, + hps = self.hyper_parameters, ) - def get_style_vector(self, style_id: int, weight: float = 1.0) -> NDArray[Any]: + def __get_style_vector(self, style_id: int, weight: float = 1.0) -> NDArray[Any]: """ スタイルベクトルを取得する。 Args: - style_id (int): スタイル ID + style_id (int): スタイル ID (0 から始まるインデックス) weight (float, optional): スタイルベクトルの重み. Defaults to 1.0. Returns: @@ -100,7 +111,7 @@ def get_style_vector(self, style_id: int, weight: float = 1.0) -> NDArray[Any]: return style_vec - def get_style_vector_from_audio(self, audio_path: str, weight: float = 1.0) -> NDArray[Any]: + def __get_style_vector_from_audio(self, audio_path: str, weight: float = 1.0) -> NDArray[Any]: """ 音声からスタイルベクトルを推論する。 @@ -130,11 +141,11 @@ def infer( self, text: str, language: Languages = Languages.JP, - sid: int = 0, + speaker_id: int = 0, reference_audio_path: Optional[str] = None, sdp_ratio: float = DEFAULT_SDP_RATIO, noise: float = DEFAULT_NOISE, - noisew: float = DEFAULT_NOISEW, + noise_w: float = DEFAULT_NOISEW, length: float = DEFAULT_LENGTH, line_split: bool = DEFAULT_LINE_SPLIT, split_interval: float = DEFAULT_SPLIT_INTERVAL, @@ -147,6 +158,33 @@ def infer( pitch_scale: float = 1.0, intonation_scale: float = 1.0, ) -> tuple[int, NDArray[Any]]: + """ + テキストから音声を合成する。 + + Args: + text (str): 読み上げるテキスト + language (Languages, optional): 言語. Defaults to Languages.JP. + speaker_id (int, optional): 話者 ID. Defaults to 0. + reference_audio_path (Optional[str], optional): 音声スタイルの参照元の音声ファイルのパス. Defaults to None. + sdp_ratio (float, optional): SDP レシオ (値を大きくするとより感情豊かになる傾向がある). Defaults to DEFAULT_SDP_RATIO. + noise (float, optional): ノイズの大きさ. Defaults to DEFAULT_NOISE. + noise_w (float, optional): ノイズの大きさの重み. Defaults to DEFAULT_NOISEW. + length (float, optional): 長さ. Defaults to DEFAULT_LENGTH. + line_split (bool, optional): テキストを改行ごとに分割して生成するかどうか. Defaults to DEFAULT_LINE_SPLIT. + split_interval (float, optional): 改行ごとに分割する場合の無音 (秒). Defaults to DEFAULT_SPLIT_INTERVAL. + assist_text (Optional[str], optional): 感情表現の参照元の補助テキスト. Defaults to None. + assist_text_weight (float, optional): 感情表現の補助テキストを適用する強さ. Defaults to DEFAULT_ASSIST_TEXT_WEIGHT. + use_assist_text (bool, optional): 音声合成時に感情表現の補助テキストを使用するかどうか. Defaults to False. + style (str, optional): 音声スタイル (Neutral, Happy など). Defaults to DEFAULT_STYLE. + style_weight (float, optional): 音声スタイルを適用する強さ. Defaults to DEFAULT_STYLE_WEIGHT. + given_tone (Optional[list[int]], optional): アクセントのトーンのリスト. Defaults to None. + pitch_scale (float, optional): ピッチの高さ (1.0 から変更すると若干音質が低下する). Defaults to 1.0. + intonation_scale (float, optional): イントネーションの高さ (1.0 から変更すると若干音質が低下する). Defaults to 1.0. 
+ + Returns: + tuple[int, NDArray[Any]]: サンプリングレートと音声データ (16bit PCM) + """ + logger.info(f"Start generating audio data from text:\n{text}") if language != "JP" and self.hyper_parameters.version.endswith("JP-Extra"): raise ValueError( @@ -158,13 +196,13 @@ def infer( assist_text = None if self.__net_g is None: - self.load_net_g() + self.load() assert self.__net_g is not None if reference_audio_path is None: style_id = self.style2id[style] - style_vector = self.get_style_vector(style_id, style_weight) + style_vector = self.__get_style_vector(style_id, style_weight) else: - style_vector = self.get_style_vector_from_audio( + style_vector = self.__get_style_vector_from_audio( reference_audio_path, style_weight ) if not line_split: @@ -173,9 +211,9 @@ def infer( text = text, sdp_ratio = sdp_ratio, noise_scale = noise, - noise_scale_w = noisew, + noise_scale_w = noise_w, length_scale = length, - sid = sid, + sid = speaker_id, language = language, hps = self.hyper_parameters, net_g = self.__net_g, @@ -196,9 +234,9 @@ def infer( text = t, sdp_ratio = sdp_ratio, noise_scale = noise, - noise_scale_w = noisew, + noise_scale_w = noise_w, length_scale = length, - sid = sid, + sid = speaker_id, language = language, hps = self.hyper_parameters, net_g = self.__net_g, @@ -225,24 +263,50 @@ def infer( return (self.hyper_parameters.data.sampling_rate, audio) -class ModelHolder: +class TTSModelHolder: """ - Style-Bert-Vits2 の音声合成モデルを管理するためのクラス。 + Style-Bert-Vits2 の音声合成モデルを管理するクラス。 + model_holder.models_info から指定されたディレクトリ内にある音声合成モデルの一覧を取得できる。 """ def __init__(self, model_root_dir: Path, device: str) -> None: + """ + Style-Bert-Vits2 の音声合成モデルを管理するクラスを初期化する。 + 音声合成モデルは下記のように配置されていることを前提とする (.safetensors のファイル名は自由) 。 + ``` + model_root_dir + ├── model-name-1 + │ ├── config.json + │ ├── model-name-1_e160_s14000.safetensors + │ └── style_vectors.npy + ├── model-name-2 + │ ├── config.json + │ ├── model-name-2_e160_s14000.safetensors + │ └── style_vectors.npy + └── ... 
+ ``` + + Args: + model_root_dir (Path): 音声合成モデルが配置されているディレクトリのパス + device (str): 音声合成時に利用するデバイス (cpu, cuda, mps など) + """ + self.root_dir: Path = model_root_dir self.device: str = device self.model_files_dict: dict[str, list[Path]] = {} - self.current_model: Optional[Model] = None + self.current_model: Optional[TTSModel] = None self.model_names: list[str] = [] - self.models: list[Model] = [] + self.models: list[TTSModel] = [] self.models_info: list[dict[str, Union[str, list[str]]]] = [] self.refresh() def refresh(self) -> None: + """ + 音声合成モデルの一覧を更新する。 + """ + self.model_files_dict = {} self.model_names = [] self.current_model = None @@ -279,23 +343,36 @@ def refresh(self) -> None: }) - def load_model(self, model_name: str, model_path_str: str) -> Model: + def get_model(self, model_name: str, model_path_str: str) -> TTSModel: + """ + 指定された音声合成モデルのインスタンスを取得する。 + この時点ではモデルはロードされていない (明示的にロードしたい場合は model.load() を呼び出す)。 + + Args: + model_name (str): 音声合成モデルの名前 + model_path_str (str): 音声合成モデルのファイルパス (.safetensors) + + Returns: + TTSModel: 音声合成モデルのインスタンス + """ + model_path = Path(model_path_str) if model_name not in self.model_files_dict: raise ValueError(f"Model `{model_name}` is not found") if model_path not in self.model_files_dict[model_name]: raise ValueError(f"Model file `{model_path}` is not found") if self.current_model is None or self.current_model.model_path != model_path: - self.current_model = Model( + self.current_model = TTSModel( model_path = model_path, config_path = self.root_dir / model_name / "config.json", style_vec_path = self.root_dir / model_name / "style_vectors.npy", device = self.device, ) + return self.current_model - def load_model_for_gradio(self, model_name: str, model_path_str: str) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]: + def get_model_for_gradio(self, model_name: str, model_path_str: str) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]: model_path = Path(model_path_str) if model_name not in self.model_files_dict: raise ValueError(f"Model `{model_name}` is not found") @@ -313,7 +390,7 @@ def load_model_for_gradio(self, model_name: str, model_path_str: str) -> tuple[g gr.Button(interactive=True, value="音声合成"), gr.Dropdown(choices=speakers, value=speakers[0]), # type: ignore ) - self.current_model = Model( + self.current_model = TTSModel( model_path = model_path, config_path = self.root_dir / model_name / "config.json", style_vec_path = self.root_dir / model_name / "style_vectors.npy", diff --git a/webui/inference.py b/webui/inference.py index 6d6eae651..e32907c62 100644 --- a/webui/inference.py +++ b/webui/inference.py @@ -23,7 +23,7 @@ from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk from style_bert_vits2.nlp.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone from style_bert_vits2.nlp.japanese.normalizer import normalize_text -from style_bert_vits2.tts_model import ModelHolder +from style_bert_vits2.tts_model import TTSModelHolder # pyopenjtalk_worker を起動 @@ -151,7 +151,7 @@ def gr_util(item): return (gr.update(visible=False), gr.update(visible=True)) -def create_inference_app(model_holder: ModelHolder) -> gr.Blocks: +def create_inference_app(model_holder: TTSModelHolder) -> gr.Blocks: def tts_fn( model_name, model_path, @@ -175,7 +175,7 @@ def tts_fn( pitch_scale, intonation_scale, ): - model_holder.load_model(model_name, model_path) + model_holder.get_model(model_name, model_path) assert model_holder.current_model is not None wrong_tone_message = "" @@ -218,7 +218,7 @@ def tts_fn( 
reference_audio_path=reference_audio_path, sdp_ratio=sdp_ratio, noise=noise_scale, - noisew=noise_scale_w, + noise_w=noise_scale_w, length=length_scale, line_split=line_split, split_interval=split_interval, @@ -228,7 +228,7 @@ def tts_fn( style=style, style_weight=style_weight, given_tone=tone, - sid=speaker_id, + speaker_id=speaker_id, pitch_scale=pitch_scale, intonation_scale=intonation_scale, ) @@ -459,7 +459,7 @@ def tts_fn( ) load_button.click( - model_holder.load_model_for_gradio, + model_holder.get_model_for_gradio, inputs=[model_name, model_path], outputs=[style, tts_button, speaker], ) diff --git a/webui/merge.py b/webui/merge.py index 85806f4a6..f692b6748 100644 --- a/webui/merge.py +++ b/webui/merge.py @@ -11,7 +11,7 @@ from style_bert_vits2.constants import DEFAULT_STYLE, GRADIO_THEME from style_bert_vits2.logging import logger -from style_bert_vits2.tts_model import Model, ModelHolder +from style_bert_vits2.tts_model import TTSModel, TTSModelHolder voice_keys = ["dec"] @@ -250,11 +250,11 @@ def simple_tts(model_name, text, style=DEFAULT_STYLE, style_weight=1.0): config_path = os.path.join(assets_root, model_name, "config.json") style_vec_path = os.path.join(assets_root, model_name, "style_vectors.npy") - model = Model(Path(model_path), Path(config_path), Path(style_vec_path), device) + model = TTSModel(Path(model_path), Path(config_path), Path(style_vec_path), device) return model.infer(text, style=style, style_weight=style_weight) -def update_two_model_names_dropdown(model_holder: ModelHolder): +def update_two_model_names_dropdown(model_holder: TTSModelHolder): new_names, new_files, _ = model_holder.update_model_names_for_gradio() return new_names, new_files, new_names, new_files @@ -328,7 +328,7 @@ def load_styles_gr(model_name_a, model_name_b): """ -def create_merge_app(model_holder: ModelHolder) -> gr.Blocks: +def create_merge_app(model_holder: TTSModelHolder) -> gr.Blocks: model_names = model_holder.model_names if len(model_names) == 0: logger.error( From afff154da4dd9eb9870e895cb143b2a2f67f19a2 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sun, 10 Mar 2024 04:18:39 +0000 Subject: [PATCH 073/148] Add: test code for style-bert-vits2 as a library By executing "hatch run test:test", you can check whether the test passes in all Python 3.9 to 3.12 environments. 
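For reference, the `test` environment configured below also defines coverage scripts; per the usage comments added to pyproject.toml, the corresponding commands are:

    # run the test suite
    hatch run test:test
    # run the tests under coverage and print the combined report
    hatch run test:cov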
--- pyproject.toml | 10 ++++++--- style_bert_vits2/tts_model.py | 12 ++++++++--- tests/.gitignore | 1 + tests/test_main.py | 39 +++++++++++++++++++++++++++++++++++ 4 files changed, 56 insertions(+), 6 deletions(-) create mode 100644 tests/.gitignore create mode 100644 tests/test_main.py diff --git a/pyproject.toml b/pyproject.toml index 4db5fe9eb..5995fc98c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,24 +52,28 @@ Source = "https://github.com/litagin02/Style-Bert-VITS2" [tool.hatch.version] path = "style_bert_vits2/constants.py" -[tool.hatch.envs.default] +[tool.hatch.envs.test] dependencies = [ "coverage[toml]>=6.5", "pytest", ] -[tool.hatch.envs.default.scripts] +[tool.hatch.envs.test.scripts] +# Usage: `hatch run test:test` test = "pytest {args:tests}" +# Usage: `hatch run test:coverage` test-cov = "coverage run -m pytest {args:tests}" +# Usage: `hatch run test:cov-report` cov-report = [ "- coverage combine", "coverage report", ] +# Usage: `hatch run test:cov` cov = [ "test-cov", "cov-report", ] -[[tool.hatch.envs.all.matrix]] +[[tool.hatch.envs.test.matrix]] python = ["3.9", "3.10", "3.11", "3.12"] [tool.coverage.run] diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index 109ef0e3d..fe4dcb5eb 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -1,6 +1,6 @@ import warnings from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Optional, Union, TypedDict import gradio as gr import numpy as np @@ -263,6 +263,13 @@ def infer( return (self.hyper_parameters.data.sampling_rate, audio) +class TTSModelInfo(TypedDict): + name: str + files: list[str] + styles: list[str] + speakers: list[str] + + class TTSModelHolder: """ Style-Bert-Vits2 の音声合成モデルを管理するクラス。 @@ -297,8 +304,7 @@ def __init__(self, model_root_dir: Path, device: str) -> None: self.model_files_dict: dict[str, list[Path]] = {} self.current_model: Optional[TTSModel] = None self.model_names: list[str] = [] - self.models: list[TTSModel] = [] - self.models_info: list[dict[str, Union[str, list[str]]]] = [] + self.models_info: list[TTSModelInfo] = [] self.refresh() diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 000000000..697e56f2d --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +*.wav \ No newline at end of file diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 000000000..597c23280 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,39 @@ +import pytest +from scipy.io import wavfile + +from style_bert_vits2.constants import BASE_DIR +from style_bert_vits2.tts_model import TTSModelHolder + + +def synthesize(device: str = 'cpu'): + + # モデル一覧を取得 + model_holder = TTSModelHolder(BASE_DIR / 'model_assets', device) + + # モデルが存在する場合、音声合成を実行 + if len(model_holder.models_info) > 0: + + # jvnv-F1-jp モデルを探す + for model_info in model_holder.models_info: + if model_info['name'] == 'jvnv-F1-jp': + + # 音声合成を実行 + model = model_holder.get_model(model_info['name'], model_info['files'][0]) + model.load() + sample_rate, audio_data = model.infer("あらゆる現実を、すべて自分のほうへねじ曲げたのだ。") + + # 音声データを保存 + with open(BASE_DIR / 'tests/test.wav', mode='wb') as f: + wavfile.write(f, sample_rate, audio_data) + else: + pytest.skip("音声合成モデルが見つかりませんでした。") + + +def test_synthesize_cpu(): + synthesize(device='cpu') + assert (BASE_DIR / 'tests/test.wav').exists() + + +def test_synthesize_cuda(): + synthesize(device='cuda') + assert (BASE_DIR / 'tests/test.wav').exists() From 84b1dbe1b5cd91e70286cb006df0b79dfd760838 Mon Sep 17 
00:00:00 2001 From: tsukumi Date: Sun, 10 Mar 2024 10:48:21 +0000 Subject: [PATCH 074/148] Fix: problem with test failures Style-Bert-VITS2 has been reported to not work with some PyTorch 2.2 series, but Python 3.12 is only supported in Torch >= 2.2, so Python 3.12 support is not provided for the time being --- pyproject.toml | 5 ++--- style_bert_vits2/models/utils/__init__.py | 9 ++++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5995fc98c..9e68d6d09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,6 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", ] dependencies = [ @@ -37,7 +36,7 @@ dependencies = [ 'pydantic', 'pyopenjtalk-dict', 'pypinyin', - 'pyworld', + # 'pyworld', 'safetensors', 'scipy', 'torch>=2.1,<2.2', @@ -74,7 +73,7 @@ cov = [ ] [[tool.hatch.envs.test.matrix]] -python = ["3.9", "3.10", "3.11", "3.12"] +python = ["3.9", "3.10", "3.11"] [tool.coverage.run] source_pkgs = ["style_bert_vits2", "tests"] diff --git a/style_bert_vits2/models/utils/__init__.py b/style_bert_vits2/models/utils/__init__.py index 51e19d9a4..e0d922f4a 100644 --- a/style_bert_vits2/models/utils/__init__.py +++ b/style_bert_vits2/models/utils/__init__.py @@ -4,24 +4,27 @@ import re import subprocess from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Optional, Union, TYPE_CHECKING import numpy as np import torch from numpy.typing import NDArray from scipy.io.wavfile import read -from torch.utils.tensorboard import SummaryWriter from style_bert_vits2.logging import logger from style_bert_vits2.models.utils import checkpoints # type: ignore from style_bert_vits2.models.utils import safetensors # type: ignore +if TYPE_CHECKING: + # tensorboard はライブラリとしてインストールされている場合は依存関係に含まれないため、型チェック時のみインポートする + from torch.utils.tensorboard import SummaryWriter + __is_matplotlib_imported = False def summarize( - writer: SummaryWriter, + writer: "SummaryWriter", global_step: int, scalars: dict[str, float] = {}, histograms: dict[str, Any] = {}, From cdc47a98cebef125b7a33d8b24a5970887f6d9a8 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sun, 10 Mar 2024 11:10:06 +0000 Subject: [PATCH 075/148] Improve: test code --- .gitignore | 1 + tests/test_main.py | 50 ++++++++++++++++++++++++++++++---------------- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index b8a19a4e7..31602318c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ __pycache__/ venv/ .venv/ dist/ +.coverage .ipynb_checkpoints/ /*.yml diff --git a/tests/test_main.py b/tests/test_main.py index 597c23280..c5d786f89 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,39 +1,55 @@ import pytest from scipy.io import wavfile -from style_bert_vits2.constants import BASE_DIR +from style_bert_vits2.constants import BASE_DIR, Languages from style_bert_vits2.tts_model import TTSModelHolder def synthesize(device: str = 'cpu'): - # モデル一覧を取得 + # 音声合成モデルが配置されていれば、音声合成を実行 model_holder = TTSModelHolder(BASE_DIR / 'model_assets', device) - - # モデルが存在する場合、音声合成を実行 if len(model_holder.models_info) > 0: - # jvnv-F1-jp モデルを探す + # jvnv-F2-jp モデルを探す for model_info in model_holder.models_info: - if model_info['name'] == 'jvnv-F1-jp': - - # 音声合成を実行 - model = model_holder.get_model(model_info['name'], model_info['files'][0]) - 
model.load() - sample_rate, audio_data = model.infer("あらゆる現実を、すべて自分のほうへねじ曲げたのだ。") - - # 音声データを保存 - with open(BASE_DIR / 'tests/test.wav', mode='wb') as f: - wavfile.write(f, sample_rate, audio_data) + if model_info['name'] == 'jvnv-F2-jp': + # すべてのスタイルに対して音声合成を実行 + for style in model_info['styles']: + + # 音声合成を実行 + model = model_holder.get_model(model_info['name'], model_info['files'][0]) + model.load() + sample_rate, audio_data = model.infer( + "あらゆる現実を、すべて自分のほうへねじ曲げたのだ。", + # 言語 (JP, EN, ZH / JP-Extra モデルの場合は JP のみ) + language = Languages.JP, + # 話者 ID (音声合成モデルに複数の話者が含まれる場合のみ必須、単一話者のみの場合は 0) + speaker_id = 0, + # 感情表現の強さ (0.0 〜 1.0) + sdp_ratio = 0.4, + # スタイル (Neutral, Happy など) + style = style, + # スタイルの強さ (0.0 〜 100.0) + style_weight = 6.0, + ) + + # 音声データを保存 + (BASE_DIR / 'tests/wavs').mkdir(exist_ok=True, parents=True) + wav_file_path = BASE_DIR / f'tests/wavs/{style}.wav' + with open(wav_file_path, 'wb') as f: + wavfile.write(f, sample_rate, audio_data) + + # 音声データが保存されたことを確認 + assert wav_file_path.exists() + # wav_file_path.unlink() else: pytest.skip("音声合成モデルが見つかりませんでした。") def test_synthesize_cpu(): synthesize(device='cpu') - assert (BASE_DIR / 'tests/test.wav').exists() def test_synthesize_cuda(): synthesize(device='cuda') - assert (BASE_DIR / 'tests/test.wav').exists() From 00bf496325ebe3ddaf2ed673d7a82db34e7a2d43 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sun, 10 Mar 2024 13:51:56 +0000 Subject: [PATCH 076/148] Add: VSCode settings Enabling type checking with Pylance. --- .gitignore | 2 -- .vscode/extensions.json | 6 ++++++ .vscode/settings.json | 22 ++++++++++++++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 .vscode/extensions.json create mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore index 31602318c..ae22e6437 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,3 @@ -.vscode/ - __pycache__/ venv/ .venv/ diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 000000000..7478fbda5 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,6 @@ +{ + "recommendations": [ + "ms-python.python", + "ms-python.vscode-pylance" + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..342c024c0 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,22 @@ +{ + // Pylance の Type Checking を有効化 + "python.languageServer": "Pylance", + "python.analysis.typeCheckingMode": "strict", + // Pylance の Type Checking のうち、いくつかのエラー報告を抑制する + "python.analysis.diagnosticSeverityOverrides": { + "reportConstantRedefinition": "none", + "reportGeneralTypeIssues": "warning", + "reportMissingParameterType": "warning", + "reportMissingTypeStubs": "none", + "reportPrivateImportUsage": "none", + "reportPrivateUsage": "warning", + "reportShadowedImports": "none", + "reportUnnecessaryComparison": "none", + "reportUnknownArgumentType": "none", + "reportUnknownMemberType": "none", + "reportUnknownParameterType": "warning", + "reportUnknownVariableType": "none", + "reportUnusedFunction": "none", + "reportUnusedVariable": "information", + }, +} \ No newline at end of file From 9c233630efbe48840dbff1e4d4774cf37413de42 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sun, 10 Mar 2024 14:18:57 +0000 Subject: [PATCH 077/148] Refactor: don't keep models in bert_feature module for each language BERT models and tokenizers are already stored and managed in the bert_models module and should not be stored here. 
In addition, since there may be situations where the user would like to use cpu instead of mps for inference when using it as a library, the automatic switching process to mps was removed. --- style_bert_vits2/nlp/chinese/bert_feature.py | 20 +++-------------- style_bert_vits2/nlp/english/bert_feature.py | 20 +++-------------- style_bert_vits2/nlp/japanese/bert_feature.py | 22 ++++--------------- 3 files changed, 10 insertions(+), 52 deletions(-) diff --git a/style_bert_vits2/nlp/chinese/bert_feature.py b/style_bert_vits2/nlp/chinese/bert_feature.py index f448b301d..adc07887d 100644 --- a/style_bert_vits2/nlp/chinese/bert_feature.py +++ b/style_bert_vits2/nlp/chinese/bert_feature.py @@ -1,16 +1,11 @@ -import sys from typing import Optional import torch -from transformers import PreTrainedModel from style_bert_vits2.constants import Languages from style_bert_vits2.nlp import bert_models -__models: dict[str, PreTrainedModel] = {} - - def extract_bert_feature( text: str, word2ph: list[int], @@ -32,18 +27,9 @@ def extract_bert_feature( torch.Tensor: BERT の特徴量 """ - if ( - sys.platform == "darwin" - and torch.backends.mps.is_available() - and device == "cpu" - ): - device = "mps" - if not device: - device = "cuda" if device == "cuda" and not torch.cuda.is_available(): device = "cpu" - if device not in __models.keys(): - __models[device] = bert_models.load_model(Languages.ZH).to(device) # type: ignore + model = bert_models.load_model(Languages.ZH).to(device) # type: ignore style_res_mean = None with torch.no_grad(): @@ -51,13 +37,13 @@ def extract_bert_feature( inputs = tokenizer(text, return_tensors="pt") for i in inputs: inputs[i] = inputs[i].to(device) # type: ignore - res = __models[device](**inputs, output_hidden_states=True) + res = model(**inputs, output_hidden_states=True) res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu() if assist_text: style_inputs = tokenizer(assist_text, return_tensors="pt") for i in style_inputs: style_inputs[i] = style_inputs[i].to(device) # type: ignore - style_res = __models[device](**style_inputs, output_hidden_states=True) + style_res = model(**style_inputs, output_hidden_states=True) style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() style_res_mean = style_res.mean(0) diff --git a/style_bert_vits2/nlp/english/bert_feature.py b/style_bert_vits2/nlp/english/bert_feature.py index 27fd5018a..7692c9e4c 100644 --- a/style_bert_vits2/nlp/english/bert_feature.py +++ b/style_bert_vits2/nlp/english/bert_feature.py @@ -1,16 +1,11 @@ -import sys from typing import Optional import torch -from transformers import PreTrainedModel from style_bert_vits2.constants import Languages from style_bert_vits2.nlp import bert_models -__models: dict[str, PreTrainedModel] = {} - - def extract_bert_feature( text: str, word2ph: list[int], @@ -32,18 +27,9 @@ def extract_bert_feature( torch.Tensor: BERT の特徴量 """ - if ( - sys.platform == "darwin" - and torch.backends.mps.is_available() - and device == "cpu" - ): - device = "mps" - if not device: - device = "cuda" if device == "cuda" and not torch.cuda.is_available(): device = "cpu" - if device not in __models.keys(): - __models[device] = bert_models.load_model(Languages.EN).to(device) # type: ignore + model = bert_models.load_model(Languages.EN).to(device) # type: ignore style_res_mean = None with torch.no_grad(): @@ -51,13 +37,13 @@ def extract_bert_feature( inputs = tokenizer(text, return_tensors="pt") for i in inputs: inputs[i] = inputs[i].to(device) # type: ignore - res = __models[device](**inputs, 
output_hidden_states=True) + res = model(**inputs, output_hidden_states=True) res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu() if assist_text: style_inputs = tokenizer(assist_text, return_tensors="pt") for i in style_inputs: style_inputs[i] = style_inputs[i].to(device) # type: ignore - style_res = __models[device](**style_inputs, output_hidden_states=True) + style_res = model(**style_inputs, output_hidden_states=True) style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() style_res_mean = style_res.mean(0) diff --git a/style_bert_vits2/nlp/japanese/bert_feature.py b/style_bert_vits2/nlp/japanese/bert_feature.py index 0d70014fd..d6e214b9a 100644 --- a/style_bert_vits2/nlp/japanese/bert_feature.py +++ b/style_bert_vits2/nlp/japanese/bert_feature.py @@ -1,17 +1,12 @@ -import sys from typing import Optional import torch -from transformers import PreTrainedModel from style_bert_vits2.constants import Languages from style_bert_vits2.nlp import bert_models from style_bert_vits2.nlp.japanese.g2p import text_to_sep_kata -__models: dict[str, PreTrainedModel] = {} - - def extract_bert_feature( text: str, word2ph: list[int], @@ -36,21 +31,12 @@ def extract_bert_feature( # 各単語が何文字かを作る `word2ph` を使う必要があるので、読めない文字は必ず無視する # でないと `word2ph` の結果とテキストの文字数結果が整合性が取れない text = "".join(text_to_sep_kata(text, raise_yomi_error=False)[0]) - if assist_text: assist_text = "".join(text_to_sep_kata(assist_text, raise_yomi_error=False)[0]) - if ( - sys.platform == "darwin" - and torch.backends.mps.is_available() - and device == "cpu" - ): - device = "mps" - if not device: - device = "cuda" + if device == "cuda" and not torch.cuda.is_available(): device = "cpu" - if device not in __models.keys(): - __models[device] = bert_models.load_model(Languages.JP).to(device) # type: ignore + model = bert_models.load_model(Languages.JP).to(device) # type: ignore style_res_mean = None with torch.no_grad(): @@ -58,13 +44,13 @@ def extract_bert_feature( inputs = tokenizer(text, return_tensors="pt") for i in inputs: inputs[i] = inputs[i].to(device) # type: ignore - res = __models[device](**inputs, output_hidden_states=True) + res = model(**inputs, output_hidden_states=True) res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu() if assist_text: style_inputs = tokenizer(assist_text, return_tensors="pt") for i in style_inputs: style_inputs[i] = style_inputs[i].to(device) # type: ignore - style_res = __models[device](**style_inputs, output_hidden_states=True) + style_res = model(**style_inputs, output_hidden_states=True) style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu() style_res_mean = style_res.mean(0) From 733a9d838d10f3accdc4151c1e14988538ff1ff9 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sun, 10 Mar 2024 15:38:08 +0000 Subject: [PATCH 078/148] Fix: clearly include Pydantic v2 in the dependencies The Pydantic models in the library are written for Pydantic v2 and will not work with Pydantic v1. 
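Concretely, the hyperparameter model relies on v2-only APIs such as ConfigDict and model_validate_json(). A small sketch of the affected surface (the config path is only an example):

    from style_bert_vits2.models.hyper_parameters import HyperParameters

    # load_from_json() calls HyperParameters.model_validate_json() internally,
    # which exists only on Pydantic v2's BaseModel (v1 has no such method).
    hps = HyperParameters.load_from_json("model_assets/your-model/config.json")
    print(hps.version, hps.data.sampling_rate)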
--- pyproject.toml | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9e68d6d09..0a1bd9939 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ 'numba', 'numpy', 'pyannote.audio>=3.1.0', - 'pydantic', + 'pydantic>=2.0', 'pyopenjtalk-dict', 'pypinyin', # 'pyworld', diff --git a/requirements.txt b/requirements.txt index 669af2f94..75dbdcdac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ numba numpy psutil pyannote.audio>=3.1.0 -pydantic +pydantic>=2.0 pyloudnorm # pyopenjtalk-prebuilt # Should be manually uninstalled pyopenjtalk-dict From be265d42ed1a946bbb4bcc12d76fe9ab01da59b2 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sun, 10 Mar 2024 16:51:58 +0000 Subject: [PATCH 079/148] Improve: switch pyworld to pyworld-prebuilt and enable it by default Prebuilt wheels for almost every OS/architecture (except musl) are now available on PyPI, eliminating the need for a build environment. ref: https://pypi.org/project/pyworld-prebuilt/ --- pyproject.toml | 2 +- requirements.txt | 2 +- style_bert_vits2/voice.py | 30 ++++++++++++++++++------------ webui/inference.py | 2 -- 4 files changed, 20 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0a1bd9939..870c3b2d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ 'pydantic>=2.0', 'pyopenjtalk-dict', 'pypinyin', - # 'pyworld', + 'pyworld-prebuilt', 'safetensors', 'scipy', 'torch>=2.1,<2.2', diff --git a/requirements.txt b/requirements.txt index 75dbdcdac..bf47ec0fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,7 +19,7 @@ pyloudnorm # pyopenjtalk-prebuilt # Should be manually uninstalled pyopenjtalk-dict pypinyin -# pyworld # Not supported on Windows without Cython... +pyworld-prebuilt PyYAML requests safetensors diff --git a/style_bert_vits2/voice.py b/style_bert_vits2/voice.py index f1cfcacb7..ed7843f73 100644 --- a/style_bert_vits2/voice.py +++ b/style_bert_vits2/voice.py @@ -1,6 +1,7 @@ from typing import Any import numpy as np +import pyworld from numpy.typing import NDArray @@ -10,29 +11,34 @@ def adjust_voice( pitch_scale: float = 1.0, intonation_scale: float = 1.0, ) -> tuple[int, NDArray[Any]]: + """ + 音声のピッチとイントネーションを調整する。 + 変更すると若干音質が劣化するので、どちらも初期値のままならそのまま返す。 + + Args: + fs (int): 音声のサンプリング周波数 + wave (NDArray[Any]): 音声データ + pitch_scale (float, optional): ピッチの高さ. Defaults to 1.0. + intonation_scale (float, optional): イントネーションの高さ. Defaults to 1.0. + + Returns: + tuple[int, NDArray[Any]]: 調整後の音声データのサンプリング周波数と音声データ + """ if pitch_scale == 1.0 and intonation_scale == 1.0: # 初期値の場合は、音質劣化を避けるためにそのまま返す return fs, wave - try: - import pyworld - except ImportError: - raise ImportError( - "pyworld is not installed. 
Please install it by `pip install pyworld`" - ) - # pyworld で f0 を加工して合成 # pyworld よりもよいのがあるかもしれないが…… - ## pyworld は Cython で書かれているが、スタブファイルがないため型補完が全く効かない… wave = wave.astype(np.double) # 質が高そうだしとりあえずharvestにしておく - f0, t = pyworld.harvest(wave, fs) # type: ignore + f0, t = pyworld.harvest(wave, fs) - sp = pyworld.cheaptrick(wave, f0, t, fs) # type: ignore - ap = pyworld.d4c(wave, f0, t, fs) # type: ignore + sp = pyworld.cheaptrick(wave, f0, t, fs) + ap = pyworld.d4c(wave, f0, t, fs) non_zero_f0 = [f for f in f0 if f != 0] f0_mean = sum(non_zero_f0) / len(non_zero_f0) @@ -42,5 +48,5 @@ def adjust_voice( continue f0[i] = pitch_scale * f0_mean + intonation_scale * (f - f0_mean) - wave = pyworld.synthesize(f0, sp, ap, fs) # type: ignore + wave = pyworld.synthesize(f0, sp, ap, fs) return fs, wave diff --git a/webui/inference.py b/webui/inference.py index e32907c62..91baed50b 100644 --- a/webui/inference.py +++ b/webui/inference.py @@ -294,7 +294,6 @@ def tts_fn( value=1, step=0.05, label="音程(1以外では音質劣化)", - visible=False, # pyworldが必要 ) intonation_scale = gr.Slider( minimum=0, @@ -302,7 +301,6 @@ def tts_fn( value=1, step=0.1, label="抑揚(1以外では音質劣化)", - visible=False, # pyworldが必要 ) line_split = gr.Checkbox( From 859d940916e253da5680f59c1091ced23d41804d Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sun, 10 Mar 2024 19:15:11 +0000 Subject: [PATCH 080/148] Fix: failed to start training --- style_bert_vits2/models/hyper_parameters.py | 31 +++++++++++---------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/style_bert_vits2/models/hyper_parameters.py b/style_bert_vits2/models/hyper_parameters.py index 9dc5afb8d..53dc7f37b 100644 --- a/style_bert_vits2/models/hyper_parameters.py +++ b/style_bert_vits2/models/hyper_parameters.py @@ -39,8 +39,8 @@ class HyperParametersTrain(BaseModel): class HyperParametersData(BaseModel): use_jp_extra: bool = True - training_files: str = "Data/dummy/train.list" - validation_files: str = "Data/dummy/val.list" + training_files: str = "Data/Dummy/train.list" + validation_files: str = "Data/Dummy/val.list" max_wav_value: float = 32768.0 sampling_rate: int = 44100 filter_length: int = 2048 @@ -53,7 +53,7 @@ class HyperParametersData(BaseModel): n_speakers: int = 512 cleaned_text: bool = True spk2id: dict[str, int] = { - "dummy": 0 + "Dummy": 0, } num_styles: int = 1 style2id: dict[str, int] = { @@ -61,6 +61,13 @@ class HyperParametersData(BaseModel): } +class HyperParametersModelSLM(BaseModel): + model: str = "./slm/wavlm-base-plus" + sr: int = 16000 + hidden: int = 768 + nlayers: int = 13 + initial_channel: int = 64 + class HyperParametersModel(BaseModel): use_spk_conditioned_encoder: bool = True use_noise_scaled_mas: bool = True @@ -79,7 +86,7 @@ class HyperParametersModel(BaseModel): resblock_dilation_sizes: list[list[int]] = [ [1, 3, 5], [1, 3, 5], - [1, 3, 5] + [1, 3, 5], ] upsample_rates: list[int] = [8, 8, 2, 2, 2] upsample_initial_channel: int = 512 @@ -87,21 +94,15 @@ class HyperParametersModel(BaseModel): n_layers_q: int = 3 use_spectral_norm: bool = False gin_channels: int = 512 - slm: dict[str, Union[int, str]] = { - "model": "./slm/wavlm-base-plus", - "sr": 16000, - "hidden": 768, - "nlayers": 13, - "initial_channel": 64 - } + slm: HyperParametersModelSLM = HyperParametersModelSLM() class HyperParameters(BaseModel): - model_name: str = 'dummy' + model_name: str = 'Dummy' version: str = "2.0-JP-Extra" - train: HyperParametersTrain - data: HyperParametersData - model: HyperParametersModel + train: HyperParametersTrain = HyperParametersTrain() + 
data: HyperParametersData = HyperParametersData() + model: HyperParametersModel = HyperParametersModel() # 以下は学習時にのみ動的に設定されるパラメータ (通常 config.json には存在しない) model_dir: Optional[str] = None From 7f02b0f1d5e6452dba368178d95037aa69acfaae Mon Sep 17 00:00:00 2001 From: tsukumi Date: Sun, 10 Mar 2024 19:21:22 +0000 Subject: [PATCH 081/148] Refactor: TTSModelInfo changed from TypedDict to Pydantic model Pydantic models are more robust and properties can be accessed by dots. --- server_editor.py | 4 ++-- style_bert_vits2/tts_model.py | 17 +++++++++-------- tests/test_main.py | 6 +++--- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/server_editor.py b/server_editor.py index aa85ef0ad..c6f856e2f 100644 --- a/server_editor.py +++ b/server_editor.py @@ -52,7 +52,7 @@ rewrite_word, update_dict, ) -from style_bert_vits2.tts_model import TTSModelHolder +from style_bert_vits2.tts_model import TTSModelHolder, TTSModelInfo # ---フロントエンド部分に関する処理--- @@ -250,7 +250,7 @@ async def normalize(item: TextRequest): return normalize_text(item.text) -@router.get("/models_info") +@router.get("/models_info", response_model=list[TTSModelInfo]) def models_info(): return model_holder.models_info diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index fe4dcb5eb..527b4abb3 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -1,6 +1,6 @@ import warnings from pathlib import Path -from typing import Any, Optional, Union, TypedDict +from typing import Any, Optional, Union import gradio as gr import numpy as np @@ -8,6 +8,7 @@ import torch from gradio.processing_utils import convert_to_16_bit_wav from numpy.typing import NDArray +from pydantic import BaseModel from style_bert_vits2.constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, @@ -263,7 +264,7 @@ def infer( return (self.hyper_parameters.data.sampling_rate, audio) -class TTSModelInfo(TypedDict): +class TTSModelInfo(BaseModel): name: str files: list[str] styles: list[str] @@ -341,12 +342,12 @@ def refresh(self) -> None: styles = list(style2id.keys()) spk2id: dict[str, int] = hyper_parameters.data.spk2id speakers = list(spk2id.keys()) - self.models_info.append({ - "name": model_dir.name, - "files": [str(f) for f in model_files], - "styles": styles, - "speakers": speakers, - }) + self.models_info.append(TTSModelInfo( + name = model_dir.name, + files = [str(f) for f in model_files], + styles = styles, + speakers = speakers, + )) def get_model(self, model_name: str, model_path_str: str) -> TTSModel: diff --git a/tests/test_main.py b/tests/test_main.py index c5d786f89..0c0fe7737 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -13,12 +13,12 @@ def synthesize(device: str = 'cpu'): # jvnv-F2-jp モデルを探す for model_info in model_holder.models_info: - if model_info['name'] == 'jvnv-F2-jp': + if model_info.name == 'jvnv-F2-jp': # すべてのスタイルに対して音声合成を実行 - for style in model_info['styles']: + for style in model_info.styles: # 音声合成を実行 - model = model_holder.get_model(model_info['name'], model_info['files'][0]) + model = model_holder.get_model(model_info.name, model_info.files[0]) model.load() sample_rate, audio_data = model.infer( "あらゆる現実を、すべて自分のほうへねじ曲げたのだ。", From 42ee7d7608388fa63cdf90be4fd08d7996d6d70f Mon Sep 17 00:00:00 2001 From: litagin02 Date: Mon, 11 Mar 2024 09:27:36 +0900 Subject: [PATCH 082/148] Fix: ensure encoding utf-8 --- initialize.py | 2 +- style_bert_vits2/models/hyper_parameters.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/initialize.py b/initialize.py index 
927a91a80..3736151f0 100644 --- a/initialize.py +++ b/initialize.py @@ -9,7 +9,7 @@ def download_bert_models(): - with open("bert/bert_models.json", "r") as fp: + with open("bert/bert_models.json", "r", encoding="utf-8") as fp: models = json.load(fp) for k, v in models.items(): local_path = Path("bert").joinpath(k) diff --git a/style_bert_vits2/models/hyper_parameters.py b/style_bert_vits2/models/hyper_parameters.py index 53dc7f37b..30c579e4a 100644 --- a/style_bert_vits2/models/hyper_parameters.py +++ b/style_bert_vits2/models/hyper_parameters.py @@ -53,11 +53,11 @@ class HyperParametersData(BaseModel): n_speakers: int = 512 cleaned_text: bool = True spk2id: dict[str, int] = { - "Dummy": 0, + "Dummy": 0, } num_styles: int = 1 style2id: dict[str, int] = { - "Neutral": 0, + "Neutral": 0, } @@ -68,6 +68,7 @@ class HyperParametersModelSLM(BaseModel): nlayers: int = 13 initial_channel: int = 64 + class HyperParametersModel(BaseModel): use_spk_conditioned_encoder: bool = True use_noise_scaled_mas: bool = True @@ -98,7 +99,7 @@ class HyperParametersModel(BaseModel): class HyperParameters(BaseModel): - model_name: str = 'Dummy' + model_name: str = "Dummy" version: str = "2.0-JP-Extra" train: HyperParametersTrain = HyperParametersTrain() data: HyperParametersData = HyperParametersData() @@ -112,7 +113,6 @@ class HyperParameters(BaseModel): # model_ 以下を Pydantic の保護対象から除外する model_config = ConfigDict(protected_namespaces=()) - @staticmethod def load_from_json(json_path: Union[str, Path]) -> "HyperParameters": """ @@ -125,5 +125,5 @@ def load_from_json(json_path: Union[str, Path]) -> "HyperParameters": HyperParameters: ハイパーパラメータ """ - with open(json_path, "r") as f: + with open(json_path, "r", encoding="utf-8") as f: return HyperParameters.model_validate_json(f.read()) From c776c082355c6e374058cd02918ea1b47a9ba1e1 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Mon, 11 Mar 2024 09:47:47 +0900 Subject: [PATCH 083/148] Apply black formatter --- .vscode/settings.json | 4 + app.py | 7 +- style_bert_vits2/constants.py | 2 + style_bert_vits2/logging.py | 6 +- style_bert_vits2/models/attentions.py | 36 +++-- style_bert_vits2/models/commons.py | 22 ++- style_bert_vits2/models/infer.py | 120 ++++++++------- style_bert_vits2/models/models.py | 30 +++- style_bert_vits2/models/models_jp_extra.py | 30 +++- style_bert_vits2/models/modules.py | 35 ++++- .../models/monotonic_alignment.py | 4 +- style_bert_vits2/models/transforms.py | 6 +- style_bert_vits2/models/utils/__init__.py | 12 +- style_bert_vits2/models/utils/checkpoints.py | 16 +- style_bert_vits2/nlp/__init__.py | 7 +- style_bert_vits2/nlp/bert_models.py | 27 +++- style_bert_vits2/nlp/chinese/g2p.py | 8 +- style_bert_vits2/nlp/english/normalizer.py | 4 +- style_bert_vits2/nlp/japanese/g2p.py | 14 +- style_bert_vits2/nlp/japanese/g2p_utils.py | 12 +- .../japanese/pyopenjtalk_worker/__init__.py | 12 +- .../pyopenjtalk_worker/worker_client.py | 15 +- .../pyopenjtalk_worker/worker_server.py | 5 +- .../nlp/japanese/user_dict/__init__.py | 14 +- style_bert_vits2/tts_model.py | 137 +++++++++--------- style_bert_vits2/utils/stdout_wrapper.py | 7 +- style_bert_vits2/utils/strenum.py | 13 +- style_bert_vits2/utils/subprocess.py | 16 +- tests/test_main.py | 26 ++-- train_ms.py | 48 +++--- train_ms_jp_extra.py | 66 +++++---- 31 files changed, 463 insertions(+), 298 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 342c024c0..2a583eb78 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -19,4 +19,8 @@ "reportUnusedFunction": 
"none", "reportUnusedVariable": "information", }, + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnType": true, + }, } \ No newline at end of file diff --git a/app.py b/app.py index b4fbc3ec8..60b91a686 100644 --- a/app.py +++ b/app.py @@ -51,4 +51,9 @@ create_merge_app(model_holder=model_holder) -app.launch(server_name=args.host, server_port=args.port, inbrowser=not args.no_autolaunch, share=args.share) +app.launch( + server_name=args.host, + server_port=args.port, + inbrowser=not args.no_autolaunch, + share=args.share, +) diff --git a/style_bert_vits2/constants.py b/style_bert_vits2/constants.py index 735aebcc2..9d8469073 100644 --- a/style_bert_vits2/constants.py +++ b/style_bert_vits2/constants.py @@ -9,6 +9,7 @@ # Style-Bert-VITS2 のベースディレクトリ BASE_DIR = Path(__file__).parent.parent + # 利用可能な言語 ## JP-Extra モデル利用時は JP 以外の言語の音声合成はできない class Languages(StrEnum): @@ -16,6 +17,7 @@ class Languages(StrEnum): EN = "EN" ZH = "ZH" + # 言語ごとのデフォルトの BERT トークナイザーのパス DEFAULT_BERT_TOKENIZER_PATHS = { Languages.JP: BASE_DIR / "bert" / "deberta-v2-large-japanese-char-wwm", diff --git a/style_bert_vits2/logging.py b/style_bert_vits2/logging.py index eec887ce3..e5c216a00 100644 --- a/style_bert_vits2/logging.py +++ b/style_bert_vits2/logging.py @@ -9,7 +9,7 @@ # Add a new handler logger.add( SAFE_STDOUT, - format = "{time:MM-DD HH:mm:ss} |{level:^8}| {file}:{line} | {message}", - backtrace = True, - diagnose = True, + format="{time:MM-DD HH:mm:ss} |{level:^8}| {file}:{line} | {message}", + backtrace=True, + diagnose=True, ) diff --git a/style_bert_vits2/models/attentions.py b/style_bert_vits2/models/attentions.py index 03b238d8a..9a101120a 100644 --- a/style_bert_vits2/models/attentions.py +++ b/style_bert_vits2/models/attentions.py @@ -24,7 +24,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @torch.jit.script # type: ignore -def fused_add_tanh_sigmoid_multiply(input_a: torch.Tensor, input_b: torch.Tensor, n_channels: list[int]) -> torch.Tensor: +def fused_add_tanh_sigmoid_multiply( + input_a: torch.Tensor, input_b: torch.Tensor, n_channels: list[int] +) -> torch.Tensor: n_channels_int = n_channels[0] in_act = input_a + input_b t_act = torch.tanh(in_act[:, :n_channels_int, :]) @@ -44,7 +46,7 @@ def __init__( p_dropout: float = 0.0, window_size: int = 4, isflow: bool = True, - **kwargs: Any + **kwargs: Any, ) -> None: super().__init__() self.hidden_channels = hidden_channels @@ -99,7 +101,9 @@ def __init__( ) self.norm_layers_2.append(LayerNorm(hidden_channels)) - def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) x = x * x_mask for i in range(self.n_layers): @@ -131,7 +135,7 @@ def __init__( p_dropout: float = 0.0, proximal_bias: bool = False, proximal_init: bool = True, - **kwargs: Any + **kwargs: Any, ) -> None: super().__init__() self.hidden_channels = hidden_channels @@ -180,7 +184,13 @@ def __init__( ) self.norm_layers_2.append(LayerNorm(hidden_channels)) - def forward(self, x: torch.Tensor, x_mask: torch.Tensor, h: torch.Tensor, h_mask: torch.Tensor) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + h: torch.Tensor, + h_mask: torch.Tensor, + ) -> torch.Tensor: """ x: decoder input h: encoder output @@ -262,7 +272,9 @@ def __init__( assert self.conv_q.bias is not None 
self.conv_k.bias.copy_(self.conv_q.bias) - def forward(self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: q = self.conv_q(x) k = self.conv_k(c) v = self.conv_v(c) @@ -329,7 +341,9 @@ def attention( ) # [b, n_h, t_t, d_k] -> [b, d, t_t] return output, p_attn - def _matmul_with_relative_values(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def _matmul_with_relative_values( + self, x: torch.Tensor, y: torch.Tensor + ) -> torch.Tensor: """ x: [b, h, l, m] y: [h or 1, m, d] @@ -338,7 +352,9 @@ def _matmul_with_relative_values(self, x: torch.Tensor, y: torch.Tensor) -> torc ret = torch.matmul(x, y.unsqueeze(0)) return ret - def _matmul_with_relative_keys(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def _matmul_with_relative_keys( + self, x: torch.Tensor, y: torch.Tensor + ) -> torch.Tensor: """ x: [b, h, l, d] y: [h or 1, m, d] @@ -347,7 +363,9 @@ def _matmul_with_relative_keys(self, x: torch.Tensor, y: torch.Tensor) -> torch. ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) return ret - def _get_relative_embeddings(self, relative_embeddings: torch.Tensor, length: int) -> torch.Tensor: + def _get_relative_embeddings( + self, relative_embeddings: torch.Tensor, length: int + ) -> torch.Tensor: assert self.window_size is not None 2 * self.window_size + 1 # type: ignore # Pad first before slice to avoid using cond ops. diff --git a/style_bert_vits2/models/commons.py b/style_bert_vits2/models/commons.py index 1106b219a..da8993018 100644 --- a/style_bert_vits2/models/commons.py +++ b/style_bert_vits2/models/commons.py @@ -67,7 +67,9 @@ def intersperse(lst: list[Any], item: Any) -> list[Any]: return result -def slice_segments(x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4) -> torch.Tensor: +def slice_segments( + x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4 +) -> torch.Tensor: """ テンソルからセグメントをスライスする @@ -85,7 +87,9 @@ def slice_segments(x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4 return torch.gather(x, 2, gather_indices) -def rand_slice_segments(x: torch.Tensor, x_lengths: Optional[torch.Tensor] = None, segment_size: int = 4) -> tuple[torch.Tensor, torch.Tensor]: +def rand_slice_segments( + x: torch.Tensor, x_lengths: Optional[torch.Tensor] = None, segment_size: int = 4 +) -> tuple[torch.Tensor, torch.Tensor]: """ ランダムなセグメントをスライスする @@ -121,7 +125,9 @@ def subsequent_mask(length: int) -> torch.Tensor: @torch.jit.script # type: ignore -def fused_add_tanh_sigmoid_multiply(input_a: torch.Tensor, input_b: torch.Tensor, n_channels: torch.Tensor) -> torch.Tensor: +def fused_add_tanh_sigmoid_multiply( + input_a: torch.Tensor, input_b: torch.Tensor, n_channels: torch.Tensor +) -> torch.Tensor: """ 加算、tanh、sigmoid の活性化関数を組み合わせた演算を行う @@ -141,7 +147,9 @@ def fused_add_tanh_sigmoid_multiply(input_a: torch.Tensor, input_b: torch.Tensor return acts -def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None) -> torch.Tensor: +def sequence_mask( + length: torch.Tensor, max_length: Optional[int] = None +) -> torch.Tensor: """ シーケンスマスクを生成する @@ -180,7 +188,11 @@ def generate_path(duration: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: return path -def clip_grad_value_(parameters: Union[torch.Tensor, list[torch.Tensor]], clip_value: Optional[float], norm_type: float = 2.0) -> float: +def clip_grad_value_( + parameters: Union[torch.Tensor, list[torch.Tensor]], 
+ clip_value: Optional[float], + norm_type: float = 2.0, +) -> float: """ 勾配の値をクリップする diff --git a/style_bert_vits2/models/infer.py b/style_bert_vits2/models/infer.py index d3ec963ef..b0ab9d293 100644 --- a/style_bert_vits2/models/infer.py +++ b/style_bert_vits2/models/infer.py @@ -9,8 +9,14 @@ from style_bert_vits2.models import utils from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.models.models import SynthesizerTrn -from style_bert_vits2.models.models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra -from style_bert_vits2.nlp import clean_text, cleaned_text_to_sequence, extract_bert_feature +from style_bert_vits2.models.models_jp_extra import ( + SynthesizerTrn as SynthesizerTrnJPExtra, +) +from style_bert_vits2.nlp import ( + clean_text, + cleaned_text_to_sequence, + extract_bert_feature, +) from style_bert_vits2.nlp.symbols import SYMBOLS @@ -18,69 +24,71 @@ def get_net_g(model_path: str, version: str, device: str, hps: HyperParameters): if version.endswith("JP-Extra"): logger.info("Using JP-Extra model") net_g = SynthesizerTrnJPExtra( - n_vocab = len(SYMBOLS), - spec_channels = hps.data.filter_length // 2 + 1, - segment_size = hps.train.segment_size // hps.data.hop_length, - n_speakers = hps.data.n_speakers, + n_vocab=len(SYMBOLS), + spec_channels=hps.data.filter_length // 2 + 1, + segment_size=hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, # hps.model 以下のすべての値を引数に渡す - use_spk_conditioned_encoder = hps.model.use_spk_conditioned_encoder, - use_noise_scaled_mas = hps.model.use_noise_scaled_mas, - use_mel_posterior_encoder = hps.model.use_mel_posterior_encoder, - use_duration_discriminator = hps.model.use_duration_discriminator, - use_wavlm_discriminator = hps.model.use_wavlm_discriminator, - inter_channels = hps.model.inter_channels, - hidden_channels = hps.model.hidden_channels, - filter_channels = hps.model.filter_channels, - n_heads = hps.model.n_heads, - n_layers = hps.model.n_layers, - kernel_size = hps.model.kernel_size, - p_dropout = hps.model.p_dropout, - resblock = hps.model.resblock, - resblock_kernel_sizes = hps.model.resblock_kernel_sizes, - resblock_dilation_sizes = hps.model.resblock_dilation_sizes, - upsample_rates = hps.model.upsample_rates, - upsample_initial_channel = hps.model.upsample_initial_channel, - upsample_kernel_sizes = hps.model.upsample_kernel_sizes, - n_layers_q = hps.model.n_layers_q, - use_spectral_norm = hps.model.use_spectral_norm, - gin_channels = hps.model.gin_channels, - slm = hps.model.slm, + use_spk_conditioned_encoder=hps.model.use_spk_conditioned_encoder, + use_noise_scaled_mas=hps.model.use_noise_scaled_mas, + use_mel_posterior_encoder=hps.model.use_mel_posterior_encoder, + use_duration_discriminator=hps.model.use_duration_discriminator, + use_wavlm_discriminator=hps.model.use_wavlm_discriminator, + inter_channels=hps.model.inter_channels, + hidden_channels=hps.model.hidden_channels, + filter_channels=hps.model.filter_channels, + n_heads=hps.model.n_heads, + n_layers=hps.model.n_layers, + kernel_size=hps.model.kernel_size, + p_dropout=hps.model.p_dropout, + resblock=hps.model.resblock, + resblock_kernel_sizes=hps.model.resblock_kernel_sizes, + resblock_dilation_sizes=hps.model.resblock_dilation_sizes, + upsample_rates=hps.model.upsample_rates, + upsample_initial_channel=hps.model.upsample_initial_channel, + upsample_kernel_sizes=hps.model.upsample_kernel_sizes, + n_layers_q=hps.model.n_layers_q, + use_spectral_norm=hps.model.use_spectral_norm, + 
gin_channels=hps.model.gin_channels, + slm=hps.model.slm, ).to(device) else: logger.info("Using normal model") net_g = SynthesizerTrn( - n_vocab = len(SYMBOLS), - spec_channels = hps.data.filter_length // 2 + 1, - segment_size = hps.train.segment_size // hps.data.hop_length, + n_vocab=len(SYMBOLS), + spec_channels=hps.data.filter_length // 2 + 1, + segment_size=hps.train.segment_size // hps.data.hop_length, n_speakers=hps.data.n_speakers, # hps.model 以下のすべての値を引数に渡す - use_spk_conditioned_encoder = hps.model.use_spk_conditioned_encoder, - use_noise_scaled_mas = hps.model.use_noise_scaled_mas, - use_mel_posterior_encoder = hps.model.use_mel_posterior_encoder, - use_duration_discriminator = hps.model.use_duration_discriminator, - use_wavlm_discriminator = hps.model.use_wavlm_discriminator, - inter_channels = hps.model.inter_channels, - hidden_channels = hps.model.hidden_channels, - filter_channels = hps.model.filter_channels, - n_heads = hps.model.n_heads, - n_layers = hps.model.n_layers, - kernel_size = hps.model.kernel_size, - p_dropout = hps.model.p_dropout, - resblock = hps.model.resblock, - resblock_kernel_sizes = hps.model.resblock_kernel_sizes, - resblock_dilation_sizes = hps.model.resblock_dilation_sizes, - upsample_rates = hps.model.upsample_rates, - upsample_initial_channel = hps.model.upsample_initial_channel, - upsample_kernel_sizes = hps.model.upsample_kernel_sizes, - n_layers_q = hps.model.n_layers_q, - use_spectral_norm = hps.model.use_spectral_norm, - gin_channels = hps.model.gin_channels, - slm = hps.model.slm, + use_spk_conditioned_encoder=hps.model.use_spk_conditioned_encoder, + use_noise_scaled_mas=hps.model.use_noise_scaled_mas, + use_mel_posterior_encoder=hps.model.use_mel_posterior_encoder, + use_duration_discriminator=hps.model.use_duration_discriminator, + use_wavlm_discriminator=hps.model.use_wavlm_discriminator, + inter_channels=hps.model.inter_channels, + hidden_channels=hps.model.hidden_channels, + filter_channels=hps.model.filter_channels, + n_heads=hps.model.n_heads, + n_layers=hps.model.n_layers, + kernel_size=hps.model.kernel_size, + p_dropout=hps.model.p_dropout, + resblock=hps.model.resblock, + resblock_kernel_sizes=hps.model.resblock_kernel_sizes, + resblock_dilation_sizes=hps.model.resblock_dilation_sizes, + upsample_rates=hps.model.upsample_rates, + upsample_initial_channel=hps.model.upsample_initial_channel, + upsample_kernel_sizes=hps.model.upsample_kernel_sizes, + n_layers_q=hps.model.n_layers_q, + use_spectral_norm=hps.model.use_spectral_norm, + gin_channels=hps.model.gin_channels, + slm=hps.model.slm, ).to(device) net_g.state_dict() _ = net_g.eval() if model_path.endswith(".pth") or model_path.endswith(".pt"): - _ = utils.checkpoints.load_checkpoint(model_path, net_g, None, skip_optimizer=True) + _ = utils.checkpoints.load_checkpoint( + model_path, net_g, None, skip_optimizer=True + ) elif model_path.endswith(".safetensors"): _ = utils.safetensors.load_safetensors(model_path, net_g, True) else: @@ -102,8 +110,8 @@ def get_text( norm_text, phone, tone, word2ph = clean_text( text, language_str, - use_jp_extra = use_jp_extra, - raise_yomi_error = False, + use_jp_extra=use_jp_extra, + raise_yomi_error=False, ) if given_tone is not None: if len(given_tone) != len(phone): diff --git a/style_bert_vits2/models/models.py b/style_bert_vits2/models/models.py index 829ca4a06..21a0be487 100644 --- a/style_bert_vits2/models/models.py +++ b/style_bert_vits2/models/models.py @@ -21,7 +21,7 @@ def __init__( filter_channels: int, kernel_size: int, p_dropout: float, - 
gin_channels: int = 0 + gin_channels: int = 0, ) -> None: super().__init__() @@ -330,7 +330,9 @@ def __init__( if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, in_channels, 1) - def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: x = torch.detach(x) if g is not None: g = torch.detach(g) @@ -582,7 +584,9 @@ def __init__( if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) - def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: x = self.conv_pre(x) if g is not None: x = x + self.cond(g) @@ -613,7 +617,13 @@ def remove_weight_norm(self) -> None: class DiscriminatorP(torch.nn.Module): - def __init__(self, period: int, kernel_size: int = 5, stride: int = 3, use_spectral_norm: bool = False) -> None: + def __init__( + self, + period: int, + kernel_size: int = 5, + stride: int = 3, + use_spectral_norm: bool = False, + ) -> None: super(DiscriminatorP, self).__init__() self.period = period self.use_spectral_norm = use_spectral_norm @@ -736,7 +746,9 @@ def forward( self, y: torch.Tensor, y_hat: torch.Tensor, - ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: + ) -> tuple[ + list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor] + ]: y_d_rs = [] y_d_gs = [] fmap_rs = [] @@ -787,7 +799,9 @@ def __init__(self, spec_channels: int, gin_channels: int = 0) -> None: ) self.proj = nn.Linear(128, gin_channels) - def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: N = inputs.size(0) out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] for conv in self.convs: @@ -805,7 +819,9 @@ def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> return self.proj(out.squeeze(0)) - def calculate_channels(self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int: + def calculate_channels( + self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int + ) -> int: for i in range(n_convs): L = (L - kernel_size + 2 * pad) // stride + 1 return L diff --git a/style_bert_vits2/models/models_jp_extra.py b/style_bert_vits2/models/models_jp_extra.py index a43a7157b..00cc02ffc 100644 --- a/style_bert_vits2/models/models_jp_extra.py +++ b/style_bert_vits2/models/models_jp_extra.py @@ -21,7 +21,7 @@ def __init__( filter_channels: int, kernel_size: int, p_dropout: float, - gin_channels: int = 0 + gin_channels: int = 0, ) -> None: super().__init__() @@ -313,7 +313,9 @@ def __init__( if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, in_channels, 1) - def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: x = torch.detach(x) if g is not None: g = torch.detach(g) @@ -587,7 +589,9 @@ def __init__( if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) - def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: x = 
self.conv_pre(x) if g is not None: x = x + self.cond(g) @@ -618,7 +622,13 @@ def remove_weight_norm(self) -> None: class DiscriminatorP(torch.nn.Module): - def __init__(self, period: int, kernel_size: int = 5, stride: int = 3, use_spectral_norm: bool = False) -> None: + def __init__( + self, + period: int, + kernel_size: int = 5, + stride: int = 3, + use_spectral_norm: bool = False, + ) -> None: super(DiscriminatorP, self).__init__() self.period = period self.use_spectral_norm = use_spectral_norm @@ -741,7 +751,9 @@ def forward( self, y: torch.Tensor, y_hat: torch.Tensor, - ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: + ) -> tuple[ + list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor] + ]: y_d_rs = [] y_d_gs = [] fmap_rs = [] @@ -845,7 +857,9 @@ def __init__(self, spec_channels: int, gin_channels: int = 0) -> None: ) self.proj = nn.Linear(128, gin_channels) - def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: N = inputs.size(0) out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] for conv in self.convs: @@ -863,7 +877,9 @@ def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> return self.proj(out.squeeze(0)) - def calculate_channels(self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int: + def calculate_channels( + self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int + ) -> int: for i in range(n_convs): L = (L - kernel_size + 2 * pad) // stride + 1 return L diff --git a/style_bert_vits2/models/modules.py b/style_bert_vits2/models/modules.py index 8eed9635c..ebc52730c 100644 --- a/style_bert_vits2/models/modules.py +++ b/style_bert_vits2/models/modules.py @@ -88,7 +88,9 @@ class DDSConv(nn.Module): Dialted and Depth-Separable Convolution """ - def __init__(self, channels: int, kernel_size: int, n_layers: int, p_dropout: float = 0.0) -> None: + def __init__( + self, channels: int, kernel_size: int, n_layers: int, p_dropout: float = 0.0 + ) -> None: super().__init__() self.channels = channels self.kernel_size = kernel_size @@ -117,7 +119,9 @@ def __init__(self, channels: int, kernel_size: int, n_layers: int, p_dropout: fl self.norms_1.append(LayerNorm(channels)) self.norms_2.append(LayerNorm(channels)) - def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: if g is not None: x = x + g for i in range(self.n_layers): @@ -184,7 +188,13 @@ def __init__( res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") self.res_skip_layers.append(res_skip_layer) - def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, **kwargs: Any) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + **kwargs: Any, + ) -> torch.Tensor: output = torch.zeros_like(x) n_channels_tensor = torch.IntTensor([self.hidden_channels]) @@ -221,7 +231,12 @@ def remove_weight_norm(self) -> None: class ResBlock1(torch.nn.Module): - def __init__(self, channels: int, kernel_size: int = 3, dilation: tuple[int, int, int] = (1, 3, 5)) -> None: + def __init__( + self, + channels: int, + kernel_size: int = 3, + dilation: tuple[int, int, int] = (1, 3, 5), + ) -> 
None: super(ResBlock1, self).__init__() self.convs1 = nn.ModuleList( [ @@ -295,7 +310,9 @@ def __init__(self, channels: int, kernel_size: int = 3, dilation: tuple[int, int ) self.convs2.apply(commons.init_weights) - def forward(self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: for c1, c2 in zip(self.convs1, self.convs2): xt = F.leaky_relu(x, LRELU_SLOPE) if x_mask is not None: @@ -318,7 +335,9 @@ def remove_weight_norm(self) -> None: class ResBlock2(torch.nn.Module): - def __init__(self, channels: int, kernel_size: int = 3, dilation: tuple[int, int] = (1, 3)) -> None: + def __init__( + self, channels: int, kernel_size: int = 3, dilation: tuple[int, int] = (1, 3) + ) -> None: super(ResBlock2, self).__init__() self.convs = nn.ModuleList( [ @@ -346,7 +365,9 @@ def __init__(self, channels: int, kernel_size: int = 3, dilation: tuple[int, int ) self.convs.apply(commons.init_weights) - def forward(self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: for c in self.convs: xt = F.leaky_relu(x, LRELU_SLOPE) if x_mask is not None: diff --git a/style_bert_vits2/models/monotonic_alignment.py b/style_bert_vits2/models/monotonic_alignment.py index b499ad05f..d33631e41 100644 --- a/style_bert_vits2/models/monotonic_alignment.py +++ b/style_bert_vits2/models/monotonic_alignment.py @@ -40,8 +40,8 @@ def maximum_path(neg_cent: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: numba.int32[::1], numba.int32[::1], ), - nopython = True, - nogil = True, + nopython=True, + nogil=True, ) # type: ignore def __maximum_path_jit(paths: Any, values: Any, t_ys: Any, t_xs: Any) -> None: """ diff --git a/style_bert_vits2/models/transforms.py b/style_bert_vits2/models/transforms.py index 61306adc7..b6f4420f9 100644 --- a/style_bert_vits2/models/transforms.py +++ b/style_bert_vits2/models/transforms.py @@ -39,12 +39,14 @@ def piecewise_rational_quadratic_transform( min_bin_width=min_bin_width, min_bin_height=min_bin_height, min_derivative=min_derivative, - **spline_kwargs # type: ignore + **spline_kwargs, # type: ignore ) return outputs, logabsdet -def searchsorted(bin_locations: torch.Tensor, inputs: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: +def searchsorted( + bin_locations: torch.Tensor, inputs: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: bin_locations[..., -1] += eps return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 diff --git a/style_bert_vits2/models/utils/__init__.py b/style_bert_vits2/models/utils/__init__.py index e0d922f4a..0fd3a47ab 100644 --- a/style_bert_vits2/models/utils/__init__.py +++ b/style_bert_vits2/models/utils/__init__.py @@ -107,7 +107,9 @@ def plot_spectrogram_to_numpy(spectrogram: NDArray[Any]) -> NDArray[Any]: return data -def plot_alignment_to_numpy(alignment: NDArray[Any], info: Optional[str] = None) -> NDArray[Any]: +def plot_alignment_to_numpy( + alignment: NDArray[Any], info: Optional[str] = None +) -> NDArray[Any]: """ 指定されたアライメントを画像データに変換する @@ -163,7 +165,9 @@ def load_wav_to_torch(full_path: Union[str, Path]) -> tuple[torch.FloatTensor, i return torch.FloatTensor(data.astype(np.float32)), sampling_rate -def load_filepaths_and_text(filename: Union[str, Path], split: str = "|") -> list[list[str]]: +def load_filepaths_and_text( + filename: Union[str, Path], split: str = "|" +) -> list[list[str]]: """ 
指定されたファイルからファイルパスとテキストを読み込む @@ -180,7 +184,9 @@ def load_filepaths_and_text(filename: Union[str, Path], split: str = "|") -> lis return filepaths_and_text -def get_logger(model_dir_path: Union[str, Path], filename: str = "train.log") -> logging.Logger: +def get_logger( + model_dir_path: Union[str, Path], filename: str = "train.log" +) -> logging.Logger: """ ロガーを取得する diff --git a/style_bert_vits2/models/utils/checkpoints.py b/style_bert_vits2/models/utils/checkpoints.py index f26f8fc92..768973bcc 100644 --- a/style_bert_vits2/models/utils/checkpoints.py +++ b/style_bert_vits2/models/utils/checkpoints.py @@ -14,7 +14,7 @@ def load_checkpoint( model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, skip_optimizer: bool = False, - for_infer: bool = False + for_infer: bool = False, ) -> tuple[torch.nn.Module, Optional[torch.optim.Optimizer], float, int]: """ 指定されたパスからチェックポイントを読み込み、モデルとオプティマイザーを更新する。 @@ -107,7 +107,9 @@ def save_checkpoint( iteration (int): イテレーション数 checkpoint_path (Union[str, Path]): 保存先のパス """ - logger.info(f"Saving model and optimizer state at iteration {iteration} to {checkpoint_path}") + logger.info( + f"Saving model and optimizer state at iteration {iteration} to {checkpoint_path}" + ) if hasattr(model, "module"): state_dict = model.module.state_dict() else: @@ -123,7 +125,11 @@ def save_checkpoint( ) -def clean_checkpoints(model_dir_path: Union[str, Path] = "logs/44k/", n_ckpts_to_keep: int = 2, sort_by_time: bool = True) -> None: +def clean_checkpoints( + model_dir_path: Union[str, Path] = "logs/44k/", + n_ckpts_to_keep: int = 2, + sort_by_time: bool = True, +) -> None: """ 指定されたディレクトリから古いチェックポイントを削除して空き容量を確保する @@ -172,7 +178,9 @@ def del_routine(x: str) -> list[Any]: [del_routine(fn) for fn in to_del] -def get_latest_checkpoint_path(model_dir_path: Union[str, Path], regex: str = "G_*.pth") -> str: +def get_latest_checkpoint_path( + model_dir_path: Union[str, Path], regex: str = "G_*.pth" +) -> str: """ 指定されたディレクトリから最新のチェックポイントのパスを取得する diff --git a/style_bert_vits2/nlp/__init__.py b/style_bert_vits2/nlp/__init__.py index 683d6d479..5f3d63f6f 100644 --- a/style_bert_vits2/nlp/__init__.py +++ b/style_bert_vits2/nlp/__init__.py @@ -74,16 +74,19 @@ def clean_text( if language == Languages.JP: from style_bert_vits2.nlp.japanese.g2p import g2p from style_bert_vits2.nlp.japanese.normalizer import normalize_text + norm_text = normalize_text(text) phones, tones, word2ph = g2p(norm_text, use_jp_extra, raise_yomi_error) elif language == Languages.EN: from style_bert_vits2.nlp.english.g2p import g2p from style_bert_vits2.nlp.english.normalizer import normalize_text + norm_text = normalize_text(text) phones, tones, word2ph = g2p(norm_text) elif language == Languages.ZH: from style_bert_vits2.nlp.chinese.g2p import g2p from style_bert_vits2.nlp.chinese.normalizer import normalize_text + norm_text = normalize_text(text) phones, tones, word2ph = g2p(norm_text) else: @@ -92,7 +95,9 @@ def clean_text( return norm_text, phones, tones, word2ph -def cleaned_text_to_sequence(cleaned_phones: list[str], tones: list[int], language: Languages) -> tuple[list[int], list[int], list[int]]: +def cleaned_text_to_sequence( + cleaned_phones: list[str], tones: list[int], language: Languages +) -> tuple[list[int], list[int], list[int]]: """ テキスト文字列を、テキスト内の記号に対応する一連の ID に変換する diff --git a/style_bert_vits2/nlp/bert_models.py b/style_bert_vits2/nlp/bert_models.py index 4385d648c..220d58448 100644 --- a/style_bert_vits2/nlp/bert_models.py +++ b/style_bert_vits2/nlp/bert_models.py @@ -30,7 
+30,9 @@ __loaded_models: dict[Languages, Union[PreTrainedModel, DebertaV2Model]] = {} # 各言語ごとのロード済みの BERT トークナイザーを格納する辞書 -__loaded_tokenizers: dict[Languages, Union[PreTrainedTokenizer, PreTrainedTokenizerFast, DebertaV2Tokenizer]] = {} +__loaded_tokenizers: dict[ + Languages, Union[PreTrainedTokenizer, PreTrainedTokenizerFast, DebertaV2Tokenizer] +] = {} def load_model( @@ -63,18 +65,24 @@ def load_model( # pretrained_model_name_or_path が指定されていない場合はデフォルトのパスを利用 if pretrained_model_name_or_path is None: - assert DEFAULT_BERT_TOKENIZER_PATHS[language].exists(), \ - f"The default {language} BERT model does not exist on the file system. Please specify the path to the pre-trained model." + assert DEFAULT_BERT_TOKENIZER_PATHS[ + language + ].exists(), f"The default {language} BERT model does not exist on the file system. Please specify the path to the pre-trained model." pretrained_model_name_or_path = str(DEFAULT_BERT_TOKENIZER_PATHS[language]) # BERT モデルをロードし、辞書に格納して返す ## 英語のみ DebertaV2Model でロードする必要がある if language == Languages.EN: - model = cast(DebertaV2Model, DebertaV2Model.from_pretrained(pretrained_model_name_or_path)) + model = cast( + DebertaV2Model, + DebertaV2Model.from_pretrained(pretrained_model_name_or_path), + ) else: model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path) __loaded_models[language] = model - logger.info(f"Loaded the {language} BERT model from {pretrained_model_name_or_path}") + logger.info( + f"Loaded the {language} BERT model from {pretrained_model_name_or_path}" + ) return model @@ -109,8 +117,9 @@ def load_tokenizer( # pretrained_model_name_or_path が指定されていない場合はデフォルトのパスを利用 if pretrained_model_name_or_path is None: - assert DEFAULT_BERT_TOKENIZER_PATHS[language].exists(), \ - f"The default {language} BERT tokenizer does not exist on the file system. Please specify the path to the pre-trained model." + assert DEFAULT_BERT_TOKENIZER_PATHS[ + language + ].exists(), f"The default {language} BERT tokenizer does not exist on the file system. Please specify the path to the pre-trained model." 
pretrained_model_name_or_path = str(DEFAULT_BERT_TOKENIZER_PATHS[language]) # BERT トークナイザーをロードし、辞書に格納して返す @@ -120,7 +129,9 @@ def load_tokenizer( else: tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) __loaded_tokenizers[language] = tokenizer - logger.info(f"Loaded the {language} BERT tokenizer from {pretrained_model_name_or_path}") + logger.info( + f"Loaded the {language} BERT tokenizer from {pretrained_model_name_or_path}" + ) return tokenizer diff --git a/style_bert_vits2/nlp/chinese/g2p.py b/style_bert_vits2/nlp/chinese/g2p.py index 1cb3839f6..b5744cd83 100644 --- a/style_bert_vits2/nlp/chinese/g2p.py +++ b/style_bert_vits2/nlp/chinese/g2p.py @@ -95,7 +95,11 @@ def __g2p(segments: list[str]) -> tuple[list[str], list[int], list[int]]: if pinyin[0] in single_rep_map.keys(): pinyin = single_rep_map[pinyin[0]] + pinyin[1:] - assert pinyin in __PINYIN_TO_SYMBOL_MAP.keys(), (pinyin, seg, raw_pinyin) + assert pinyin in __PINYIN_TO_SYMBOL_MAP.keys(), ( + pinyin, + seg, + raw_pinyin, + ) phone = __PINYIN_TO_SYMBOL_MAP[pinyin].split(" ") word2ph.append(len(phone)) @@ -125,7 +129,7 @@ def __get_initials_finals(word: str) -> tuple[list[str], list[str]]: text = normalize_text(text) print(text) phones, tones, word2ph = g2p(text) - bert = extract_bert_feature(text, word2ph, 'cuda') + bert = extract_bert_feature(text, word2ph, "cuda") print(phones, tones, word2ph, bert.shape) diff --git a/style_bert_vits2/nlp/english/normalizer.py b/style_bert_vits2/nlp/english/normalizer.py index 81b71d7cc..f6ddc90c2 100644 --- a/style_bert_vits2/nlp/english/normalizer.py +++ b/style_bert_vits2/nlp/english/normalizer.py @@ -121,7 +121,9 @@ def __expand_number(m: re.Match[str]) -> str: else: return __INFLECT.number_to_words( num, andword="", zero="oh", group=2 # type: ignore - ).replace(", ", " ") # type: ignore + ).replace( + ", ", " " + ) # type: ignore else: return __INFLECT.number_to_words(num, andword="") # type: ignore diff --git a/style_bert_vits2/nlp/japanese/g2p.py b/style_bert_vits2/nlp/japanese/g2p.py index 7fc97f210..1f6b450d9 100644 --- a/style_bert_vits2/nlp/japanese/g2p.py +++ b/style_bert_vits2/nlp/japanese/g2p.py @@ -10,9 +10,7 @@ def g2p( - norm_text: str, - use_jp_extra: bool = True, - raise_yomi_error: bool = False + norm_text: str, use_jp_extra: bool = True, raise_yomi_error: bool = False ) -> tuple[list[str], list[int], list[int]]: """ 他で使われるメインの関数。`normalize_text()` で正規化された `norm_text` を受け取り、 @@ -93,8 +91,7 @@ def g2p( def text_to_sep_kata( - norm_text: str, - raise_yomi_error: bool = False + norm_text: str, raise_yomi_error: bool = False ) -> tuple[list[str], list[str]]: """ `normalize_text` で正規化済みの `norm_text` を受け取り、それを単語分割し、 @@ -212,7 +209,9 @@ def __g2phone_tone_wo_punct(text: str) -> list[tuple[str, int]]: return result -def __pyopenjtalk_g2p_prosody(text: str, drop_unvoiced_vowels: bool = True) -> list[str]: +def __pyopenjtalk_g2p_prosody( + text: str, drop_unvoiced_vowels: bool = True +) -> list[str]: """ ESPnet の実装から引用、変更点無し。「ん」は「N」なことに注意。 ref: https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py @@ -414,8 +413,7 @@ def mora2phonemes(mora: str) -> str: def __align_tones( - phones_with_punct: list[str], - phone_tone_list: list[tuple[str, int]] + phones_with_punct: list[str], phone_tone_list: list[tuple[str, int]] ) -> list[tuple[str, int]]: """ 例: …私は、、そう思う。 diff --git a/style_bert_vits2/nlp/japanese/g2p_utils.py b/style_bert_vits2/nlp/japanese/g2p_utils.py index 893d3b531..511793f34 100644 --- a/style_bert_vits2/nlp/japanese/g2p_utils.py 
+++ b/style_bert_vits2/nlp/japanese/g2p_utils.py @@ -34,11 +34,13 @@ def phone_tone2kata_tone(phone_tone: list[tuple[str, int]]) -> list[tuple[str, i """ # 子音の集合 - CONSONANTS = set([ - consonant - for consonant, _ in MORA_KATA_TO_MORA_PHONEMES.values() - if consonant is not None - ]) + CONSONANTS = set( + [ + consonant + for consonant, _ in MORA_KATA_TO_MORA_PHONEMES.values() + if consonant is not None + ] + ) phone_tone = phone_tone[1:] # 最初の("_", 0)を無視 phones = [phone for phone, _ in phone_tone] diff --git a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py index 212d21ab2..3a146b671 100644 --- a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py +++ b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/__init__.py @@ -25,6 +25,7 @@ def run_frontend(text: str) -> list[dict[str, Any]]: else: # without worker import pyopenjtalk + return pyopenjtalk.run_frontend(text) @@ -36,6 +37,7 @@ def make_label(njd_features: Any) -> list[str]: else: # without worker import pyopenjtalk + return pyopenjtalk.make_label(njd_features) @@ -45,6 +47,7 @@ def mecab_dict_index(path: str, out_path: str, dn_mecab: Optional[str] = None) - else: # without worker import pyopenjtalk + pyopenjtalk.mecab_dict_index(path, out_path, dn_mecab) @@ -54,6 +57,7 @@ def update_global_jtalk_with_user_dict(path: str) -> None: else: # without worker import pyopenjtalk + pyopenjtalk.update_global_jtalk_with_user_dict(path) @@ -63,6 +67,7 @@ def unset_user_dict() -> None: else: # without worker import pyopenjtalk + pyopenjtalk.unset_user_dict() @@ -102,7 +107,12 @@ def initialize_worker(port: int = WORKER_PORT) -> None: else: # align with Windows behavior # start_new_session is same as specifying setsid in preexec_fn - subprocess.Popen(args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True) + subprocess.Popen( + args, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + start_new_session=True, + ) # wait until server listening count = 0 diff --git a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_client.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_client.py index 425cf25fd..c4c5606ea 100644 --- a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_client.py +++ b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_client.py @@ -2,12 +2,15 @@ from typing import Any, cast from style_bert_vits2.logging import logger -from style_bert_vits2.nlp.japanese.pyopenjtalk_worker.worker_common import RequestType, receive_data, send_data +from style_bert_vits2.nlp.japanese.pyopenjtalk_worker.worker_common import ( + RequestType, + receive_data, + send_data, +) class WorkerClient: - """ pyopenjtalk worker client """ - + """pyopenjtalk worker client""" def __init__(self, port: int) -> None: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -16,19 +19,15 @@ def __init__(self, port: int) -> None: sock.connect((socket.gethostname(), port)) self.sock = sock - def __enter__(self) -> "WorkerClient": return self - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self.close() - def close(self) -> None: self.sock.close() - def dispatch_pyopenjtalk(self, func: str, *args: Any, **kwargs: Any) -> Any: data = { "request-type": RequestType.PYOPENJTALK, @@ -43,7 +42,6 @@ def dispatch_pyopenjtalk(self, func: str, *args: Any, **kwargs: Any) -> Any: logger.trace(f"client received response: {response}") return response.get("return") - def status(self) -> int: data = {"request-type": 
RequestType.STATUS} logger.trace(f"client sends request: {data}") @@ -53,7 +51,6 @@ def status(self) -> int: logger.trace(f"client received response: {response}") return cast(int, response.get("client-count")) - def quit_server(self) -> None: data = {"request-type": RequestType.QUIT_SERVER} logger.trace(f"client sends request: {data}") diff --git a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_server.py b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_server.py index 149323ad8..ed0d4e7a7 100644 --- a/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_server.py +++ b/style_bert_vits2/nlp/japanese/pyopenjtalk_worker/worker_server.py @@ -26,14 +26,12 @@ class WorkerServer: - """ pyopenjtalk worker server """ - + """pyopenjtalk worker server""" def __init__(self) -> None: self.client_count: int = 0 self.quit: bool = False - def handle_request(self, request: dict[str, Any]) -> dict[str, Any]: request_type = None try: @@ -70,7 +68,6 @@ def handle_request(self, request: dict[str, Any]) -> dict[str, Any]: return response - def start_server(self, port: int, no_client_timeout: int = 30) -> None: logger.info("start pyopenjtalk worker server") with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket: diff --git a/style_bert_vits2/nlp/japanese/user_dict/__init__.py b/style_bert_vits2/nlp/japanese/user_dict/__init__.py index a2cc43ea9..2a4aa2fc6 100644 --- a/style_bert_vits2/nlp/japanese/user_dict/__init__.py +++ b/style_bert_vits2/nlp/japanese/user_dict/__init__.py @@ -18,7 +18,11 @@ from style_bert_vits2.constants import DEFAULT_USER_DICT_DIR from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk from style_bert_vits2.nlp.japanese.user_dict.word_model import UserDictWord, WordTypes -from style_bert_vits2.nlp.japanese.user_dict.part_of_speech_data import MAX_PRIORITY, MIN_PRIORITY, part_of_speech_data +from style_bert_vits2.nlp.japanese.user_dict.part_of_speech_data import ( + MAX_PRIORITY, + MIN_PRIORITY, + part_of_speech_data, +) # root_dir = engine_root() # save_dir = get_save_dir() @@ -26,9 +30,13 @@ # if not save_dir.is_dir(): # save_dir.mkdir(parents=True) -default_dict_path = DEFAULT_USER_DICT_DIR / "default.csv" # VOICEVOXデフォルト辞書ファイルのパス +default_dict_path = ( + DEFAULT_USER_DICT_DIR / "default.csv" +) # VOICEVOXデフォルト辞書ファイルのパス user_dict_path = DEFAULT_USER_DICT_DIR / "user_dict.json" # ユーザー辞書ファイルのパス -compiled_dict_path = DEFAULT_USER_DICT_DIR / "user.dic" # コンパイル済み辞書ファイルのパス +compiled_dict_path = ( + DEFAULT_USER_DICT_DIR / "user.dic" +) # コンパイル済み辞書ファイルのパス # # 同時書き込みの制御 diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index 527b4abb3..a769ae620 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -25,7 +25,9 @@ from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.models.infer import get_net_g, infer from style_bert_vits2.models.models import SynthesizerTrn -from style_bert_vits2.models.models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra +from style_bert_vits2.models.models_jp_extra import ( + SynthesizerTrn as SynthesizerTrnJPExtra, +) from style_bert_vits2.logging import logger from style_bert_vits2.voice import adjust_voice @@ -36,7 +38,6 @@ class TTSModel: モデル/ハイパーパラメータ/スタイルベクトルのパスとデバイスを指定して初期化し、model.infer() メソッドを呼び出すと音声合成を行える。 """ - def __init__( self, model_path: Path, @@ -59,7 +60,9 @@ def __init__( self.config_path: Path = config_path self.style_vec_path: Path = style_vec_path self.device: str = device - self.hyper_parameters: 
HyperParameters = HyperParameters.load_from_json(self.config_path) + self.hyper_parameters: HyperParameters = HyperParameters.load_from_json( + self.config_path + ) self.spk2id: dict[str, int] = self.hyper_parameters.data.spk2id self.id2spk: dict[int, str] = {v: k for k, v in self.spk2id.items()} @@ -82,19 +85,17 @@ def __init__( self.__net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra, None] = None - def load(self) -> None: """ 音声合成モデルをデバイスにロードする。 """ self.__net_g = get_net_g( - model_path = str(self.model_path), - version = self.hyper_parameters.version, - device = self.device, - hps = self.hyper_parameters, + model_path=str(self.model_path), + version=self.hyper_parameters.version, + device=self.device, + hps=self.hyper_parameters, ) - def __get_style_vector(self, style_id: int, weight: float = 1.0) -> NDArray[Any]: """ スタイルベクトルを取得する。 @@ -111,8 +112,9 @@ def __get_style_vector(self, style_id: int, weight: float = 1.0) -> NDArray[Any] style_vec = mean + (style_vec - mean) * weight return style_vec - - def __get_style_vector_from_audio(self, audio_path: str, weight: float = 1.0) -> NDArray[Any]: + def __get_style_vector_from_audio( + self, audio_path: str, weight: float = 1.0 + ) -> NDArray[Any]: """ 音声からスタイルベクトルを推論する。 @@ -126,8 +128,10 @@ def __get_style_vector_from_audio(self, audio_path: str, weight: float = 1.0) -> # スタイルベクトルを取得するための推論モデルを初期化 if self.__style_vector_inference is None: self.__style_vector_inference = pyannote.audio.Inference( - model = pyannote.audio.Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM"), - window = "whole", + model=pyannote.audio.Model.from_pretrained( + "pyannote/wespeaker-voxceleb-resnet34-LM" + ), + window="whole", ) self.__style_vector_inference.to(torch.device(self.device)) @@ -137,7 +141,6 @@ def __get_style_vector_from_audio(self, audio_path: str, weight: float = 1.0) -> xvec = mean + (xvec - mean) * weight return xvec - def infer( self, text: str, @@ -209,20 +212,20 @@ def infer( if not line_split: with torch.no_grad(): audio = infer( - text = text, - sdp_ratio = sdp_ratio, - noise_scale = noise, - noise_scale_w = noise_w, - length_scale = length, - sid = speaker_id, - language = language, - hps = self.hyper_parameters, - net_g = self.__net_g, - device = self.device, - assist_text = assist_text, - assist_text_weight = assist_text_weight, - style_vec = style_vector, - given_tone = given_tone, + text=text, + sdp_ratio=sdp_ratio, + noise_scale=noise, + noise_scale_w=noise_w, + length_scale=length, + sid=speaker_id, + language=language, + hps=self.hyper_parameters, + net_g=self.__net_g, + device=self.device, + assist_text=assist_text, + assist_text_weight=assist_text_weight, + style_vec=style_vector, + given_tone=given_tone, ) else: texts = text.split("\n") @@ -232,19 +235,19 @@ def infer( for i, t in enumerate(texts): audios.append( infer( - text = t, - sdp_ratio = sdp_ratio, - noise_scale = noise, - noise_scale_w = noise_w, - length_scale = length, - sid = speaker_id, - language = language, - hps = self.hyper_parameters, - net_g = self.__net_g, - device = self.device, - assist_text = assist_text, - assist_text_weight = assist_text_weight, - style_vec = style_vector, + text=t, + sdp_ratio=sdp_ratio, + noise_scale=noise, + noise_scale_w=noise_w, + length_scale=length, + sid=speaker_id, + language=language, + hps=self.hyper_parameters, + net_g=self.__net_g, + device=self.device, + assist_text=assist_text, + assist_text_weight=assist_text_weight, + style_vec=style_vector, ) ) if i != len(texts) - 1: @@ -253,10 +256,10 @@ def infer( 
logger.info("Audio data generated successfully") if not (pitch_scale == 1.0 and intonation_scale == 1.0): _, audio = adjust_voice( - fs = self.hyper_parameters.data.sampling_rate, - wave = audio, - pitch_scale = pitch_scale, - intonation_scale = intonation_scale, + fs=self.hyper_parameters.data.sampling_rate, + wave=audio, + pitch_scale=pitch_scale, + intonation_scale=intonation_scale, ) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -277,7 +280,6 @@ class TTSModelHolder: model_holder.models_info から指定されたディレクトリ内にある音声合成モデルの一覧を取得できる。 """ - def __init__(self, model_root_dir: Path, device: str) -> None: """ Style-Bert-Vits2 の音声合成モデルを管理するクラスを初期化する。 @@ -308,7 +310,6 @@ def __init__(self, model_root_dir: Path, device: str) -> None: self.models_info: list[TTSModelInfo] = [] self.refresh() - def refresh(self) -> None: """ 音声合成モデルの一覧を更新する。 @@ -342,13 +343,14 @@ def refresh(self) -> None: styles = list(style2id.keys()) spk2id: dict[str, int] = hyper_parameters.data.spk2id speakers = list(spk2id.keys()) - self.models_info.append(TTSModelInfo( - name = model_dir.name, - files = [str(f) for f in model_files], - styles = styles, - speakers = speakers, - )) - + self.models_info.append( + TTSModelInfo( + name=model_dir.name, + files=[str(f) for f in model_files], + styles=styles, + speakers=speakers, + ) + ) def get_model(self, model_name: str, model_path_str: str) -> TTSModel: """ @@ -370,16 +372,17 @@ def get_model(self, model_name: str, model_path_str: str) -> TTSModel: raise ValueError(f"Model file `{model_path}` is not found") if self.current_model is None or self.current_model.model_path != model_path: self.current_model = TTSModel( - model_path = model_path, - config_path = self.root_dir / model_name / "config.json", - style_vec_path = self.root_dir / model_name / "style_vectors.npy", - device = self.device, + model_path=model_path, + config_path=self.root_dir / model_name / "config.json", + style_vec_path=self.root_dir / model_name / "style_vectors.npy", + device=self.device, ) return self.current_model - - def get_model_for_gradio(self, model_name: str, model_path_str: str) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]: + def get_model_for_gradio( + self, model_name: str, model_path_str: str + ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]: model_path = Path(model_path_str) if model_name not in self.model_files_dict: raise ValueError(f"Model `{model_name}` is not found") @@ -398,10 +401,10 @@ def get_model_for_gradio(self, model_name: str, model_path_str: str) -> tuple[gr gr.Dropdown(choices=speakers, value=speakers[0]), # type: ignore ) self.current_model = TTSModel( - model_path = model_path, - config_path = self.root_dir / model_name / "config.json", - style_vec_path = self.root_dir / model_name / "style_vectors.npy", - device = self.device, + model_path=model_path, + config_path=self.root_dir / model_name / "config.json", + style_vec_path=self.root_dir / model_name / "style_vectors.npy", + device=self.device, ) speakers = list(self.current_model.spk2id.keys()) styles = list(self.current_model.style2id.keys()) @@ -411,13 +414,13 @@ def get_model_for_gradio(self, model_name: str, model_path_str: str) -> tuple[gr gr.Dropdown(choices=speakers, value=speakers[0]), # type: ignore ) - def update_model_files_for_gradio(self, model_name: str) -> gr.Dropdown: model_files = self.model_files_dict[model_name] return gr.Dropdown(choices=model_files, value=model_files[0]) # type: ignore - - def update_model_names_for_gradio(self) -> tuple[gr.Dropdown, gr.Dropdown, gr.Button]: + def 
update_model_names_for_gradio( + self, + ) -> tuple[gr.Dropdown, gr.Dropdown, gr.Button]: self.refresh() initial_model_name = self.model_names[0] initial_model_files = self.model_files_dict[initial_model_name] diff --git a/style_bert_vits2/utils/stdout_wrapper.py b/style_bert_vits2/utils/stdout_wrapper.py index 09254ada4..174d1fd26 100644 --- a/style_bert_vits2/utils/stdout_wrapper.py +++ b/style_bert_vits2/utils/stdout_wrapper.py @@ -8,40 +8,35 @@ class StdoutWrapper(TextIO): `sys.stdout` wrapper for both Google Colab and local environment. """ - def __init__(self) -> None: self.temp_file = tempfile.NamedTemporaryFile( mode="w+", delete=False, encoding="utf-8" ) self.original_stdout = sys.stdout - def write(self, message: str) -> int: result = self.temp_file.write(message) self.temp_file.flush() print(message, end="", file=self.original_stdout) return result - def flush(self) -> None: self.temp_file.flush() - def read(self, n: int = -1) -> str: self.temp_file.seek(0) return self.temp_file.read(n) - def close(self) -> None: self.temp_file.close() - def fileno(self) -> int: return self.temp_file.fileno() try: import google.colab # type: ignore + SAFE_STDOUT = StdoutWrapper() except ImportError: SAFE_STDOUT = sys.stdout diff --git a/style_bert_vits2/utils/strenum.py b/style_bert_vits2/utils/strenum.py index 40d3b0cb5..b2e2e3c28 100644 --- a/style_bert_vits2/utils/strenum.py +++ b/style_bert_vits2/utils/strenum.py @@ -9,27 +9,28 @@ class StrEnum(str, enum.Enum): def __new__(cls, *values: str) -> "StrEnum": "values must already be of type `str`" if len(values) > 3: - raise TypeError('too many arguments for str(): %r' % (values, )) + raise TypeError("too many arguments for str(): %r" % (values,)) if len(values) == 1: # it must be a string if not isinstance(values[0], str): # type: ignore - raise TypeError('%r is not a string' % (values[0], )) + raise TypeError("%r is not a string" % (values[0],)) if len(values) >= 2: # check that encoding argument is a string if not isinstance(values[1], str): # type: ignore - raise TypeError('encoding must be a string, not %r' % (values[1], )) + raise TypeError("encoding must be a string, not %r" % (values[1],)) if len(values) == 3: # check that errors argument is a string if not isinstance(values[2], str): # type: ignore - raise TypeError('errors must be a string, not %r' % (values[2])) + raise TypeError("errors must be a string, not %r" % (values[2])) value = str(*values) member = str.__new__(cls, value) member._value_ = value return member - @staticmethod - def _generate_next_value_(name: str, start: int, count: int, last_values: list[str]) -> str: + def _generate_next_value_( + name: str, start: int, count: int, last_values: list[str] + ) -> str: """ Return the lower-cased version of the member name. 
""" diff --git a/style_bert_vits2/utils/subprocess.py b/style_bert_vits2/utils/subprocess.py index 542f94b9d..f8e8e9454 100644 --- a/style_bert_vits2/utils/subprocess.py +++ b/style_bert_vits2/utils/subprocess.py @@ -6,7 +6,9 @@ from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT -def run_script_with_log(cmd: list[str], ignore_warning: bool = False) -> tuple[bool, str]: +def run_script_with_log( + cmd: list[str], ignore_warning: bool = False +) -> tuple[bool, str]: """ 指定されたコマンドを実行し、そのログを記録する。 @@ -21,10 +23,10 @@ def run_script_with_log(cmd: list[str], ignore_warning: bool = False) -> tuple[b logger.info(f"Running: {' '.join(cmd)}") result = subprocess.run( [sys.executable] + cmd, - stdout = SAFE_STDOUT, - stderr = subprocess.PIPE, - text = True, - encoding = "utf-8", + stdout=SAFE_STDOUT, + stderr=subprocess.PIPE, + text=True, + encoding="utf-8", ) if result.returncode != 0: logger.error(f"Error: {' '.join(cmd)}\n{result.stderr}") @@ -37,7 +39,9 @@ def run_script_with_log(cmd: list[str], ignore_warning: bool = False) -> tuple[b return True, "" -def second_elem_of(original_function: Callable[..., tuple[Any, Any]]) -> Callable[..., Any]: +def second_elem_of( + original_function: Callable[..., tuple[Any, Any]] +) -> Callable[..., Any]: """ 与えられた関数をラップし、その戻り値の 2 番目の要素のみを返す関数を生成する。 diff --git a/tests/test_main.py b/tests/test_main.py index 0c0fe7737..f00d90839 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -5,15 +5,15 @@ from style_bert_vits2.tts_model import TTSModelHolder -def synthesize(device: str = 'cpu'): +def synthesize(device: str = "cpu"): # 音声合成モデルが配置されていれば、音声合成を実行 - model_holder = TTSModelHolder(BASE_DIR / 'model_assets', device) + model_holder = TTSModelHolder(BASE_DIR / "model_assets", device) if len(model_holder.models_info) > 0: # jvnv-F2-jp モデルを探す for model_info in model_holder.models_info: - if model_info.name == 'jvnv-F2-jp': + if model_info.name == "jvnv-F2-jp": # すべてのスタイルに対して音声合成を実行 for style in model_info.styles: @@ -23,21 +23,21 @@ def synthesize(device: str = 'cpu'): sample_rate, audio_data = model.infer( "あらゆる現実を、すべて自分のほうへねじ曲げたのだ。", # 言語 (JP, EN, ZH / JP-Extra モデルの場合は JP のみ) - language = Languages.JP, + language=Languages.JP, # 話者 ID (音声合成モデルに複数の話者が含まれる場合のみ必須、単一話者のみの場合は 0) - speaker_id = 0, + speaker_id=0, # 感情表現の強さ (0.0 〜 1.0) - sdp_ratio = 0.4, + sdp_ratio=0.4, # スタイル (Neutral, Happy など) - style = style, + style=style, # スタイルの強さ (0.0 〜 100.0) - style_weight = 6.0, + style_weight=6.0, ) # 音声データを保存 - (BASE_DIR / 'tests/wavs').mkdir(exist_ok=True, parents=True) - wav_file_path = BASE_DIR / f'tests/wavs/{style}.wav' - with open(wav_file_path, 'wb') as f: + (BASE_DIR / "tests/wavs").mkdir(exist_ok=True, parents=True) + wav_file_path = BASE_DIR / f"tests/wavs/{style}.wav" + with open(wav_file_path, "wb") as f: wavfile.write(f, sample_rate, audio_data) # 音声データが保存されたことを確認 @@ -48,8 +48,8 @@ def synthesize(device: str = 'cpu'): def test_synthesize_cpu(): - synthesize(device='cpu') + synthesize(device="cpu") def test_synthesize_cuda(): - synthesize(device='cuda') + synthesize(device="cuda") diff --git a/train_ms.py b/train_ms.py index 3d50c0398..50a2453cb 100644 --- a/train_ms.py +++ b/train_ms.py @@ -281,28 +281,28 @@ def run(): mas_noise_scale_initial=mas_noise_scale_initial, noise_scale_delta=noise_scale_delta, # hps.model 以下のすべての値を引数に渡す - use_spk_conditioned_encoder = hps.model.use_spk_conditioned_encoder, - use_noise_scaled_mas = hps.model.use_noise_scaled_mas, - use_mel_posterior_encoder = hps.model.use_mel_posterior_encoder, - 
use_duration_discriminator = hps.model.use_duration_discriminator, - use_wavlm_discriminator = hps.model.use_wavlm_discriminator, - inter_channels = hps.model.inter_channels, - hidden_channels = hps.model.hidden_channels, - filter_channels = hps.model.filter_channels, - n_heads = hps.model.n_heads, - n_layers = hps.model.n_layers, - kernel_size = hps.model.kernel_size, - p_dropout = hps.model.p_dropout, - resblock = hps.model.resblock, - resblock_kernel_sizes = hps.model.resblock_kernel_sizes, - resblock_dilation_sizes = hps.model.resblock_dilation_sizes, - upsample_rates = hps.model.upsample_rates, - upsample_initial_channel = hps.model.upsample_initial_channel, - upsample_kernel_sizes = hps.model.upsample_kernel_sizes, - n_layers_q = hps.model.n_layers_q, - use_spectral_norm = hps.model.use_spectral_norm, - gin_channels = hps.model.gin_channels, - slm = hps.model.slm, + use_spk_conditioned_encoder=hps.model.use_spk_conditioned_encoder, + use_noise_scaled_mas=hps.model.use_noise_scaled_mas, + use_mel_posterior_encoder=hps.model.use_mel_posterior_encoder, + use_duration_discriminator=hps.model.use_duration_discriminator, + use_wavlm_discriminator=hps.model.use_wavlm_discriminator, + inter_channels=hps.model.inter_channels, + hidden_channels=hps.model.hidden_channels, + filter_channels=hps.model.filter_channels, + n_heads=hps.model.n_heads, + n_layers=hps.model.n_layers, + kernel_size=hps.model.kernel_size, + p_dropout=hps.model.p_dropout, + resblock=hps.model.resblock, + resblock_kernel_sizes=hps.model.resblock_kernel_sizes, + resblock_dilation_sizes=hps.model.resblock_dilation_sizes, + upsample_rates=hps.model.upsample_rates, + upsample_initial_channel=hps.model.upsample_initial_channel, + upsample_kernel_sizes=hps.model.upsample_kernel_sizes, + n_layers_q=hps.model.n_layers_q, + use_spectral_norm=hps.model.use_spectral_norm, + gin_channels=hps.model.gin_channels, + slm=hps.model.slm, ).cuda(local_rank) if getattr(hps.train, "freeze_ZH_bert", False): @@ -389,7 +389,9 @@ def run(): epoch_str = max(epoch_str, 1) # global_step = (epoch_str - 1) * len(train_loader) global_step = int( - utils.get_steps(utils.checkpoints.get_latest_checkpoint_path(model_dir, "G_*.pth")) + utils.get_steps( + utils.checkpoints.get_latest_checkpoint_path(model_dir, "G_*.pth") + ) ) logger.info( f"******************Found the model. 
Current epoch is {epoch_str}, gloabl step is {global_step}*********************" diff --git a/train_ms_jp_extra.py b/train_ms_jp_extra.py index 5ed68a103..321a040c1 100644 --- a/train_ms_jp_extra.py +++ b/train_ms_jp_extra.py @@ -288,28 +288,28 @@ def run(): mas_noise_scale_initial=mas_noise_scale_initial, noise_scale_delta=noise_scale_delta, # hps.model 以下のすべての値を引数に渡す - use_spk_conditioned_encoder = hps.model.use_spk_conditioned_encoder, - use_noise_scaled_mas = hps.model.use_noise_scaled_mas, - use_mel_posterior_encoder = hps.model.use_mel_posterior_encoder, - use_duration_discriminator = hps.model.use_duration_discriminator, - use_wavlm_discriminator = hps.model.use_wavlm_discriminator, - inter_channels = hps.model.inter_channels, - hidden_channels = hps.model.hidden_channels, - filter_channels = hps.model.filter_channels, - n_heads = hps.model.n_heads, - n_layers = hps.model.n_layers, - kernel_size = hps.model.kernel_size, - p_dropout = hps.model.p_dropout, - resblock = hps.model.resblock, - resblock_kernel_sizes = hps.model.resblock_kernel_sizes, - resblock_dilation_sizes = hps.model.resblock_dilation_sizes, - upsample_rates = hps.model.upsample_rates, - upsample_initial_channel = hps.model.upsample_initial_channel, - upsample_kernel_sizes = hps.model.upsample_kernel_sizes, - n_layers_q = hps.model.n_layers_q, - use_spectral_norm = hps.model.use_spectral_norm, - gin_channels = hps.model.gin_channels, - slm = hps.model.slm, + use_spk_conditioned_encoder=hps.model.use_spk_conditioned_encoder, + use_noise_scaled_mas=hps.model.use_noise_scaled_mas, + use_mel_posterior_encoder=hps.model.use_mel_posterior_encoder, + use_duration_discriminator=hps.model.use_duration_discriminator, + use_wavlm_discriminator=hps.model.use_wavlm_discriminator, + inter_channels=hps.model.inter_channels, + hidden_channels=hps.model.hidden_channels, + filter_channels=hps.model.filter_channels, + n_heads=hps.model.n_heads, + n_layers=hps.model.n_layers, + kernel_size=hps.model.kernel_size, + p_dropout=hps.model.p_dropout, + resblock=hps.model.resblock, + resblock_kernel_sizes=hps.model.resblock_kernel_sizes, + resblock_dilation_sizes=hps.model.resblock_dilation_sizes, + upsample_rates=hps.model.upsample_rates, + upsample_initial_channel=hps.model.upsample_initial_channel, + upsample_kernel_sizes=hps.model.upsample_kernel_sizes, + n_layers_q=hps.model.n_layers_q, + use_spectral_norm=hps.model.use_spectral_norm, + gin_channels=hps.model.gin_channels, + slm=hps.model.slm, ).cuda(local_rank) if getattr(hps.train, "freeze_JP_bert", False): logger.info("Freezing (JP) bert encoder !!!") @@ -383,7 +383,9 @@ def run(): if net_dur_disc is not None: try: _, _, dur_resume_lr, epoch_str = utils.checkpoints.load_checkpoint( - utils.checkpoints.get_latest_checkpoint_path(model_dir, "DUR_*.pth"), + utils.checkpoints.get_latest_checkpoint_path( + model_dir, "DUR_*.pth" + ), net_dur_disc, optim_dur_disc, skip_optimizer=hps.train.skip_optimizer, @@ -396,11 +398,15 @@ def run(): print("Initialize dur_disc") if net_wd is not None: try: - _, optim_wd, wd_resume_lr, epoch_str = utils.checkpoints.load_checkpoint( - utils.checkpoints.get_latest_checkpoint_path(model_dir, "WD_*.pth"), - net_wd, - optim_wd, - skip_optimizer=hps.train.skip_optimizer, + _, optim_wd, wd_resume_lr, epoch_str = ( + utils.checkpoints.load_checkpoint( + utils.checkpoints.get_latest_checkpoint_path( + model_dir, "WD_*.pth" + ), + net_wd, + optim_wd, + skip_optimizer=hps.train.skip_optimizer, + ) ) if not optim_wd.param_groups[0].get("initial_lr"): 
optim_wd.param_groups[0]["initial_lr"] = wd_resume_lr @@ -430,7 +436,9 @@ def run(): epoch_str = max(epoch_str, 1) # global_step = (epoch_str - 1) * len(train_loader) global_step = int( - utils.get_steps(utils.checkpoints.get_latest_checkpoint_path(model_dir, "G_*.pth")) + utils.get_steps( + utils.checkpoints.get_latest_checkpoint_path(model_dir, "G_*.pth") + ) ) logger.info( f"******************Found the model. Current epoch is {epoch_str}, gloabl step is {global_step}*********************" From 164b5c4a85744b259598bcac4b89b20dbf966e8f Mon Sep 17 00:00:00 2001 From: litagin02 Date: Mon, 11 Mar 2024 10:33:00 +0900 Subject: [PATCH 084/148] Clean unused tools module (previously used in webui.py) --- tools/__init__.py | 3 - tools/classify_language.py | 197 ------------------------------------- tools/sentence.py | 173 -------------------------------- tools/translate.py | 62 ------------ 4 files changed, 435 deletions(-) delete mode 100644 tools/__init__.py delete mode 100644 tools/classify_language.py delete mode 100644 tools/sentence.py delete mode 100644 tools/translate.py diff --git a/tools/__init__.py b/tools/__init__.py deleted file mode 100644 index b68d33295..000000000 --- a/tools/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -工具包 -""" diff --git a/tools/classify_language.py b/tools/classify_language.py deleted file mode 100644 index 2b8a7ab42..000000000 --- a/tools/classify_language.py +++ /dev/null @@ -1,197 +0,0 @@ -import regex as re - -try: - from config import config - - LANGUAGE_IDENTIFICATION_LIBRARY = ( - config.webui_config.language_identification_library - ) -except: - LANGUAGE_IDENTIFICATION_LIBRARY = "langid" - -module = LANGUAGE_IDENTIFICATION_LIBRARY.lower() - -langid_languages = [ - "af", - "am", - "an", - "ar", - "as", - "az", - "be", - "bg", - "bn", - "br", - "bs", - "ca", - "cs", - "cy", - "da", - "de", - "dz", - "el", - "en", - "eo", - "es", - "et", - "eu", - "fa", - "fi", - "fo", - "fr", - "ga", - "gl", - "gu", - "he", - "hi", - "hr", - "ht", - "hu", - "hy", - "id", - "is", - "it", - "ja", - "jv", - "ka", - "kk", - "km", - "kn", - "ko", - "ku", - "ky", - "la", - "lb", - "lo", - "lt", - "lv", - "mg", - "mk", - "ml", - "mn", - "mr", - "ms", - "mt", - "nb", - "ne", - "nl", - "nn", - "no", - "oc", - "or", - "pa", - "pl", - "ps", - "pt", - "qu", - "ro", - "ru", - "rw", - "se", - "si", - "sk", - "sl", - "sq", - "sr", - "sv", - "sw", - "ta", - "te", - "th", - "tl", - "tr", - "ug", - "uk", - "ur", - "vi", - "vo", - "wa", - "xh", - "zh", - "zu", -] - - -def classify_language(text: str, target_languages: list = None) -> str: - if module == "fastlid" or module == "fasttext": - from fastlid import fastlid, supported_langs - - classifier = fastlid - if target_languages != None: - target_languages = [ - lang for lang in target_languages if lang in supported_langs - ] - fastlid.set_languages = target_languages - elif module == "langid": - import langid - - classifier = langid.classify - if target_languages != None: - target_languages = [ - lang for lang in target_languages if lang in langid_languages - ] - langid.set_languages(target_languages) - else: - raise ValueError(f"Wrong module {module}") - - lang = classifier(text)[0] - - return lang - - -def classify_zh_ja(text: str) -> str: - for idx, char in enumerate(text): - unicode_val = ord(char) - - # 检测日语字符 - if 0x3040 <= unicode_val <= 0x309F or 0x30A0 <= unicode_val <= 0x30FF: - return "ja" - - # 检测汉字字符 - if 0x4E00 <= unicode_val <= 0x9FFF: - # 检查周围的字符 - next_char = text[idx + 1] if idx + 1 < len(text) else None - - if 
next_char and ( - 0x3040 <= ord(next_char) <= 0x309F or 0x30A0 <= ord(next_char) <= 0x30FF - ): - return "ja" - - return "zh" - - -def split_alpha_nonalpha(text, mode=1): - if mode == 1: - pattern = r"(?<=[\u4e00-\u9fff\u3040-\u30FF\d\s])(?=[\p{Latin}])|(?<=[\p{Latin}\s])(?=[\u4e00-\u9fff\u3040-\u30FF\d])" - elif mode == 2: - pattern = r"(?<=[\u4e00-\u9fff\u3040-\u30FF\s])(?=[\p{Latin}\d])|(?<=[\p{Latin}\d\s])(?=[\u4e00-\u9fff\u3040-\u30FF])" - else: - raise ValueError("Invalid mode. Supported modes are 1 and 2.") - - return re.split(pattern, text) - - -if __name__ == "__main__": - text = "这是一个测试文本" - print(classify_language(text)) - print(classify_zh_ja(text)) # "zh" - - text = "これはテストテキストです" - print(classify_language(text)) - print(classify_zh_ja(text)) # "ja" - - text = "vits和Bert-VITS2是tts模型。花费3days.花费3天。Take 3 days" - - print(split_alpha_nonalpha(text, mode=1)) - # output: ['vits', '和', 'Bert-VITS', '2是', 'tts', '模型。花费3', 'days.花费3天。Take 3 days'] - - print(split_alpha_nonalpha(text, mode=2)) - # output: ['vits', '和', 'Bert-VITS2', '是', 'tts', '模型。花费', '3days.花费', '3', '天。Take 3 days'] - - text = "vits 和 Bert-VITS2 是 tts 模型。花费3days.花费3天。Take 3 days" - print(split_alpha_nonalpha(text, mode=1)) - # output: ['vits ', '和 ', 'Bert-VITS', '2 ', '是 ', 'tts ', '模型。花费3', 'days.花费3天。Take ', '3 ', 'days'] - - text = "vits 和 Bert-VITS2 是 tts 模型。花费3days.花费3天。Take 3 days" - print(split_alpha_nonalpha(text, mode=2)) - # output: ['vits ', '和 ', 'Bert-VITS2 ', '是 ', 'tts ', '模型。花费', '3days.花费', '3', '天。Take ', '3 ', 'days'] diff --git a/tools/sentence.py b/tools/sentence.py deleted file mode 100644 index b66864ca0..000000000 --- a/tools/sentence.py +++ /dev/null @@ -1,173 +0,0 @@ -import logging - -import regex as re - -from tools.classify_language import classify_language, split_alpha_nonalpha - - -def check_is_none(item) -> bool: - """none -> True, not none -> False""" - return ( - item is None - or (isinstance(item, str) and str(item).isspace()) - or str(item) == "" - ) - - -def markup_language(text: str, target_languages: list = None) -> str: - pattern = ( - r"[\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\<\>\=\?\@\[\]\{\}\\\\\^\_\`" - r"\!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」" - r"『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘\'\‛\“\”\„\‟…‧﹏.]+" - ) - sentences = re.split(pattern, text) - - pre_lang = "" - p = 0 - - if target_languages is not None: - sorted_target_languages = sorted(target_languages) - if sorted_target_languages in [["en", "zh"], ["en", "ja"], ["en", "ja", "zh"]]: - new_sentences = [] - for sentence in sentences: - new_sentences.extend(split_alpha_nonalpha(sentence)) - sentences = new_sentences - - for sentence in sentences: - if check_is_none(sentence): - continue - - lang = classify_language(sentence, target_languages) - - if pre_lang == "": - text = text[:p] + text[p:].replace( - sentence, f"[{lang.upper()}]{sentence}", 1 - ) - p += len(f"[{lang.upper()}]") - elif pre_lang != lang: - text = text[:p] + text[p:].replace( - sentence, f"[{pre_lang.upper()}][{lang.upper()}]{sentence}", 1 - ) - p += len(f"[{pre_lang.upper()}][{lang.upper()}]") - pre_lang = lang - p += text[p:].index(sentence) + len(sentence) - text += f"[{pre_lang.upper()}]" - - return text - - -def split_by_language(text: str, target_languages: list = None) -> list: - pattern = ( - r"[\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\<\>\=\?\@\[\]\{\}\\\\\^\_\`" - r"\!?\。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」" - r"『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘\'\‛\“\”\„\‟…‧﹏.]+" - ) - sentences = re.split(pattern, text) - - pre_lang = "" - start = 0 - end = 0 - sentences_list = [] - 
- if target_languages is not None: - sorted_target_languages = sorted(target_languages) - if sorted_target_languages in [["en", "zh"], ["en", "ja"], ["en", "ja", "zh"]]: - new_sentences = [] - for sentence in sentences: - new_sentences.extend(split_alpha_nonalpha(sentence)) - sentences = new_sentences - - for sentence in sentences: - if check_is_none(sentence): - continue - - lang = classify_language(sentence, target_languages) - - end += text[end:].index(sentence) - if pre_lang != "" and pre_lang != lang: - sentences_list.append((text[start:end], pre_lang)) - start = end - end += len(sentence) - pre_lang = lang - sentences_list.append((text[start:], pre_lang)) - - return sentences_list - - -def sentence_split(text: str, max: int) -> list: - pattern = r"[!(),—+\-.:;??。,、;:]+" - sentences = re.split(pattern, text) - discarded_chars = re.findall(pattern, text) - - sentences_list, count, p = [], 0, 0 - - # 按被分割的符号遍历 - for i, discarded_chars in enumerate(discarded_chars): - count += len(sentences[i]) + len(discarded_chars) - if count >= max: - sentences_list.append(text[p : p + count].strip()) - p += count - count = 0 - - # 加入最后剩余的文本 - if p < len(text): - sentences_list.append(text[p:]) - - return sentences_list - - -def sentence_split_and_markup(text, max=50, lang="auto", speaker_lang=None): - # 如果该speaker只支持一种语言 - if speaker_lang is not None and len(speaker_lang) == 1: - if lang.upper() not in ["AUTO", "MIX"] and lang.lower() != speaker_lang[0]: - logging.debug( - f'lang "{lang}" is not in speaker_lang {speaker_lang},automatically set lang={speaker_lang[0]}' - ) - lang = speaker_lang[0] - - sentences_list = [] - if lang.upper() != "MIX": - if max <= 0: - sentences_list.append( - markup_language(text, speaker_lang) - if lang.upper() == "AUTO" - else f"[{lang.upper()}]{text}[{lang.upper()}]" - ) - else: - for i in sentence_split(text, max): - if check_is_none(i): - continue - sentences_list.append( - markup_language(i, speaker_lang) - if lang.upper() == "AUTO" - else f"[{lang.upper()}]{i}[{lang.upper()}]" - ) - else: - sentences_list.append(text) - - for i in sentences_list: - logging.debug(i) - - return sentences_list - - -if __name__ == "__main__": - text = "这几天心里颇不宁静。今晚在院子里坐着乘凉,忽然想起日日走过的荷塘,在这满月的光里,总该另有一番样子吧。月亮渐渐地升高了,墙外马路上孩子们的欢笑,已经听不见了;妻在屋里拍着闰儿,迷迷糊糊地哼着眠歌。我悄悄地披了大衫,带上门出去。" - print(markup_language(text, target_languages=None)) - print(sentence_split(text, max=50)) - print(sentence_split_and_markup(text, max=50, lang="auto", speaker_lang=None)) - - text = "你好,这是一段用来测试自动标注的文本。こんにちは,これは自動ラベリングのテスト用テキストです.Hello, this is a piece of text to test autotagging.你好!今天我们要介绍VITS项目,其重点是使用了GAN Duration predictor和transformer flow,并且接入了Bert模型来提升韵律。Bert embedding会在稍后介绍。" - print(split_by_language(text, ["zh", "ja", "en"])) - - text = "vits和Bert-VITS2是tts模型。花费3days.花费3天。Take 3 days" - - print(split_by_language(text, ["zh", "ja", "en"])) - # output: [('vits', 'en'), ('和', 'ja'), ('Bert-VITS', 'en'), ('2是', 'zh'), ('tts', 'en'), ('模型。花费3', 'zh'), ('days.', 'en'), ('花费3天。', 'zh'), ('Take 3 days', 'en')] - - print(split_by_language(text, ["zh", "en"])) - # output: [('vits', 'en'), ('和', 'zh'), ('Bert-VITS', 'en'), ('2是', 'zh'), ('tts', 'en'), ('模型。花费3', 'zh'), ('days.', 'en'), ('花费3天。', 'zh'), ('Take 3 days', 'en')] - - text = "vits 和 Bert-VITS2 是 tts 模型。花费 3 days. 花费 3天。Take 3 days" - print(split_by_language(text, ["zh", "en"])) - # output: [('vits ', 'en'), ('和 ', 'zh'), ('Bert-VITS2 ', 'en'), ('是 ', 'zh'), ('tts ', 'en'), ('模型。花费 ', 'zh'), ('3 days. 
', 'en'), ('花费 3天。', 'zh'), ('Take 3 days', 'en')] diff --git a/tools/translate.py b/tools/translate.py deleted file mode 100644 index be0f7ea45..000000000 --- a/tools/translate.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -翻译api -""" - -from config import config - -import random -import hashlib -import requests - - -def translate(Sentence: str, to_Language: str = "jp", from_Language: str = ""): - """ - :param Sentence: 待翻译语句 - :param from_Language: 待翻译语句语言 - :param to_Language: 目标语言 - :return: 翻译后语句 出错时返回None - - 常见语言代码:中文 zh 英语 en 日语 jp - """ - appid = config.translate_config.app_key - key = config.translate_config.secret_key - if appid == "" or key == "": - return "请开发者在config.yml中配置app_key与secret_key" - url = "https://fanyi-api.baidu.com/api/trans/vip/translate" - texts = Sentence.splitlines() - outTexts = [] - for t in texts: - if t != "": - # 签名计算 参考文档 https://api.fanyi.baidu.com/product/113 - salt = str(random.randint(1, 100000)) - signString = appid + t + salt + key - hs = hashlib.md5() - hs.update(signString.encode("utf-8")) - signString = hs.hexdigest() - if from_Language == "": - from_Language = "auto" - headers = {"Content-Type": "application/x-www-form-urlencoded"} - payload = { - "q": t, - "from": from_Language, - "to": to_Language, - "appid": appid, - "salt": salt, - "sign": signString, - } - # 发送请求 - try: - response = requests.post( - url=url, data=payload, headers=headers, timeout=3 - ) - response = response.json() - if "trans_result" in response.keys(): - result = response["trans_result"][0] - if "dst" in result.keys(): - dst = result["dst"] - outTexts.append(dst) - except Exception: - return Sentence - else: - outTexts.append(t) - return "\n".join(outTexts) From dada0af2b6ebb525d7132f4fe92c99f89c44160a Mon Sep 17 00:00:00 2001 From: litagin02 Date: Mon, 11 Mar 2024 12:37:51 +0900 Subject: [PATCH 085/148] Format import using isort --profile black --gitignore --lai 2 . 
--- config.py | 1 + data_utils.py | 1 + default_style.py | 7 ++++--- gen_yaml.py | 1 + mel_processing.py | 4 +++- preprocess_all.py | 4 +++- preprocess_text.py | 1 + resample.py | 3 ++- server_fastapi.py | 1 + slice.py | 1 + speech_mos.py | 1 + style_bert_vits2/models/attentions.py | 2 +- style_bert_vits2/models/commons.py | 3 ++- style_bert_vits2/models/infer.py | 5 ++--- style_bert_vits2/models/models.py | 5 +---- style_bert_vits2/models/models_jp_extra.py | 5 +---- style_bert_vits2/models/monotonic_alignment.py | 5 +++-- style_bert_vits2/models/utils/__init__.py | 3 ++- style_bert_vits2/nlp/__init__.py | 7 ++++--- style_bert_vits2/nlp/bert_models.py | 2 +- style_bert_vits2/nlp/chinese/g2p.py | 2 +- style_bert_vits2/nlp/chinese/tone_sandhi.py | 3 +-- style_bert_vits2/nlp/japanese/normalizer.py | 1 + style_bert_vits2/nlp/japanese/user_dict/__init__.py | 3 ++- .../nlp/japanese/user_dict/part_of_speech_data.py | 1 + style_bert_vits2/nlp/japanese/user_dict/word_model.py | 1 + style_gen.py | 4 +++- train_ms.py | 4 ++-- train_ms_jp_extra.py | 4 ++-- 29 files changed, 50 insertions(+), 35 deletions(-) diff --git a/config.py b/config.py index 6369e6bbd..77c384d76 100644 --- a/config.py +++ b/config.py @@ -11,6 +11,7 @@ from style_bert_vits2.logging import logger + # If not cuda available, set possible devices to cpu cuda_available = torch.cuda.is_available() diff --git a/data_utils.py b/data_utils.py index 04047e210..73d4303c8 100644 --- a/data_utils.py +++ b/data_utils.py @@ -15,6 +15,7 @@ from style_bert_vits2.models.utils import load_filepaths_and_text, load_wav_to_torch from style_bert_vits2.nlp import cleaned_text_to_sequence + """Multi speaker version""" diff --git a/default_style.py b/default_style.py index 67b6fc353..e75257d9d 100644 --- a/default_style.py +++ b/default_style.py @@ -1,9 +1,10 @@ +import json import os -from style_bert_vits2.constants import DEFAULT_STYLE -from style_bert_vits2.logging import logger import numpy as np -import json + +from style_bert_vits2.constants import DEFAULT_STYLE +from style_bert_vits2.logging import logger def set_style_config(json_path, output_path): diff --git a/gen_yaml.py b/gen_yaml.py index 76df20646..ac27103ea 100644 --- a/gen_yaml.py +++ b/gen_yaml.py @@ -1,6 +1,7 @@ import argparse import os import shutil + import yaml diff --git a/mel_processing.py b/mel_processing.py index e9e7ec362..02cd7d4bd 100644 --- a/mel_processing.py +++ b/mel_processing.py @@ -1,7 +1,9 @@ +import warnings + import torch import torch.utils.data from librosa.filters import mel as librosa_mel_fn -import warnings + # warnings.simplefilter(action='ignore', category=FutureWarning) warnings.filterwarnings(action="ignore") diff --git a/preprocess_all.py b/preprocess_all.py index 62c0b4f6e..c41159cde 100644 --- a/preprocess_all.py +++ b/preprocess_all.py @@ -1,7 +1,9 @@ import argparse -from webui.train import preprocess_all from multiprocessing import cpu_count +from webui.train import preprocess_all + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( diff --git a/preprocess_text.py b/preprocess_text.py index 4966305e3..ef2e2590e 100644 --- a/preprocess_text.py +++ b/preprocess_text.py @@ -12,6 +12,7 @@ from style_bert_vits2.nlp import clean_text from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT + preprocess_text_config = config.preprocess_text_config diff --git a/resample.py b/resample.py index 7001af6a9..b63c64a23 100644 --- a/resample.py +++ b/resample.py @@ -7,9 +7,10 @@ import soundfile from tqdm import tqdm +from config 
import config from style_bert_vits2.logging import logger from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT -from config import config + DEFAULT_BLOCK_SIZE: float = 0.400 # seconds diff --git a/server_fastapi.py b/server_fastapi.py index 1f60b3c57..4ac4a7e51 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -38,6 +38,7 @@ from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk from style_bert_vits2.tts_model import TTSModel, TTSModelHolder + ln = config.server_config.language diff --git a/slice.py b/slice.py index c69f8bf88..fb5bb7165 100644 --- a/slice.py +++ b/slice.py @@ -11,6 +11,7 @@ from style_bert_vits2.logging import logger from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT + vad_model, utils = torch.hub.load( repo_or_dir="snakers4/silero-vad", model="silero_vad", diff --git a/speech_mos.py b/speech_mos.py index 6dd2caa58..453b7d313 100644 --- a/speech_mos.py +++ b/speech_mos.py @@ -14,6 +14,7 @@ from style_bert_vits2.logging import logger from style_bert_vits2.tts_model import TTSModel + warnings.filterwarnings("ignore") mos_result_dir = Path("mos_results") diff --git a/style_bert_vits2/models/attentions.py b/style_bert_vits2/models/attentions.py index 9a101120a..b851b5329 100644 --- a/style_bert_vits2/models/attentions.py +++ b/style_bert_vits2/models/attentions.py @@ -1,6 +1,6 @@ +import math from typing import Any, Optional -import math import torch from torch import nn from torch.nn import functional as F diff --git a/style_bert_vits2/models/commons.py b/style_bert_vits2/models/commons.py index da8993018..38f548cd7 100644 --- a/style_bert_vits2/models/commons.py +++ b/style_bert_vits2/models/commons.py @@ -3,9 +3,10 @@ コードと完全に一致している保証はない。あくまで参考程度とすること。 """ +from typing import Any, Optional, Union + import torch from torch.nn import functional as F -from typing import Any, Optional, Union def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None: diff --git a/style_bert_vits2/models/infer.py b/style_bert_vits2/models/infer.py index b0ab9d293..a9bcbf658 100644 --- a/style_bert_vits2/models/infer.py +++ b/style_bert_vits2/models/infer.py @@ -1,12 +1,11 @@ -from typing import Any, cast, Optional, Union +from typing import Any, Optional, Union, cast import torch from numpy.typing import NDArray from style_bert_vits2.constants import Languages from style_bert_vits2.logging import logger -from style_bert_vits2.models import commons -from style_bert_vits2.models import utils +from style_bert_vits2.models import commons, utils from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.models.models import SynthesizerTrn from style_bert_vits2.models.models_jp_extra import ( diff --git a/style_bert_vits2/models/models.py b/style_bert_vits2/models/models.py index 21a0be487..56fb27c62 100644 --- a/style_bert_vits2/models/models.py +++ b/style_bert_vits2/models/models.py @@ -7,10 +7,7 @@ from torch.nn import functional as F from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm -from style_bert_vits2.models import attentions -from style_bert_vits2.models import commons -from style_bert_vits2.models import modules -from style_bert_vits2.models import monotonic_alignment +from style_bert_vits2.models import attentions, commons, modules, monotonic_alignment from style_bert_vits2.nlp.symbols import NUM_LANGUAGES, NUM_TONES, SYMBOLS diff --git a/style_bert_vits2/models/models_jp_extra.py b/style_bert_vits2/models/models_jp_extra.py index 00cc02ffc..2850baf20 100644 --- 
a/style_bert_vits2/models/models_jp_extra.py +++ b/style_bert_vits2/models/models_jp_extra.py @@ -7,10 +7,7 @@ from torch.nn import functional as F from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm -from style_bert_vits2.models import attentions -from style_bert_vits2.models import commons -from style_bert_vits2.models import modules -from style_bert_vits2.models import monotonic_alignment +from style_bert_vits2.models import attentions, commons, modules, monotonic_alignment from style_bert_vits2.nlp.symbols import NUM_LANGUAGES, NUM_TONES, SYMBOLS diff --git a/style_bert_vits2/models/monotonic_alignment.py b/style_bert_vits2/models/monotonic_alignment.py index d33631e41..8ec9d247f 100644 --- a/style_bert_vits2/models/monotonic_alignment.py +++ b/style_bert_vits2/models/monotonic_alignment.py @@ -3,10 +3,11 @@ コードと完全に一致している保証はない。あくまで参考程度とすること。 """ +from typing import Any + import numba import torch -from numpy import int32, float32, zeros -from typing import Any +from numpy import float32, int32, zeros def maximum_path(neg_cent: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: diff --git a/style_bert_vits2/models/utils/__init__.py b/style_bert_vits2/models/utils/__init__.py index 0fd3a47ab..f17837289 100644 --- a/style_bert_vits2/models/utils/__init__.py +++ b/style_bert_vits2/models/utils/__init__.py @@ -4,7 +4,7 @@ import re import subprocess from pathlib import Path -from typing import Any, Optional, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np import torch @@ -15,6 +15,7 @@ from style_bert_vits2.models.utils import checkpoints # type: ignore from style_bert_vits2.models.utils import safetensors # type: ignore + if TYPE_CHECKING: # tensorboard はライブラリとしてインストールされている場合は依存関係に含まれないため、型チェック時のみインポートする from torch.utils.tensorboard import SummaryWriter diff --git a/style_bert_vits2/nlp/__init__.py b/style_bert_vits2/nlp/__init__.py index 5f3d63f6f..b0c908e42 100644 --- a/style_bert_vits2/nlp/__init__.py +++ b/style_bert_vits2/nlp/__init__.py @@ -1,4 +1,4 @@ -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from style_bert_vits2.constants import Languages from style_bert_vits2.nlp.symbols import ( @@ -7,6 +7,7 @@ SYMBOLS, ) + # __init__.py は配下のモジュールをインポートした時点で実行される # PyTorch のインポートは重いので、型チェック時以外はインポートしない if TYPE_CHECKING: @@ -99,10 +100,10 @@ def cleaned_text_to_sequence( cleaned_phones: list[str], tones: list[int], language: Languages ) -> tuple[list[int], list[int], list[int]]: """ - テキスト文字列を、テキスト内の記号に対応する一連の ID に変換する + 音素リスト・アクセントリスト・言語を、テキスト内の対応する ID に変換する Args: - cleaned_phones (list[str]): clean_text() でクリーニングされた音素のリスト (?) 
+ cleaned_phones (list[str]): clean_text() でクリーニングされた音素のリスト tones (list[int]): 各音素のアクセント language (Languages): テキストの言語 diff --git a/style_bert_vits2/nlp/bert_models.py b/style_bert_vits2/nlp/bert_models.py index 220d58448..1e346a481 100644 --- a/style_bert_vits2/nlp/bert_models.py +++ b/style_bert_vits2/nlp/bert_models.py @@ -9,7 +9,7 @@ """ import gc -from typing import cast, Optional, Union +from typing import Optional, Union, cast import torch from transformers import ( diff --git a/style_bert_vits2/nlp/chinese/g2p.py b/style_bert_vits2/nlp/chinese/g2p.py index b5744cd83..004616742 100644 --- a/style_bert_vits2/nlp/chinese/g2p.py +++ b/style_bert_vits2/nlp/chinese/g2p.py @@ -2,7 +2,7 @@ from pathlib import Path import jieba.posseg as psg -from pypinyin import lazy_pinyin, Style +from pypinyin import Style, lazy_pinyin from style_bert_vits2.nlp.chinese.tone_sandhi import ToneSandhi from style_bert_vits2.nlp.symbols import PUNCTUATIONS diff --git a/style_bert_vits2/nlp/chinese/tone_sandhi.py b/style_bert_vits2/nlp/chinese/tone_sandhi.py index 5832434fd..552cb0d36 100644 --- a/style_bert_vits2/nlp/chinese/tone_sandhi.py +++ b/style_bert_vits2/nlp/chinese/tone_sandhi.py @@ -13,8 +13,7 @@ # limitations under the License. import jieba -from pypinyin import lazy_pinyin -from pypinyin import Style +from pypinyin import Style, lazy_pinyin class ToneSandhi: diff --git a/style_bert_vits2/nlp/japanese/normalizer.py b/style_bert_vits2/nlp/japanese/normalizer.py index b8cad9045..07b742c0a 100644 --- a/style_bert_vits2/nlp/japanese/normalizer.py +++ b/style_bert_vits2/nlp/japanese/normalizer.py @@ -1,5 +1,6 @@ import re import unicodedata + from num2words import num2words from style_bert_vits2.nlp.symbols import PUNCTUATIONS diff --git a/style_bert_vits2/nlp/japanese/user_dict/__init__.py b/style_bert_vits2/nlp/japanese/user_dict/__init__.py index 2a4aa2fc6..53ea63fdc 100644 --- a/style_bert_vits2/nlp/japanese/user_dict/__init__.py +++ b/style_bert_vits2/nlp/japanese/user_dict/__init__.py @@ -17,12 +17,13 @@ from style_bert_vits2.constants import DEFAULT_USER_DICT_DIR from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk -from style_bert_vits2.nlp.japanese.user_dict.word_model import UserDictWord, WordTypes from style_bert_vits2.nlp.japanese.user_dict.part_of_speech_data import ( MAX_PRIORITY, MIN_PRIORITY, part_of_speech_data, ) +from style_bert_vits2.nlp.japanese.user_dict.word_model import UserDictWord, WordTypes + # root_dir = engine_root() # save_dir = get_save_dir() diff --git a/style_bert_vits2/nlp/japanese/user_dict/part_of_speech_data.py b/style_bert_vits2/nlp/japanese/user_dict/part_of_speech_data.py index 443bdc521..a48f0c589 100644 --- a/style_bert_vits2/nlp/japanese/user_dict/part_of_speech_data.py +++ b/style_bert_vits2/nlp/japanese/user_dict/part_of_speech_data.py @@ -14,6 +14,7 @@ WordTypes, ) + MIN_PRIORITY = USER_DICT_MIN_PRIORITY MAX_PRIORITY = USER_DICT_MAX_PRIORITY diff --git a/style_bert_vits2/nlp/japanese/user_dict/word_model.py b/style_bert_vits2/nlp/japanese/user_dict/word_model.py index bcd4d377f..c85a5b954 100644 --- a/style_bert_vits2/nlp/japanese/user_dict/word_model.py +++ b/style_bert_vits2/nlp/japanese/user_dict/word_model.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, Field, validator + USER_DICT_MIN_PRIORITY = 0 USER_DICT_MAX_PRIORITY = 10 diff --git a/style_gen.py b/style_gen.py index 5190575b0..f9685cf7f 100644 --- a/style_gen.py +++ b/style_gen.py @@ -1,6 +1,6 @@ import argparse -from concurrent.futures import ThreadPoolExecutor import 
warnings +from concurrent.futures import ThreadPoolExecutor import numpy as np import torch @@ -11,9 +11,11 @@ from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT + warnings.filterwarnings("ignore", category=UserWarning) from pyannote.audio import Inference, Model + model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM") inference = Inference(model, window="whole") device = torch.device(config.style_gen_config.device) diff --git a/train_ms.py b/train_ms.py index 50a2453cb..c1a392e5b 100644 --- a/train_ms.py +++ b/train_ms.py @@ -24,8 +24,7 @@ from losses import discriminator_loss, feature_loss, generator_loss, kl_loss from mel_processing import mel_spectrogram_torch, spec_to_mel_torch from style_bert_vits2.logging import logger -from style_bert_vits2.models import commons -from style_bert_vits2.models import utils +from style_bert_vits2.models import commons, utils from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.models.models import ( DurationDiscriminator, @@ -35,6 +34,7 @@ from style_bert_vits2.nlp.symbols import SYMBOLS from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT + torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = ( True # If encontered training problem,please try to disable TF32. diff --git a/train_ms_jp_extra.py b/train_ms_jp_extra.py index 321a040c1..722a52197 100644 --- a/train_ms_jp_extra.py +++ b/train_ms_jp_extra.py @@ -24,8 +24,7 @@ from losses import WavLMLoss, discriminator_loss, feature_loss, generator_loss, kl_loss from mel_processing import mel_spectrogram_torch, spec_to_mel_torch from style_bert_vits2.logging import logger -from style_bert_vits2.models import commons -from style_bert_vits2.models import utils +from style_bert_vits2.models import commons, utils from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.models.models_jp_extra import ( DurationDiscriminator, @@ -36,6 +35,7 @@ from style_bert_vits2.nlp.symbols import SYMBOLS from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT + torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = ( True # If encontered training problem,please try to disable TF32. 
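Note on PATCH 085: the import layout above comes from running isort with the black profile, .gitignore-based exclusion, and two blank lines after the import block (--lai 2), which is why every touched module now separates its imports from the first statement with two blank lines. The block below is a minimal sketch of how that same invocation could be pinned as project configuration so later runs stay consistent; the [tool.isort] table and its keys are standard isort options, but this exact pyproject.toml block is an assumption added for illustration and is not part of this patch series (the series instead records the command itself in hatch scripts in a later patch).

    [tool.isort]
    # assumed equivalent of: isort --profile black --gitignore --lai 2 .
    profile = "black"           # black-compatible import grouping and wrapping
    lines_after_imports = 2     # --lai 2: two blank lines after the import block
    skip_gitignore = true       # --gitignore: skip files ignored by git

With such a table in place, running isort . (or isort --check-only --diff . in CI) should reproduce the ordering shown in the hunks above without repeating the flags on the command line.
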
From 44851a62199ffadd9a4f1d18d2294dcdd10dc1f9 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Mon, 11 Mar 2024 17:05:42 +0900 Subject: [PATCH 086/148] Clean and fix docs --- re_matching.py | 81 -------------------- style_bert_vits2/models/utils/__init__.py | 4 +- style_bert_vits2/models/utils/checkpoints.py | 4 +- style_bert_vits2/models/utils/safetensors.py | 4 +- style_bert_vits2/tts_model.py | 15 ++-- style_bert_vits2/voice.py | 2 +- 6 files changed, 13 insertions(+), 97 deletions(-) delete mode 100644 re_matching.py diff --git a/re_matching.py b/re_matching.py deleted file mode 100644 index dd464a5cc..000000000 --- a/re_matching.py +++ /dev/null @@ -1,81 +0,0 @@ -import re - - -def extract_language_and_text_updated(speaker, dialogue): - # 使用正则表达式匹配<语言>标签和其后的文本 - pattern_language_text = r"<(\S+?)>([^<]+)" - matches = re.findall(pattern_language_text, dialogue, re.DOTALL) - speaker = speaker[1:-1] - # 清理文本:去除两边的空白字符 - matches_cleaned = [(lang.upper(), text.strip()) for lang, text in matches] - matches_cleaned.append(speaker) - return matches_cleaned - - -def validate_text(input_text): - # 验证说话人的正则表达式 - pattern_speaker = r"(\[\S+?\])((?:\s*<\S+?>[^<\[\]]+?)+)" - - # 使用re.DOTALL标志使.匹配包括换行符在内的所有字符 - matches = re.findall(pattern_speaker, input_text, re.DOTALL) - - # 对每个匹配到的说话人内容进行进一步验证 - for _, dialogue in matches: - language_text_matches = extract_language_and_text_updated(_, dialogue) - if not language_text_matches: - return ( - False, - "Error: Invalid format detected in dialogue content. Please check your input.", - ) - - # 如果输入的文本中没有找到任何匹配项 - if not matches: - return ( - False, - "Error: No valid speaker format detected. Please check your input.", - ) - - return True, "Input is valid." - - -def text_matching(text: str) -> list: - speaker_pattern = r"(\[\S+?\])(.+?)(?=\[\S+?\]|$)" - matches = re.findall(speaker_pattern, text, re.DOTALL) - result = [] - for speaker, dialogue in matches: - result.append(extract_language_and_text_updated(speaker, dialogue)) - return result - - -def cut_para(text): - splitted_para = re.split("[\n]", text) # 按段分 - splitted_para = [ - sentence.strip() for sentence in splitted_para if sentence.strip() - ] # 删除空字符串 - return splitted_para - - -def cut_sent(para): - para = re.sub("([。!;?\?])([^”’])", r"\1\n\2", para) # 单字符断句符 - para = re.sub("(\.{6})([^”’])", r"\1\n\2", para) # 英文省略号 - para = re.sub("(\…{2})([^”’])", r"\1\n\2", para) # 中文省略号 - para = re.sub("([。!?\?][”’])([^,。!?\?])", r"\1\n\2", para) - para = para.rstrip() # 段尾如果有多余的\n就去掉它 - return para.split("\n") - - -if __name__ == "__main__": - text = """ - [说话人1] - [说话人2]你好吗?元気ですか?こんにちは,世界。你好吗? - [说话人3]谢谢。どういたしまして。 - """ - text_matching(text) - # 测试函数 - test_text = """ - [说话人1]你好,こんにちは!こんにちは,世界。 - [说话人2]你好吗? 
- """ - text_matching(test_text) - res = validate_text(test_text) - print(res) diff --git a/style_bert_vits2/models/utils/__init__.py b/style_bert_vits2/models/utils/__init__.py index f17837289..33e13247f 100644 --- a/style_bert_vits2/models/utils/__init__.py +++ b/style_bert_vits2/models/utils/__init__.py @@ -215,13 +215,13 @@ def get_logger( def get_steps(model_path: Union[str, Path]) -> Optional[int]: """ - モデルのパスからイテレーション番号を取得する + モデルのパスからイテレーション回数を取得する Args: model_path (Union[str, Path]): モデルのパス Returns: - Optional[int]: イテレーション番号 + Optional[int]: イテレーション回数 """ matches = re.findall(r"\d+", model_path) # type: ignore diff --git a/style_bert_vits2/models/utils/checkpoints.py b/style_bert_vits2/models/utils/checkpoints.py index 768973bcc..63a5fa3da 100644 --- a/style_bert_vits2/models/utils/checkpoints.py +++ b/style_bert_vits2/models/utils/checkpoints.py @@ -27,7 +27,7 @@ def load_checkpoint( for_infer (bool): 推論用に読み込むかどうかのフラグ Returns: - tuple[torch.nn.Module, Optional[torch.optim.Optimizer], float, int]: 更新されたモデルとオプティマイザー、学習率、イテレーション番号 + tuple[torch.nn.Module, Optional[torch.optim.Optimizer], float, int]: 更新されたモデルとオプティマイザー、学習率、イテレーション回数 """ assert os.path.isfile(checkpoint_path) @@ -104,7 +104,7 @@ def save_checkpoint( model (torch.nn.Module): 保存するモデル optimizer (Union[torch.optim.Optimizer, torch.optim.AdamW]): 保存するオプティマイザー learning_rate (float): 学習率 - iteration (int): イテレーション数 + iteration (int): イテレーション回数 checkpoint_path (Union[str, Path]): 保存先のパス """ logger.info( diff --git a/style_bert_vits2/models/utils/safetensors.py b/style_bert_vits2/models/utils/safetensors.py index 8917c778e..52ab115b2 100644 --- a/style_bert_vits2/models/utils/safetensors.py +++ b/style_bert_vits2/models/utils/safetensors.py @@ -22,7 +22,7 @@ def load_safetensors( for_infer (bool): 推論用に読み込むかどうかのフラグ Returns: - tuple[torch.nn.Module, Optional[int]]: 読み込まれたモデルとイテレーション番号(存在する場合) + tuple[torch.nn.Module, Optional[int]]: 読み込まれたモデルとイテレーション回数(存在する場合) """ tensors: dict[str, Any] = {} @@ -64,7 +64,7 @@ def save_safetensors( Args: model (torch.nn.Module): 保存するモデル - iteration (int): イテレーション番号 + iteration (int): イテレーション回数 checkpoint_path (Union[str, Path]): 保存先のパス is_half (bool): モデルを半精度で保存するかどうかのフラグ for_infer (bool): 推論用に保存するかどうかのフラグ diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index a769ae620..4276c970e 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -28,6 +28,7 @@ from style_bert_vits2.models.models_jp_extra import ( SynthesizerTrn as SynthesizerTrnJPExtra, ) + from style_bert_vits2.logging import logger from style_bert_vits2.voice import adjust_voice @@ -39,11 +40,7 @@ class TTSModel: """ def __init__( - self, - model_path: Path, - config_path: Path, - style_vec_path: Path, - device: str, + self, model_path: Path, config_path: Path, style_vec_path: Path, device: str ) -> None: """ Style-Bert-Vits2 の音声合成モデルを初期化する。 @@ -170,10 +167,10 @@ def infer( language (Languages, optional): 言語. Defaults to Languages.JP. speaker_id (int, optional): 話者 ID. Defaults to 0. reference_audio_path (Optional[str], optional): 音声スタイルの参照元の音声ファイルのパス. Defaults to None. - sdp_ratio (float, optional): SDP レシオ (値を大きくするとより感情豊かになる傾向がある). Defaults to DEFAULT_SDP_RATIO. - noise (float, optional): ノイズの大きさ. Defaults to DEFAULT_NOISE. - noise_w (float, optional): ノイズの大きさの重み. Defaults to DEFAULT_NOISEW. - length (float, optional): 長さ. Defaults to DEFAULT_LENGTH. + sdp_ratio (float, optional): DP と SDP の混合比。0 で DP のみ、1で SDP のみを使用 (値を大きくするとテンポに緩急がつく). Defaults to DEFAULT_SDP_RATIO. 
+ noise (float, optional): DP に与えられるノイズ. Defaults to DEFAULT_NOISE. + noise_w (float, optional): SDP に与えられるノイズ. Defaults to DEFAULT_NOISEW. + length (float, optional): 生成音声の長さ(話速)のパラメータ。大きいほど生成音声が長くゆっくり、小さいほど短く早くなる。 Defaults to DEFAULT_LENGTH. line_split (bool, optional): テキストを改行ごとに分割して生成するかどうか. Defaults to DEFAULT_LINE_SPLIT. split_interval (float, optional): 改行ごとに分割する場合の無音 (秒). Defaults to DEFAULT_SPLIT_INTERVAL. assist_text (Optional[str], optional): 感情表現の参照元の補助テキスト. Defaults to None. diff --git a/style_bert_vits2/voice.py b/style_bert_vits2/voice.py index ed7843f73..75f7d51d3 100644 --- a/style_bert_vits2/voice.py +++ b/style_bert_vits2/voice.py @@ -19,7 +19,7 @@ def adjust_voice( fs (int): 音声のサンプリング周波数 wave (NDArray[Any]): 音声データ pitch_scale (float, optional): ピッチの高さ. Defaults to 1.0. - intonation_scale (float, optional): イントネーションの高さ. Defaults to 1.0. + intonation_scale (float, optional): イントネーションの平均からの変更比率. Defaults to 1.0. Returns: tuple[int, NDArray[Any]]: 調整後の音声データのサンプリング周波数と音声データ From b3cc705caa1d8572cc41d6906d6511c4f1e41709 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Mon, 11 Mar 2024 17:51:55 +0900 Subject: [PATCH 087/148] Fix: add use_jp_extra parameter to config.json (for non-jp-extra training) --- configs/config.json | 1 + style_bert_vits2/models/hyper_parameters.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/configs/config.json b/configs/config.json index 2f2ce7f0f..b8082bc79 100644 --- a/configs/config.json +++ b/configs/config.json @@ -24,6 +24,7 @@ "freeze_encoder": false }, "data": { + "use_jp_extra": false, "training_files": "Data/your_model_name/filelists/train.list", "validation_files": "Data/your_model_name/filelists/val.list", "max_wav_value": 32768.0, diff --git a/style_bert_vits2/models/hyper_parameters.py b/style_bert_vits2/models/hyper_parameters.py index 30c579e4a..254b285b7 100644 --- a/style_bert_vits2/models/hyper_parameters.py +++ b/style_bert_vits2/models/hyper_parameters.py @@ -16,7 +16,7 @@ class HyperParametersTrain(BaseModel): seed: int = 42 epochs: int = 1000 learning_rate: float = 0.0001 - betas: list[float] = [0.8, 0.99] + betas: tuple[float, float] = (0.8, 0.99) eps: float = 1e-9 batch_size: int = 2 bf16_run: bool = False @@ -50,7 +50,7 @@ class HyperParametersData(BaseModel): mel_fmin: float = 0.0 mel_fmax: Optional[float] = None add_blank: bool = True - n_speakers: int = 512 + n_speakers: int = 1 cleaned_text: bool = True spk2id: dict[str, int] = { "Dummy": 0, From 7736532c751a932d9d985adde7949196ff46b74b Mon Sep 17 00:00:00 2001 From: litagin02 Date: Mon, 11 Mar 2024 17:56:09 +0900 Subject: [PATCH 088/148] Match config.json values to default hyper_parameters.py --- configs/config.json | 8 ++++---- configs/configs_jp_extra.json | 7 ++++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/configs/config.json b/configs/config.json index b8082bc79..07a9304b8 100644 --- a/configs/config.json +++ b/configs/config.json @@ -1,5 +1,5 @@ { - "model_name": "your_model_name", + "model_name": "Dummy", "train": { "log_interval": 200, "eval_interval": 1000, @@ -25,8 +25,8 @@ }, "data": { "use_jp_extra": false, - "training_files": "Data/your_model_name/filelists/train.list", - "validation_files": "Data/your_model_name/filelists/val.list", + "training_files": "Data/Dummy/train.list", + "validation_files": "Data/Dummy/val.list", "max_wav_value": 32768.0, "sampling_rate": 44100, "filter_length": 2048, @@ -69,5 +69,5 @@ "use_spectral_norm": false, "gin_channels": 256 }, - "version": "2.3.1" + "version": "2.4" } diff --git 
a/configs/configs_jp_extra.json b/configs/configs_jp_extra.json index b566be165..bc0a9a4eb 100644 --- a/configs/configs_jp_extra.json +++ b/configs/configs_jp_extra.json @@ -1,4 +1,5 @@ { + "model_name": "Dummy", "train": { "log_interval": 200, "eval_interval": 1000, @@ -27,8 +28,8 @@ }, "data": { "use_jp_extra": true, - "training_files": "filelists/train.list", - "validation_files": "filelists/val.list", + "training_files": "Data/Dummy/train.list", + "validation_files": "Data/Dummy/val.list", "max_wav_value": 32768.0, "sampling_rate": 44100, "filter_length": 2048, @@ -75,5 +76,5 @@ "initial_channel": 64 } }, - "version": "2.3.1-JP-Extra" + "version": "2.4-JP-Extra" } From bc0729d970b1e2d490ca67f6136f9b81532e66e7 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Tue, 12 Mar 2024 15:23:18 +0900 Subject: [PATCH 089/148] Fix: initialize worker and dict for use --- app.py | 20 +++++++++++++------- bert_gen.py | 9 +++++++++ preprocess_all.py | 8 ++++++++ preprocess_text.py | 8 ++++++++ server_fastapi.py | 5 +++++ webui/__init__.py | 16 ---------------- 6 files changed, 43 insertions(+), 23 deletions(-) delete mode 100644 webui/__init__.py diff --git a/app.py b/app.py index 60b91a686..0869ca570 100644 --- a/app.py +++ b/app.py @@ -6,15 +6,21 @@ import yaml from style_bert_vits2.constants import GRADIO_THEME, VERSION +from style_bert_vits2.nlp.japanese import pyopenjtalk_worker +from style_bert_vits2.nlp.japanese.user_dict import update_dict from style_bert_vits2.tts_model import TTSModelHolder -from webui import ( - create_dataset_app, - create_inference_app, - create_merge_app, - create_style_vectors_app, - create_train_app, -) +from webui.dataset import create_dataset_app +from webui.inference import create_inference_app +from webui.merge import create_merge_app +from webui.style_vectors import create_style_vectors_app +from webui.train import create_train_app + + +# このプロセスからはワーカーを起動して辞書を使いたいので、ここで初期化 +pyopenjtalk_worker.initialize_worker() +# dict_data/ 以下の辞書データを pyopenjtalk に適用 +update_dict() # Get path settings with Path("configs/paths.yml").open("r", encoding="utf-8") as f: diff --git a/bert_gen.py b/bert_gen.py index 5a16af7e5..2ded36962 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -10,9 +10,18 @@ from style_bert_vits2.models import commons from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.nlp import cleaned_text_to_sequence, extract_bert_feature +from style_bert_vits2.nlp.japanese import pyopenjtalk_worker +from style_bert_vits2.nlp.japanese.user_dict import update_dict from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT +# このプロセスからはワーカーを起動して辞書を使いたいので、ここで初期化 +pyopenjtalk_worker.initialize_worker() + +# dict_data/ 以下の辞書データを pyopenjtalk に適用 +update_dict() + + def process_line(x): line, add_blank = x device = config.bert_gen_config.device diff --git a/preprocess_all.py b/preprocess_all.py index c41159cde..2eaf846b2 100644 --- a/preprocess_all.py +++ b/preprocess_all.py @@ -1,9 +1,17 @@ import argparse from multiprocessing import cpu_count +from style_bert_vits2.nlp.japanese import pyopenjtalk_worker +from style_bert_vits2.nlp.japanese.user_dict import update_dict from webui.train import preprocess_all +# このプロセスからはワーカーを起動して辞書を使いたいので、ここで初期化 +pyopenjtalk_worker.initialize_worker() + +# dict_data/ 以下の辞書データを pyopenjtalk に適用 +update_dict() + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( diff --git a/preprocess_text.py b/preprocess_text.py index ef2e2590e..6bc858563 100644 --- a/preprocess_text.py +++ 
b/preprocess_text.py @@ -10,9 +10,17 @@ from config import config from style_bert_vits2.logging import logger from style_bert_vits2.nlp import clean_text +from style_bert_vits2.nlp.japanese import pyopenjtalk_worker +from style_bert_vits2.nlp.japanese.user_dict import update_dict from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT +# このプロセスからはワーカーを起動して辞書を使いたいので、ここで初期化 +pyopenjtalk_worker.initialize_worker() + +# dict_data/ 以下の辞書データを pyopenjtalk に適用 +update_dict() + preprocess_text_config = config.preprocess_text_config diff --git a/server_fastapi.py b/server_fastapi.py index 4ac4a7e51..a8e8c93a7 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -35,7 +35,9 @@ ) from style_bert_vits2.logging import logger from style_bert_vits2.nlp import bert_models +from style_bert_vits2.nlp.japanese import pyopenjtalk_worker from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk +from style_bert_vits2.nlp.japanese.user_dict import update_dict from style_bert_vits2.tts_model import TTSModel, TTSModelHolder @@ -46,6 +48,9 @@ ## pyopenjtalk_worker は TCP ソケットサーバーのため、ここで起動する pyopenjtalk.initialize_worker() +# dict_data/ 以下の辞書データを pyopenjtalk に適用 +update_dict() + # 事前に BERT モデル/トークナイザーをロードしておく ## ここでロードしなくても必要になった際に自動ロードされるが、時間がかかるため事前にロードしておいた方が体験が良い bert_models.load_model(Languages.JP) diff --git a/webui/__init__.py b/webui/__init__.py deleted file mode 100644 index 4bc8efab8..000000000 --- a/webui/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from .dataset import create_dataset_app -from .inference import create_inference_app -from .merge import create_merge_app -from .style_vectors import create_style_vectors_app -from .train import create_train_app - - -class TrainSettings: - def __init__(self, setting_json): - self.setting_json = setting_json - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - pass From 3632d476c1a4eb99530af86e19962ec4a8884f5b Mon Sep 17 00:00:00 2001 From: litagin02 Date: Tue, 12 Mar 2024 15:24:42 +0900 Subject: [PATCH 090/148] Clean --- server_fastapi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/server_fastapi.py b/server_fastapi.py index a8e8c93a7..d544212a1 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -35,7 +35,6 @@ ) from style_bert_vits2.logging import logger from style_bert_vits2.nlp import bert_models -from style_bert_vits2.nlp.japanese import pyopenjtalk_worker from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk from style_bert_vits2.nlp.japanese.user_dict import update_dict from style_bert_vits2.tts_model import TTSModel, TTSModelHolder From 005706a45658b6b2537ab123580293ee0280e6d1 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Tue, 12 Mar 2024 17:26:03 +0900 Subject: [PATCH 091/148] Fix: colab notebook for 2.4 --- colab.ipynb | 42 +++++------------------------------ style_bert_vits2/tts_model.py | 3 +-- 2 files changed, 6 insertions(+), 39 deletions(-) diff --git a/colab.ipynb b/colab.ipynb index a48affce7..97a5b8def 100644 --- a/colab.ipynb +++ b/colab.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Style-Bert-VITS2 (ver 2.3.1) のGoogle Colabでの学習\n", + "# Style-Bert-VITS2 (ver 2.4) のGoogle Colabでの学習\n", "\n", "Google Colab上でStyle-Bert-VITS2の学習を行うことができます。\n", "\n", @@ -219,7 +219,7 @@ "batch_size = 4\n", "\n", "# 学習のエポック数(データセットを合計何周するか)。\n", - "# 100ぐらいで十分かもしれませんが、もっと多くやると質が上がるのかもしれません。\n", + "# 100で多すぎるほどかもしれませんが、もっと多くやると質が上がるのかもしれません。\n", "epochs = 100\n", "\n", "# 保存頻度。何ステップごとにモデルを保存するか。分からなければデフォルトのままで。\n", @@ -255,7 
+255,7 @@ }, "outputs": [], "source": [ - "from webui_train import preprocess_all\n", + "from webui.train import preprocess_all\n", "\n", "preprocess_all(\n", " model_name=model_name,\n", @@ -307,7 +307,7 @@ "\n", "\n", "import yaml\n", - "from webui_train import get_path\n", + "from webui.train import get_path\n", "\n", "dataset_path, _, _, _, config_path = get_path(model_name)\n", "\n", @@ -350,41 +350,9 @@ }, "outputs": [], "source": [ - "#@title 学習結果を試すならここから\n", + "# 学習結果を試す・マージ・スタイル分けはこちらから\n", "!python app.py --share --dir {assets_root}" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. スタイル分け" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!python webui_style_vectors.py --share" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. マージ" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!python webui_merge.py --share" - ] } ], "metadata": { diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index 4276c970e..f7d4f2194 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -22,14 +22,13 @@ DEFAULT_STYLE_WEIGHT, Languages, ) +from style_bert_vits2.logging import logger from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.models.infer import get_net_g, infer from style_bert_vits2.models.models import SynthesizerTrn from style_bert_vits2.models.models_jp_extra import ( SynthesizerTrn as SynthesizerTrnJPExtra, ) - -from style_bert_vits2.logging import logger from style_bert_vits2.voice import adjust_voice From dfb4513ea7ad755a86c8e6a6e6f20dff002f667d Mon Sep 17 00:00:00 2001 From: litagin02 Date: Tue, 12 Mar 2024 17:29:24 +0900 Subject: [PATCH 092/148] Change default port to None to avoid conflicting --- app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app.py b/app.py index 0869ca570..281eb2764 100644 --- a/app.py +++ b/app.py @@ -31,7 +31,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--host", type=str, default="127.0.0.1") -parser.add_argument("--port", type=int, default=7860) +parser.add_argument("--port", type=int, default=None) parser.add_argument("--no_autolaunch", action="store_true") parser.add_argument("--share", action="store_true") From 77ee026ab65b1d53a9a5348ea913901a3bf63ac4 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Tue, 12 Mar 2024 17:47:00 +0900 Subject: [PATCH 093/148] Update .dockerignore --- .dockerignore | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/.dockerignore b/.dockerignore index a90abd2db..da10b929b 100644 --- a/.dockerignore +++ b/.dockerignore @@ -3,23 +3,15 @@ * +!/style_bert_vits2/ + !/bert/deberta-v2-large-japanese-char-wwm/ !/common/ !/configs/ !/dict_data/default.csv !/model_assets/ -!/monotonic_align/ -!/text/ -!/attentions.py -!/commons.py !/config.py !/default_config.yml -!/infer.py -!/models.py -!/models_jp_extra.py -!/modules.py !/requirements.txt !/server_editor.py -!/transforms.py -!/utils.py From 456fe00ae4dfd6ab7ee765877b5e2011578e244c Mon Sep 17 00:00:00 2001 From: litagin02 Date: Tue, 12 Mar 2024 17:51:09 +0900 Subject: [PATCH 094/148] Delete old webui bat files --- Dataset.bat | 11 ----------- Merge.bat | 13 ------------- Style.bat | 12 ------------ Train.bat | 13 ------------- 4 files changed, 49 deletions(-) delete mode 100644 Dataset.bat 
delete mode 100644 Merge.bat delete mode 100644 Style.bat delete mode 100644 Train.bat diff --git a/Dataset.bat b/Dataset.bat deleted file mode 100644 index 03d3850a0..000000000 --- a/Dataset.bat +++ /dev/null @@ -1,11 +0,0 @@ -chcp 65001 > NUL -@echo off - -pushd %~dp0 -echo Running webui_dataset.py... -venv\Scripts\python webui_dataset.py - -if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) - -popd -pause \ No newline at end of file diff --git a/Merge.bat b/Merge.bat deleted file mode 100644 index 8e1dedba6..000000000 --- a/Merge.bat +++ /dev/null @@ -1,13 +0,0 @@ -chcp 65001 > NUL - -@echo off - -pushd %~dp0 - -echo Running webui_merge.py... -venv\Scripts\python webui_merge.py - -if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) - -popd -pause \ No newline at end of file diff --git a/Style.bat b/Style.bat deleted file mode 100644 index 409cf9105..000000000 --- a/Style.bat +++ /dev/null @@ -1,12 +0,0 @@ -chcp 65001 > NUL - -@echo off - -pushd %~dp0 -echo Running webui_style_vectors.py... -venv\Scripts\python webui_style_vectors.py - -if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) - -popd -pause \ No newline at end of file diff --git a/Train.bat b/Train.bat deleted file mode 100644 index 5a93d026f..000000000 --- a/Train.bat +++ /dev/null @@ -1,13 +0,0 @@ -chcp 65001 > NUL - -@echo off - -pushd %~dp0 - -echo Running webui_train.py... -venv\Scripts\python webui_train.py - -if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) - -popd -pause \ No newline at end of file From 4acfc2977e2c2c7c291ca424a6a4f401544d83eb Mon Sep 17 00:00:00 2001 From: litagin02 Date: Tue, 12 Mar 2024 19:10:43 +0900 Subject: [PATCH 095/148] Ensure to write use_jp_extra option to config.json --- colab.ipynb | 4 ++-- style_bert_vits2/utils/stdout_wrapper.py | 2 +- webui/train.py | 3 +++ 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/colab.ipynb b/colab.ipynb index 97a5b8def..06840884a 100644 --- a/colab.ipynb +++ b/colab.ipynb @@ -35,8 +35,8 @@ "metadata": {}, "outputs": [], "source": [ - "#@title このセルを実行して環境構築してください。\n", - "#@markdown 最後に赤文字でエラーや警告が出ても何故かうまくいくみたいです。\n", + "# このセルを実行して環境構築してください。\n", + "# エラーダイアログ「WARNING: The following packages were previously imported in this runtime: [pydevd_plugins]」が出るが「キャンセル」を選択して続行してください。\n", "\n", "!git clone https://github.com/litagin02/Style-Bert-VITS2.git\n", "%cd Style-Bert-VITS2/\n", diff --git a/style_bert_vits2/utils/stdout_wrapper.py b/style_bert_vits2/utils/stdout_wrapper.py index 174d1fd26..95659dc9a 100644 --- a/style_bert_vits2/utils/stdout_wrapper.py +++ b/style_bert_vits2/utils/stdout_wrapper.py @@ -35,7 +35,7 @@ def fileno(self) -> int: try: - import google.colab # type: ignore + # import google.colab # type: ignore SAFE_STDOUT = StdoutWrapper() except ImportError: diff --git a/webui/train.py b/webui/train.py index d83747149..7ce8c1e91 100644 --- a/webui/train.py +++ b/webui/train.py @@ -87,6 +87,9 @@ def initialize( config["train"]["bf16_run"] = False # デフォルトでFalseのはずだが念のため + # 今はデフォルトであるが、以前は非JP-Extra版になくバグの原因になるので念のため + config["data"]["use_jp_extra"] = use_jp_extra + model_path = os.path.join(dataset_path, "models") if os.path.exists(model_path): logger.warning( From f12f20f3dadd85e1cc0353c2d4933ac08c9ed6da Mon Sep 17 00:00:00 2001 From: litagin02 Date: Tue, 12 Mar 2024 19:13:13 +0900 Subject: [PATCH 096/148] Rollback of SAFE_STDOUT debug --- style_bert_vits2/utils/stdout_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/style_bert_vits2/utils/stdout_wrapper.py 
b/style_bert_vits2/utils/stdout_wrapper.py index 95659dc9a..174d1fd26 100644 --- a/style_bert_vits2/utils/stdout_wrapper.py +++ b/style_bert_vits2/utils/stdout_wrapper.py @@ -35,7 +35,7 @@ def fileno(self) -> int: try: - # import google.colab # type: ignore + import google.colab # type: ignore SAFE_STDOUT = StdoutWrapper() except ImportError: From 528d2cc4ba4c2395d83b46583e51f751c44b00ac Mon Sep 17 00:00:00 2001 From: litagin02 Date: Tue, 12 Mar 2024 20:05:32 +0900 Subject: [PATCH 097/148] Refactor: typing and pathlib --- slice.py | 48 +++++++++++++++++++++++++++--------------------- style_gen.py | 29 ++++++++++++----------------- transcribe.py | 13 ++++++++----- 3 files changed, 47 insertions(+), 43 deletions(-) diff --git a/slice.py b/slice.py index fb5bb7165..99fac19ed 100644 --- a/slice.py +++ b/slice.py @@ -12,6 +12,8 @@ from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT +# TODO: 並列処理による高速化 + vad_model, utils = torch.hub.load( repo_or_dir="snakers4/silero-vad", model="silero_vad", @@ -23,7 +25,10 @@ def get_stamps( - audio_file, min_silence_dur_ms: int = 700, min_sec: float = 2, max_sec: float = 12 + audio_file: Path, + min_silence_dur_ms: int = 700, + min_sec: float = 2, + max_sec: float = 12, ): """ min_silence_dur_ms: int (ミリ秒): @@ -42,7 +47,7 @@ def get_stamps( min_ms = int(min_sec * 1000) - wav = read_audio(audio_file, sampling_rate=sampling_rate) + wav = read_audio(str(audio_file), sampling_rate=sampling_rate) speech_timestamps = get_speech_timestamps( wav, vad_model, @@ -56,13 +61,13 @@ def get_stamps( def split_wav( - audio_file, - target_dir="raw", - min_sec=2, - max_sec=12, - min_silence_dur_ms=700, -): - margin = 200 # ミリ秒単位で、音声の前後に余裕を持たせる + audio_file: Path, + target_dir: Path, + min_sec: float = 2, + max_sec: float = 12, + min_silence_dur_ms: int = 700, +) -> tuple[float, int]: + margin: int = 200 # ミリ秒単位で、音声の前後に余裕を持たせる speech_timestamps = get_stamps( audio_file, min_silence_dur_ms=min_silence_dur_ms, @@ -74,10 +79,10 @@ def split_wav( total_ms = len(data) / sr * 1000 - file_name = os.path.basename(audio_file).split(".")[0] - os.makedirs(target_dir, exist_ok=True) + file_name = audio_file.stem + target_dir.mkdir(parents=True, exist_ok=True) - total_time_ms = 0 + total_time_ms: float = 0 count = 0 # タイムスタンプに従って分割し、ファイルに保存 @@ -89,7 +94,7 @@ def split_wav( end_sample = int(end_ms / 1000 * sr) segment = data[start_sample:end_sample] - sf.write(os.path.join(target_dir, f"{file_name}-{i}.wav"), segment, sr) + sf.write(str(target_dir / f"{file_name}-{i}.wav"), segment, sr) total_time_ms += end_ms - start_ms count += 1 @@ -126,20 +131,21 @@ def split_wav( ) args = parser.parse_args() - with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: + with open(Path("configs/paths.yml"), "r", encoding="utf-8") as f: path_config: dict[str, str] = yaml.safe_load(f.read()) dataset_root = path_config["dataset_root"] - input_dir = args.input_dir - output_dir = os.path.join(dataset_root, args.model_name, "raw") - min_sec = args.min_sec - max_sec = args.max_sec - min_silence_dur_ms = args.min_silence_dur_ms + model_name = str(args.model_name) + input_dir = Path(args.input_dir) + output_dir = Path(dataset_root) / model_name / "raw" + min_sec: float = args.min_sec + max_sec: float = args.max_sec + min_silence_dur_ms: int = args.min_silence_dur_ms wav_files = Path(input_dir).glob("**/*.wav") wav_files = list(wav_files) logger.info(f"Found {len(wav_files)} wav files.") - if os.path.exists(output_dir): + if output_dir.exists(): logger.warning(f"Output directory 
{output_dir} already exists, deleting...") shutil.rmtree(output_dir) @@ -147,7 +153,7 @@ def split_wav( total_count = 0 for wav_file in tqdm(wav_files, file=SAFE_STDOUT): time_sec, count = split_wav( - audio_file=str(wav_file), + audio_file=wav_file, target_dir=output_dir, min_sec=min_sec, max_sec=max_sec, diff --git a/style_gen.py b/style_gen.py index f9685cf7f..b22292d14 100644 --- a/style_gen.py +++ b/style_gen.py @@ -1,9 +1,11 @@ import argparse import warnings from concurrent.futures import ThreadPoolExecutor +from typing import Any import numpy as np import torch +from numpy.typing import NDArray from tqdm import tqdm from config import config @@ -11,11 +13,9 @@ from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT - warnings.filterwarnings("ignore", category=UserWarning) from pyannote.audio import Inference, Model - model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM") inference = Inference(model, window="whole") device = torch.device(config.style_gen_config.device) @@ -29,11 +29,11 @@ class NaNValueError(ValueError): # 推論時にインポートするために短いが関数を書く -def get_style_vector(wav_path): - return inference(wav_path) +def get_style_vector(wav_path: str) -> NDArray[Any]: + return inference(wav_path) # type: ignore -def save_style_vector(wav_path): +def save_style_vector(wav_path: str): try: style_vec = get_style_vector(wav_path) except Exception as e: @@ -48,20 +48,15 @@ def save_style_vector(wav_path): np.save(f"{wav_path}.npy", style_vec) # `test.wav` -> `test.wav.npy` -def process_line(line): - wavname = line.split("|")[0] +def process_line(line: str): + wav_path = line.split("|")[0] try: - save_style_vector(wavname) + save_style_vector(wav_path) return line, None except NaNValueError: return line, "nan_error" -def save_average_style_vector(style_vectors, filename="style_vectors.npy"): - average_vector = np.mean(style_vectors, axis=0) - np.save(filename, average_vector) - - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -71,14 +66,14 @@ def save_average_style_vector(style_vectors, filename="style_vectors.npy"): "--num_processes", type=int, default=config.style_gen_config.num_processes ) args, _ = parser.parse_known_args() - config_path = args.config - num_processes = args.num_processes + config_path: str = args.config + num_processes: int = args.num_processes hps = HyperParameters.load_from_json(config_path) device = config.style_gen_config.device - training_lines = [] + training_lines: list[str] = [] with open(hps.data.training_files, encoding="utf-8") as f: training_lines.extend(f.readlines()) with ThreadPoolExecutor(max_workers=num_processes) as executor: @@ -99,7 +94,7 @@ def save_average_style_vector(style_vectors, filename="style_vectors.npy"): f"Found NaN value in {len(nan_training_lines)} files: {nan_files}, so they will be deleted from training data." 
) - val_lines = [] + val_lines: list[str] = [] with open(hps.data.validation_files, encoding="utf-8") as f: val_lines.extend(f.readlines()) diff --git a/transcribe.py b/transcribe.py index 18509c9ea..f44ecc386 100644 --- a/transcribe.py +++ b/transcribe.py @@ -2,6 +2,7 @@ import os import sys from pathlib import Path +from typing import Optional import yaml from faster_whisper import WhisperModel @@ -12,7 +13,9 @@ from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT -def transcribe(wav_path: Path, initial_prompt=None, language="ja"): +def transcribe( + wav_path: Path, initial_prompt: Optional[str] = None, language: str = "ja" +): segments, _ = model.transcribe( str(wav_path), beam_size=5, language=language, initial_prompt=initial_prompt ) @@ -45,10 +48,10 @@ def transcribe(wav_path: Path, initial_prompt=None, language="ja"): input_dir = dataset_root / model_name / "raw" output_file = dataset_root / model_name / "esd.list" - initial_prompt = args.initial_prompt - language = args.language - device = args.device - compute_type = args.compute_type + initial_prompt: str = args.initial_prompt + language: str = args.language + device: str = args.device + compute_type: str = args.compute_type output_file.parent.mkdir(parents=True, exist_ok=True) From bdf394d2c214e37e979290e3a3b5c4e73bf09676 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Tue, 12 Mar 2024 20:56:50 +0900 Subject: [PATCH 098/148] Fix: migration to new TTSModelHolder --- server_fastapi.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/server_fastapi.py b/server_fastapi.py index d544212a1..49e85ac83 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -8,7 +8,7 @@ import sys from io import BytesIO from pathlib import Path -from typing import Dict, Optional, Union +from typing import Optional, Any from urllib.parse import unquote import GPUtil @@ -72,8 +72,12 @@ class AudioResponse(Response): media_type = "audio/wav" +loaded_models: list[TTSModel] = [] + + def load_models(model_holder: TTSModelHolder): - model_holder.models = [] + global loaded_models + loaded_models = [] for model_name, model_paths in model_holder.model_files_dict.items(): model = TTSModel( model_path=model_paths[0], @@ -81,8 +85,9 @@ def load_models(model_holder: TTSModelHolder): style_vec_path=model_holder.root_dir / model_name / "style_vectors.npy", device=model_holder.device, ) - model.load() - model_holder.models.append(model) + # 起動時に全てのモデルを読み込むのは時間がかかりメモリを食うのでやめる + # model.load() + loaded_models.append(model) if __name__ == "__main__": @@ -106,6 +111,7 @@ def load_models(model_holder: TTSModelHolder): logger.info("Loading models...") load_models(model_holder) + limit = config.server_config.limit app = FastAPI() allow_origins = config.server_config.origins @@ -120,7 +126,8 @@ def load_models(model_holder: TTSModelHolder): allow_methods=["*"], allow_headers=["*"], ) - app.logger = logger + # app.logger = logger + # ↑効いていなさそう。loggerをどうやって上書きするかはよく分からなかった。 @app.get("/voice", response_class=AudioResponse) async def voice( @@ -165,7 +172,7 @@ async def voice( assist_text_weight: float = Query( DEFAULT_ASSIST_TEXT_WEIGHT, description="assist_textの強さ" ), - style: Optional[Union[int, str]] = Query(DEFAULT_STYLE, description="スタイル"), + style: Optional[str] = Query(DEFAULT_STYLE, description="スタイル"), style_weight: float = Query(DEFAULT_STYLE_WEIGHT, description="スタイルの強さ"), reference_audio_path: Optional[str] = Query( None, description="スタイルを音声ファイルで行う" @@ -176,11 +183,11 @@ async def voice( 
f"{request.client.host}:{request.client.port}/voice { unquote(str(request.query_params) )}" ) if model_id >= len( - model_holder.models + model_holder.model_names ): # /models/refresh があるためQuery(le)で表現不可 raise_validation_error(f"model_id={model_id} not found", "model_id") - model = model_holder.models[model_id] + model = loaded_models[model_id] if speaker_name is None: if speaker_id not in model.id2spk.keys(): raise_validation_error( @@ -194,6 +201,7 @@ async def voice( speaker_id = model.spk2id[speaker_name] if style not in model.style2id.keys(): raise_validation_error(f"style={style} not found", "style") + assert style is not None if encoding is not None: text = unquote(text, encoding=encoding) sr, audio = model.infer( @@ -222,8 +230,8 @@ async def voice( def get_loaded_models_info(): """ロードされたモデル情報の取得""" - result: Dict[str, Dict] = dict() - for model_id, model in enumerate(model_holder.models): + result: dict[str, dict[str, Any]] = dict() + for model_id, model in enumerate(loaded_models): result[str(model_id)] = { "config_path": model.config_path, "model_path": model.model_path, From 6e4dcf0232d8fb3847724853e538b9e27989f48f Mon Sep 17 00:00:00 2001 From: litagin02 Date: Tue, 12 Mar 2024 21:38:27 +0900 Subject: [PATCH 099/148] Add Data, model_assets to toplevel gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index ae22e6437..22f52c8fb 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,9 @@ dist/ .coverage .ipynb_checkpoints/ +/Data/ +/model_assets/ + /*.yml !/default_config.yml /bert/*/*.bin From 69225a3e13701f4194cce147c12b7e9a52f609a9 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Tue, 12 Mar 2024 21:47:59 +0900 Subject: [PATCH 100/148] Style --- .gitignore | 1 + pyproject.toml | 17 +++++++++++++++++ resample.py | 1 + server_fastapi.py | 2 +- slice.py | 1 - style_gen.py | 3 +-- train_ms.py | 1 + webui/inference.py | 4 ++-- webui/train.py | 2 +- 9 files changed, 25 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 22f52c8fb..fe8824b59 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ venv/ dist/ .coverage .ipynb_checkpoints/ +.ruff_cache/ /Data/ /model_assets/ diff --git a/pyproject.toml b/pyproject.toml index 870c3b2d7..e8c218df5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,23 @@ cov = [ "cov-report", ] +[tool.hatch.envs.style] +detached = true +dependencies = [ + "black", + "isort", +] +[tool.hatch.envs.style.scripts] +check = [ + "black --check --diff .", + "isort --check-only --diff --profile black --gitignore --lai 2 .", +] +fmt = [ + "black .", + "isort --profile black --gitignore --lai 2 .", + "check", +] + [[tool.hatch.envs.test.matrix]] python = ["3.9", "3.10", "3.11"] diff --git a/resample.py b/resample.py index b63c64a23..5c3cc7992 100644 --- a/resample.py +++ b/resample.py @@ -1,6 +1,7 @@ import argparse import os from concurrent.futures import ThreadPoolExecutor +from multiprocessing import cpu_count import librosa import pyloudnorm as pyln diff --git a/server_fastapi.py b/server_fastapi.py index 49e85ac83..f3d3c94c8 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -8,7 +8,7 @@ import sys from io import BytesIO from pathlib import Path -from typing import Optional, Any +from typing import Any, Optional from urllib.parse import unquote import GPUtil diff --git a/slice.py b/slice.py index 99fac19ed..4a86a3dfa 100644 --- a/slice.py +++ b/slice.py @@ -1,5 +1,4 @@ import argparse -import os import shutil from pathlib import Path diff --git a/style_gen.py b/style_gen.py 
index b22292d14..384319a08 100644 --- a/style_gen.py +++ b/style_gen.py @@ -6,6 +6,7 @@ import numpy as np import torch from numpy.typing import NDArray +from pyannote.audio import Inference, Model from tqdm import tqdm from config import config @@ -13,8 +14,6 @@ from style_bert_vits2.models.hyper_parameters import HyperParameters from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT -warnings.filterwarnings("ignore", category=UserWarning) -from pyannote.audio import Inference, Model model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM") inference = Inference(model, window="whole") diff --git a/train_ms.py b/train_ms.py index c1a392e5b..c2538dade 100644 --- a/train_ms.py +++ b/train_ms.py @@ -1,5 +1,6 @@ import argparse import datetime +import gc import os import platform diff --git a/webui/inference.py b/webui/inference.py index 91baed50b..db598290e 100644 --- a/webui/inference.py +++ b/webui/inference.py @@ -39,7 +39,7 @@ bert_models.load_model(Languages.ZH) bert_models.load_tokenizer(Languages.ZH) -languages = [l.value for l in Languages] +languages = [lang.value for lang in Languages] initial_text = "こんにちは、初めまして。あなたの名前はなんていうの?" @@ -100,7 +100,7 @@ ], ] -initial_md = f""" +initial_md = """ - Ver 2.3で追加されたエディターのほうが実際に読み上げさせるには使いやすいかもしれません。`Editor.bat`か`python server_editor.py --inbrowser`で起動できます。 - 初期からある[jvnvのモデル](https://huggingface.co/litagin/style_bert_vits2_jvnv)は、[JVNVコーパス(言語音声と非言語音声を持つ日本語感情音声コーパス)](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvnv_corpus)で学習されたモデルです。ライセンスは[CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/deed.ja)です。 diff --git a/webui/train.py b/webui/train.py index 7ce8c1e91..6d7206181 100644 --- a/webui/train.py +++ b/webui/train.py @@ -400,7 +400,7 @@ def run_tensorboard(model_name): yield gr.Button("Tensorboardを開く") -how_to_md = f""" +how_to_md = """ ## 使い方 - データを準備して、モデル名を入力して、必要なら設定を調整してから、「自動前処理を実行」ボタンを押してください。進捗状況等はターミナルに表示されます。 From 74af2c831c71b481ead7ce2d191c5f09abffd53c Mon Sep 17 00:00:00 2001 From: tsukumi Date: Tue, 12 Mar 2024 17:45:10 +0000 Subject: [PATCH 101/148] Add: cache_dir and revision optional arguments to load_model() / load_tokenizer() --- style_bert_vits2/nlp/bert_models.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/style_bert_vits2/nlp/bert_models.py b/style_bert_vits2/nlp/bert_models.py index 1e346a481..5f630c39c 100644 --- a/style_bert_vits2/nlp/bert_models.py +++ b/style_bert_vits2/nlp/bert_models.py @@ -38,12 +38,15 @@ def load_model( language: Languages, pretrained_model_name_or_path: Optional[str] = None, + cache_dir: Optional[str] = None, + revision: str = 'main', ) -> Union[PreTrainedModel, DebertaV2Model]: """ 指定された言語の BERT モデルをロードし、ロード済みの BERT モデルを返す。 一度ロードされていれば、ロード済みの BERT モデルを即座に返す。 - ライブラリ利用時は常に pretrain_model_name_or_path (Hugging Face のリポジトリ名 or ローカルのファイルパス) を指定する必要がある。 + ライブラリ利用時は常に必ず pretrain_model_name_or_path (Hugging Face のリポジトリ名 or ローカルのファイルパス) を指定する必要がある。 ロードにはそれなりに時間がかかるため、ライブラリ利用前に明示的に pretrained_model_name_or_path を指定してロードしておくべき。 + cache_dir と revision は pretrain_model_name_or_path がリポジトリ名の場合のみ有効。 Style-Bert-VITS2 では、BERT モデルに下記の 3 つが利用されている。 これ以外の BERT モデルを指定した場合は正常に動作しない可能性が高い。 @@ -54,6 +57,8 @@ def load_model( Args: language (Languages): ロードする学習済みモデルの対象言語 pretrained_model_name_or_path (Optional[str]): ロードする学習済みモデルの名前またはパス。指定しない場合はデフォルトのパスが利用される (デフォルト: None) + cache_dir (Optional[str]): モデルのキャッシュディレクトリ。指定しない場合はデフォルトのキャッシュディレクトリが利用される (デフォルト: None) + revision (str): モデルの Hugging Face 上の Git 
リビジョン。指定しない場合は最新の main ブランチの内容が利用される (デフォルト: None) Returns: Union[PreTrainedModel, DebertaV2Model]: ロード済みの BERT モデル @@ -75,10 +80,10 @@ def load_model( if language == Languages.EN: model = cast( DebertaV2Model, - DebertaV2Model.from_pretrained(pretrained_model_name_or_path), + DebertaV2Model.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir, revision=revision), ) else: - model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path) + model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir, revision=revision) __loaded_models[language] = model logger.info( f"Loaded the {language} BERT model from {pretrained_model_name_or_path}" @@ -90,12 +95,15 @@ def load_model( def load_tokenizer( language: Languages, pretrained_model_name_or_path: Optional[str] = None, + cache_dir: Optional[str] = None, + revision: str = 'main', ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast, DebertaV2Tokenizer]: """ 指定された言語の BERT モデルをロードし、ロード済みの BERT トークナイザーを返す。 一度ロードされていれば、ロード済みの BERT トークナイザーを即座に返す。 - ライブラリ利用時は常に pretrain_model_name_or_path (Hugging Face のリポジトリ名 or ローカルのファイルパス) を指定する必要がある。 + ライブラリ利用時は常に必ず pretrain_model_name_or_path (Hugging Face のリポジトリ名 or ローカルのファイルパス) を指定する必要がある。 ロードにはそれなりに時間がかかるため、ライブラリ利用前に明示的に pretrained_model_name_or_path を指定してロードしておくべき。 + cache_dir と revision は pretrain_model_name_or_path がリポジトリ名の場合のみ有効。 Style-Bert-VITS2 では、BERT モデルに下記の 3 つが利用されている。 これ以外の BERT モデルを指定した場合は正常に動作しない可能性が高い。 @@ -106,6 +114,8 @@ def load_tokenizer( Args: language (Languages): ロードする学習済みモデルの対象言語 pretrained_model_name_or_path (Optional[str]): ロードする学習済みモデルの名前またはパス。指定しない場合はデフォルトのパスが利用される (デフォルト: None) + cache_dir (Optional[str]): モデルのキャッシュディレクトリ。指定しない場合はデフォルトのキャッシュディレクトリが利用される (デフォルト: None) + revision (str): モデルの Hugging Face 上の Git リビジョン。指定しない場合は最新の main ブランチの内容が利用される (デフォルト: None) Returns: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, DebertaV2Tokenizer]: ロード済みの BERT トークナイザー @@ -125,9 +135,9 @@ def load_tokenizer( # BERT トークナイザーをロードし、辞書に格納して返す ## 英語のみ DebertaV2Tokenizer でロードする必要がある if language == Languages.EN: - tokenizer = DebertaV2Tokenizer.from_pretrained(pretrained_model_name_or_path) + tokenizer = DebertaV2Tokenizer.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir, revision=revision) else: - tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir, revision=revision) __loaded_tokenizers[language] = tokenizer logger.info( f"Loaded the {language} BERT tokenizer from {pretrained_model_name_or_path}" From 483bc68d579fb028ae3c6c6febbebc9c42afefca Mon Sep 17 00:00:00 2001 From: tsukumi Date: Tue, 12 Mar 2024 18:00:51 +0000 Subject: [PATCH 102/148] Refactor: unify invoke format of open() function --- bert_gen.py | 4 ++-- config.py | 2 +- style_bert_vits2/models/utils/__init__.py | 8 +++++--- style_bert_vits2/models/utils/checkpoints.py | 2 +- style_bert_vits2/nlp/bert_models.py | 4 ++-- style_bert_vits2/nlp/chinese/g2p.py | 9 +++++---- style_bert_vits2/nlp/english/cmudict.py | 2 +- style_gen.py | 4 ++-- webui/merge.py | 10 +++++----- 9 files changed, 24 insertions(+), 21 deletions(-) diff --git a/bert_gen.py b/bert_gen.py index 2ded36962..068afd63e 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -73,10 +73,10 @@ def process_line(x): config_path = args.config hps = HyperParameters.load_from_json(config_path) lines = [] - with open(hps.data.training_files, encoding="utf-8") as f: + with open(hps.data.training_files, 
"r", encoding="utf-8") as f: lines.extend(f.readlines()) - with open(hps.data.validation_files, encoding="utf-8") as f: + with open(hps.data.validation_files, "r", encoding="utf-8") as f: lines.extend(f.readlines()) add_blank = [hps.data.add_blank] * len(lines) diff --git a/config.py b/config.py index 77c384d76..40f3f53d8 100644 --- a/config.py +++ b/config.py @@ -238,7 +238,7 @@ def __init__(self, config_path: str, path_config: dict[str, str]): "If you have no special needs, please do not modify default_config.yml." ) # sys.exit(0) - with open(file=config_path, mode="r", encoding="utf-8") as file: + with open(config_path, "r", encoding="utf-8") as file: yaml_config: Dict[str, any] = yaml.safe_load(file.read()) model_name: str = yaml_config["model_name"] self.model_name: str = model_name diff --git a/style_bert_vits2/models/utils/__init__.py b/style_bert_vits2/models/utils/__init__.py index 33e13247f..edd51ccb6 100644 --- a/style_bert_vits2/models/utils/__init__.py +++ b/style_bert_vits2/models/utils/__init__.py @@ -180,7 +180,7 @@ def load_filepaths_and_text( list[list[str]]: ファイルパスとテキストのリスト """ - with open(filename, encoding="utf-8") as f: + with open(filename, "r", encoding="utf-8") as f: filepaths_and_text = [line.strip().split(split) for line in f] return filepaths_and_text @@ -249,7 +249,8 @@ def check_git_hash(model_dir_path: Union[str, Path]) -> None: path = os.path.join(model_dir_path, "githash") if os.path.exists(path): - saved_hash = open(path).read() + with open(path, "r", encoding="utf-8") as f: + saved_hash = f.read() if saved_hash != cur_hash: logger.warning( "git hash values are different. {}(saved) != {}(current)".format( @@ -257,4 +258,5 @@ def check_git_hash(model_dir_path: Union[str, Path]) -> None: ) ) else: - open(path, "w").write(cur_hash) + with open(path, "w", encoding="utf-8") as f: + f.write(cur_hash) diff --git a/style_bert_vits2/models/utils/checkpoints.py b/style_bert_vits2/models/utils/checkpoints.py index 63a5fa3da..c601dda8a 100644 --- a/style_bert_vits2/models/utils/checkpoints.py +++ b/style_bert_vits2/models/utils/checkpoints.py @@ -85,7 +85,7 @@ def load_checkpoint( else: model.load_state_dict(new_state_dict, strict=False) - logger.info("Loaded '{}' (iteration {})".format(checkpoint_path, iteration)) + logger.info(f"Loaded '{checkpoint_path}' (iteration {iteration})") return model, optimizer, learning_rate, iteration diff --git a/style_bert_vits2/nlp/bert_models.py b/style_bert_vits2/nlp/bert_models.py index 5f630c39c..eb84eb750 100644 --- a/style_bert_vits2/nlp/bert_models.py +++ b/style_bert_vits2/nlp/bert_models.py @@ -39,7 +39,7 @@ def load_model( language: Languages, pretrained_model_name_or_path: Optional[str] = None, cache_dir: Optional[str] = None, - revision: str = 'main', + revision: str = "main", ) -> Union[PreTrainedModel, DebertaV2Model]: """ 指定された言語の BERT モデルをロードし、ロード済みの BERT モデルを返す。 @@ -96,7 +96,7 @@ def load_tokenizer( language: Languages, pretrained_model_name_or_path: Optional[str] = None, cache_dir: Optional[str] = None, - revision: str = 'main', + revision: str = "main", ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast, DebertaV2Tokenizer]: """ 指定された言語の BERT モデルをロードし、ロード済みの BERT トークナイザーを返す。 diff --git a/style_bert_vits2/nlp/chinese/g2p.py b/style_bert_vits2/nlp/chinese/g2p.py index 004616742..4ce2b26a9 100644 --- a/style_bert_vits2/nlp/chinese/g2p.py +++ b/style_bert_vits2/nlp/chinese/g2p.py @@ -8,10 +8,11 @@ from style_bert_vits2.nlp.symbols import PUNCTUATIONS -__PINYIN_TO_SYMBOL_MAP = { - line.split("\t")[0]: 
line.strip().split("\t")[1] - for line in open(Path(__file__).parent / "opencpop-strict.txt").readlines() -} +with open(Path(__file__).parent / "opencpop-strict.txt", "r", encoding="utf-8") as f: + __PINYIN_TO_SYMBOL_MAP = { + line.split("\t")[0]: line.strip().split("\t")[1] + for line in f.readlines() + } def g2p(text: str) -> tuple[list[str], list[int], list[int]]: diff --git a/style_bert_vits2/nlp/english/cmudict.py b/style_bert_vits2/nlp/english/cmudict.py index e6afb89a7..7772e77b8 100644 --- a/style_bert_vits2/nlp/english/cmudict.py +++ b/style_bert_vits2/nlp/english/cmudict.py @@ -20,7 +20,7 @@ def get_dict() -> dict[str, list[list[str]]]: def read_dict() -> dict[str, list[list[str]]]: g2p_dict = {} start_line = 49 - with open(CMU_DICT_PATH) as f: + with open(CMU_DICT_PATH, "r", encoding="utf-8") as f: line = f.readline() line_index = 1 while line: diff --git a/style_gen.py b/style_gen.py index 384319a08..02af06742 100644 --- a/style_gen.py +++ b/style_gen.py @@ -73,7 +73,7 @@ def process_line(line: str): device = config.style_gen_config.device training_lines: list[str] = [] - with open(hps.data.training_files, encoding="utf-8") as f: + with open(hps.data.training_files, "r", encoding="utf-8") as f: training_lines.extend(f.readlines()) with ThreadPoolExecutor(max_workers=num_processes) as executor: training_results = list( @@ -94,7 +94,7 @@ def process_line(line: str): ) val_lines: list[str] = [] - with open(hps.data.validation_files, encoding="utf-8") as f: + with open(hps.data.validation_files, "r", encoding="utf-8") as f: val_lines.extend(f.readlines()) with ThreadPoolExecutor(max_workers=num_processes) as executor: diff --git a/webui/merge.py b/webui/merge.py index f692b6748..18454c9a9 100644 --- a/webui/merge.py +++ b/webui/merge.py @@ -47,11 +47,11 @@ def merge_style(model_name_a, model_name_b, weight, output_name, style_triple_li os.path.join(assets_root, model_name_b, "style_vectors.npy") ) # (style_num_b, 256) with open( - os.path.join(assets_root, model_name_a, "config.json"), encoding="utf-8" + os.path.join(assets_root, model_name_a, "config.json"), "r", encoding="utf-8" ) as f: config_a = json.load(f) with open( - os.path.join(assets_root, model_name_b, "config.json"), encoding="utf-8" + os.path.join(assets_root, model_name_b, "config.json"), "r", encoding="utf-8" ) as f: config_b = json.load(f) style2id_a = config_a["data"]["style2id"] @@ -88,7 +88,7 @@ def merge_style(model_name_a, model_name_b, weight, output_name, style_triple_li # recipe.jsonを読み込んで、style_triple_listを追記 info_path = os.path.join(assets_root, output_name, "recipe.json") if os.path.exists(info_path): - with open(info_path, encoding="utf-8") as f: + with open(info_path, "r", encoding="utf-8") as f: info = json.load(f) else: info = {} @@ -261,12 +261,12 @@ def update_two_model_names_dropdown(model_holder: TTSModelHolder): def load_styles_gr(model_name_a, model_name_b): config_path_a = os.path.join(assets_root, model_name_a, "config.json") - with open(config_path_a, encoding="utf-8") as f: + with open(config_path_a, "r", encoding="utf-8") as f: config_a = json.load(f) styles_a = list(config_a["data"]["style2id"].keys()) config_path_b = os.path.join(assets_root, model_name_b, "config.json") - with open(config_path_b, encoding="utf-8") as f: + with open(config_path_b, "r", encoding="utf-8") as f: config_b = json.load(f) styles_b = list(config_b["data"]["style2id"].keys()) From e8a76e547bc32aeb6108fdae929e20402e91245d Mon Sep 17 00:00:00 2001 From: tsukumi Date: Tue, 12 Mar 2024 18:22:34 +0000 Subject: [PATCH 
103/148] Refactor: No preloading of BERT models to avoid unnecessary GPU VRAM consumption during training in the Web UI Since the BERT features of the dataset are pre-extracted by bert_gen.py, there is no need to load the BERT model at training time. --- style_bert_vits2/tts_model.py | 8 ++++++++ webui/inference.py | 13 ++++--------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index f7d4f2194..da30f4836 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -29,6 +29,7 @@ from style_bert_vits2.models.models_jp_extra import ( SynthesizerTrn as SynthesizerTrnJPExtra, ) +from style_bert_vits2.nlp import bert_models from style_bert_vits2.voice import adjust_voice @@ -379,6 +380,13 @@ def get_model(self, model_name: str, model_path_str: str) -> TTSModel: def get_model_for_gradio( self, model_name: str, model_path_str: str ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]: + bert_models.load_model(Languages.JP) + bert_models.load_tokenizer(Languages.JP) + bert_models.load_model(Languages.EN) + bert_models.load_tokenizer(Languages.EN) + bert_models.load_model(Languages.ZH) + bert_models.load_tokenizer(Languages.ZH) + model_path = Path(model_path_str) if model_name not in self.model_files_dict: raise ValueError(f"Model `{model_name}` is not found") diff --git a/webui/inference.py b/webui/inference.py index db598290e..d711846d3 100644 --- a/webui/inference.py +++ b/webui/inference.py @@ -19,7 +19,6 @@ ) from style_bert_vits2.logging import logger from style_bert_vits2.models.infer import InvalidToneError -from style_bert_vits2.nlp import bert_models from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk from style_bert_vits2.nlp.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone from style_bert_vits2.nlp.japanese.normalizer import normalize_text @@ -30,14 +29,10 @@ ## pyopenjtalk_worker は TCP ソケットサーバーのため、ここで起動する pyopenjtalk.initialize_worker() -# 事前に BERT モデル/トークナイザーをロードしておく -## ここでロードしなくても必要になった際に自動ロードされるが、時間がかかるため事前にロードしておいた方が体験が良い -bert_models.load_model(Languages.JP) -bert_models.load_tokenizer(Languages.JP) -bert_models.load_model(Languages.EN) -bert_models.load_tokenizer(Languages.EN) -bert_models.load_model(Languages.ZH) -bert_models.load_tokenizer(Languages.ZH) +# Web UI での学習時の無駄な GPU VRAM 消費を避けるため、あえてここでは BERT モデルの事前ロードを行わない +# データセットの BERT 特徴量は事前に bert_gen.py により抽出されているため、学習時に BERT モデルをロードしておく必要はない +# BERT モデルの事前ロードは「ロード」ボタン押下時に実行される TTSModelHolder.get_model_for_gradio() 内で行われる +# Web UI での学習時、音声合成タブの「ロード」ボタンを押さなければ、BERT モデルが VRAM にロードされていない状態で学習を開始できる languages = [lang.value for lang in Languages] From 07d246b98b2f476eb7f6a0aa4856313d010cbb62 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Tue, 12 Mar 2024 18:27:08 +0000 Subject: [PATCH 104/148] Refactor: run "hatch run style:fmt" --- pyproject.toml | 14 +++++++------- style_bert_vits2/nlp/bert_models.py | 20 ++++++++++++++++---- style_bert_vits2/nlp/chinese/g2p.py | 3 +-- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e8c218df5..a5f53ec44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,18 +75,18 @@ cov = [ [tool.hatch.envs.style] detached = true dependencies = [ - "black", - "isort", + "black", + "isort", ] [tool.hatch.envs.style.scripts] check = [ - "black --check --diff .", - "isort --check-only --diff --profile black --gitignore --lai 2 .", + "black --check --diff .", + "isort --check-only --diff --profile black --gitignore --lai 2 .", ] fmt 
= [ - "black .", - "isort --profile black --gitignore --lai 2 .", - "check", + "black .", + "isort --profile black --gitignore --lai 2 .", + "check", ] [[tool.hatch.envs.test.matrix]] diff --git a/style_bert_vits2/nlp/bert_models.py b/style_bert_vits2/nlp/bert_models.py index eb84eb750..1166846f5 100644 --- a/style_bert_vits2/nlp/bert_models.py +++ b/style_bert_vits2/nlp/bert_models.py @@ -80,10 +80,14 @@ def load_model( if language == Languages.EN: model = cast( DebertaV2Model, - DebertaV2Model.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir, revision=revision), + DebertaV2Model.from_pretrained( + pretrained_model_name_or_path, cache_dir=cache_dir, revision=revision + ), ) else: - model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir, revision=revision) + model = AutoModelForMaskedLM.from_pretrained( + pretrained_model_name_or_path, cache_dir=cache_dir, revision=revision + ) __loaded_models[language] = model logger.info( f"Loaded the {language} BERT model from {pretrained_model_name_or_path}" @@ -135,9 +139,17 @@ def load_tokenizer( # BERT トークナイザーをロードし、辞書に格納して返す ## 英語のみ DebertaV2Tokenizer でロードする必要がある if language == Languages.EN: - tokenizer = DebertaV2Tokenizer.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir, revision=revision) + tokenizer = DebertaV2Tokenizer.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + revision=revision, + ) else: - tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir, revision=revision) + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + revision=revision, + ) __loaded_tokenizers[language] = tokenizer logger.info( f"Loaded the {language} BERT tokenizer from {pretrained_model_name_or_path}" diff --git a/style_bert_vits2/nlp/chinese/g2p.py b/style_bert_vits2/nlp/chinese/g2p.py index 4ce2b26a9..f38e09fa8 100644 --- a/style_bert_vits2/nlp/chinese/g2p.py +++ b/style_bert_vits2/nlp/chinese/g2p.py @@ -10,8 +10,7 @@ with open(Path(__file__).parent / "opencpop-strict.txt", "r", encoding="utf-8") as f: __PINYIN_TO_SYMBOL_MAP = { - line.split("\t")[0]: line.strip().split("\t")[1] - for line in f.readlines() + line.split("\t")[0]: line.strip().split("\t")[1] for line in f.readlines() } From c4d6a8cdb7011f3478348541bbf2df3464b68aef Mon Sep 17 00:00:00 2001 From: tsukumi Date: Tue, 12 Mar 2024 18:45:54 +0000 Subject: [PATCH 105/148] Fix: large number of unnecessary files in built sdist Include in sdist only the minimum required files for style-bert-vits2 as a library. 
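A quick way to confirm what the new [tool.hatch.build.targets.sdist] section actually ships is to build the package and list the archive contents. A minimal sketch, assuming the sdist has already been built into dist/ (the exact archive file name depends on the version string):

    import tarfile
    from pathlib import Path

    # Print every file included in the built sdist so that stray files
    # (Data/, model_assets/, caches, ...) are easy to spot.
    sdist_path = next(Path("dist").glob("*.tar.gz"))
    with tarfile.open(sdist_path, "r:gz") as archive:
        for member in archive.getnames():
            print(member)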
--- pyproject.toml | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a5f53ec44..3045f039e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,27 @@ Source = "https://github.com/litagin02/Style-Bert-VITS2" [tool.hatch.version] path = "style_bert_vits2/constants.py" +[tool.hatch.build.targets.sdist] +only-include = [ + ".vscode", + "dict_data/default.csv", + "docs", + "style_bert_vits2", + "tests", + "LGPL_LICENSE", + "LICENSE", + "pyproject.toml", + "README.md", +] +exclude = [ + ".git", + ".gitignore", + ".gitattributes", +] + +[tool.hatch.build.targets.wheel] +packages = ["style_bert_vits2"] + [tool.hatch.envs.test] dependencies = [ "coverage[toml]>=6.5", From fedd019c04cd46ffd18125bb31181deaadc3de41 Mon Sep 17 00:00:00 2001 From: tsukumi Date: Tue, 12 Mar 2024 19:01:45 +0000 Subject: [PATCH 106/148] Revert "Refactor: No preloading of BERT models to avoid unnecessary GPU VRAM consumption during training in the Web UI" This reverts commit e8a76e547bc32aeb6108fdae929e20402e91245d. --- style_bert_vits2/tts_model.py | 8 -------- webui/inference.py | 13 +++++++++---- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index da30f4836..f7d4f2194 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -29,7 +29,6 @@ from style_bert_vits2.models.models_jp_extra import ( SynthesizerTrn as SynthesizerTrnJPExtra, ) -from style_bert_vits2.nlp import bert_models from style_bert_vits2.voice import adjust_voice @@ -380,13 +379,6 @@ def get_model(self, model_name: str, model_path_str: str) -> TTSModel: def get_model_for_gradio( self, model_name: str, model_path_str: str ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]: - bert_models.load_model(Languages.JP) - bert_models.load_tokenizer(Languages.JP) - bert_models.load_model(Languages.EN) - bert_models.load_tokenizer(Languages.EN) - bert_models.load_model(Languages.ZH) - bert_models.load_tokenizer(Languages.ZH) - model_path = Path(model_path_str) if model_name not in self.model_files_dict: raise ValueError(f"Model `{model_name}` is not found") diff --git a/webui/inference.py b/webui/inference.py index d711846d3..db598290e 100644 --- a/webui/inference.py +++ b/webui/inference.py @@ -19,6 +19,7 @@ ) from style_bert_vits2.logging import logger from style_bert_vits2.models.infer import InvalidToneError +from style_bert_vits2.nlp import bert_models from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk from style_bert_vits2.nlp.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone from style_bert_vits2.nlp.japanese.normalizer import normalize_text @@ -29,10 +30,14 @@ ## pyopenjtalk_worker は TCP ソケットサーバーのため、ここで起動する pyopenjtalk.initialize_worker() -# Web UI での学習時の無駄な GPU VRAM 消費を避けるため、あえてここでは BERT モデルの事前ロードを行わない -# データセットの BERT 特徴量は事前に bert_gen.py により抽出されているため、学習時に BERT モデルをロードしておく必要はない -# BERT モデルの事前ロードは「ロード」ボタン押下時に実行される TTSModelHolder.get_model_for_gradio() 内で行われる -# Web UI での学習時、音声合成タブの「ロード」ボタンを押さなければ、BERT モデルが VRAM にロードされていない状態で学習を開始できる +# 事前に BERT モデル/トークナイザーをロードしておく +## ここでロードしなくても必要になった際に自動ロードされるが、時間がかかるため事前にロードしておいた方が体験が良い +bert_models.load_model(Languages.JP) +bert_models.load_tokenizer(Languages.JP) +bert_models.load_model(Languages.EN) +bert_models.load_tokenizer(Languages.EN) +bert_models.load_model(Languages.ZH) +bert_models.load_tokenizer(Languages.ZH) languages = [lang.value for lang in Languages] From 28fd715cb89b11711b128521f3b046c89e6fcca1 Mon Sep 17 
00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 10:25:22 +0900 Subject: [PATCH 107/148] Restore gc empty_cache in train_ms_jp --- train_ms_jp_extra.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/train_ms_jp_extra.py b/train_ms_jp_extra.py index 722a52197..7b2cb4210 100644 --- a/train_ms_jp_extra.py +++ b/train_ms_jp_extra.py @@ -1,5 +1,6 @@ import argparse import datetime +import gc import os import platform @@ -1008,9 +1009,9 @@ def train_and_evaluate( ) ) pbar.update() - # 本家ではこれをスピードアップのために消すと書かれていたので、一応消してみる - # gc.collect() - # torch.cuda.empty_cache() + + gc.collect() + torch.cuda.empty_cache() if pbar is None and rank == 0: logger.info(f"====> Epoch: {epoch}, step: {global_step}") From c37893a76b4ed05ce0bc42f5920d00ac2f9124b5 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 11:46:56 +0900 Subject: [PATCH 108/148] Add version suffix dev0 --- configs/config.json | 2 +- configs/configs_jp_extra.json | 2 +- style_bert_vits2/constants.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/config.json b/configs/config.json index 07a9304b8..db0411aed 100644 --- a/configs/config.json +++ b/configs/config.json @@ -69,5 +69,5 @@ "use_spectral_norm": false, "gin_channels": 256 }, - "version": "2.4" + "version": "2.4.dev0" } diff --git a/configs/configs_jp_extra.json b/configs/configs_jp_extra.json index bc0a9a4eb..23c5fe500 100644 --- a/configs/configs_jp_extra.json +++ b/configs/configs_jp_extra.json @@ -76,5 +76,5 @@ "initial_channel": 64 } }, - "version": "2.4-JP-Extra" + "version": "2.4.dev0-JP-Extra" } diff --git a/style_bert_vits2/constants.py b/style_bert_vits2/constants.py index 9d8469073..aa1433101 100644 --- a/style_bert_vits2/constants.py +++ b/style_bert_vits2/constants.py @@ -4,7 +4,7 @@ # Style-Bert-VITS2 のバージョン -VERSION = "2.4" +VERSION = "2.4.dev0" # Style-Bert-VITS2 のベースディレクトリ BASE_DIR = Path(__file__).parent.parent From 8baecb138afb6ed2c33c7b5644df018cbaf14af7 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 11:54:24 +0900 Subject: [PATCH 109/148] Skip cuda test --- tests/test_main.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_main.py b/tests/test_main.py index f00d90839..e2e6530ef 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -51,5 +51,6 @@ def test_synthesize_cpu(): synthesize(device="cpu") -def test_synthesize_cuda(): - synthesize(device="cuda") +# Windows環境ではtorchのcudaが簡単に入らないため、テストをスキップ +# def test_synthesize_cuda(): +# synthesize(device="cuda") From 9e5222619fd6d00ca9837c2325b728df0ead0b7c Mon Sep 17 00:00:00 2001 From: tsukumi Date: Tue, 12 Mar 2024 18:22:34 +0000 Subject: [PATCH 110/148] Refactor: No preloading of BERT models to avoid unnecessary GPU VRAM consumption during training in the Web UI Since the BERT features of the dataset are pre-extracted by bert_gen.py, there is no need to load the BERT model at training time. 
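For library users the flip side is that nothing loads the BERT weights until they are first needed, so the first synthesis call pays the load cost. When that matters, the model and tokenizer can still be preloaded explicitly; a minimal sketch, assuming the default Japanese BERT repository used elsewhere in this series (cache_dir and revision are the optional arguments introduced in PATCH 101):

    from style_bert_vits2.constants import Languages
    from style_bert_vits2.nlp import bert_models

    # Load the Japanese BERT model and tokenizer once, up front, instead of
    # letting the first inference call trigger the (slow) load.
    bert_models.load_model(
        Languages.JP,
        "ku-nlp/deberta-v2-large-japanese-char-wwm",
        cache_dir=None,    # default Hugging Face cache location
        revision="main",
    )
    bert_models.load_tokenizer(
        Languages.JP,
        "ku-nlp/deberta-v2-large-japanese-char-wwm",
    )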
--- style_bert_vits2/tts_model.py | 8 ++++++++ webui/inference.py | 13 ++++--------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index f7d4f2194..da30f4836 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -29,6 +29,7 @@ from style_bert_vits2.models.models_jp_extra import ( SynthesizerTrn as SynthesizerTrnJPExtra, ) +from style_bert_vits2.nlp import bert_models from style_bert_vits2.voice import adjust_voice @@ -379,6 +380,13 @@ def get_model(self, model_name: str, model_path_str: str) -> TTSModel: def get_model_for_gradio( self, model_name: str, model_path_str: str ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]: + bert_models.load_model(Languages.JP) + bert_models.load_tokenizer(Languages.JP) + bert_models.load_model(Languages.EN) + bert_models.load_tokenizer(Languages.EN) + bert_models.load_model(Languages.ZH) + bert_models.load_tokenizer(Languages.ZH) + model_path = Path(model_path_str) if model_name not in self.model_files_dict: raise ValueError(f"Model `{model_name}` is not found") diff --git a/webui/inference.py b/webui/inference.py index db598290e..d711846d3 100644 --- a/webui/inference.py +++ b/webui/inference.py @@ -19,7 +19,6 @@ ) from style_bert_vits2.logging import logger from style_bert_vits2.models.infer import InvalidToneError -from style_bert_vits2.nlp import bert_models from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk from style_bert_vits2.nlp.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone from style_bert_vits2.nlp.japanese.normalizer import normalize_text @@ -30,14 +29,10 @@ ## pyopenjtalk_worker は TCP ソケットサーバーのため、ここで起動する pyopenjtalk.initialize_worker() -# 事前に BERT モデル/トークナイザーをロードしておく -## ここでロードしなくても必要になった際に自動ロードされるが、時間がかかるため事前にロードしておいた方が体験が良い -bert_models.load_model(Languages.JP) -bert_models.load_tokenizer(Languages.JP) -bert_models.load_model(Languages.EN) -bert_models.load_tokenizer(Languages.EN) -bert_models.load_model(Languages.ZH) -bert_models.load_tokenizer(Languages.ZH) +# Web UI での学習時の無駄な GPU VRAM 消費を避けるため、あえてここでは BERT モデルの事前ロードを行わない +# データセットの BERT 特徴量は事前に bert_gen.py により抽出されているため、学習時に BERT モデルをロードしておく必要はない +# BERT モデルの事前ロードは「ロード」ボタン押下時に実行される TTSModelHolder.get_model_for_gradio() 内で行われる +# Web UI での学習時、音声合成タブの「ロード」ボタンを押さなければ、BERT モデルが VRAM にロードされていない状態で学習を開始できる languages = [lang.value for lang in Languages] From bc89fde15e243549c166240aab6f8d6881cdc4ce Mon Sep 17 00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 14:55:42 +0900 Subject: [PATCH 111/148] Delete local_dir_use_symlinks=False --- initialize.py | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/initialize.py b/initialize.py index 3736151f0..d2680aa90 100644 --- a/initialize.py +++ b/initialize.py @@ -16,12 +16,7 @@ def download_bert_models(): for file in v["files"]: if not Path(local_path).joinpath(file).exists(): logger.info(f"Downloading {k} {file}") - hf_hub_download( - v["repo_id"], - file, - local_dir=local_path, - local_dir_use_symlinks=False, - ) + hf_hub_download(v["repo_id"], file, local_dir=local_path) def download_slm_model(): @@ -29,12 +24,7 @@ def download_slm_model(): file = "pytorch_model.bin" if not Path(local_path).joinpath(file).exists(): logger.info(f"Downloading wavlm-base-plus {file}") - hf_hub_download( - "microsoft/wavlm-base-plus", - file, - local_dir=local_path, - local_dir_use_symlinks=False, - ) + hf_hub_download("microsoft/wavlm-base-plus", file, local_dir=local_path) def 
download_pretrained_models(): @@ -44,10 +34,7 @@ def download_pretrained_models(): if not Path(local_path).joinpath(file).exists(): logger.info(f"Downloading pretrained {file}") hf_hub_download( - "litagin/Style-Bert-VITS2-1.0-base", - file, - local_dir=local_path, - local_dir_use_symlinks=False, + "litagin/Style-Bert-VITS2-1.0-base", file, local_dir=local_path ) @@ -58,10 +45,7 @@ def download_jp_extra_pretrained_models(): if not Path(local_path).joinpath(file).exists(): logger.info(f"Downloading JP-Extra pretrained {file}") hf_hub_download( - "litagin/Style-Bert-VITS2-2.0-base-JP-Extra", - file, - local_dir=local_path, - local_dir_use_symlinks=False, + "litagin/Style-Bert-VITS2-2.0-base-JP-Extra", file, local_dir=local_path ) From bfcf5091ed4f4c9ddbbf7170437ba7cab54f5ebf Mon Sep 17 00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 15:07:01 +0900 Subject: [PATCH 112/148] Fix running server_editor.py in bat (and dev ver) --- scripts/Update-Style-Bert-VITS2.bat | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/scripts/Update-Style-Bert-VITS2.bat b/scripts/Update-Style-Bert-VITS2.bat index b4f8b8d6c..04f593720 100644 --- a/scripts/Update-Style-Bert-VITS2.bat +++ b/scripts/Update-Style-Bert-VITS2.bat @@ -10,12 +10,12 @@ if not exist %CURL_CMD% ( pause & popd & exit /b 1 ) -@REM Style-Bert-VITS2.zip をGitHubのmasterの最新のものをダウンロード +@REM Style-Bert-VITS2.zip をGitHubのdev-refactorの最新のものをダウンロード %CURL_CMD% -Lo Style-Bert-VITS2.zip^ - https://github.com/litagin02/Style-Bert-VITS2/archive/refs/heads/master.zip + https://github.com/litagin02/Style-Bert-VITS2/archive/refs/heads/dev-refactor.zip if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) -@REM Style-Bert-VITS2.zip を解凍(フォルダ名前がBert-VITS2-masterになる) +@REM Style-Bert-VITS2.zip を解凍(フォルダ名前がBert-VITS2-dev-refactorになる) %PS_CMD% Expand-Archive -Path Style-Bert-VITS2.zip -DestinationPath . -Force if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) @@ -23,9 +23,9 @@ if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) del Style-Bert-VITS2.zip if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) -@REM Bert-VITS2-masterの中身をStyle-Bert-VITS2に上書き移動 -xcopy /QSY .\Style-Bert-VITS2-master\ .\Style-Bert-VITS2\ -rmdir /s /q Style-Bert-VITS2-master +@REM Bert-VITS2-dev-refactorの中身をStyle-Bert-VITS2に上書き移動 +xcopy /QSY .\Style-Bert-VITS2-dev-refactor\ .\Style-Bert-VITS2\ +rmdir /s /q Style-Bert-VITS2-dev-refactor if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) @REM 仮想環境のpip requirements.txtを更新 @@ -39,9 +39,14 @@ if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) echo Update completed. Running Style-Bert-VITS2 Editor... 
+@REM Style-Bert-VITS2フォルダに移動 +pushd Style-Bert-VITS2 + @REM Style-Bert-VITS2 Editorを起動 python server_editor.py --inbrowser pause -popd \ No newline at end of file +popd + +popd From 55c3847176c41b1df6ad2b2f1d61ff64239f9596 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 15:26:30 +0900 Subject: [PATCH 113/148] Add POST method for /voice in server_fastapi.py --- server_fastapi.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/server_fastapi.py b/server_fastapi.py index 8c5a7ccf3..3c9d31e68 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -129,7 +129,7 @@ def load_models(model_holder: TTSModelHolder): # app.logger = logger # ↑効いていなさそう。loggerをどうやって上書きするかはよく分からなかった。 - @app.post("/voice", response_class=AudioResponse) + @app.api_route("/voice", methods=["GET", "POST"], response_class=AudioResponse) async def voice( request: Request, text: str = Query(..., min_length=1, max_length=limit, description="セリフ"), @@ -182,6 +182,10 @@ async def voice( logger.info( f"{request.client.host}:{request.client.port}/voice { unquote(str(request.query_params) )}" ) + if request.method == "GET": + logger.warning( + "The GET method is not recommended for this endpoint due to various restrictions. Please use the POST method." + ) if model_id >= len( model_holder.model_names ): # /models/refresh があるためQuery(le)で表現不可 From 3c1269e520d0720b65dc5b285ba469ef0b7e5ae7 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 15:41:46 +0900 Subject: [PATCH 114/148] Add library usage example notebook --- library.ipynb | 117 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 library.ipynb diff --git a/library.ipynb b/library.ipynb new file mode 100644 index 000000000..81a102894 --- /dev/null +++ b/library.ipynb @@ -0,0 +1,117 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Style-Bert-VITS2ライブラリの使用例\n", + "\n", + "`pip install style-bert-vits2`を使った、colabで動く使用例です。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LLrngKcQEAyP" + }, + "outputs": [], + "source": [ + "!pip install style-bert-vits2==2.4.dev0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9xRtfUg5EZkx" + }, + "outputs": [], + "source": [ + "from style_bert_vits2.nlp import bert_models\n", + "from style_bert_vits2.constants import Languages\n", + "\n", + "\n", + "bert_models.load_model(Languages.JP, \"ku-nlp/deberta-v2-large-japanese-char-wwm\")\n", + "bert_models.load_tokenizer(Languages.JP, \"ku-nlp/deberta-v2-large-japanese-char-wwm\")\n", + "# bert_models.load_model(Languages.EN, \"microsoft/deberta-v3-large\")\n", + "# bert_models.load_tokenizer(Languages.EN, \"microsoft/deberta-v3-large\")\n", + "# bert_models.load_model(Languages.ZH, \"hfl/chinese-roberta-wwm-ext-large\")\n", + "# bert_models.load_tokenizer(Languages.ZH, \"hfl/chinese-roberta-wwm-ext-large\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "q2V9d3HyFAr_" + }, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "from huggingface_hub import hf_hub_download\n", + "\n", + "\n", + "model_file = \"jvnv-F1-jp/jvnv-F1-jp_e160_s14000.safetensors\"\n", + "config_file = \"jvnv-F1-jp/config.json\"\n", + "style_file = \"jvnv-F1-jp/style_vectors.npy\"\n", + "\n", + "for file in [model_file, config_file, style_file]:\n", + " print(file)\n", + " hf_hub_download(\n", + " \"litagin/style_bert_vits2_jvnv\",\n", + " file,\n", + " 
local_dir=\"model_assets\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hJa31MEUFhe4" + }, + "outputs": [], + "source": [ + "from style_bert_vits2.tts_model import TTSModel\n", + "\n", + "assets_root = Path(\"model_assets\")\n", + "\n", + "model = TTSModel(\n", + " model_path=assets_root / model_file,\n", + " config_path = assets_root / config_file,\n", + " style_vec_path = assets_root / style_file,\n", + " device=\"cpu\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Gal0tqrtGXZx" + }, + "outputs": [], + "source": [ + "from IPython.display import Audio, display\n", + "\n", + "sr, audio = model.infer(text=\"こんにちは\")\n", + "display(Audio(audio, rate=sr))" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From b309c9d74aaab4783105c969f520375acff95b00 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 15:43:33 +0900 Subject: [PATCH 115/148] =?UTF-8?q?Change=20from=20=E9=9F=B3=E7=A8=8B=20to?= =?UTF-8?q?=20=E9=9F=B3=E9=AB=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- webui/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui/inference.py b/webui/inference.py index d711846d3..ef71928a8 100644 --- a/webui/inference.py +++ b/webui/inference.py @@ -288,7 +288,7 @@ def tts_fn( maximum=1.5, value=1, step=0.05, - label="音程(1以外では音質劣化)", + label="音高(1以外では音質劣化)", ) intonation_scale = gr.Slider( minimum=0, From 7e3e6c55fab65b27a5b077c124bd77378f0483f8 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 15:46:02 +0900 Subject: [PATCH 116/148] docs --- style_bert_vits2/tts_model.py | 2 +- style_bert_vits2/voice.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index da30f4836..01c6fb420 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -180,7 +180,7 @@ def infer( style_weight (float, optional): 音声スタイルを適用する強さ. Defaults to DEFAULT_STYLE_WEIGHT. given_tone (Optional[list[int]], optional): アクセントのトーンのリスト. Defaults to None. pitch_scale (float, optional): ピッチの高さ (1.0 から変更すると若干音質が低下する). Defaults to 1.0. - intonation_scale (float, optional): イントネーションの高さ (1.0 から変更すると若干音質が低下する). Defaults to 1.0. + intonation_scale (float, optional): 抑揚の平均からの変化幅 (1.0 から変更すると若干音質が低下する). Defaults to 1.0. Returns: tuple[int, NDArray[Any]]: サンプリングレートと音声データ (16bit PCM) diff --git a/style_bert_vits2/voice.py b/style_bert_vits2/voice.py index 75f7d51d3..b0584195d 100644 --- a/style_bert_vits2/voice.py +++ b/style_bert_vits2/voice.py @@ -12,14 +12,14 @@ def adjust_voice( intonation_scale: float = 1.0, ) -> tuple[int, NDArray[Any]]: """ - 音声のピッチとイントネーションを調整する。 + 音声のピッチと抑揚を調整する。 変更すると若干音質が劣化するので、どちらも初期値のままならそのまま返す。 Args: fs (int): 音声のサンプリング周波数 wave (NDArray[Any]): 音声データ pitch_scale (float, optional): ピッチの高さ. Defaults to 1.0. - intonation_scale (float, optional): イントネーションの平均からの変更比率. Defaults to 1.0. + intonation_scale (float, optional): 抑揚の平均からの変更比率. Defaults to 1.0. 
Returns: tuple[int, NDArray[Any]]: 調整後の音声データのサンプリング周波数と音声データ From 3141a8e341423aa7fd19ed6e6627ad8e9888d994 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 16:26:29 +0900 Subject: [PATCH 117/148] debug logging --- style_bert_vits2/logging.py | 1 + 1 file changed, 1 insertion(+) diff --git a/style_bert_vits2/logging.py b/style_bert_vits2/logging.py index e5c216a00..e552bf780 100644 --- a/style_bert_vits2/logging.py +++ b/style_bert_vits2/logging.py @@ -12,4 +12,5 @@ format="{time:MM-DD HH:mm:ss} |{level:^8}| {file}:{line} | {message}", backtrace=True, diagnose=True, + level="TRACE", ) From 860957006be81500067f28cd747bf4efba9e7fea Mon Sep 17 00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 16:27:06 +0900 Subject: [PATCH 118/148] Add typing and load bert models on the top --- bert_gen.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/bert_gen.py b/bert_gen.py index 068afd63e..933fa2121 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -13,7 +13,15 @@ from style_bert_vits2.nlp.japanese import pyopenjtalk_worker from style_bert_vits2.nlp.japanese.user_dict import update_dict from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT +from style_bert_vits2.nlp import bert_models +from style_bert_vits2.constants import Languages +bert_models.load_model(Languages.JP) +bert_models.load_tokenizer(Languages.JP) +bert_models.load_model(Languages.EN) +bert_models.load_tokenizer(Languages.EN) +bert_models.load_model(Languages.ZH) +bert_models.load_tokenizer(Languages.ZH) # このプロセスからはワーカーを起動して辞書を使いたいので、ここで初期化 pyopenjtalk_worker.initialize_worker() @@ -22,7 +30,7 @@ update_dict() -def process_line(x): +def process_line(x: tuple[str, bool]): line, add_blank = x device = config.bert_gen_config.device if config.bert_gen_config.use_multi_device: @@ -30,9 +38,9 @@ def process_line(x): rank = rank[0] if len(rank) > 0 else 0 if torch.cuda.is_available(): gpu_id = rank % torch.cuda.device_count() - device = torch.device(f"cuda:{gpu_id}") + device = f"cuda:{gpu_id}" else: - device = torch.device("cpu") + device = "cpu" wav_path, _, language_str, text, phones, tone, word2ph = line.strip().split("|") phone = phones.split(" ") tone = [int(i) for i in tone.split(" ")] @@ -72,7 +80,7 @@ def process_line(x): args, _ = parser.parse_known_args() config_path = args.config hps = HyperParameters.load_from_json(config_path) - lines = [] + lines: list[str] = [] with open(hps.data.training_files, "r", encoding="utf-8") as f: lines.extend(f.readlines()) From aff666db752d766fc31bec1b344eac69590cce7d Mon Sep 17 00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 16:51:23 +0900 Subject: [PATCH 119/148] Update webui for time_suffix --- webui/dataset.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/webui/dataset.py b/webui/dataset.py index acaef99a3..e0f0de7c7 100644 --- a/webui/dataset.py +++ b/webui/dataset.py @@ -9,6 +9,7 @@ def do_slice( min_sec: float, max_sec: float, min_silence_dur_ms: int, + time_suffix: bool, input_dir: str, ): if model_name == "": @@ -25,6 +26,8 @@ def do_slice( "--min_silence_dur_ms", str(min_silence_dur_ms), ] + if time_suffix: + cmd.append("--time_suffix") if input_dir != "": cmd += ["--input_dir", input_dir] # onnxの警告が出るので無視する @@ -130,6 +133,10 @@ def create_dataset_app() -> gr.Blocks: step=100, label="無音とみなして区切る最小の無音の長さ(ms)", ) + time_suffix = gr.Checkbox( + value=False, + label="WAVファイル名の末尾に元ファイルの時間範囲を付与する", + ) slice_button = gr.Button("スライスを実行") result1 = gr.Textbox(label="結果") with gr.Row(): @@ -172,7 
+179,14 @@ def create_dataset_app() -> gr.Blocks: result2 = gr.Textbox(label="結果") slice_button.click( do_slice, - inputs=[model_name, min_sec, max_sec, min_silence_dur_ms, input_dir], + inputs=[ + model_name, + min_sec, + max_sec, + min_silence_dur_ms, + time_suffix, + input_dir, + ], outputs=[result1], ) transcribe_button.click( From b71e304fe948d8ed251f6872f82395ab5d05b1cb Mon Sep 17 00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 16:52:21 +0900 Subject: [PATCH 120/148] Change SileroVAD to my repo to disable onnx warning --- slice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/slice.py b/slice.py index 56ea0bfed..fa631a47d 100644 --- a/slice.py +++ b/slice.py @@ -14,7 +14,7 @@ # TODO: 並列処理による高速化 vad_model, utils = torch.hub.load( - repo_or_dir="snakers4/silero-vad", + repo_or_dir="litagin02/silero-vad", model="silero_vad", onnx=True, trust_repo=True, From 67f19ba627b89753e8cd5efe73717c11237323f5 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 17:12:12 +0900 Subject: [PATCH 121/148] Fix: disable multiprocessing for bert_gen due to pyopenjtalk_worker --- bert_gen.py | 16 +++++++++------- config.py | 2 +- default_config.yml | 2 +- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/bert_gen.py b/bert_gen.py index 933fa2121..1e4a92ca1 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -6,15 +6,19 @@ from tqdm import tqdm from config import config +from style_bert_vits2.constants import Languages from style_bert_vits2.logging import logger from style_bert_vits2.models import commons from style_bert_vits2.models.hyper_parameters import HyperParameters -from style_bert_vits2.nlp import cleaned_text_to_sequence, extract_bert_feature +from style_bert_vits2.nlp import ( + bert_models, + cleaned_text_to_sequence, + extract_bert_feature, +) from style_bert_vits2.nlp.japanese import pyopenjtalk_worker from style_bert_vits2.nlp.japanese.user_dict import update_dict from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT -from style_bert_vits2.nlp import bert_models -from style_bert_vits2.constants import Languages + bert_models.load_model(Languages.JP) bert_models.load_tokenizer(Languages.JP) @@ -74,9 +78,6 @@ def process_line(x: tuple[str, bool]): parser.add_argument( "-c", "--config", type=str, default=config.bert_gen_config.config_path ) - parser.add_argument( - "--num_processes", type=int, default=config.bert_gen_config.num_processes - ) args, _ = parser.parse_known_args() config_path = args.config hps = HyperParameters.load_from_json(config_path) @@ -89,7 +90,8 @@ def process_line(x: tuple[str, bool]): add_blank = [hps.data.add_blank] * len(lines) if len(lines) != 0: - num_processes = args.num_processes + # pyopenjtalkの別ワーカー化により、並列処理でエラーがでる模様なので、一旦シングルスレッド強制にする + num_processes = 1 with ThreadPoolExecutor(max_workers=num_processes) as executor: _ = list( tqdm( diff --git a/config.py b/config.py index 40f3f53d8..2229e6156 100644 --- a/config.py +++ b/config.py @@ -92,7 +92,7 @@ class Bert_gen_config: def __init__( self, config_path: str, - num_processes: int = 2, + num_processes: int = 1, device: str = "cuda", use_multi_device: bool = False, ): diff --git a/default_config.yml b/default_config.yml index ca51e7f44..8060cfbec 100644 --- a/default_config.yml +++ b/default_config.yml @@ -22,7 +22,7 @@ preprocess_text: bert_gen: config_path: "config.json" - num_processes: 2 + num_processes: 1 device: "cuda" use_multi_device: false From dd407d882d3a60dbfd47f69e8b6d0ec1c98d5f2e Mon Sep 17 00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 18:45:25 
+0900 Subject: [PATCH 122/148] Fix typing of arguments (Languages class) --- bert_gen.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bert_gen.py b/bert_gen.py index 1e4a92ca1..5619ee624 100644 --- a/bert_gen.py +++ b/bert_gen.py @@ -50,7 +50,9 @@ def process_line(x: tuple[str, bool]): tone = [int(i) for i in tone.split(" ")] word2ph = [int(i) for i in word2ph.split(" ")] word2ph = [i for i in word2ph] - phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str) + phone, tone, language = cleaned_text_to_sequence( + phone, tone, Languages[language_str] + ) if add_blank: phone = commons.intersperse(phone, 0) From 4f60a3d5d5a488ba622be9d1980043109e7a6ac3 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Wed, 13 Mar 2024 18:56:51 +0900 Subject: [PATCH 123/148] Feat: multiprocessing of slicing for much faster slicing --- slice.py | 120 ++++++++++++++++++++++++++++++++++++++--------- webui/dataset.py | 12 +++++ 2 files changed, 111 insertions(+), 21 deletions(-) diff --git a/slice.py b/slice.py index fa631a47d..367c62b66 100644 --- a/slice.py +++ b/slice.py @@ -1,6 +1,10 @@ import argparse import shutil +import sys from pathlib import Path +from queue import Queue +from threading import Thread +from typing import Any, Optional import soundfile as sf import torch @@ -10,20 +14,12 @@ from style_bert_vits2.logging import logger from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT - # TODO: 並列処理による高速化 -vad_model, utils = torch.hub.load( - repo_or_dir="litagin02/silero-vad", - model="silero_vad", - onnx=True, - trust_repo=True, -) - -(get_speech_timestamps, _, read_audio, *_) = utils - def get_stamps( + vad_model: Any, + utils: Any, audio_file: Path, min_silence_dur_ms: int = 700, min_sec: float = 2, @@ -42,6 +38,7 @@ def get_stamps( この秒数より大きい発話は無視する。 """ + (get_speech_timestamps, _, read_audio, *_) = utils sampling_rate = 16000 # 16kHzか8kHzのみ対応 min_ms = int(min_sec * 1000) @@ -60,6 +57,8 @@ def get_stamps( def split_wav( + vad_model: Any, + utils: Any, audio_file: Path, target_dir: Path, min_sec: float = 2, @@ -69,7 +68,9 @@ def split_wav( ) -> tuple[float, int]: margin: int = 200 # ミリ秒単位で、音声の前後に余裕を持たせる speech_timestamps = get_stamps( - audio_file, + vad_model=vad_model, + utils=utils, + audio_file=audio_file, min_silence_dur_ms=min_silence_dur_ms, min_sec=min_sec, max_sec=max_sec, @@ -139,6 +140,12 @@ def split_wav( action="store_true", help="Make the filename end with -start_ms-end_ms when saving wav.", ) + parser.add_argument( + "--num_processes", + type=int, + default=3, + help="Number of processes to use. 
Default 3 seems to be the best.", + ) args = parser.parse_args() with open(Path("configs/paths.yml"), "r", encoding="utf-8") as f: @@ -152,6 +159,7 @@ def split_wav( max_sec: float = args.max_sec min_silence_dur_ms: int = args.min_silence_dur_ms time_suffix: bool = args.time_suffix + num_processes: int = args.num_processes wav_files = Path(input_dir).glob("**/*.wav") wav_files = list(wav_files) @@ -160,19 +168,89 @@ def split_wav( logger.warning(f"Output directory {output_dir} already exists, deleting...") shutil.rmtree(output_dir) + # Silero VADのモデルは、同じインスタンスで並列処理するとおかしくなるらしい + # ワーカーごとにモデルをロードするようにするため、Queueを使って処理する + def process_queue( + q: Queue[Optional[Path]], + result_queue: Queue[tuple[float, int]], + error_queue: Queue[tuple[Path, Exception]], + ): + # logger.debug("Worker started.") + vad_model, utils = torch.hub.load( + repo_or_dir="litagin02/silero-vad", + model="silero_vad", + onnx=True, + trust_repo=True, + ) + while True: + file = q.get() + if file is None: # 終了シグナルを確認 + q.task_done() + break + try: + time_sec, count = split_wav( + vad_model=vad_model, + utils=utils, + audio_file=file, + target_dir=output_dir, + min_sec=min_sec, + max_sec=max_sec, + min_silence_dur_ms=min_silence_dur_ms, + time_suffix=time_suffix, + ) + result_queue.put((time_sec, count)) + except Exception as e: + logger.error(f"Error processing {file}: {e}") + error_queue.put((file, e)) + result_queue.put((0, 0)) + finally: + q.task_done() + + q: Queue[Optional[Path]] = Queue() + result_queue: Queue[tuple[float, int]] = Queue() + error_queue: Queue[tuple[Path, Exception]] = Queue() + + # ファイル数が少ない場合は、ワーカー数をファイル数に合わせる + num_processes = min(num_processes, len(wav_files)) + + threads = [ + Thread(target=process_queue, args=(q, result_queue, error_queue)) + for _ in range(num_processes) + ] + for t in threads: + t.start() + + pbar = tqdm(total=len(wav_files), file=SAFE_STDOUT) + for file in wav_files: + q.put(file) + + # result_queueを監視し、要素が追加されるごとに結果を加算しプログレスバーを更新 total_sec = 0 total_count = 0 - for wav_file in tqdm(wav_files, file=SAFE_STDOUT): - time_sec, count = split_wav( - audio_file=wav_file, - target_dir=output_dir, - min_sec=min_sec, - max_sec=max_sec, - min_silence_dur_ms=min_silence_dur_ms, - time_suffix=time_suffix, - ) - total_sec += time_sec + for _ in range(len(wav_files)): + time, count = result_queue.get() + total_sec += time total_count += count + pbar.update(1) + + # 全ての処理が終わるまで待つ + q.join() + + # 終了シグナル None を送る + for _ in range(num_processes): + q.put(None) + + for t in threads: + t.join() + + pbar.close() + + if not error_queue.empty(): + error_str = "Error slicing some files:" + while not error_queue.empty(): + file, e = error_queue.get() + error_str += f"\n{file}: {e}" + raise RuntimeError(error_str) logger.info( f"Slice done! Total time: {total_sec / 60:.2f} min, {total_count} files." 
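The Web UI change that follows only builds a command line around this script; the slicer can be driven the same way from plain Python. A rough sketch, mirroring the flags the Web UI passes ("MyModel" and the numeric values are illustrative placeholders, not values taken from this patch):

    import subprocess
    import sys

    # Invoke the parallelised slicer with three worker threads, the same
    # default the dataset Web UI uses.
    subprocess.run(
        [
            sys.executable,
            "slice.py",
            "--model_name", "MyModel",
            "--min_sec", "2",
            "--max_sec", "12",
            "--min_silence_dur_ms", "700",
            "--num_processes", "3",
        ],
        check=True,
    )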
diff --git a/webui/dataset.py b/webui/dataset.py index e0f0de7c7..1305e21fc 100644 --- a/webui/dataset.py +++ b/webui/dataset.py @@ -11,6 +11,7 @@ def do_slice( min_silence_dur_ms: int, time_suffix: bool, input_dir: str, + num_processes: int = 3, ): if model_name == "": return "Error: モデル名を入力してください。" @@ -25,6 +26,8 @@ def do_slice( str(max_sec), "--min_silence_dur_ms", str(min_silence_dur_ms), + "--num_processes", + str(num_processes), ] if time_suffix: cmd.append("--time_suffix") @@ -137,6 +140,14 @@ def create_dataset_app() -> gr.Blocks: value=False, label="WAVファイル名の末尾に元ファイルの時間範囲を付与する", ) + num_processes = gr.Slider( + minimum=1, + maximum=10, + value=3, + step=1, + label="並列処理数(速度向上のため)", + info="3で十分高速、多くしてもCPU負荷が増すだけでそこまで速度は変わらない", + ) slice_button = gr.Button("スライスを実行") result1 = gr.Textbox(label="結果") with gr.Row(): @@ -186,6 +197,7 @@ def create_dataset_app() -> gr.Blocks: min_silence_dur_ms, time_suffix, input_dir, + num_processes, ], outputs=[result1], ) From b1972a3d3d90bc7d2bf8045dba331bacbb0c916d Mon Sep 17 00:00:00 2001 From: litagin02 Date: Thu, 14 Mar 2024 15:10:08 +0900 Subject: [PATCH 124/148] Feat: HF whisper for transcribing (faster than faster-whisper) --- slice.py | 3 - tes.py | 0 transcribe.py | 150 +++++++++++++++++++++++++++++++++++++++++------ webui/dataset.py | 76 ++++++++++++++++++------ 4 files changed, 189 insertions(+), 40 deletions(-) create mode 100644 tes.py diff --git a/slice.py b/slice.py index 367c62b66..c51e35284 100644 --- a/slice.py +++ b/slice.py @@ -1,6 +1,5 @@ import argparse import shutil -import sys from pathlib import Path from queue import Queue from threading import Thread @@ -14,8 +13,6 @@ from style_bert_vits2.logging import logger from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT -# TODO: 並列処理による高速化 - def get_stamps( vad_model: Any, diff --git a/tes.py b/tes.py new file mode 100644 index 000000000..e69de29bb diff --git a/transcribe.py b/transcribe.py index f44ecc386..cf61ddefa 100644 --- a/transcribe.py +++ b/transcribe.py @@ -2,10 +2,10 @@ import os import sys from pathlib import Path -from typing import Optional +from typing import Any, Optional import yaml -from faster_whisper import WhisperModel +from torch.utils.data import Dataset from tqdm import tqdm from style_bert_vits2.constants import Languages @@ -13,16 +13,97 @@ from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT -def transcribe( - wav_path: Path, initial_prompt: Optional[str] = None, language: str = "ja" +# faster-whisperは並列処理しても速度が向上しないので、単一モデルでループ処理する +def transcribe_with_faster_whisper( + model: "WhisperModel", + audio_file: Path, + initial_prompt: Optional[str] = None, + language: str = "ja", + num_beams: int = 1, ): segments, _ = model.transcribe( - str(wav_path), beam_size=5, language=language, initial_prompt=initial_prompt + str(audio_file), + beam_size=num_beams, + language=language, + initial_prompt=initial_prompt, ) texts = [segment.text for segment in segments] return "".join(texts) +# HF pipelineで進捗表示をするために必要なDatasetクラス +class StrListDataset(Dataset[str]): + def __init__(self, original_list: list[str]) -> None: + self.original_list = original_list + + def __len__(self) -> int: + return len(self.original_list) + + def __getitem__(self, i: int) -> str: + return self.original_list[i] + + +# HFのWhisperはファイルリストを与えるとバッチ処理ができて速い +def transcribe_files_with_hf_whisper( + audio_files: list[Path], + model_id: str, + initial_prompt: Optional[str] = None, + language: str = "ja", + batch_size: int = 16, + num_beams: int = 1, + device: str = "cuda", + 
pbar: Optional[tqdm] = None, +) -> list[str]: + import torch + from transformers import WhisperProcessor, pipeline + + processor: WhisperProcessor = WhisperProcessor.from_pretrained(model_id) + generate_kwargs: dict[str, Any] = { + "language": language, + "do_sample": False, + "num_beams": 5, + "early_stopping": True, + "num_return_sequences": 5, + } + if initial_prompt is not None: + prompt_ids: torch.Tensor = processor.get_prompt_ids( + initial_prompt, return_tensors="pt" + ) + prompt_ids = prompt_ids.to(device) + generate_kwargs["prompt_ids"] = prompt_ids + + pipe = pipeline( + model=model_id, + max_new_tokens=128, + chunk_length_s=30, + batch_size=batch_size, + torch_dtype=torch.float16, + device="cuda", + generate_kwargs=generate_kwargs, + ) + dataset = StrListDataset([str(f) for f in audio_files]) + + results: list[str] = [] + for whisper_result in pipe(dataset): + logger.debug(whisper_result) + for result in enumerate(whisper_result): + logger.debug(result) + logger.debug(f"Transcribed: {result['text']}") + text: str = whisper_result["text"] + # なぜかテキストの最初に" {initial_prompt}"が入るので、文字の最初からこれを削除する + # cf. https://github.com/huggingface/transformers/issues/27594 + if text.startswith(f" {initial_prompt}"): + text = text[len(f" {initial_prompt}") :] + results.append(text) + if pbar is not None: + pbar.update(1) + + if pbar is not None: + pbar.close() + + return results + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model_name", type=str, required=True) @@ -37,6 +118,9 @@ def transcribe( parser.add_argument("--model", type=str, default="large-v3") parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--compute_type", type=str, default="bfloat16") + parser.add_argument("--use_hf_whisper", action="store_true") + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--num_beams", type=int, default=1) args = parser.parse_args() @@ -49,22 +133,18 @@ def transcribe( input_dir = dataset_root / model_name / "raw" output_file = dataset_root / model_name / "esd.list" initial_prompt: str = args.initial_prompt + initial_prompt = initial_prompt.strip('"') language: str = args.language device: str = args.device compute_type: str = args.compute_type + batch_size: int = args.batch_size + num_beams: int = args.num_beams output_file.parent.mkdir(parents=True, exist_ok=True) - logger.info( - f"Loading Whisper model ({args.model}) with compute_type={compute_type}" - ) - try: - model = WhisperModel(args.model, device=device, compute_type=compute_type) - except ValueError as e: - logger.warning(f"Failed to load model, so use `auto` compute_type: {e}") - model = WhisperModel(args.model, device=device) - wav_files = [f for f in input_dir.rglob("*.wav") if f.is_file()] + wav_files = sorted(wav_files, key=lambda x: x.name) + if output_file.exists(): logger.warning(f"{output_file} exists, backing up to {output_file}.bak") backup_path = output_file.with_name(output_file.name + ".bak") @@ -82,10 +162,42 @@ def transcribe( else: raise ValueError(f"{language} is not supported.") - wav_files = sorted(wav_files, key=lambda x: x.name) + logger.info( + f"Loading Whisper model ({args.model}) with compute_type={compute_type}" + ) + if not args.use_hf_whisper: + from faster_whisper import WhisperModel + + try: + model = WhisperModel(args.model, device=device, compute_type=compute_type) + except ValueError as e: + logger.warning(f"Failed to load model, so use `auto` compute_type: {e}") + model = WhisperModel(args.model, device=device) 
+ for wav_file in tqdm(wav_files, file=SAFE_STDOUT): + text = transcribe_with_faster_whisper( + model=model, + audio_file=wav_file, + initial_prompt=initial_prompt, + language=language, + num_beams=num_beams, + ) + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"{wav_file.name}|{model_name}|{language_id}|{text}\n") + else: + model_id = f"openai/whisper-{args.model}" + pbar = tqdm(total=len(wav_files), file=SAFE_STDOUT) + results = transcribe_files_with_hf_whisper( + audio_files=wav_files, + model_id=model_id, + initial_prompt=initial_prompt, + language=language, + batch_size=batch_size, + num_beams=num_beams, + device=device, + pbar=pbar, + ) + with open(output_file, "w", encoding="utf-8") as f: + for wav_file, text in zip(wav_files, results): + f.write(f"{wav_file.name}|{model_name}|{language_id}|{text}\n") - for wav_file in tqdm(wav_files, file=SAFE_STDOUT): - text = transcribe(wav_file, initial_prompt=initial_prompt, language=language) - with open(output_file, "a", encoding="utf-8") as f: - f.write(f"{wav_file.name}|{model_name}|{language_id}|{text}\n") sys.exit(0) diff --git a/webui/dataset.py b/webui/dataset.py index 1305e21fc..ffd1b95e0 100644 --- a/webui/dataset.py +++ b/webui/dataset.py @@ -41,28 +41,40 @@ def do_slice( def do_transcribe( - model_name, whisper_model, compute_type, language, initial_prompt, device + model_name, + whisper_model, + compute_type, + language, + initial_prompt, + device, + use_hf_whisper, + batch_size, + num_beams, ): if model_name == "": return "Error: モデル名を入力してください。" - success, message = run_script_with_log( - [ - "transcribe.py", - "--model_name", - model_name, - "--model", - whisper_model, - "--compute_type", - compute_type, - "--device", - device, - "--language", - language, - "--initial_prompt", - f'"{initial_prompt}"', - ] - ) + cmd = [ + "transcribe.py", + "--model_name", + model_name, + "--model", + whisper_model, + "--compute_type", + compute_type, + "--device", + device, + "--language", + language, + "--initial_prompt", + f'"{initial_prompt}"', + "--num_beams", + str(num_beams), + ] + if use_hf_whisper: + cmd.append("--use_hf_whisper") + cmd.extend(["--batch_size", str(batch_size)]) + success, message = run_script_with_log(cmd) if not success: return f"Error: {message}. 
しかし何故かエラーが起きても正常に終了している場合がほとんどなので、書き起こし結果を確認して問題なければ学習に使えます。" return "音声の文字起こしが完了しました。" @@ -165,6 +177,9 @@ def create_dataset_app() -> gr.Blocks: label="Whisperモデル", value="large-v3", ) + use_hf_whisper = gr.Checkbox( + label="HuggingFaceのWhisperを使う(使うと速度が速いがVRAMを多く使う)", + ) compute_type = gr.Dropdown( [ "int8", @@ -186,6 +201,23 @@ def create_dataset_app() -> gr.Blocks: value="こんにちは。元気、ですかー?ふふっ、私は……ちゃんと元気だよ!", info="このように書き起こしてほしいという例文(句読点の入れ方・笑い方・固有名詞等)", ) + num_beams = gr.Slider( + minimum=1, + maximum=10, + value=5, + step=1, + label="ビームサーチのビーム数", + info="小さいほど速度が上がり(以前は5)、精度は少し落ちるかもしれないがほぼ変わらない体感", + ) + batch_size = gr.Slider( + minimum=1, + maximum=128, + value=32, + step=1, + label="バッチサイズ", + info="大きくすると速度が速くなるがVRAMを多く使う", + visible=False, + ) transcribe_button = gr.Button("音声の文字起こし") result2 = gr.Textbox(label="結果") slice_button.click( @@ -210,8 +242,16 @@ def create_dataset_app() -> gr.Blocks: language, initial_prompt, device, + use_hf_whisper, + batch_size, + num_beams, ], outputs=[result2], ) + use_hf_whisper.change( + lambda x: gr.update(visible=x), + inputs=[use_hf_whisper], + outputs=[batch_size], + ) return app From b65d1d47554bcb7098523de1ff84b02ab3450dda Mon Sep 17 00:00:00 2001 From: litagin02 Date: Thu, 14 Mar 2024 15:19:09 +0900 Subject: [PATCH 125/148] Improve log and webui --- transcribe.py | 12 ++++++------ webui/dataset.py | 4 ++-- webui/merge.py | 8 +++----- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/transcribe.py b/transcribe.py index cf61ddefa..582a89a99 100644 --- a/transcribe.py +++ b/transcribe.py @@ -61,9 +61,7 @@ def transcribe_files_with_hf_whisper( generate_kwargs: dict[str, Any] = { "language": language, "do_sample": False, - "num_beams": 5, - "early_stopping": True, - "num_return_sequences": 5, + "num_beams": num_beams, } if initial_prompt is not None: prompt_ids: torch.Tensor = processor.get_prompt_ids( @@ -72,6 +70,7 @@ def transcribe_files_with_hf_whisper( prompt_ids = prompt_ids.to(device) generate_kwargs["prompt_ids"] = prompt_ids + logger.info(f"generate_kwargs: {generate_kwargs}") pipe = pipeline( model=model_id, max_new_tokens=128, @@ -162,12 +161,12 @@ def transcribe_files_with_hf_whisper( else: raise ValueError(f"{language} is not supported.") - logger.info( - f"Loading Whisper model ({args.model}) with compute_type={compute_type}" - ) if not args.use_hf_whisper: from faster_whisper import WhisperModel + logger.info( + f"Loading Whisper model ({args.model}) with compute_type={compute_type}" + ) try: model = WhisperModel(args.model, device=device, compute_type=compute_type) except ValueError as e: @@ -185,6 +184,7 @@ def transcribe_files_with_hf_whisper( f.write(f"{wav_file.name}|{model_name}|{language_id}|{text}\n") else: model_id = f"openai/whisper-{args.model}" + logger.info(f"Loading HF Whisper model ({model_id})") pbar = tqdm(total=len(wav_files), file=SAFE_STDOUT) results = transcribe_files_with_hf_whisper( audio_files=wav_files, diff --git a/webui/dataset.py b/webui/dataset.py index ffd1b95e0..a38f7c6f9 100644 --- a/webui/dataset.py +++ b/webui/dataset.py @@ -249,9 +249,9 @@ def create_dataset_app() -> gr.Blocks: outputs=[result2], ) use_hf_whisper.change( - lambda x: gr.update(visible=x), + lambda x: (gr.update(visible=x), gr.update(visible=not x)), inputs=[use_hf_whisper], - outputs=[batch_size], + outputs=[batch_size, compute_type], ) return app diff --git a/webui/merge.py b/webui/merge.py index 18454c9a9..c7f8eeede 100644 --- a/webui/merge.py +++ b/webui/merge.py @@ -286,10 +286,6 @@ def 
load_styles_gr(model_name_a, model_name_b): initial_md = """ -# Style-Bert-VITS2 モデルマージツール - -2つのStyle-Bert-VITS2モデルから、声質・話し方・話す速さを取り替えたり混ぜたりできます。 - ## 使い方 1. マージしたい2つのモデルを選択してください(`model_assets`フォルダの中から選ばれます)。 @@ -343,7 +339,9 @@ def create_merge_app(model_holder: TTSModelHolder) -> gr.Blocks: initial_model_files = model_holder.model_files_dict[model_names[initial_id]] with gr.Blocks(theme=GRADIO_THEME) as app: - gr.Markdown(initial_md) + gr.Markdown( + "2つのStyle-Bert-VITS2モデルから、声質・話し方・話す速さを取り替えたり混ぜたりできます。" + ) with gr.Accordion(label="使い方", open=False): gr.Markdown(initial_md) with gr.Row(): From 2996a0b1955eb32d4663040e0c64749367ecea42 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Thu, 14 Mar 2024 15:41:10 +0900 Subject: [PATCH 126/148] docs --- webui/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/webui/dataset.py b/webui/dataset.py index a38f7c6f9..1d8367c10 100644 --- a/webui/dataset.py +++ b/webui/dataset.py @@ -178,7 +178,7 @@ def create_dataset_app() -> gr.Blocks: value="large-v3", ) use_hf_whisper = gr.Checkbox( - label="HuggingFaceのWhisperを使う(使うと速度が速いがVRAMを多く使う)", + label="HuggingFaceのWhisperを使う(速度が速いがVRAMを多く使う)", ) compute_type = gr.Dropdown( [ @@ -204,10 +204,10 @@ def create_dataset_app() -> gr.Blocks: num_beams = gr.Slider( minimum=1, maximum=10, - value=5, + value=1, step=1, label="ビームサーチのビーム数", - info="小さいほど速度が上がり(以前は5)、精度は少し落ちるかもしれないがほぼ変わらない体感", + info="小さいほど速度が上がる(以前は5)", ) batch_size = gr.Slider( minimum=1, From a5540bdb2ec0fb4f111a99d86ce05d50af6e7a02 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Thu, 14 Mar 2024 16:25:04 +0900 Subject: [PATCH 127/148] Clean --- tes.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tes.py diff --git a/tes.py b/tes.py deleted file mode 100644 index e69de29bb..000000000 From 8c95f0ff75ecbbc21f36dd8cdee64244a4983df7 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Thu, 14 Mar 2024 16:45:31 +0900 Subject: [PATCH 128/148] Fix hf whisper --- colab.ipynb | 4 +++- transcribe.py | 4 ---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/colab.ipynb b/colab.ipynb index 06840884a..5e80f526d 100644 --- a/colab.ipynb +++ b/colab.ipynb @@ -26,7 +26,9 @@ "source": [ "## 0. 環境構築\n", "\n", - "Style-Bert-VITS2の環境をcolab上に構築します。グラボモードが有効になっていることを確認し、以下のセルを順に実行してください。" + "Style-Bert-VITS2の環境をcolab上に構築します。グラボモードが有効になっていることを確認し、以下のセルを順に実行してください。\n", + "\n", + "最近のcolabのアップデートにより、エラーダイアログ「WARNING: The following packages were previously imported in this runtime: [pydevd_plugins]」が出るが、「キャンセル」を選択して続行してください。" ] }, { diff --git a/transcribe.py b/transcribe.py index 582a89a99..f0392262c 100644 --- a/transcribe.py +++ b/transcribe.py @@ -84,10 +84,6 @@ def transcribe_files_with_hf_whisper( results: list[str] = [] for whisper_result in pipe(dataset): - logger.debug(whisper_result) - for result in enumerate(whisper_result): - logger.debug(result) - logger.debug(f"Transcribed: {result['text']}") text: str = whisper_result["text"] # なぜかテキストの最初に" {initial_prompt}"が入るので、文字の最初からこれを削除する # cf. 
https://github.com/huggingface/transformers/issues/27594 From 6f42a1514ff60071c4d73e43ab65c731ff6eeb9c Mon Sep 17 00:00:00 2001 From: litagin02 Date: Thu, 14 Mar 2024 17:42:29 +0900 Subject: [PATCH 129/148] Add ngram_size option --- slice.py | 8 ++++++++ transcribe.py | 12 ++++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/slice.py b/slice.py index c51e35284..2dece8ab1 100644 --- a/slice.py +++ b/slice.py @@ -165,6 +165,14 @@ def split_wav( logger.warning(f"Output directory {output_dir} already exists, deleting...") shutil.rmtree(output_dir) + # モデルをダウンロードしておく + _ = torch.hub.load( + repo_or_dir="litagin02/silero-vad", + model="silero_vad", + onnx=True, + trust_repo=True, + ) + # Silero VADのモデルは、同じインスタンスで並列処理するとおかしくなるらしい # ワーカーごとにモデルをロードするようにするため、Queueを使って処理する def process_queue( diff --git a/transcribe.py b/transcribe.py index f0392262c..8b3d21a38 100644 --- a/transcribe.py +++ b/transcribe.py @@ -20,12 +20,14 @@ def transcribe_with_faster_whisper( initial_prompt: Optional[str] = None, language: str = "ja", num_beams: int = 1, + no_repeat_ngram_size: int = 10, ): segments, _ = model.transcribe( str(audio_file), beam_size=num_beams, language=language, initial_prompt=initial_prompt, + no_repeat_ngram_size=no_repeat_ngram_size, ) texts = [segment.text for segment in segments] return "".join(texts) @@ -51,6 +53,7 @@ def transcribe_files_with_hf_whisper( language: str = "ja", batch_size: int = 16, num_beams: int = 1, + no_repeat_ngram_size: int = 10, device: str = "cuda", pbar: Optional[tqdm] = None, ) -> list[str]: @@ -62,7 +65,10 @@ def transcribe_files_with_hf_whisper( "language": language, "do_sample": False, "num_beams": num_beams, + "no_repeat_ngram_size": no_repeat_ngram_size, } + logger.info(f"generate_kwargs: {generate_kwargs}") + if initial_prompt is not None: prompt_ids: torch.Tensor = processor.get_prompt_ids( initial_prompt, return_tensors="pt" @@ -70,7 +76,6 @@ def transcribe_files_with_hf_whisper( prompt_ids = prompt_ids.to(device) generate_kwargs["prompt_ids"] = prompt_ids - logger.info(f"generate_kwargs: {generate_kwargs}") pipe = pipeline( model=model_id, max_new_tokens=128, @@ -116,7 +121,7 @@ def transcribe_files_with_hf_whisper( parser.add_argument("--use_hf_whisper", action="store_true") parser.add_argument("--batch_size", type=int, default=16) parser.add_argument("--num_beams", type=int, default=1) - + parser.add_argument("--no_repeat_ngram_size", type=int, default=10) args = parser.parse_args() with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: @@ -134,6 +139,7 @@ def transcribe_files_with_hf_whisper( compute_type: str = args.compute_type batch_size: int = args.batch_size num_beams: int = args.num_beams + no_repeat_ngram_size: int = args.no_repeat_ngram_size output_file.parent.mkdir(parents=True, exist_ok=True) @@ -175,6 +181,7 @@ def transcribe_files_with_hf_whisper( initial_prompt=initial_prompt, language=language, num_beams=num_beams, + no_repeat_ngram_size=no_repeat_ngram_size, ) with open(output_file, "a", encoding="utf-8") as f: f.write(f"{wav_file.name}|{model_name}|{language_id}|{text}\n") @@ -189,6 +196,7 @@ def transcribe_files_with_hf_whisper( language=language, batch_size=batch_size, num_beams=num_beams, + no_repeat_ngram_size=no_repeat_ngram_size, device=device, pbar=pbar, ) From 2eb489f91edadbcedcfe6dd074486c6a00cb3ac9 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Thu, 14 Mar 2024 18:29:06 +0900 Subject: [PATCH 130/148] docs --- transcribe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/transcribe.py b/transcribe.py index 8b3d21a38..210ce2c7e 100644 --- a/transcribe.py +++ b/transcribe.py @@ -167,7 +167,7 @@ def transcribe_files_with_hf_whisper( from faster_whisper import WhisperModel logger.info( - f"Loading Whisper model ({args.model}) with compute_type={compute_type}" + f"Loading faster-whisper model ({args.model}) with compute_type={compute_type}" ) try: model = WhisperModel(args.model, device=device, compute_type=compute_type) From 51e246658d25631d2498cd6cbe79b6906595ee21 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Thu, 14 Mar 2024 18:30:21 +0900 Subject: [PATCH 131/148] Change torch==2.1 to torch >=2.1 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bf47ec0fd..680f28b12 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,6 @@ requests safetensors scipy tensorboard -torch>=2.1,<2.2 +torch transformers umap-learn From 1b72023e02031f95917d667654e174c8843ce801 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Thu, 14 Mar 2024 19:15:21 +0900 Subject: [PATCH 132/148] Require torch>=2.1 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 680f28b12..d111e6910 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,6 @@ requests safetensors scipy tensorboard -torch +torch>=2.1 transformers umap-learn From 73b01f717276051b7c17783eea4e1aa224c8d2e1 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Thu, 14 Mar 2024 23:31:16 +0900 Subject: [PATCH 133/148] Add another wave dash for long symbol --- style_bert_vits2/nlp/japanese/normalizer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/style_bert_vits2/nlp/japanese/normalizer.py b/style_bert_vits2/nlp/japanese/normalizer.py index 07b742c0a..d3bc40e9c 100644 --- a/style_bert_vits2/nlp/japanese/normalizer.py +++ b/style_bert_vits2/nlp/japanese/normalizer.py @@ -36,9 +36,10 @@ def normalize_text(text: str) -> str: res = unicodedata.normalize("NFKC", text) # ここでアルファベットは半角になる res = __convert_numbers_to_words(res) # 「100円」→「百円」等 - # 「~」と「~」も長音記号として扱う + # 「~」と「〜」とと「~」も長音記号として扱う res = res.replace("~", "ー") res = res.replace("~", "ー") + res = res.replace("〜", "ー") res = replace_punctuation(res) # 句読点等正規化、読めない文字を削除 From 5c1e791b9a3748a7258c54a26e3c0cafb537c936 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 15 Mar 2024 10:30:22 +0900 Subject: [PATCH 134/148] Add cleaning bat script --- scripts/Clean.bat | 66 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 scripts/Clean.bat diff --git a/scripts/Clean.bat b/scripts/Clean.bat new file mode 100644 index 000000000..0a3af3471 --- /dev/null +++ b/scripts/Clean.bat @@ -0,0 +1,66 @@ +chcp 65001 > NUL +@echo off +setlocal +echo 不要になった以下のフォルダ・ファイルを削除します: +echo 注: 学習やマージ等はApp.batへ統合されました。 +echo Style-Bert-VITS2\common\ +echo Style-Bert-VITS2\monotonic_align\ +echo Style-Bert-VITS2\text\ +echo Style-Bert-VITS2\tools\ +echo Style-Bert-VITS2\attentions.py +echo Style-Bert-VITS2\commons.py +echo Style-Bert-VITS2\Dataset.bat +echo Style-Bert-VITS2\infer.py +echo Style-Bert-VITS2\Merge.bat +echo Style-Bert-VITS2\models_jp_extra.py +echo Style-Bert-VITS2\models.py +echo Style-Bert-VITS2\modules.py +echo Style-Bert-VITS2\re_matching.py +echo Style-Bert-VITS2\spec_gen.py +echo Style-Bert-VITS2\Style.bat +echo Style-Bert-VITS2\Train.bat +echo Style-Bert-VITS2\transforms.py +echo Style-Bert-VITS2\update_status.py +echo Style-Bert-VITS2\utils.py +echo 
Style-Bert-VITS2\webui_dataset.py +echo Style-Bert-VITS2\webui_merge.py +echo Style-Bert-VITS2\webui_style_vectors.py +echo Style-Bert-VITS2\webui_train.py +echo Style-Bert-VITS2\webui.py +echo. +set /p delConfirm=削除しますか? (y/n): +if /I "%delConfirm%"=="Y" goto proceed +if /I "%delConfirm%"=="y" goto proceed +if "%delConfirm%"=="" goto proceed +goto end + +:proceed +rd /s /q "Style-Bert-VITS2\common" +rd /s /q "Style-Bert-VITS2\monotonic_align" +rd /s /q "Style-Bert-VITS2\text" +rd /s /q "Style-Bert-VITS2\tools" +del /q "Style-Bert-VITS2\attentions.py" +del /q "Style-Bert-VITS2\commons.py" +del /q "Style-Bert-VITS2\Dataset.bat" +del /q "Style-Bert-VITS2\infer.py" +del /q "Style-Bert-VITS2\Merge.bat" +del /q "Style-Bert-VITS2\models_jp_extra.py" +del /q "Style-Bert-VITS2\models.py" +del /q "Style-Bert-VITS2\modules.py" +del /q "Style-Bert-VITS2\re_matching.py" +del /q "Style-Bert-VITS2\spec_gen.py" +del /q "Style-Bert-VITS2\Style.bat" +del /q "Style-Bert-VITS2\Train.bat" +del /q "Style-Bert-VITS2\transforms.py" +del /q "Style-Bert-VITS2\update_status.py" +del /q "Style-Bert-VITS2\utils.py" +del /q "Style-Bert-VITS2\webui_dataset.py" +del /q "Style-Bert-VITS2\webui_merge.py" +del /q "Style-Bert-VITS2\webui_style_vectors.py" +del /q "Style-Bert-VITS2\webui_train.py" +del /q "Style-Bert-VITS2\webui.py" +echo 完了しました。 +pause + +:end +endlocal From 867855ace9bf4b824cdfcb99251663a581573e7e Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 15 Mar 2024 11:10:23 +0900 Subject: [PATCH 135/148] Refactor, improve: use pathlib, allow audios other than wavs in resample --- resample.py | 100 ++++++++++++++++++++++++------------------------- webui/train.py | 12 +++--- 2 files changed, 54 insertions(+), 58 deletions(-) diff --git a/resample.py b/resample.py index 5c3cc7992..f9f1e3b23 100644 --- a/resample.py +++ b/resample.py @@ -1,18 +1,19 @@ import argparse -import os -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed from multiprocessing import cpu_count +from pathlib import Path +from typing import Any import librosa import pyloudnorm as pyln import soundfile +from numpy.typing import NDArray from tqdm import tqdm from config import config from style_bert_vits2.logging import logger from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT - DEFAULT_BLOCK_SIZE: float = 0.400 # seconds @@ -20,32 +21,37 @@ class BlockSizeException(Exception): pass -def normalize_audio(data, sr): +def normalize_audio(data: NDArray[Any], sr: int): meter = pyln.Meter(sr, block_size=DEFAULT_BLOCK_SIZE) # create BS.1770 meter try: loudness = meter.integrated_loudness(data) except ValueError as e: raise BlockSizeException(e) - # logger.info(f"loudness: {loudness}") + data = pyln.normalize.loudness(data, loudness, -23.0) return data -def process(item): - spkdir, wav_name, args = item - wav_path = os.path.join(args.in_dir, spkdir, wav_name) - if os.path.exists(wav_path) and wav_path.lower().endswith(".wav"): - wav, sr = librosa.load(wav_path, sr=args.sr) - if args.normalize: +def resample(file: Path, output_dir: Path, target_sr: int, normalize: bool, trim: bool): + try: + # librosaが読めるファイルかチェック + # wav以外にもmp3やoggやflacなども読める + wav: NDArray[Any] + sr: int + wav, sr = librosa.load(file, sr=target_sr) + if normalize: try: wav = normalize_audio(wav, sr) except BlockSizeException: + print("") logger.info( - f"Skip normalize due to less than {DEFAULT_BLOCK_SIZE} second audio: {wav_path}" + f"Skip normalize due to less than {DEFAULT_BLOCK_SIZE} second audio: {file}" ) - 
if args.trim: + if trim: wav, _ = librosa.effects.trim(wav, top_db=30) - soundfile.write(os.path.join(args.out_dir, spkdir, wav_name), wav, sr) + soundfile.write(output_dir / file.with_suffix(".wav").name, wav, sr) + except Exception as e: + logger.warning(f"Cannot load file, so skipping: {file}, {e}") if __name__ == "__main__": @@ -57,14 +63,14 @@ def process(item): help="sampling rate", ) parser.add_argument( - "--in_dir", + "--input_dir", "-i", type=str, default=config.resample_config.in_dir, help="path to source dir", ) parser.add_argument( - "--out_dir", + "--output_dir", "-o", type=str, default=config.resample_config.out_dir, @@ -88,46 +94,36 @@ def process(item): default=False, help="trim silence (start and end only)", ) - args, _ = parser.parse_known_args() - # autodl 无卡模式会识别出46个cpu + args = parser.parse_args() + if args.num_processes == 0: processes = cpu_count() - 2 if cpu_count() > 4 else 1 else: - processes = args.num_processes - - tasks = [] - - for dirpath, _, filenames in os.walk(args.in_dir): - # 子级目录 - spk_dir = os.path.relpath(dirpath, args.in_dir) - spk_dir_out = os.path.join(args.out_dir, spk_dir) - if not os.path.isdir(spk_dir_out): - os.makedirs(spk_dir_out, exist_ok=True) - for filename in filenames: - if filename.lower().endswith(".wav"): - twople = (spk_dir, filename, args) - tasks.append(twople) - - if len(tasks) == 0: - logger.error(f"No wav files found in {args.in_dir}") - raise ValueError(f"No wav files found in {args.in_dir}") - - # pool = Pool(processes=processes) - # for _ in tqdm( - # pool.imap_unordered(process, tasks), file=SAFE_STDOUT, total=len(tasks) - # ): - # pass - - # pool.close() - # pool.join() + processes: int = args.num_processes + + input_dir = Path(args.input_dir) + output_dir = Path(args.output_dir) + sr = int(args.sr) + normalize: bool = args.normalize + trim: bool = args.trim + + # 後でlibrosaに読ませて有効な音声ファイルかチェックするので、全てのファイルを取得 + original_files = [f for f in input_dir.rglob("*") if f.is_file()] + + if len(original_files) == 0: + logger.error(f"No files found in {input_dir}") + raise ValueError(f"No files found in {input_dir}") + + output_dir.mkdir(parents=True, exist_ok=True) with ThreadPoolExecutor(max_workers=processes) as executor: - _ = list( - tqdm( - executor.map(process, tasks), - total=len(tasks), - file=SAFE_STDOUT, - ) - ) + futures = [ + executor.submit(resample, file, output_dir, sr, normalize, trim) + for file in original_files + ] + for future in tqdm( + as_completed(futures), total=len(original_files), file=SAFE_STDOUT + ): + pass logger.info("Resampling Done!") diff --git a/webui/train.py b/webui/train.py index 6d7206181..f7c165938 100644 --- a/webui/train.py +++ b/webui/train.py @@ -129,14 +129,14 @@ def initialize( def resample(model_name, normalize, trim, num_processes): logger.info("Step 2: start resampling...") dataset_path, _, _, _, _ = get_path(model_name) - in_dir = os.path.join(dataset_path, "raw") - out_dir = os.path.join(dataset_path, "wavs") + input_dir = os.path.join(dataset_path, "raw") + output_dir = os.path.join(dataset_path, "wavs") cmd = [ "resample.py", - "--in_dir", - in_dir, - "--out_dir", - out_dir, + "-i", + input_dir, + "-o", + output_dir, "--num_processes", str(num_processes), "--sr", From 3d8b60c03c4b5bf23d36090ede0d425a184936fc Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 15 Mar 2024 12:27:46 +0900 Subject: [PATCH 136/148] Refactor --- preprocess_text.py | 269 ++++++++++++-------- resample.py | 4 + style_bert_vits2/logging.py | 1 - style_bert_vits2/nlp/japanese/normalizer.py | 2 +- 
style_gen.py | 1 - webui/style_vectors.py | 2 +- webui/train.py | 161 ++++++------ 7 files changed, 242 insertions(+), 198 deletions(-) diff --git a/preprocess_text.py b/preprocess_text.py index 6bc858563..65fc77b9d 100644 --- a/preprocess_text.py +++ b/preprocess_text.py @@ -1,145 +1,161 @@ +import argparse import json -import os from collections import defaultdict +from pathlib import Path from random import shuffle from typing import Optional -import click from tqdm import tqdm -from config import config +from config import Preprocess_text_config, config from style_bert_vits2.logging import logger from style_bert_vits2.nlp import clean_text from style_bert_vits2.nlp.japanese import pyopenjtalk_worker from style_bert_vits2.nlp.japanese.user_dict import update_dict from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT - # このプロセスからはワーカーを起動して辞書を使いたいので、ここで初期化 pyopenjtalk_worker.initialize_worker() # dict_data/ 以下の辞書データを pyopenjtalk に適用 update_dict() -preprocess_text_config = config.preprocess_text_config + +preprocess_text_config: Preprocess_text_config = config.preprocess_text_config # Count lines for tqdm -def count_lines(file_path: str): - with open(file_path, "r", encoding="utf-8") as file: +def count_lines(file_path: Path): + with file_path.open("r", encoding="utf-8") as file: return sum(1 for _ in file) -@click.command() -@click.option( - "--transcription-path", - default=preprocess_text_config.transcription_path, - type=click.Path(exists=True, file_okay=True, dir_okay=False), -) -@click.option("--cleaned-path", default=preprocess_text_config.cleaned_path) -@click.option("--train-path", default=preprocess_text_config.train_path) -@click.option("--val-path", default=preprocess_text_config.val_path) -@click.option( - "--config-path", - default=preprocess_text_config.config_path, - type=click.Path(exists=True, file_okay=True, dir_okay=False), -) -@click.option("--val-per-lang", default=preprocess_text_config.val_per_lang) -@click.option("--max-val-total", default=preprocess_text_config.max_val_total) -@click.option("--clean/--no-clean", default=preprocess_text_config.clean) -@click.option("-y", "--yml_config") -@click.option("--use_jp_extra", is_flag=True) -@click.option("--yomi_error", default="raise") +def write_error_log(error_log_path: Path, line: str, error: Exception): + with error_log_path.open("a", encoding="utf-8") as error_log: + error_log.write(f"{line.strip()}\n{error}\n\n") + + +def process_line( + line: str, + transcription_path: Path, + correct_path: bool, + use_jp_extra: bool, + yomi_error: str, +): + splitted_line = line.strip().split("|") + if len(splitted_line) != 4: + raise ValueError(f"Invalid line format: {line.strip()}") + utt, spk, language, text = splitted_line + norm_text, phones, tones, word2ph = clean_text( + text=text, + language=language, # type: ignore + use_jp_extra=use_jp_extra, + raise_yomi_error=(yomi_error != "use"), + ) + if correct_path: + utt = str(transcription_path.parent / "wavs" / utt) + + return "{}|{}|{}|{}|{}|{}|{}\n".format( + utt, + spk, + language, + norm_text, + " ".join(phones), + " ".join([str(i) for i in tones]), + " ".join([str(i) for i in word2ph]), + ) + + def preprocess( - transcription_path: str, - cleaned_path: Optional[str], - train_path: str, - val_path: str, - config_path: str, + transcription_path: Path, + cleaned_path: Optional[Path], + train_path: Path, + val_path: Path, + config_path: Path, val_per_lang: int, max_val_total: int, - clean: bool, - yml_config: str, # 这个不要删 + # clean: bool, use_jp_extra: bool, 
yomi_error: str, + correct_path: bool, ): assert yomi_error in ["raise", "skip", "use"] if cleaned_path == "" or cleaned_path is None: - cleaned_path = transcription_path + ".cleaned" + cleaned_path = transcription_path.with_name( + transcription_path.name + ".cleaned" + ) - error_log_path = os.path.join(os.path.dirname(cleaned_path), "text_error.log") - if os.path.exists(error_log_path): - os.remove(error_log_path) + error_log_path = transcription_path.parent / "text_error.log" + if error_log_path.exists(): + error_log_path.unlink() error_count = 0 - if clean: - total_lines = count_lines(transcription_path) - with open(cleaned_path, "w", encoding="utf-8") as out_file: - with open(transcription_path, "r", encoding="utf-8") as trans_file: - for line in tqdm(trans_file, file=SAFE_STDOUT, total=total_lines): - try: - utt, spk, language, text = line.strip().split("|") - norm_text, phones, tones, word2ph = clean_text( - text=text, - language=language, # type: ignore - use_jp_extra=use_jp_extra, - raise_yomi_error=(yomi_error != "use"), - ) - - out_file.write( - "{}|{}|{}|{}|{}|{}|{}\n".format( - utt, - spk, - language, - norm_text, - " ".join(phones), - " ".join([str(i) for i in tones]), - " ".join([str(i) for i in word2ph]), - ) - ) - except Exception as e: - logger.error( - f"An error occurred at line:\n{line.strip()}\n{e}", - encoding="utf-8", - ) - with open(error_log_path, "a", encoding="utf-8") as error_log: - error_log.write(f"{line.strip()}\n{e}\n\n") - error_count += 1 + total_lines = count_lines(transcription_path) + + # transcription_path から 1行ずつ読み込んで文章処理して cleaned_path に書き込む + with ( + transcription_path.open("r", encoding="utf-8") as trans_file, + cleaned_path.open("w", encoding="utf-8") as out_file, + ): + for line in tqdm(trans_file, file=SAFE_STDOUT, total=total_lines): + try: + processed_line = process_line( + line, + transcription_path, + correct_path, + use_jp_extra, + yomi_error, + ) + out_file.write(processed_line) + except Exception as e: + logger.error( + f"An error occurred at line:\n{line.strip()}\n{e}", encoding="utf-8" + ) + write_error_log(error_log_path, line, e) + error_count += 1 transcription_path = cleaned_path - spk_utt_map = defaultdict(list) - spk_id_map = {} - current_sid = 0 - - with open(transcription_path, "r", encoding="utf-8") as f: - audioPaths = set() - countSame = 0 - countNotFound = 0 + + # 各話者ごとのlineの辞書 + spk_utt_map: dict[str, list[str]] = defaultdict(list) + + # 話者からIDへの写像 + spk_id_map: dict[str, int] = {} + + # 話者ID + current_sid: int = 0 + + # 音源ファイルのチェックや、spk_id_mapの作成 + with transcription_path.open("r", encoding="utf-8") as f: + audio_paths: set[str] = set() + count_same = 0 + count_not_found = 0 for line in f.readlines(): - utt, spk, language, text, phones, tones, word2ph = line.strip().split("|") - if utt in audioPaths: - # 过滤数据集错误:相同的音频匹配多个文本,导致后续bert出问题 - logger.warning(f"Same audio matches multiple texts: {line}") - countSame += 1 + utt, spk = line.strip().split("|")[:2] + if utt in audio_paths: + logger.warning(f"Same audio file appears multiple times: {utt}") + count_same += 1 continue - if not os.path.isfile(utt): - # 过滤数据集错误:不存在对应音频 + if not Path(utt).is_file(): logger.warning(f"Audio not found: {utt}") - countNotFound += 1 + count_not_found += 1 continue - audioPaths.add(utt) - spk_utt_map[language].append(line) + audio_paths.add(utt) + spk_utt_map[spk].append(line) + + # 新しい話者が出てきたら話者IDを割り当て、current_sidを1増やす if spk not in spk_id_map.keys(): spk_id_map[spk] = current_sid current_sid += 1 - if countSame > 0 or countNotFound > 0: + if 
count_same > 0 or count_not_found > 0: logger.warning( - f"Total repeated audios: {countSame}, Total number of audio not found: {countNotFound}" + f"Total repeated audios: {count_same}, Total number of audio not found: {count_not_found}" ) - train_list = [] - val_list = [] + train_list: list[str] = [] + val_list: list[str] = [] + # 各話者ごとにシャッフルして、val_per_lang個をval_listに、残りをtrain_listに追加 for spk, utts in spk_utt_map.items(): shuffle(utts) val_list += utts[:val_per_lang] @@ -150,26 +166,21 @@ def preprocess( train_list += val_list[max_val_total:] val_list = val_list[:max_val_total] - with open(train_path, "w", encoding="utf-8") as f: + with train_path.open("w", encoding="utf-8") as f: for line in train_list: f.write(line) - with open(val_path, "w", encoding="utf-8") as f: + with val_path.open("w", encoding="utf-8") as f: for line in val_list: f.write(line) - json_config = json.load(open(config_path, encoding="utf-8")) + with config_path.open("r", encoding="utf-8") as f: + json_config = json.load(f) + json_config["data"]["spk2id"] = spk_id_map json_config["data"]["n_speakers"] = len(spk_id_map) - # 新增写入:写入训练版本、数据集路径 - # json_config["version"] = latest_version - json_config["data"]["training_files"] = os.path.normpath(train_path).replace( - "\\", "/" - ) - json_config["data"]["validation_files"] = os.path.normpath(val_path).replace( - "\\", "/" - ) - with open(config_path, "w", encoding="utf-8") as f: + + with config_path.open("w", encoding="utf-8") as f: json.dump(json_config, f, indent=2, ensure_ascii=False) if error_count > 0: if yomi_error == "skip": @@ -194,4 +205,50 @@ def preprocess( if __name__ == "__main__": - preprocess() + parser = argparse.ArgumentParser() + parser.add_argument( + "--transcription-path", default=preprocess_text_config.transcription_path + ) + parser.add_argument("--cleaned-path", default=preprocess_text_config.cleaned_path) + parser.add_argument("--train-path", default=preprocess_text_config.train_path) + parser.add_argument("--val-path", default=preprocess_text_config.val_path) + parser.add_argument("--config-path", default=preprocess_text_config.config_path) + + # 「話者ごと」のバリデーションデータ数、言語ごとではない! 
+ # 元のコードや設定ファイルでval_per_langとなっていたので名前をそのままにしている + parser.add_argument( + "--val-per-lang", + default=preprocess_text_config.val_per_lang, + help="Number of validation data per SPEAKER, not per language (due to compatibility with the original code).", + ) + parser.add_argument("--max-val-total", default=preprocess_text_config.max_val_total) + parser.add_argument("--use_jp_extra", action="store_true") + parser.add_argument("--yomi_error", default="raise") + parser.add_argument("--correct_path", action="store_true") + + args = parser.parse_args() + logger.debug(f"args: {args}") + + transcription_path = Path(args.transcription_path) + cleaned_path = Path(args.cleaned_path) if args.cleaned_path else None + train_path = Path(args.train_path) + val_path = Path(args.val_path) + config_path = Path(args.config_path) + val_per_lang = int(args.val_per_lang) + max_val_total = int(args.max_val_total) + use_jp_extra: bool = args.use_jp_extra + yomi_error: str = args.yomi_error + correct_path: bool = args.correct_path + + preprocess( + transcription_path=transcription_path, + cleaned_path=cleaned_path, + train_path=train_path, + val_path=val_path, + config_path=config_path, + val_per_lang=val_per_lang, + max_val_total=max_val_total, + use_jp_extra=use_jp_extra, + yomi_error=yomi_error, + correct_path=correct_path, + ) diff --git a/resample.py b/resample.py index f9f1e3b23..5be01a734 100644 --- a/resample.py +++ b/resample.py @@ -14,6 +14,7 @@ from style_bert_vits2.logging import logger from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT + DEFAULT_BLOCK_SIZE: float = 0.400 # seconds @@ -33,6 +34,9 @@ def normalize_audio(data: NDArray[Any], sr: int): def resample(file: Path, output_dir: Path, target_sr: int, normalize: bool, trim: bool): + """ + fileを読み込んで、target_srなwavファイルに変換してoutput_dir直下に保存する + """ try: # librosaが読めるファイルかチェック # wav以外にもmp3やoggやflacなども読める diff --git a/style_bert_vits2/logging.py b/style_bert_vits2/logging.py index e552bf780..e5c216a00 100644 --- a/style_bert_vits2/logging.py +++ b/style_bert_vits2/logging.py @@ -12,5 +12,4 @@ format="{time:MM-DD HH:mm:ss} |{level:^8}| {file}:{line} | {message}", backtrace=True, diagnose=True, - level="TRACE", ) diff --git a/style_bert_vits2/nlp/japanese/normalizer.py b/style_bert_vits2/nlp/japanese/normalizer.py index d3bc40e9c..5ceb2f8c1 100644 --- a/style_bert_vits2/nlp/japanese/normalizer.py +++ b/style_bert_vits2/nlp/japanese/normalizer.py @@ -36,7 +36,7 @@ def normalize_text(text: str) -> str: res = unicodedata.normalize("NFKC", text) # ここでアルファベットは半角になる res = __convert_numbers_to_words(res) # 「100円」→「百円」等 - # 「~」と「〜」とと「~」も長音記号として扱う + # 「~」と「〜」と「~」も長音記号として扱う res = res.replace("~", "ー") res = res.replace("~", "ー") res = res.replace("〜", "ー") diff --git a/style_gen.py b/style_gen.py index 02af06742..a21ab2b33 100644 --- a/style_gen.py +++ b/style_gen.py @@ -1,5 +1,4 @@ import argparse -import warnings from concurrent.futures import ThreadPoolExecutor from typing import Any diff --git a/webui/style_vectors.py b/webui/style_vectors.py index 05056eb03..e125c3fe9 100644 --- a/webui/style_vectors.py +++ b/webui/style_vectors.py @@ -6,12 +6,12 @@ import matplotlib.pyplot as plt import numpy as np import yaml +from config import config from scipy.spatial.distance import pdist, squareform from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans from sklearn.manifold import TSNE from umap import UMAP -from config import config from style_bert_vits2.constants import DEFAULT_STYLE, GRADIO_THEME from style_bert_vits2.logging import logger diff 
--git a/webui/train.py b/webui/train.py index f7c165938..c4cd60a80 100644 --- a/webui/train.py +++ b/webui/train.py @@ -24,32 +24,31 @@ # Get path settings with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: path_config: dict[str, str] = yaml.safe_load(f.read()) - dataset_root = path_config["dataset_root"] - # assets_root = path_config["assets_root"] + dataset_root = Path(path_config["dataset_root"]) -def get_path(model_name): +def get_path(model_name: str) -> tuple[Path, Path, Path, Path, Path]: assert model_name != "", "モデル名は空にできません" - dataset_path = os.path.join(dataset_root, model_name) - lbl_path = os.path.join(dataset_path, "esd.list") - train_path = os.path.join(dataset_path, "train.list") - val_path = os.path.join(dataset_path, "val.list") - config_path = os.path.join(dataset_path, "config.json") + dataset_path = dataset_root / model_name + lbl_path = dataset_path / "esd.list" + train_path = dataset_path / "train.list" + val_path = dataset_path / "val.list" + config_path = dataset_path / "config.json" return dataset_path, lbl_path, train_path, val_path, config_path def initialize( - model_name, - batch_size, - epochs, - save_every_steps, - freeze_EN_bert, - freeze_JP_bert, - freeze_ZH_bert, - freeze_style, - freeze_decoder, - use_jp_extra, - log_interval, + model_name: str, + batch_size: int, + epochs: int, + save_every_steps: int, + freeze_EN_bert: bool, + freeze_JP_bert: bool, + freeze_ZH_bert: bool, + freeze_style: bool, + freeze_decoder: bool, + use_jp_extra: bool, + log_interval: int, ): global logger_handler dataset_path, _, train_path, val_path, config_path = get_path(model_name) @@ -72,8 +71,8 @@ def initialize( with open(default_config_path, "r", encoding="utf-8") as f: config = json.load(f) config["model_name"] = model_name - config["data"]["training_files"] = train_path - config["data"]["validation_files"] = val_path + config["data"]["training_files"] = str(train_path) + config["data"]["validation_files"] = str(val_path) config["train"]["batch_size"] = batch_size config["train"]["epochs"] = epochs config["train"]["eval_interval"] = save_every_steps @@ -90,18 +89,18 @@ def initialize( # 今はデフォルトであるが、以前は非JP-Extra版になくバグの原因になるので念のため config["data"]["use_jp_extra"] = use_jp_extra - model_path = os.path.join(dataset_path, "models") - if os.path.exists(model_path): + model_path = dataset_path / "models" + if model_path.exists(): logger.warning( f"Step 1: {model_path} already exists, so copy it to backup to {model_path}_backup" ) shutil.copytree( src=model_path, - dst=os.path.join(dataset_path, "models_backup"), + dst=dataset_path / "models_backup", dirs_exist_ok=True, ) shutil.rmtree(model_path) - pretrained_dir = "pretrained" if not use_jp_extra else "pretrained_jp_extra" + pretrained_dir = Path("pretrained" if not use_jp_extra else "pretrained_jp_extra") try: shutil.copytree( src=pretrained_dir, @@ -113,30 +112,29 @@ def initialize( with open(config_path, "w", encoding="utf-8") as f: json.dump(config, f, indent=2, ensure_ascii=False) - if not os.path.exists("config.yml"): + if not Path("config.yml").exists(): shutil.copy(src="default_config.yml", dst="config.yml") - # yml_data = safe_load(open("config.yml", "r", encoding="utf-8")) with open("config.yml", "r", encoding="utf-8") as f: yml_data = yaml.safe_load(f) yml_data["model_name"] = model_name - yml_data["dataset_path"] = dataset_path + yml_data["dataset_path"] = str(dataset_path) with open("config.yml", "w", encoding="utf-8") as f: yaml.dump(yml_data, f, allow_unicode=True) logger.success("Step 1: 
initialization finished.") return True, "Step 1, Success: 初期設定が完了しました" -def resample(model_name, normalize, trim, num_processes): +def resample(model_name: str, normalize: bool, trim: bool, num_processes: int): logger.info("Step 2: start resampling...") dataset_path, _, _, _, _ = get_path(model_name) - input_dir = os.path.join(dataset_path, "raw") - output_dir = os.path.join(dataset_path, "wavs") + input_dir = dataset_path / "raw" + output_dir = dataset_path / "wavs" cmd = [ "resample.py", "-i", - input_dir, + str(input_dir), "-o", - output_dir, + str(output_dir), "--num_processes", str(num_processes), "--sr", @@ -157,42 +155,30 @@ def resample(model_name, normalize, trim, num_processes): return True, "Step 2, Success: 音声ファイルの前処理が完了しました" -def preprocess_text(model_name, use_jp_extra, val_per_lang, yomi_error): +def preprocess_text( + model_name: str, use_jp_extra: bool, val_per_lang: int, yomi_error: str +): logger.info("Step 3: start preprocessing text...") - dataset_path, lbl_path, train_path, val_path, config_path = get_path(model_name) - try: - lines = open(lbl_path, "r", encoding="utf-8").readlines() - except FileNotFoundError: + _, lbl_path, train_path, val_path, config_path = get_path(model_name) + if not lbl_path.exists(): logger.error(f"Step 3: {lbl_path} not found.") return False, f"Step 3, Error: 書き起こしファイル {lbl_path} が見つかりません。" - new_lines = [] - for line in lines: - if len(line.strip().split("|")) != 4: - logger.error(f"Step 3: {lbl_path} has invalid format at line:\n{line}") - return ( - False, - f"Step 3, Error: 書き起こしファイル次の行の形式が不正です:\n{line}", - ) - path, spk, language, text = line.strip().split("|") - # pathをファイル名だけ取り出して正しいパスに変更 - path = Path(dataset_path) / "wavs" / Path(path).name - new_lines.append(f"{path}|{spk}|{language}|{text}\n") - with open(lbl_path, "w", encoding="utf-8") as f: - f.writelines(new_lines) + cmd = [ "preprocess_text.py", "--config-path", - config_path, + str(config_path), "--transcription-path", - lbl_path, + str(lbl_path), "--train-path", - train_path, + str(train_path), "--val-path", - val_path, + str(val_path), "--val-per-lang", str(val_per_lang), "--yomi_error", yomi_error, + "--correct_path", # 音声ファイルのパスを正しいパスに修正する ] if use_jp_extra: cmd.append("--use_jp_extra") @@ -213,17 +199,11 @@ def preprocess_text(model_name, use_jp_extra, val_per_lang, yomi_error): return True, "Step 3, Success: 書き起こしファイルの前処理が完了しました" -def bert_gen(model_name): +def bert_gen(model_name: str): logger.info("Step 4: start bert_gen...") _, _, _, _, config_path = get_path(model_name) success, message = run_script_with_log( - [ - "bert_gen.py", - "--config", - config_path, - # "--num_processes", # bert_genは重いのでプロセス数いじらない - # str(num_processes), - ] + ["bert_gen.py", "--config", str(config_path)] ) if not success: logger.error("Step 4: bert_gen failed.") @@ -238,14 +218,14 @@ def bert_gen(model_name): return True, "Step 4, Success: BERT特徴ファイルの生成が完了しました" -def style_gen(model_name, num_processes): +def style_gen(model_name: str, num_processes: int): logger.info("Step 5: start style_gen...") _, _, _, _, config_path = get_path(model_name) success, message = run_script_with_log( [ "style_gen.py", "--config", - config_path, + str(config_path), "--num_processes", str(num_processes), ] @@ -267,22 +247,22 @@ def style_gen(model_name, num_processes): def preprocess_all( - model_name, - batch_size, - epochs, - save_every_steps, - num_processes, - normalize, - trim, - freeze_EN_bert, - freeze_JP_bert, - freeze_ZH_bert, - freeze_style, - freeze_decoder, - use_jp_extra, - val_per_lang, - 
log_interval, - yomi_error, + model_name: str, + batch_size: int, + epochs: int, + save_every_steps: int, + num_processes: int, + normalize: bool, + trim: bool, + freeze_EN_bert: bool, + freeze_JP_bert: bool, + freeze_ZH_bert: bool, + freeze_style: bool, + freeze_decoder: bool, + use_jp_extra: bool, + val_per_lang: int, + log_interval: int, + yomi_error: str, ): if model_name == "": return False, "Error: モデル名を入力してください" @@ -333,18 +313,23 @@ def preprocess_all( ) -def train(model_name, skip_style=False, use_jp_extra=True, speedup=False): +def train( + model_name: str, + skip_style: bool = False, + use_jp_extra: bool = True, + speedup: bool = False, +): dataset_path, _, _, _, config_path = get_path(model_name) - # 学習再開の場合は念のためconfig.ymlの名前等を更新 + # 学習再開の場合を考えて念のためconfig.ymlの名前等を更新 with open("config.yml", "r", encoding="utf-8") as f: yml_data = yaml.safe_load(f) yml_data["model_name"] = model_name - yml_data["dataset_path"] = dataset_path + yml_data["dataset_path"] = str(dataset_path) with open("config.yml", "w", encoding="utf-8") as f: yaml.dump(yml_data, f, allow_unicode=True) train_py = "train_ms.py" if not use_jp_extra else "train_ms_jp_extra.py" - cmd = [train_py, "--config", config_path, "--model", dataset_path] + cmd = [train_py, "--config", str(config_path), "--model", str(dataset_path)] if skip_style: cmd.append("--skip_default_style") if speedup: @@ -360,7 +345,7 @@ def train(model_name, skip_style=False, use_jp_extra=True, speedup=False): return True, "Success: 学習が完了しました" -def wait_for_tensorboard(port=6006, timeout=10): +def wait_for_tensorboard(port: int = 6006, timeout: float = 10): start_time = time.time() while True: try: @@ -375,7 +360,7 @@ def wait_for_tensorboard(port=6006, timeout=10): time.sleep(0.1) -def run_tensorboard(model_name): +def run_tensorboard(model_name: str): global tensorboard_executed if not tensorboard_executed: python = sys.executable From 432edcc52b910b7a795cf8174471d49082c2221c Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 15 Mar 2024 12:34:54 +0900 Subject: [PATCH 137/148] Bump ver, rename configs_jp_extra to config_jp_extra for consistency --- configs/config.json | 2 +- configs/{configs_jp_extra.json => config_jp_extra.json} | 2 +- style_bert_vits2/constants.py | 2 +- style_bert_vits2/models/hyper_parameters.py | 2 +- webui/train.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) rename configs/{configs_jp_extra.json => config_jp_extra.json} (98%) diff --git a/configs/config.json b/configs/config.json index db0411aed..594cd182a 100644 --- a/configs/config.json +++ b/configs/config.json @@ -69,5 +69,5 @@ "use_spectral_norm": false, "gin_channels": 256 }, - "version": "2.4.dev0" + "version": "2.4.0" } diff --git a/configs/configs_jp_extra.json b/configs/config_jp_extra.json similarity index 98% rename from configs/configs_jp_extra.json rename to configs/config_jp_extra.json index 23c5fe500..bf5722c73 100644 --- a/configs/configs_jp_extra.json +++ b/configs/config_jp_extra.json @@ -76,5 +76,5 @@ "initial_channel": 64 } }, - "version": "2.4.dev0-JP-Extra" + "version": "2.4.0-JP-Extra" } diff --git a/style_bert_vits2/constants.py b/style_bert_vits2/constants.py index aa1433101..337cb81fa 100644 --- a/style_bert_vits2/constants.py +++ b/style_bert_vits2/constants.py @@ -4,7 +4,7 @@ # Style-Bert-VITS2 のバージョン -VERSION = "2.4.dev0" +VERSION = "2.4.0" # Style-Bert-VITS2 のベースディレクトリ BASE_DIR = Path(__file__).parent.parent diff --git a/style_bert_vits2/models/hyper_parameters.py b/style_bert_vits2/models/hyper_parameters.py index 254b285b7..feb6bfbaf 
100644 --- a/style_bert_vits2/models/hyper_parameters.py +++ b/style_bert_vits2/models/hyper_parameters.py @@ -1,6 +1,6 @@ """ Style-Bert-VITS2 モデルのハイパーパラメータを表す Pydantic モデル。 -デフォルト値は configs/configs_jp_extra.json 内の定義と概ね同一で、 +デフォルト値は configs/config_jp_extra.json 内の定義と概ね同一で、 万が一ロードした config.json に存在しないキーがあった際のフェイルセーフとして適用される。 """ diff --git a/webui/train.py b/webui/train.py index c4cd60a80..eadefafed 100644 --- a/webui/train.py +++ b/webui/train.py @@ -65,7 +65,7 @@ def initialize( ) default_config_path = ( - "configs/config.json" if not use_jp_extra else "configs/configs_jp_extra.json" + "configs/config.json" if not use_jp_extra else "configs/config_jp_extra.json" ) with open(default_config_path, "r", encoding="utf-8") as f: From b02b882c04ac0142adca6ed86279167c5b2dcec4 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 15 Mar 2024 12:53:24 +0900 Subject: [PATCH 138/148] Feat: support various audio format in slice.py --- slice.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/slice.py b/slice.py index 2dece8ab1..c4b2292fd 100644 --- a/slice.py +++ b/slice.py @@ -14,6 +14,11 @@ from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT +def is_audio_file(file: Path) -> bool: + supported_extensions = [".wav", ".flac", ".mp3", ".ogg", ".opus"] + return file.suffix.lower() in supported_extensions + + def get_stamps( vad_model: Any, utils: Any, @@ -158,9 +163,9 @@ def split_wav( time_suffix: bool = args.time_suffix num_processes: int = args.num_processes - wav_files = Path(input_dir).glob("**/*.wav") - wav_files = list(wav_files) - logger.info(f"Found {len(wav_files)} wav files.") + audio_files = [file for file in input_dir.rglob("*") if is_audio_file(file)] + + logger.info(f"Found {len(audio_files)} audio files.") if output_dir.exists(): logger.warning(f"Output directory {output_dir} already exists, deleting...") shutil.rmtree(output_dir) @@ -216,7 +221,7 @@ def process_queue( error_queue: Queue[tuple[Path, Exception]] = Queue() # ファイル数が少ない場合は、ワーカー数をファイル数に合わせる - num_processes = min(num_processes, len(wav_files)) + num_processes = min(num_processes, len(audio_files)) threads = [ Thread(target=process_queue, args=(q, result_queue, error_queue)) @@ -225,14 +230,14 @@ def process_queue( for t in threads: t.start() - pbar = tqdm(total=len(wav_files), file=SAFE_STDOUT) - for file in wav_files: + pbar = tqdm(total=len(audio_files), file=SAFE_STDOUT) + for file in audio_files: q.put(file) # result_queueを監視し、要素が追加されるごとに結果を加算しプログレスバーを更新 total_sec = 0 total_count = 0 - for _ in range(len(wav_files)): + for _ in range(len(audio_files)): time, count = result_queue.get() total_sec += time total_count += count From d80879396020aa639b0e625335741ad6b9b2c759 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 15 Mar 2024 19:56:55 +0900 Subject: [PATCH 139/148] Change default HF whisper batch size to 16 --- webui/dataset.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/webui/dataset.py b/webui/dataset.py index 1d8367c10..4a270f65a 100644 --- a/webui/dataset.py +++ b/webui/dataset.py @@ -194,6 +194,15 @@ def create_dataset_app() -> gr.Blocks: label="計算精度", value="bfloat16", ) + batch_size = gr.Slider( + minimum=1, + maximum=128, + value=16, + step=1, + label="バッチサイズ", + info="大きくすると速度が速くなるがVRAMを多く使う", + visible=False, + ) device = gr.Radio(["cuda", "cpu"], label="デバイス", value="cuda") language = gr.Dropdown(["ja", "en", "zh"], value="ja", label="言語") initial_prompt = gr.Textbox( @@ -209,15 +218,6 @@ def create_dataset_app() 
-> gr.Blocks: label="ビームサーチのビーム数", info="小さいほど速度が上がる(以前は5)", ) - batch_size = gr.Slider( - minimum=1, - maximum=128, - value=32, - step=1, - label="バッチサイズ", - info="大きくすると速度が速くなるがVRAMを多く使う", - visible=False, - ) transcribe_button = gr.Button("音声の文字起こし") result2 = gr.Textbox(label="結果") slice_button.click( @@ -249,9 +249,9 @@ def create_dataset_app() -> gr.Blocks: outputs=[result2], ) use_hf_whisper.change( - lambda x: (gr.update(visible=x), gr.update(visible=not x)), + lambda x: (gr.update(visible=x), gr.update(visible=not x), gr.update(visible=not x)), inputs=[use_hf_whisper], - outputs=[batch_size, compute_type], + outputs=[batch_size, compute_type, device], ) return app From d808cd635de4992168ed3ce8393c64325823c36c Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 15 Mar 2024 19:58:16 +0900 Subject: [PATCH 140/148] Delete num_process option for slice for simplicity --- webui/dataset.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/webui/dataset.py b/webui/dataset.py index 4a270f65a..fe14030ac 100644 --- a/webui/dataset.py +++ b/webui/dataset.py @@ -11,7 +11,6 @@ def do_slice( min_silence_dur_ms: int, time_suffix: bool, input_dir: str, - num_processes: int = 3, ): if model_name == "": return "Error: モデル名を入力してください。" @@ -26,8 +25,6 @@ def do_slice( str(max_sec), "--min_silence_dur_ms", str(min_silence_dur_ms), - "--num_processes", - str(num_processes), ] if time_suffix: cmd.append("--time_suffix") @@ -152,14 +149,6 @@ def create_dataset_app() -> gr.Blocks: value=False, label="WAVファイル名の末尾に元ファイルの時間範囲を付与する", ) - num_processes = gr.Slider( - minimum=1, - maximum=10, - value=3, - step=1, - label="並列処理数(速度向上のため)", - info="3で十分高速、多くしてもCPU負荷が増すだけでそこまで速度は変わらない", - ) slice_button = gr.Button("スライスを実行") result1 = gr.Textbox(label="結果") with gr.Row(): @@ -229,7 +218,6 @@ def create_dataset_app() -> gr.Blocks: min_silence_dur_ms, time_suffix, input_dir, - num_processes, ], outputs=[result1], ) From 3efcb98858003b5577f1c197848964858f491155 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 15 Mar 2024 20:16:58 +0900 Subject: [PATCH 141/148] docs --- colab.ipynb | 4 +-- docs/CHANGELOG.md | 53 +++++++++++++++++++++++++++++ docs/CLI.md | 15 +++++--- library.ipynb | 29 +++++++++++++--- preprocess_text.py | 1 - scripts/Clean.bat | 2 +- scripts/Update-Style-Bert-VITS2.bat | 14 ++++---- webui/dataset.py | 3 +- webui/merge.py | 2 +- 9 files changed, 102 insertions(+), 21 deletions(-) diff --git a/colab.ipynb b/colab.ipynb index 5e80f526d..27ccbc180 100644 --- a/colab.ipynb +++ b/colab.ipynb @@ -43,7 +43,7 @@ "!git clone https://github.com/litagin02/Style-Bert-VITS2.git\n", "%cd Style-Bert-VITS2/\n", "!pip install -r requirements.txt\n", - "!apt install libcublas11\n", + "# !apt install libcublas11\n", "!python initialize.py --skip_jvnv" ] }, @@ -121,7 +121,7 @@ "initial_prompt = \"こんにちは。元気、ですかー?ふふっ、私は……ちゃんと元気だよ!\"\n", "\n", "!python slice.py -i {input_dir} --model_name {model_name}\n", - "!python transcribe.py --model_name {model_name} --compute_type float16 --initial_prompt {initial_prompt}" + "!python transcribe.py --model_name {model_name} --initial_prompt {initial_prompt} --use_hf_whisper" ] }, { diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index dee5ee491..3e3ed0ca0 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,5 +1,58 @@ # Changelog +## v2.4.0 (2024-03-15) + +大規模リファクタリング・日本語処理のワーカー化と機能追加等。データセット作り・学習・音声合成・マージ・スタイルWebUIは全て`app.py` (`App.bat`) へ統一されましたのでご注意ください。 + +### アップデート手順 +- 
2.3未満(辞書・エディター追加前)からのアップデートの場合は、[Update-to-Dict-Editor.bat](https://github.com/litagin02/Style-Bert-VITS2/releases/download/2.4.0/Update-to-Dict-Editor.bat)をダウンロードし、`Style-Bert-VITS2`フォルダがある場所(インストールbatファイルとかがあったところ)においてダブルクリックしてください。
+- それ以外の場合は、単純に今までの`Update-Style-Bert-VITS2.bat`でアップデートできます。
+- ただしアップデートにより多くのファイルが移動したり不要になったりしたので、それらを削除したい場合は[Clean.bat](https://github.com/litagin02/Style-Bert-VITS2/releases/download/2.4.0/Clean.bat)を`Update-Style-Bert-VITS2.bat`と同じ場所に保存して実行してください。
+
+### 内部改善
+
+- [tsukumijimaさんによる大規模リファクタリングのプルリク](https://github.com/litagin02/Style-Bert-VITS2/pull/92) によって、内部コードが非常に整理され可読性が高まりライブラリ化もされた。[tsukumijimaさん](https://github.com/tsukumijima) 大変な作業を本当にありがとうございます!
+- ライブラリとして`pip install style-bert-vits2`によりすぐにインストールでき、音声合成部分の機能が使えます(使用例は[/library.ipynb](/library.ipynb)を参照してください)
+- その他このプルリクに動機づけられ、多くのコードのリファクタリング・型アノテーションの追加等を行った
+- 日本語処理のpyopenjtalkをソケット通信を用いて別プロセス化し、複数同時に学習や音声合成を立ち上げても辞書の競合エラーが起きないように。[kale4eat](https://github.com/kale4eat) さんによる[PR](https://github.com/litagin02/Style-Bert-VITS2/pull/89) で
+
+### バグ修正
+
+- 上記にもある通り、音声合成と学習前処理など、日本語処理を扱うものを2つ以上起動しようとするとエラーが発生する仕様の解決。ユーザー辞書は追加すれば常にどこからでも適応されます。
+- `raw`フォルダの直下でなくサブフォルダ内に音声ファイルがある場合に、`wavs`フォルダでもその構造が保たれてしまい、書き起こしファイルとの整合性が取れなくなる挙動を修正し、常に`wav`フォルダ直下へ`wav`ファイルを保存するように変更
+- スライス時に元ファイル名にピリオド `.` が含まれると、スライス後のファイル名がおかしくなるバグの修正
+
+### 機能改善・追加
+
+- 各種WebUIを一つ`app.py` `App.bat` に統一
+- その他以下の変更や、軽微なUI・説明文の改善等
+
+**データセット作成**
+
+- スライス処理の高速化(マルチスレッドにした、大量にスライス元ファイルがある場合に高速になります)、またスライス元のファイルを`wav`以外の`mp3`や`ogg`などの形式にも対応
+- スライス処理時に、ファイル名にスライスされた開始終了区間を含めるオプションを追加([aka7774](https://github.com/aka7774) さんによるPRです、ありがとうございます!)
+- 書き起こしの高速化、またHugging FaceのWhisperモデルを使うオプションを追加。バッチサイズを上げることでVRAMを食う代わりに速度が大幅に向上します。
+
+**学習**
+
+- 学習元の音声ファイル(`Data/モデル名/raw`にいれるやつ)を、`wav`以外の`mp3`や`ogg`などの形式にも対応(前処理段階で自動的に`wav`ファイルに変換されます)(ただし変わらず1ファイル2-12秒程度の範囲の長さが望ましい)
+
+**音声合成**
+
+- 音声合成時に、生成音声の音の高さ(音高)と抑揚の幅を調整できるように(ただし音質が少し劣化する)。`App.bat`や`Editor.bat`のどちらからでも使えます。
+- `Editor.bat`の複数話者モデルでの話者指定を可能に
+- `Editor.bat`で、改行を含む文字列をペーストすると自動的に欄が増えるように。また「↑↓」キーで欄を追加・行き来できるように
+- `Editor.bat`でモデル一覧のリロードをメニューに追加
+
+**API**
+
+- `server_fastapi.py`の音声合成エンドポイント`/voice`について、GETメソッドに加えてPOSTメソッドを追加。GETメソッドでは多くの制約があるようなのでPOSTを使うことが推奨されます。
+
+**CLI**
+
+- `preprocess_text.py`で、書き起こしファイルでの音声ファイル名を自動的に正しい`Data/モデル名/wavs/`へ書き換える`--correct_path`オプションの追加(WebUIでは今までもこの挙動でした)
+- その他上述のデータセット作成の機能追加に伴うCLIのオプションの追加(詳しくは[CLI.md](/docs/CLI.md)を参照)
+
 ## v2.3.1 (2024-02-27)
 
 ### バグ修正
diff --git a/docs/CLI.md b/docs/CLI.md
index 95ab8b501..537e45063 100644
--- a/docs/CLI.md
+++ b/docs/CLI.md
@@ -24,9 +24,11 @@ Optional:
 ## 1. Dataset preparation
 
-### 1.1. Slice wavs
+### 1.1. Slice audio files
+
+The following audio formats are supported: ".wav", ".flac", ".mp3", ".ogg", ".opus".
 ```bash
-python slice.py --model_name <model_name> [-i <input_dir>] [-m <min_sec>] [-M <max_sec>]
+python slice.py --model_name <model_name> [-i <input_dir>] [-m <min_sec>] [-M <max_sec>] [--time_suffix]
 ```
 
 Required:
@@ -36,8 +38,9 @@ Optional:
 - `input_dir`: Path to the directory containing the audio files to slice (default: `inputs`)
 - `min_sec`: Minimum duration of the sliced audio files in seconds (default: 2).
 - `max_sec`: Maximum duration of the sliced audio files in seconds (default: 12).
+- `--time_suffix`: Make the filename end with -start_ms-end_ms when saving wav.
 
-### 1.2. Transcribe wavs
+### 1.2. Transcribe audio files
 
 ```bash
 python transcribe.py --model_name <model_name>
 ```
 
 Optional
 - `--device`: `cuda` or `cpu` (default: `cuda`).
 - `--language`: `jp`, `en`, or `zh` (default: `jp`). 
- `--model`: Whisper model, default: `large-v3` -- `--compute_type`: default: `bfloat16` +- `--compute_type`: default: `bfloat16`. Only used if not `--use_hf_whisper`. +- `--use_hf_whisper`: Use Hugging Face's whisper model instead of default faster-whisper (HF whisper is faster but requires more VRAM). +- `--batch_size`: Batch size (default: 16). Only used if `--use_hf_whisper`. +- `--num_beams`: Beam size (default: 1). +- `--no_repeat_ngram_size`: N-gram size for no repeat (default: 10). ## 2. Preprocess diff --git a/library.ipynb b/library.ipynb index 81a102894..753059a79 100644 --- a/library.ipynb +++ b/library.ipynb @@ -6,7 +6,19 @@ "source": [ "# Style-Bert-VITS2ライブラリの使用例\n", "\n", - "`pip install style-bert-vits2`を使った、colabで動く使用例です。" + "`pip install style-bert-vits2`を使った、jupyter notebookでの使用例です。Google colab等でも動きます。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# PyTorch環境の構築(ない場合)\n", + "# 参照: https://pytorch.org/get-started/locally/\n", + "\n", + "!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118" ] }, { @@ -17,7 +29,9 @@ }, "outputs": [], "source": [ - "!pip install style-bert-vits2==2.4.dev0" + "# style-bert-vits2のインストール\n", + "\n", + "!pip install style-bert-vits2" ] }, { @@ -28,6 +42,8 @@ }, "outputs": [], "source": [ + "# BERTモデルをロード(ローカルに手動でダウンロードする必要はありません)\n", + "\n", "from style_bert_vits2.nlp import bert_models\n", "from style_bert_vits2.constants import Languages\n", "\n", @@ -48,6 +64,9 @@ }, "outputs": [], "source": [ + "# Hugging Faceから試しにデフォルトモデルをダウンロードしてみて、それを音声合成に使ってみる\n", + "# model_assetsディレクトリにダウンロードされます\n", + "\n", "from pathlib import Path\n", "from huggingface_hub import hf_hub_download\n", "\n", @@ -73,14 +92,16 @@ }, "outputs": [], "source": [ + "# 上でダウンロードしたモデルファイルを指定して音声合成のテスト\n", + "\n", "from style_bert_vits2.tts_model import TTSModel\n", "\n", "assets_root = Path(\"model_assets\")\n", "\n", "model = TTSModel(\n", " model_path=assets_root / model_file,\n", - " config_path = assets_root / config_file,\n", - " style_vec_path = assets_root / style_file,\n", + " config_path=assets_root / config_file,\n", + " style_vec_path=assets_root / style_file,\n", " device=\"cpu\"\n", ")" ] diff --git a/preprocess_text.py b/preprocess_text.py index 65fc77b9d..c75e2306c 100644 --- a/preprocess_text.py +++ b/preprocess_text.py @@ -227,7 +227,6 @@ def preprocess( parser.add_argument("--correct_path", action="store_true") args = parser.parse_args() - logger.debug(f"args: {args}") transcription_path = Path(args.transcription_path) cleaned_path = Path(args.cleaned_path) if args.cleaned_path else None diff --git a/scripts/Clean.bat b/scripts/Clean.bat index 0a3af3471..aed7a1df9 100644 --- a/scripts/Clean.bat +++ b/scripts/Clean.bat @@ -28,7 +28,7 @@ echo Style-Bert-VITS2\webui_style_vectors.py echo Style-Bert-VITS2\webui_train.py echo Style-Bert-VITS2\webui.py echo. -set /p delConfirm=削除しますか? (y/n): +set /p delConfirm=以上のフォルダファイルを削除しますか? 
(y/n): if /I "%delConfirm%"=="Y" goto proceed if /I "%delConfirm%"=="y" goto proceed if "%delConfirm%"=="" goto proceed diff --git a/scripts/Update-Style-Bert-VITS2.bat b/scripts/Update-Style-Bert-VITS2.bat index 04f593720..dc00cea02 100644 --- a/scripts/Update-Style-Bert-VITS2.bat +++ b/scripts/Update-Style-Bert-VITS2.bat @@ -10,12 +10,12 @@ if not exist %CURL_CMD% ( pause & popd & exit /b 1 ) -@REM Style-Bert-VITS2.zip をGitHubのdev-refactorの最新のものをダウンロード +@REM Style-Bert-VITS2.zip をGitHubのdevの最新のものをダウンロード %CURL_CMD% -Lo Style-Bert-VITS2.zip^ - https://github.com/litagin02/Style-Bert-VITS2/archive/refs/heads/dev-refactor.zip + https://github.com/litagin02/Style-Bert-VITS2/archive/refs/heads/dev.zip if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) -@REM Style-Bert-VITS2.zip を解凍(フォルダ名前がBert-VITS2-dev-refactorになる) +@REM Style-Bert-VITS2.zip を解凍(フォルダ名前がBert-VITS2-devになる) %PS_CMD% Expand-Archive -Path Style-Bert-VITS2.zip -DestinationPath . -Force if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) @@ -23,9 +23,9 @@ if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) del Style-Bert-VITS2.zip if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) -@REM Bert-VITS2-dev-refactorの中身をStyle-Bert-VITS2に上書き移動 -xcopy /QSY .\Style-Bert-VITS2-dev-refactor\ .\Style-Bert-VITS2\ -rmdir /s /q Style-Bert-VITS2-dev-refactor +@REM Bert-VITS2-devの中身をStyle-Bert-VITS2に上書き移動 +xcopy /QSY .\Style-Bert-VITS2-dev\ .\Style-Bert-VITS2\ +rmdir /s /q Style-Bert-VITS2-dev if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) @REM 仮想環境のpip requirements.txtを更新 @@ -37,7 +37,9 @@ if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) pip install -U -r Style-Bert-VITS2\requirements.txt if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) +echo ---------------------------------------- echo Update completed. Running Style-Bert-VITS2 Editor... +echo ---------------------------------------- @REM Style-Bert-VITS2フォルダに移動 pushd Style-Bert-VITS2 diff --git a/webui/dataset.py b/webui/dataset.py index fe14030ac..33da3ff98 100644 --- a/webui/dataset.py +++ b/webui/dataset.py @@ -73,8 +73,7 @@ def do_transcribe( cmd.extend(["--batch_size", str(batch_size)]) success, message = run_script_with_log(cmd) if not success: - return f"Error: {message}. しかし何故かエラーが起きても正常に終了している場合がほとんどなので、書き起こし結果を確認して問題なければ学習に使えます。" - return "音声の文字起こしが完了しました。" + return f"Error: {message}. エラーメッセージが空の場合、何も問題がない可能性があるので、書き起こしファイルをチェックして問題なければ無視してください。" how_to_md = """ diff --git a/webui/merge.py b/webui/merge.py index c7f8eeede..661f23369 100644 --- a/webui/merge.py +++ b/webui/merge.py @@ -38,7 +38,7 @@ def merge_style(model_name_a, model_name_b, weight, output_name, style_triple_li sorted_list = sorted(style_triple_list, key=lambda x: x[2] != DEFAULT_STYLE) else: # 存在しない場合、エラーを発生 - raise ValueError("No element with {DEFAULT_STYLE} output style name found.") + raise ValueError(f"No element with {DEFAULT_STYLE} output style name found.") style_vectors_a = np.load( os.path.join(assets_root, model_name_a, "style_vectors.npy") From 5100b4a57e971a4f27fc8ce6fc3cac88ba0ae78e Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 15 Mar 2024 20:30:22 +0900 Subject: [PATCH 142/148] docs --- README.md | 15 ++++++++++----- colab.ipynb | 2 +- docs/CHANGELOG.md | 1 + 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index db2043d95..93e8eec23 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ Bert-VITS2 with more controllable voice styles. 
https://github.com/litagin02/Style-Bert-VITS2/assets/139731664/e853f9a2-db4a-4202-a1dd-56ded3c562a0 +You can install via `pip install style-bert-vits2` (inference only), see [library.ipynb](/library.ipynb) for example usage. + - **解説チュートリアル動画** [YouTube](https://youtu.be/aTUSzgDl1iY) [ニコニコ動画](https://www.nicovideo.jp/watch/sm43391524) - [English README](docs/README_en.md) - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/litagin02/Style-Bert-VITS2/blob/master/colab.ipynb) @@ -12,6 +14,7 @@ https://github.com/litagin02/Style-Bert-VITS2/assets/139731664/e853f9a2-db4a-420 - [**リリースページ**](https://github.com/litagin02/Style-Bert-VITS2/releases/)、[更新履歴](/docs/CHANGELOG.md) + - 2024-03-15: ver 2.4.0 (大規模リファクタリングや種々の改良、ライブラリ化) - 2024-02-26: ver 2.3 (辞書機能とエディター機能) - 2024-02-09: ver 2.2 - 2024-02-07: ver 2.1 @@ -42,6 +45,8 @@ CLIでの使い方は[こちら](/docs/CLI.md)を参照してください。 ### インストール +Pythonライブラリとしてのpipでのインストールや使用例は[library.ipynb](/library.ipynb)を参照してください。 + #### GitやPythonに馴染みが無い方 Windowsを前提としています。 @@ -107,8 +112,8 @@ model_assets #### データセット作り -- `Dataset.bat`をダブルクリックか`python webui_dataset.py`すると、音声ファイルからデータセットを作るためのWebUIが起動します(音声ファイルを適切な長さにスライスし、その後に文字の書き起こしを自動で行います)。 -- 指示に従った後、閉じて下の「学習WebUI」でそのまま学習を行うことができます。 +- `App.bat`をダブルクリックか`python app.py`したところの「データセット作成」タブから、音声ファイルを適切な長さにスライスし、その後に文字の書き起こしを自動で行えます。 +- 指示に従った後、下の「学習」タブでそのまま学習を行うことができます。 注意: データセットの手動修正やノイズ除去等、細かい修正を行いたい場合は[Aivis](https://github.com/tsukumijima/Aivis)や、そのデータセット部分のWindows対応版 [Aivis Dataset](https://github.com/litagin02/Aivis-Dataset) を使うといいかもしれません。ですがファイル数が多い場合などは、このツールで簡易的に切り出してデータセットを作るだけでも十分という気もしています。 @@ -116,12 +121,12 @@ model_assets #### 学習WebUI -- `Train.bat`をダブルクリックか`python webui_train.py`するとWebUIが起動するので指示に従ってください。 +- `App.bat`をダブルクリックか`python app.py`して開くWebUIの「学習」タブから指示に従ってください。 ### スタイルの生成 - デフォルトスタイル「Neutral」以外のスタイルを使いたい人向けです。 -- `Style.bat`をダブルクリックか`python webui_style_vectors.py`するとWebUIが起動します。 +- `App.bat`をダブルクリックか`python app.py`して開くWebUIの「スタイル作成」タブから、音声ファイルを使ってスタイルを生成できます。 - 学習とは独立しているので、学習中でもできるし、学習が終わっても何度もやりなおせます(前処理は終わらせている必要があります)。 - スタイルについての仕様の詳細は[clustering.ipynb](clustering.ipynb)を参照してください。 @@ -140,7 +145,7 @@ API仕様は起動後に`/docs`にて確認ください。 ### マージ 2つのモデルを、「声質」「声の高さ」「感情表現」「テンポ」の4点で混ぜ合わせて、新しいモデルを作ることが出来ます。 -`Merge.bat`をダブルクリックか`python webui_merge.py`するとWebUIが起動します。 +`App.bat`をダブルクリックか`python app.py`して開くWebUIの「マージ」タブから、2つのモデルを選択してマージすることができます。 ### 自然性評価 diff --git a/colab.ipynb b/colab.ipynb index 27ccbc180..cd21c3c2d 100644 --- a/colab.ipynb +++ b/colab.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Style-Bert-VITS2 (ver 2.4) のGoogle Colabでの学習\n", + "# Style-Bert-VITS2 (ver 2.4.0) のGoogle Colabでの学習\n", "\n", "Google Colab上でStyle-Bert-VITS2の学習を行うことができます。\n", "\n", diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 3e3ed0ca0..95c9b4507 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -46,6 +46,7 @@ **API** +- `server_fastapi.py`の実行時に全てのモデルファイルを読み込もうとする挙動を修正。音声合成がリクエストされて初めてそのモデルを読み込むように変更(APIを使わない音声合成のときと同じ挙動) - `server_fastapi.py`の音声合成エンドポイント`/voice`について、GETメソッドに加えてPOSTメソッドを追加。GETメソッドでは多くの制約があるようなのでPOSTを使うことが推奨されます。 **CLI** From f796af5fbb77d48cb02b1434b642ecd51c871f02 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 15 Mar 2024 21:07:56 +0900 Subject: [PATCH 143/148] Remove loading bert since necesary model is loaded automatically --- bert_gen.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/bert_gen.py b/bert_gen.py index 5619ee624..1dcab9eb2 100644 --- a/bert_gen.py +++ b/bert_gen.py 
@@ -20,13 +20,6 @@ from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT -bert_models.load_model(Languages.JP) -bert_models.load_tokenizer(Languages.JP) -bert_models.load_model(Languages.EN) -bert_models.load_tokenizer(Languages.EN) -bert_models.load_model(Languages.ZH) -bert_models.load_tokenizer(Languages.ZH) - # このプロセスからはワーカーを起動して辞書を使いたいので、ここで初期化 pyopenjtalk_worker.initialize_worker() From e42991f7e3881f112c6d7198d655ebfc45398999 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 15 Mar 2024 21:10:24 +0900 Subject: [PATCH 144/148] Remove loading bert since necessary model is loaded automatically --- style_bert_vits2/tts_model.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index 01c6fb420..6d04c5a68 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -380,12 +380,6 @@ def get_model(self, model_name: str, model_path_str: str) -> TTSModel: def get_model_for_gradio( self, model_name: str, model_path_str: str ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]: - bert_models.load_model(Languages.JP) - bert_models.load_tokenizer(Languages.JP) - bert_models.load_model(Languages.EN) - bert_models.load_tokenizer(Languages.EN) - bert_models.load_model(Languages.ZH) - bert_models.load_tokenizer(Languages.ZH) model_path = Path(model_path_str) if model_name not in self.model_files_dict: From 1a8a7edb9ccdc97f34846d20e74be9246483c52e Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 15 Mar 2024 21:25:24 +0900 Subject: [PATCH 145/148] Docs and bump default pytorch version to 2.2.1 latest --- README.md | 10 +--------- colab.ipynb | 3 +-- docs/CLI.md | 2 +- docs/paperspace.md | 4 ++-- style_bert_vits2/tts_model.py | 1 - 5 files changed, 5 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 93e8eec23..4a40de25b 100644 --- a/README.md +++ b/README.md @@ -67,8 +67,7 @@ git clone https://github.com/litagin02/Style-Bert-VITS2.git cd Style-Bert-VITS2 python -m venv venv venv\Scripts\activate -# PyTorch 2.2.x系は今のところは学習エラーが出るので前のバージョンを使う -pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118 +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 pip install -r requirements.txt python initialize.py # 必要なモデルとデフォルトTTSモデルをダウンロード ``` @@ -170,13 +169,6 @@ python speech_mos.py -m - safetensors形式のサポート、デフォルトでsafetensorsを使用するように - その他軽微なbugfixやリファクタリング -## TODO -- [x] デフォルトのJVNVモデルにJP-Extra版のものを追加 -- [x] LinuxやWSL等、Windowsの通常環境以外でのサポート ← おそらく問題ないとの報告あり -- [x] 複数話者学習での音声合成対応(学習は現在でも可能) -- [x] `server_fastapi.py`の対応、とくにAPIで使えるようになると嬉しい人が増えるのかもしれない -- [x] モデルのマージで声音と感情表現を混ぜる機能の実装 -- [ ] 英語等多言語対応? 
## References In addition to the original reference (written below), I used the following repositories: diff --git a/colab.ipynb b/colab.ipynb index cd21c3c2d..3da742767 100644 --- a/colab.ipynb +++ b/colab.ipynb @@ -28,7 +28,7 @@ "\n", "Style-Bert-VITS2の環境をcolab上に構築します。グラボモードが有効になっていることを確認し、以下のセルを順に実行してください。\n", "\n", - "最近のcolabのアップデートにより、エラーダイアログ「WARNING: The following packages were previously imported in this runtime: [pydevd_plugins]」が出るが、「キャンセル」を選択して続行してください。" + "**最近のcolabのアップデートにより、エラーダイアログ「WARNING: The following packages were previously imported in this runtime: [pydevd_plugins]」が出るが、「キャンセル」を選択して続行してください。**" ] }, { @@ -43,7 +43,6 @@ "!git clone https://github.com/litagin02/Style-Bert-VITS2.git\n", "%cd Style-Bert-VITS2/\n", "!pip install -r requirements.txt\n", - "# !apt install libcublas11\n", "!python initialize.py --skip_jvnv" ] }, diff --git a/docs/CLI.md b/docs/CLI.md index 537e45063..97d296aa0 100644 --- a/docs/CLI.md +++ b/docs/CLI.md @@ -7,7 +7,7 @@ git clone https://github.com/litagin02/Style-Bert-VITS2.git cd Style-Bert-VITS2 python -m venv venv venv\Scripts\activate -pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118 +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 pip install -r requirements.txt ``` diff --git a/docs/paperspace.md b/docs/paperspace.md index 2ae9895e4..4026031b7 100644 --- a/docs/paperspace.md +++ b/docs/paperspace.md @@ -28,7 +28,7 @@ git clone https://github.com/litagin02/Style-Bert-VITS2.git 環境構築(デフォルトはPyTorch 1.x系、Python 3.9の模様) ```bash cd /storage/sbv2/Style-Bert-VITS2 -pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118 && pip install -r requirements.txt +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 && pip install -r requirements.txt ``` 事前学習済みモデル等のダウンロード、またパスを`/notebooks/`以下のものに設定 ```bash @@ -57,7 +57,7 @@ mkdir inputs unzip Foo.zip -d inputs cd /storage/sbv2/Style-Bert-VITS2 python slice.py --model_name Foo -i /notebooks/inputs -python transcribe.py --model_name Foo +python transcribe.py --model_name Foo --use_hf_whisper ``` それが終わったら、以下のコマンドで一括前処理を行う(パラメータは各自お好み、バッチサイズ5か6でVRAM 16GBギリくらい)。 diff --git a/style_bert_vits2/tts_model.py b/style_bert_vits2/tts_model.py index 6d04c5a68..12803086f 100644 --- a/style_bert_vits2/tts_model.py +++ b/style_bert_vits2/tts_model.py @@ -380,7 +380,6 @@ def get_model(self, model_name: str, model_path_str: str) -> TTSModel: def get_model_for_gradio( self, model_name: str, model_path_str: str ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]: - model_path = Path(model_path_str) if model_name not in self.model_files_dict: raise ValueError(f"Model `{model_name}` is not found") From 2ee4f4ac86bf6043a633ca0025be4f302a54ea17 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 15 Mar 2024 21:27:22 +0900 Subject: [PATCH 146/148] Change dev to master in bat --- scripts/Update-Style-Bert-VITS2.bat | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/Update-Style-Bert-VITS2.bat b/scripts/Update-Style-Bert-VITS2.bat index dc00cea02..cf2733f13 100644 --- a/scripts/Update-Style-Bert-VITS2.bat +++ b/scripts/Update-Style-Bert-VITS2.bat @@ -10,12 +10,12 @@ if not exist %CURL_CMD% ( pause & popd & exit /b 1 ) -@REM Style-Bert-VITS2.zip をGitHubのdevの最新のものをダウンロード +@REM Style-Bert-VITS2.zip をGitHubのmasterの最新のものをダウンロード %CURL_CMD% -Lo Style-Bert-VITS2.zip^ - 
https://github.com/litagin02/Style-Bert-VITS2/archive/refs/heads/dev.zip + https://github.com/litagin02/Style-Bert-VITS2/archive/refs/heads/master.zip if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) -@REM Style-Bert-VITS2.zip を解凍(フォルダ名前がBert-VITS2-devになる) +@REM Style-Bert-VITS2.zip を解凍(フォルダ名前がBert-VITS2-masterになる) %PS_CMD% Expand-Archive -Path Style-Bert-VITS2.zip -DestinationPath . -Force if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) @@ -23,9 +23,9 @@ if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) del Style-Bert-VITS2.zip if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) -@REM Bert-VITS2-devの中身をStyle-Bert-VITS2に上書き移動 -xcopy /QSY .\Style-Bert-VITS2-dev\ .\Style-Bert-VITS2\ -rmdir /s /q Style-Bert-VITS2-dev +@REM Bert-VITS2-masterの中身をStyle-Bert-VITS2に上書き移動 +xcopy /QSY .\Style-Bert-VITS2-master\ .\Style-Bert-VITS2\ +rmdir /s /q Style-Bert-VITS2-master if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) @REM 仮想環境のpip requirements.txtを更新 From 14c68053374e6b968c9c45c0e7b8350effd68a5b Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 15 Mar 2024 21:36:53 +0900 Subject: [PATCH 147/148] Update hatch setting and fmt --- preprocess_text.py | 1 + pyproject.toml | 6 +++--- webui/dataset.py | 6 +++++- webui/style_vectors.py | 2 +- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/preprocess_text.py b/preprocess_text.py index c75e2306c..a65a1d6be 100644 --- a/preprocess_text.py +++ b/preprocess_text.py @@ -14,6 +14,7 @@ from style_bert_vits2.nlp.japanese.user_dict import update_dict from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT + # このプロセスからはワーカーを起動して辞書を使いたいので、ここで初期化 pyopenjtalk_worker.initialize_worker() diff --git a/pyproject.toml b/pyproject.toml index 3045f039e..2d26680f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ dependencies = [ 'pyworld-prebuilt', 'safetensors', 'scipy', - 'torch>=2.1,<2.2', + 'torch>=2.1', 'transformers', ] @@ -102,11 +102,11 @@ dependencies = [ [tool.hatch.envs.style.scripts] check = [ "black --check --diff .", - "isort --check-only --diff --profile black --gitignore --lai 2 .", + "isort --check-only --diff --profile black --gitignore --lai 2 . --sg \"Data/*\" --sg \"inputs/*\" --sg \"model_assets/*\" --sg \"static/*\"", ] fmt = [ "black .", - "isort --profile black --gitignore --lai 2 .", + "isort --profile black --gitignore --lai 2 . 
--sg \"Data/*\" --sg \"inputs/*\" --sg \"model_assets/*\" --sg \"static/*\"", "check", ] diff --git a/webui/dataset.py b/webui/dataset.py index 33da3ff98..21b1063b5 100644 --- a/webui/dataset.py +++ b/webui/dataset.py @@ -236,7 +236,11 @@ def create_dataset_app() -> gr.Blocks: outputs=[result2], ) use_hf_whisper.change( - lambda x: (gr.update(visible=x), gr.update(visible=not x), gr.update(visible=not x)), + lambda x: ( + gr.update(visible=x), + gr.update(visible=not x), + gr.update(visible=not x), + ), inputs=[use_hf_whisper], outputs=[batch_size, compute_type, device], ) diff --git a/webui/style_vectors.py b/webui/style_vectors.py index e125c3fe9..05056eb03 100644 --- a/webui/style_vectors.py +++ b/webui/style_vectors.py @@ -6,12 +6,12 @@ import matplotlib.pyplot as plt import numpy as np import yaml -from config import config from scipy.spatial.distance import pdist, squareform from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans from sklearn.manifold import TSNE from umap import UMAP +from config import config from style_bert_vits2.constants import DEFAULT_STYLE, GRADIO_THEME from style_bert_vits2.logging import logger From ab4020545ef4a95fa493b4c6f9d01bb93db2fa63 Mon Sep 17 00:00:00 2001 From: litagin02 Date: Fri, 15 Mar 2024 21:43:57 +0900 Subject: [PATCH 148/148] Bump torch ver and docs --- Dockerfile.train | 2 +- docs/CHANGELOG.md | 2 +- scripts/Install-Style-Bert-VITS2.bat | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Dockerfile.train b/Dockerfile.train index 333fd4634..59d565514 100644 --- a/Dockerfile.train +++ b/Dockerfile.train @@ -90,7 +90,7 @@ ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH # Based on https://pytorch.org/get-started/locally/ -RUN $PIP_INSTALL torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118 +RUN $PIP_INSTALL torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 RUN $PIP_INSTALL jupyterlab diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 95c9b4507..240c8da4c 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -41,7 +41,7 @@ - 音声合成時に、生成音声の音の高さ(音高)と抑揚の幅を調整できるように(ただし音質が少し劣化する)。`App.bat`や`Editor.bat`のどちらからでも使えます。 - `Editor.bat`の複数話者モデルでの話者指定を可能に -- `Editor.bat`で、改行を含む文字列をペーストすると自動的に欄が増えるように。また「↑↓」キーで欄を追加・行き来できるように +- `Editor.bat`で、改行を含む文字列をペーストすると自動的に欄が増えるように。また「↑↓」キーで欄を追加・行き来できるように(エディター側で以前に既にアプデしていました) - `Editor.bat`でモデル一覧のリロードをメニューに追加 **API** diff --git a/scripts/Install-Style-Bert-VITS2.bat b/scripts/Install-Style-Bert-VITS2.bat index 59fe97ec3..d674a9e5f 100644 --- a/scripts/Install-Style-Bert-VITS2.bat +++ b/scripts/Install-Style-Bert-VITS2.bat @@ -41,7 +41,8 @@ call Style-Bert-VITS2\scripts\Setup-Python.bat ..\..\lib\python ..\venv if %errorlevel% neq 0 ( popd & exit /b %errorlevel% ) @REM 依存関係インストール -pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118 +pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 + if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) pip install -r Style-Bert-VITS2\requirements.txt