diff --git a/.dockerignore b/.dockerignore index 7ff74dab1..a90abd2db 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,43 +1,25 @@ -# Dockerfile.deploy用 - -*.pyc -*.pyo -*.pyd -__pycache__ -*.pyc - -venv/ -.vscode/ - -.ipynb_checkpoints/ -*.ipynb - -.git/ -.gitignore - -Dockerfile* -.dockerignore -*.md -*.bat -LICENSE - -*.wav -*.zip -*.csv - -# 中国語と英語が必要な場合はコメントアウト -/bert/chinese-roberta-wwm-ext-large/ -/bert/deberta-v3-large/ - -Data/ -dict_data/user_dic.json -dict_data/user_dic.dic -docs/ -inputs/ -mos_results/ -pretrained/ -pretrained_jp_extra/ -scripts/ -slm/ -static/ -tools/ +# Dockerfile.deploy用の.dockerignore +# 日本語のJP-Extraのエディター稼働のみに必要なファイルを指定する + +* + +!/bert/deberta-v2-large-japanese-char-wwm/ +!/common/ +!/configs/ +!/dict_data/default.csv +!/model_assets/ +!/monotonic_align/ +!/text/ + +!/attentions.py +!/commons.py +!/config.py +!/default_config.yml +!/infer.py +!/models.py +!/models_jp_extra.py +!/modules.py +!/requirements.txt +!/server_editor.py +!/transforms.py +!/utils.py diff --git a/app.py b/app.py index e17a0a56a..1514444f7 100644 --- a/app.py +++ b/app.py @@ -96,6 +96,8 @@ def tts_fn( start_time = datetime.datetime.now() + assert model_holder.current_model is not None + try: sr, audio = model_holder.current_model.infer( text=text, diff --git a/colab.ipynb b/colab.ipynb index f4e932b25..a48affce7 100644 --- a/colab.ipynb +++ b/colab.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Style-Bert-VITS2 (ver 2.3) のGoogle Colabでの学習\n", + "# Style-Bert-VITS2 (ver 2.3.1) のGoogle Colabでの学習\n", "\n", "Google Colab上でStyle-Bert-VITS2の学習を行うことができます。\n", "\n", @@ -118,8 +118,8 @@ "# こういうふうに書き起こして欲しいという例文(句読点の入れ方・笑い方や固有名詞等)\n", "initial_prompt = \"こんにちは。元気、ですかー?ふふっ、私は……ちゃんと元気だよ!\"\n", "\n", - "!python slice.py -i {input_dir} -o {dataset_root}/{model_name}/raw\n", - "!python transcribe.py -i {dataset_root}/{model_name}/raw -o {dataset_root}/{model_name}/esd.list --speaker_name {model_name} --compute_type float16 --initial_prompt {initial_prompt}" + "!python slice.py -i {input_dir} --model_name {model_name}\n", + "!python transcribe.py --model_name {model_name} --compute_type float16 --initial_prompt {initial_prompt}" ] }, { @@ -229,7 +229,11 @@ "normalize = False\n", "\n", "# 音声ファイルの開始・終了にある無音区間を削除するかどうか\n", - "trim = False" + "trim = False\n", + "\n", + "# 読みのエラーが出た場合にどうするか。\n", + "# \"raise\"ならテキスト前処理が終わったら中断、\"skip\"なら読めない行は学習に使わない、\"use\"なら無理やり使う\n", + "yomi_error = \"skip\"" ] }, { @@ -269,6 +273,7 @@ " use_jp_extra=use_jp_extra,\n", " val_per_lang=0,\n", " log_interval=200,\n", + " yomi_error=yomi_error\n", ")" ] }, diff --git a/common/constants.py b/common/constants.py index 751d5e123..fe620195d 100644 --- a/common/constants.py +++ b/common/constants.py @@ -4,7 +4,7 @@ # See https://huggingface.co/spaces/gradio/theme-gallery for more themes GRADIO_THEME: str = "NoCrypt/miku" -LATEST_VERSION: str = "2.3" +LATEST_VERSION: str = "2.3.1" USER_DICT_DIR = "dict_data" diff --git a/common/tts_model.py b/common/tts_model.py index de4e830df..e09787e7e 100644 --- a/common/tts_model.py +++ b/common/tts_model.py @@ -136,7 +136,6 @@ def infer( given_tone: Optional[list[int]] = None, pitch_scale: float = 1.0, intonation_scale: float = 1.0, - ignore_unknown: bool = False, ) -> tuple[int, np.ndarray]: logger.info(f"Start generating audio data from text:\n{text}") if language != "JP" and self.hps.version.endswith("JP-Extra"): @@ -174,7 +173,6 @@ def infer( assist_text_weight=assist_text_weight, style_vec=style_vector, given_tone=given_tone, - ignore_unknown=ignore_unknown, ) else: texts = text.split("\n") @@ -197,7 +195,6 @@ def infer( assist_text=assist_text, assist_text_weight=assist_text_weight, style_vec=style_vector, - ignore_unknown=ignore_unknown, ) ) if i != len(texts) - 1: diff --git a/configs/config.json b/configs/config.json index 25e86db6b..2f2ce7f0f 100644 --- a/configs/config.json +++ b/configs/config.json @@ -68,5 +68,5 @@ "use_spectral_norm": false, "gin_channels": 256 }, - "version": "2.3" + "version": "2.3.1" } diff --git a/configs/configs_jp_extra.json b/configs/configs_jp_extra.json index 616d1d31d..b566be165 100644 --- a/configs/configs_jp_extra.json +++ b/configs/configs_jp_extra.json @@ -75,5 +75,5 @@ "initial_channel": 64 } }, - "version": "2.3-JP-Extra" + "version": "2.3.1-JP-Extra" } diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index bb35e4941..dee5ee491 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,5 +1,20 @@ # Changelog +## v2.3.1 (2024-02-27) + +### バグ修正 +- colabの学習用ノートブックが動かなかったのを修正 +- `App.bat`や`server_fastapi.py`では読めない文字でまだエラーが発生するようになっていたので、推論時は必ず読めない文字を無視して強引に読むように挙動を変更 + +### 改善 +- 読みが取得できない場合に、テキスト前処理完了時にエラーで中断する今までの挙動に加えて、「読み取得失敗ファイルを学習に使わずに進める」もしくは「読めない文字を無視して読んでファイルを学習に使い進める」というオプションを追加。 +- マージ方法に線形補間の他に球面線形補完を追加 ([@frodo821](https://github.com/frodo821) さんによるPRです、ありがとうございます!) +- デプロイ用`.dockerignore`を更新 + +### アップデート手順 +- 2.3未満からのアップデートの場合は、[Update-to-Dict-Editor.bat](https://github.com/litagin02/Style-Bert-VITS2/releases/download/2.3/Update-to-Dict-Editor.bat)をダウンロードし、`Style-Bert-VITS2`フォルダがある場所(インストールbatファイルとかがあったところ)においてダブルクリックしてください。 +- 2.3からのアップデートの場合は、単純に今までの`Update-Style-Bert-VITS2.bat`でアップデートできます。 + ## v2.3 (2024-02-26) ### 大きな変更 diff --git a/docs/CLI.md b/docs/CLI.md index 08e2fd03e..95ab8b501 100644 --- a/docs/CLI.md +++ b/docs/CLI.md @@ -55,7 +55,7 @@ Optional ## 2. Preprocess ```bash -python preprocess_all.py -m [--use_jp_extra] [-b ] [-e ] [-s ] [--num_processes ] [--normalize] [--trim] [--val_per_lang ] [--log_interval ] [--freeze_EN_bert] [--freeze_JP_bert] [--freeze_ZH_bert] [--freeze_style] [--freeze_decoder] +python preprocess_all.py -m [--use_jp_extra] [-b ] [-e ] [-s ] [--num_processes ] [--normalize] [--trim] [--val_per_lang ] [--log_interval ] [--freeze_EN_bert] [--freeze_JP_bert] [--freeze_ZH_bert] [--freeze_style] [--freeze_decoder] [--yomi_error ] ``` Required: @@ -76,6 +76,7 @@ Optional: - `--use_jp_extra`: Use JP-Extra model. - `--val_per_lang`: Validation data per language (default: 0). - `--log_interval`: Log interval (default: 200). +- `--yomi_error`: How to handle yomi errors (default: `raise`: raise an error after preprocessing all texts, `skip`: skip the texts with errors, `use`: use the texts with errors by ignoring unknown characters). ## 3. Train diff --git a/infer.py b/infer.py index 914a3554e..3707df1f9 100644 --- a/infer.py +++ b/infer.py @@ -52,11 +52,11 @@ def get_text( assist_text=None, assist_text_weight=0.7, given_tone=None, - ignore_unknown=False, ): use_jp_extra = hps.version.endswith("JP-Extra") + # 推論のときにのみ呼び出されるので、raise_yomi_errorはFalseに設定 norm_text, phone, tone, word2ph = clean_text( - text, language_str, use_jp_extra, ignore_unknown=ignore_unknown + text, language_str, use_jp_extra, raise_yomi_error=False ) if given_tone is not None: if len(given_tone) != len(phone): @@ -80,7 +80,6 @@ def get_text( device, assist_text, assist_text_weight, - ignore_unknown, ) del word2ph assert bert_ori.shape[-1] == len(phone), phone @@ -127,7 +126,6 @@ def infer( assist_text=None, assist_text_weight=0.7, given_tone=None, - ignore_unknown=False, ): is_jp_extra = hps.version.endswith("JP-Extra") bert, ja_bert, en_bert, phones, tones, lang_ids = get_text( @@ -138,7 +136,6 @@ def infer( assist_text=assist_text, assist_text_weight=assist_text_weight, given_tone=given_tone, - ignore_unknown=ignore_unknown, ) if skip_start: phones = phones[3:] diff --git a/preprocess_all.py b/preprocess_all.py index 82a3c202c..fd3dcd5b7 100644 --- a/preprocess_all.py +++ b/preprocess_all.py @@ -74,6 +74,9 @@ help="Log interval", default=200, ) + parser.add_argument( + "--yomi_error", type=str, help="Yomi error. raise, skip, use", default="raise" + ) args = parser.parse_args() @@ -93,4 +96,5 @@ use_jp_extra=args.use_jp_extra, val_per_lang=args.val_per_lang, log_interval=args.log_interval, + yomi_error=args.yomi_error, ) diff --git a/preprocess_text.py b/preprocess_text.py index 2c26941e7..e0c3399f5 100644 --- a/preprocess_text.py +++ b/preprocess_text.py @@ -40,6 +40,7 @@ def count_lines(file_path: str): @click.option("--clean/--no-clean", default=preprocess_text_config.clean) @click.option("-y", "--yml_config") @click.option("--use_jp_extra", is_flag=True) +@click.option("--yomi_error", default="raise") def preprocess( transcription_path: str, cleaned_path: Optional[str], @@ -51,11 +52,15 @@ def preprocess( clean: bool, yml_config: str, # 这个不要删 use_jp_extra: bool, + yomi_error: str, ): + assert yomi_error in ["raise", "skip", "use"] if cleaned_path == "" or cleaned_path is None: cleaned_path = transcription_path + ".cleaned" error_log_path = os.path.join(os.path.dirname(cleaned_path), "text_error.log") + if os.path.exists(error_log_path): + os.remove(error_log_path) error_count = 0 if clean: @@ -66,8 +71,12 @@ def preprocess( try: utt, spk, language, text = line.strip().split("|") norm_text, phones, tones, word2ph = clean_text( - text, language, use_jp_extra + text=text, + language=language, + use_jp_extra=use_jp_extra, + raise_yomi_error=(yomi_error != "use"), ) + out_file.write( "{}|{}|{}|{}|{}|{}|{}\n".format( utt, @@ -151,12 +160,20 @@ def preprocess( with open(config_path, "w", encoding="utf-8") as f: json.dump(json_config, f, indent=2, ensure_ascii=False) if error_count > 0: - logger.error( - f"An error occurred in {error_count} lines. Please check {error_log_path} for details. You can proceed with lines that do not have errors." - ) - raise Exception( - f"An error occurred in {error_count} lines. Please check {error_log_path} for details. You can proceed with lines that do not have errors." - ) + if yomi_error == "skip": + logger.warning( + f"An error occurred in {error_count} lines. Proceed with lines without errors. Please check {error_log_path} for details." + ) + else: + # yom_error == "raise"と"use"の場合。 + # "use"の場合は、そもそもyomi_error = Falseで処理しているので、 + # ここが実行されるのは他の例外のときなので、エラーをraiseする。 + logger.error( + f"An error occurred in {error_count} lines. Please check {error_log_path} for details." + ) + raise Exception( + f"An error occurred in {error_count} lines. Please check {error_log_path} for details." + ) else: logger.info( "Training set and validation set generation from texts is complete!" diff --git a/server_editor.py b/server_editor.py index 93c7c58c6..afb9321df 100644 --- a/server_editor.py +++ b/server_editor.py @@ -214,7 +214,7 @@ async def read_item(item: TextRequest): try: # 最初に正規化しないと整合性がとれない text = text_normalize(item.text) - kata_tone_list = g2kata_tone(text, ignore_unknown=True) + kata_tone_list = g2kata_tone(text) except Exception as e: raise HTTPException( status_code=400, @@ -289,7 +289,6 @@ def synthesis(request: SynthesisRequest): assist_text_weight=request.assistTextWeight, use_assist_text=bool(request.assistText), line_split=False, - ignore_unknown=True, pitch_scale=request.pitchScale, intonation_scale=request.intonationScale, ) @@ -348,7 +347,6 @@ def multi_synthesis(request: MultiSynthesisRequest): assist_text_weight=req.assistTextWeight, use_assist_text=bool(req.assistText), line_split=False, - ignore_unknown=True, pitch_scale=req.pitchScale, intonation_scale=req.intonationScale, ) diff --git a/text/__init__.py b/text/__init__.py index 2151cafc5..d8ae88dea 100644 --- a/text/__init__.py +++ b/text/__init__.py @@ -18,28 +18,14 @@ def cleaned_text_to_sequence(cleaned_text, tones, language): return phones, tones, lang_ids -def get_bert( - text, - word2ph, - language, - device, - assist_text=None, - assist_text_weight=0.7, - ignore_unknown=False, -): +def get_bert(text, word2ph, language, device, assist_text=None, assist_text_weight=0.7): if language == "ZH": - from .chinese_bert import get_bert_feature as zh_bert - - return zh_bert(text, word2ph, device, assist_text, assist_text_weight) + from .chinese_bert import get_bert_feature elif language == "EN": - from .english_bert_mock import get_bert_feature as en_bert - - return en_bert(text, word2ph, device, assist_text, assist_text_weight) + from .english_bert_mock import get_bert_feature elif language == "JP": - from .japanese_bert import get_bert_feature as jp_bert - - return jp_bert( - text, word2ph, device, assist_text, assist_text_weight, ignore_unknown - ) + from .japanese_bert import get_bert_feature else: raise ValueError(f"Language {language} not supported") + + return get_bert_feature(text, word2ph, device, assist_text, assist_text_weight) diff --git a/text/cleaner.py b/text/cleaner.py index 8da4c7b6e..d805b5145 100644 --- a/text/cleaner.py +++ b/text/cleaner.py @@ -1,4 +1,4 @@ -def clean_text(text, language, use_jp_extra=True, ignore_unknown=False): +def clean_text(text, language, use_jp_extra=True, raise_yomi_error=False): # Changed to import inside if condition to avoid unnecessary import if language == "ZH": from . import chinese as language_module @@ -15,7 +15,7 @@ def clean_text(text, language, use_jp_extra=True, ignore_unknown=False): norm_text = language_module.text_normalize(text) phones, tones, word2ph = language_module.g2p( - norm_text, use_jp_extra, ignore_unknown=ignore_unknown + norm_text, use_jp_extra, raise_yomi_error=raise_yomi_error ) else: raise ValueError(f"Language {language} not supported") diff --git a/text/japanese.py b/text/japanese.py index 5a36eb69e..b18bc682e 100644 --- a/text/japanese.py +++ b/text/japanese.py @@ -33,6 +33,16 @@ VOWELS = {"a", "i", "u", "e", "o", "N"} +class YomiError(Exception): + """ + OpenJTalkで、読みが正しく取得できない箇所があるときに発生する例外。 + 基本的に「学習の前処理のテキスト処理時」には発生させ、そうでない場合は、 + ignore_yomi_error=Trueにしておいて、この例外を発生させないようにする。 + """ + + pass + + # 正規化で記号を変換するための辞書 rep_map = { ":": ",", @@ -166,7 +176,7 @@ def japanese_convert_numbers_to_words(text: str) -> str: def g2p( - norm_text: str, use_jp_extra: bool = True, ignore_unknown: bool = False + norm_text: str, use_jp_extra: bool = True, raise_yomi_error: bool = False ) -> tuple[list[str], list[int], list[int]]: """ 他で使われるメインの関数。`text_normalize()`で正規化された`norm_text`を受け取り、 @@ -175,7 +185,10 @@ def g2p( - word2ph: 元のテキストの各文字に音素が何個割り当てられるかを表すリスト のタプルを返す。 ただし`phones`と`tones`の最初と終わりに`_`が入り、応じて`word2ph`の最初と最後に1が追加される。 + use_jp_extra: Falseの場合、「ん」の音素を「N」ではなく「n」とする。 + raise_yomi_error: Trueの場合、読めない文字があるときに例外を発生させる。 + Falseの場合は読めない文字が消えたような扱いとして処理される。 """ # pyopenjtalkのフルコンテキストラベルを使ってアクセントを取り出すと、punctuationの位置が消えてしまい情報が失われてしまう: # 「こんにちは、世界。」と「こんにちは!世界。」と「こんにちは!!!???世界……。」は全て同じになる。 @@ -186,9 +199,9 @@ def g2p( # punctuationがすべて消えた、音素とアクセントのタプルのリスト(「ん」は「N」) phone_tone_list_wo_punct = g2phone_tone_wo_punct(norm_text) - # sep_text: 単語単位の単語のリスト + # sep_text: 単語単位の単語のリスト、読めない文字があったらraise_yomi_errorなら例外、そうでないなら読めない文字が消えて返ってくる # sep_kata: 単語単位の単語のカタカナ読みのリスト - sep_text, sep_kata = text2sep_kata(norm_text, ignore_unknown=ignore_unknown) + sep_text, sep_kata = text2sep_kata(norm_text, raise_yomi_error=raise_yomi_error) # sep_phonemes: 各単語ごとの音素のリストのリスト sep_phonemes = handle_long([kata2phoneme_list(i) for i in sep_kata]) @@ -237,8 +250,12 @@ def g2p( return phones, tones, word2ph -def g2kata_tone(norm_text: str, ignore_unknown: bool = False) -> list[tuple[str, int]]: - phones, tones, _ = g2p(norm_text, use_jp_extra=True, ignore_unknown=ignore_unknown) +def g2kata_tone(norm_text: str) -> list[tuple[str, int]]: + """ + テキストからカタカナとアクセントのペアのリストを返す。 + 推論時のみに使われるので、常に`raise_yomi_error=False`でg2pを呼ぶ。 + """ + phones, tones, _ = g2p(norm_text, use_jp_extra=True, raise_yomi_error=False) return phone_tone2kata_tone(list(zip(phones, tones))) @@ -332,7 +349,7 @@ def g2phone_tone_wo_punct(text: str) -> list[tuple[str, int]]: def text2sep_kata( - norm_text: str, ignore_unknown: bool = False + norm_text: str, raise_yomi_error: bool = False ) -> tuple[list[str], list[str]]: """ `text_normalize`で正規化済みの`norm_text`を受け取り、それを単語分割し、 @@ -341,6 +358,9 @@ def text2sep_kata( 例: `私はそう思う!って感じ?` → ["私", "は", "そう", "思う", "!", "って", "感じ", "?"], ["ワタシ", "ワ", "ソー", "オモウ", "!", "ッテ", "カンジ", "?"] + + raise_yomi_error: Trueの場合、読めない文字があるときに例外を発生させる。 + Falseの場合は読めない文字が消えたような扱いとして処理される。 """ # parsed: OpenJTalkの解析結果 parsed = pyopenjtalk.run_frontend(norm_text) @@ -369,10 +389,10 @@ def text2sep_kata( # wordは正規化されているので、`.`, `,`, `!`, `'`, `-`, `--` のいずれか if not set(word).issubset(set(punctuation)): # 記号繰り返しか判定 # ここはpyopenjtalkが読めない文字等のときに起こる - if ignore_unknown: - logger.error(f"Ignoring unknown: {word} in:\n{norm_text}") - continue - raise ValueError(f"Cannot read: {word} in:\n{norm_text}") + if raise_yomi_error: + raise YomiError(f"Cannot read: {word} in:\n{norm_text}") + logger.warning(f"Ignoring unknown: {word} in:\n{norm_text}") + continue # yomiは元の記号のままに変更 yomi = word elif yomi == "?": diff --git a/text/japanese_bert.py b/text/japanese_bert.py index 9efe7a46a..dcee0f3d2 100644 --- a/text/japanese_bert.py +++ b/text/japanese_bert.py @@ -19,12 +19,13 @@ def get_bert_feature( device=config.bert_gen_config.device, assist_text=None, assist_text_weight=0.7, - ignore_unknown=False, ): - text = "".join(text2sep_kata(text, ignore_unknown=ignore_unknown)[0]) - # text = text_normalize(text) + # 各単語が何文字かを作る`word2ph`を使う必要があるので、読めない文字は必ず無視する + # でないと`word2ph`の結果とテキストの文字数結果が整合性が取れない + text = "".join(text2sep_kata(text, raise_yomi_error=False)[0]) + if assist_text: - assist_text = "".join(text2sep_kata(assist_text)[0]) + assist_text = "".join(text2sep_kata(assist_text, raise_yomi_error=False)[0]) if ( sys.platform == "darwin" and torch.backends.mps.is_available() diff --git a/webui_merge.py b/webui_merge.py index 69f7e8dc0..a58471a80 100644 --- a/webui_merge.py +++ b/webui_merge.py @@ -102,6 +102,29 @@ def merge_style(model_name_a, model_name_b, weight, output_name, style_triple_li return output_style_path, list(new_style2id.keys()) +def lerp_tensors(t, v0, v1): + return v0 * (1 - t) + v1 * t + + +def slerp_tensors(t, v0, v1, dot_thres=0.998): + device = v0.device + v0c = v0.cpu().numpy() + v1c = v1.cpu().numpy() + + dot = np.sum(v0c * v1c / (np.linalg.norm(v0c) * np.linalg.norm(v1c))) + + if abs(dot) > dot_thres: + return lerp_tensors(t, v0, v1) + + th0 = np.arccos(dot) + sin_th0 = np.sin(th0) + th_t = th0 * t + + return torch.from_numpy( + v0c * np.sin(th0 - th_t) / sin_th0 + v1c * np.sin(th_t) / sin_th0 + ).to(device) + + def merge_models( model_path_a, model_path_b, @@ -110,6 +133,7 @@ def merge_models( speech_style_weight, tempo_weight, output_name, + use_slerp_instead_of_lerp, ): """model Aを起点に、model Bの各要素を重み付けしてマージする。 safetensors形式を前提とする。""" @@ -137,8 +161,8 @@ def merge_models( else: continue merged_model_weight[key] = ( - model_a_weight[key] * (1 - weight) + model_b_weight[key] * weight - ) + slerp_tensors if use_slerp_instead_of_lerp else lerp_tensors + )(weight, model_a_weight[key], model_b_weight[key]) merged_model_path = os.path.join( assets_root, output_name, f"{output_name}.safetensors" @@ -171,6 +195,7 @@ def merge_models_gr( voice_pitch_weight, speech_style_weight, tempo_weight, + use_slerp_instead_of_lerp, ): if output_name == "": return "Error: 新しいモデル名を入力してください。" @@ -182,6 +207,7 @@ def merge_models_gr( speech_style_weight, tempo_weight, output_name, + use_slerp_instead_of_lerp, ) return f"Success: モデルを{merged_model_path}に保存しました。" @@ -246,7 +272,20 @@ def load_styles_gr(model_name_a, model_name_b): with open(config_path_b, encoding="utf-8") as f: config_b = json.load(f) styles_b = list(config_b["data"]["style2id"].keys()) - return gr.Textbox(value=", ".join(styles_a)), gr.Textbox(value=", ".join(styles_b)) + + return ( + gr.Textbox(value=", ".join(styles_a)), + gr.Textbox(value=", ".join(styles_b)), + gr.TextArea( + label="スタイルのマージリスト", + placeholder=f"{DEFAULT_STYLE}, {DEFAULT_STYLE},{DEFAULT_STYLE}\nAngry, Angry, Angry", + value="\n".join( + f"{sty_a}, {sty_b}, {sty_a if sty_a != sty_b else ''}{sty_b}" + for sty_a in styles_a + for sty_b in styles_b + ), + ), + ) initial_md = """ @@ -359,6 +398,10 @@ def load_styles_gr(model_name_a, model_name_b): maximum=1, step=0.1, ) + use_slerp_instead_of_lerp = gr.Checkbox( + label="線形補完のかわりに球面線形補完を使う", + value=False, + ) with gr.Column(variant="panel"): gr.Markdown("## モデルファイル(safetensors)のマージ") model_merge_button = gr.Button("モデルファイルのマージ", variant="primary") @@ -414,7 +457,7 @@ def load_styles_gr(model_name_a, model_name_b): load_style_button.click( load_styles_gr, inputs=[model_name_a, model_name_b], - outputs=[styles_a, styles_b], + outputs=[styles_a, styles_b, style_triple_list], ) model_merge_button.click( @@ -429,6 +472,7 @@ def load_styles_gr(model_name_a, model_name_b): voice_pitch_slider, speech_style_slider, tempo_slider, + use_slerp_instead_of_lerp, ], outputs=[info_model_merge], ) diff --git a/webui_train.py b/webui_train.py index cd7af20e9..59cc9f594 100644 --- a/webui_train.py +++ b/webui_train.py @@ -155,7 +155,7 @@ def resample(model_name, normalize, trim, num_processes): return True, "Step 2, Success: 音声ファイルの前処理が完了しました" -def preprocess_text(model_name, use_jp_extra, val_per_lang): +def preprocess_text(model_name, use_jp_extra, val_per_lang, yomi_error): logger.info("Step 3: start preprocessing text...") dataset_path, lbl_path, train_path, val_path, config_path = get_path(model_name) try: @@ -189,6 +189,8 @@ def preprocess_text(model_name, use_jp_extra, val_per_lang): val_path, "--val-per-lang", str(val_per_lang), + "--yomi_error", + yomi_error, ] if use_jp_extra: cmd.append("--use_jp_extra") @@ -278,6 +280,7 @@ def preprocess_all( use_jp_extra, val_per_lang, log_interval, + yomi_error, ): if model_name == "": return False, "Error: モデル名を入力してください" @@ -304,8 +307,12 @@ def preprocess_all( ) if not success: return False, message + success, message = preprocess_text( - model_name=model_name, use_jp_extra=use_jp_extra, val_per_lang=val_per_lang + model_name=model_name, + use_jp_extra=use_jp_extra, + val_per_lang=val_per_lang, + yomi_error=yomi_error, ) if not success: return False, message @@ -488,6 +495,15 @@ def run_tensorboard(model_name): label="音声の最初と最後の無音を取り除く", value=False, ) + yomi_error = gr.Radio( + label="書き起こしが読めないファイルの扱い", + choices=[ + ("エラー出たらテキスト前処理が終わった時点で中断", "raise"), + ("読めないファイルは使わず続行", "skip"), + ("読めないファイルも無理やり読んで学習に使う", "use"), + ], + value="raise", + ) with gr.Accordion("詳細設定", open=False): num_processes = gr.Slider( label="プロセス数", @@ -630,6 +646,15 @@ def run_tensorboard(model_name): maximum=100, step=1, ) + yomi_error_manual = gr.Radio( + label="書き起こしが読めないファイルの扱い", + choices=[ + ("エラー出たらテキスト前処理が終わった時点で中断", "raise"), + ("読めないファイルは使わず続行", "skip"), + ("読めないファイルも無理やり読んで学習に使う", "use"), + ], + value="raise", + ) with gr.Column(): preprocess_text_btn = gr.Button(value="実行", variant="primary") info_preprocess_text = gr.Textbox(label="状況") @@ -693,6 +718,7 @@ def run_tensorboard(model_name): use_jp_extra, val_per_lang, log_interval, + yomi_error, ], outputs=[info_all], ) @@ -731,6 +757,7 @@ def run_tensorboard(model_name): model_name, use_jp_extra_manual, val_per_lang_manual, + yomi_error_manual, ], outputs=[info_preprocess_text], )