diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..7ff74dab1 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,43 @@ +# Dockerfile.deploy用 + +*.pyc +*.pyo +*.pyd +__pycache__ +*.pyc + +venv/ +.vscode/ + +.ipynb_checkpoints/ +*.ipynb + +.git/ +.gitignore + +Dockerfile* +.dockerignore +*.md +*.bat +LICENSE + +*.wav +*.zip +*.csv + +# 中国語と英語が必要な場合はコメントアウト +/bert/chinese-roberta-wwm-ext-large/ +/bert/deberta-v3-large/ + +Data/ +dict_data/user_dic.json +dict_data/user_dic.dic +docs/ +inputs/ +mos_results/ +pretrained/ +pretrained_jp_extra/ +scripts/ +slm/ +static/ +tools/ diff --git a/.gitignore b/.gitignore index 31602b04e..b556dfff7 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,7 @@ venv/ safetensors.ipynb *.wav +/static/ + +# pyopenjtalk's dictionary +*.dic diff --git a/Dockerfile.deploy b/Dockerfile.deploy new file mode 100644 index 000000000..dd351d107 --- /dev/null +++ b/Dockerfile.deploy @@ -0,0 +1,23 @@ +# Hugging face spaces (CPU) でエディタ (server_editor.py) のデプロイ用 + +# See https://huggingface.co/docs/hub/spaces-sdks-docker-first-demo + +FROM python:3.10 + +RUN useradd -m -u 1000 user + +USER user + +ENV HOME=/home/user \ + PATH=/home/user/.local/bin:$PATH + +WORKDIR $HOME/app + +RUN pip install --no-cache-dir --upgrade pip + +COPY --chown=user . $HOME/app + +RUN pip install --no-cache-dir -r $HOME/app/requirements.txt + +# 必要に応じて制限を変更してください +CMD ["python", "server_editor.py", "--line_length", "50", "--line_count", "3"] diff --git a/Dockerfile.train b/Dockerfile.train new file mode 100644 index 000000000..333fd4634 --- /dev/null +++ b/Dockerfile.train @@ -0,0 +1,109 @@ +# PaperspaceのGradient環境での学習環境構築用Dockerfileです。 +# 環境のみ構築するため、イメージには学習用のコードは含まれていません。 +# 以下を参照しました。 +# https://github.com/gradient-ai/base-container/tree/main/pt211-tf215-cudatk120-py311 + +# 主なバージョン等 +# Ubuntu 22.04 +# Python 3.10 +# PyTorch 2.1.2 (CUDA 11.8) +# CUDA Toolkit 12.0, CUDNN 8.9.7 + + +# ================================================================== +# Initial setup +# ------------------------------------------------------------------ + +# Ubuntu 22.04 as base image +FROM ubuntu:22.04 +# RUN yes| unminimize + +# Set ENV variables +ENV LANG C.UTF-8 +ENV SHELL=/bin/bash +ENV DEBIAN_FRONTEND=noninteractive + +ENV APT_INSTALL="apt-get install -y --no-install-recommends" +ENV PIP_INSTALL="python3 -m pip --no-cache-dir install --upgrade" +ENV GIT_CLONE="git clone --depth 10" + +# ================================================================== +# Tools +# ------------------------------------------------------------------ + +RUN apt-get update && \ + $APT_INSTALL \ + sudo \ + build-essential \ + ca-certificates \ + wget \ + curl \ + git \ + zip \ + unzip \ + nano \ + ffmpeg \ + software-properties-common \ + gnupg \ + python3 \ + python3-pip \ + python3-dev + +# ================================================================== +# Git-lfs +# ------------------------------------------------------------------ + +RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash && \ + $APT_INSTALL git-lfs + + +# Add symlink so python and python3 commands use same python3.9 executable +RUN ln -s /usr/bin/python3 /usr/local/bin/python + +# ================================================================== +# Installing CUDA packages (CUDA Toolkit 12.0 and CUDNN 8.9.7) +# ------------------------------------------------------------------ +RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin && \ + mv 
cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600 && \ + wget https://developer.download.nvidia.com/compute/cuda/12.0.0/local_installers/cuda-repo-ubuntu2204-12-0-local_12.0.0-525.60.13-1_amd64.deb && \ + dpkg -i cuda-repo-ubuntu2204-12-0-local_12.0.0-525.60.13-1_amd64.deb && \ + cp /var/cuda-repo-ubuntu2204-12-0-local/cuda-*-keyring.gpg /usr/share/keyrings/ && \ + apt-get update && \ + $APT_INSTALL cuda && \ + rm cuda-repo-ubuntu2204-12-0-local_12.0.0-525.60.13-1_amd64.deb + +# Installing CUDNN +RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub && \ + add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /" && \ + apt-get update && \ + $APT_INSTALL libcudnn8=8.9.7.29-1+cuda12.2 \ + libcudnn8-dev=8.9.7.29-1+cuda12.2 + + +ENV PATH=$PATH:/usr/local/cuda/bin +ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH + + +# ================================================================== +# PyTorch +# ------------------------------------------------------------------ + +# Based on https://pytorch.org/get-started/locally/ + +RUN $PIP_INSTALL torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118 + + +RUN $PIP_INSTALL jupyterlab + +# Install requirements.txt from the project +COPY requirements.txt /tmp/requirements.txt +RUN $PIP_INSTALL -r /tmp/requirements.txt +RUN rm /tmp/requirements.txt + +# ================================================================== +# Startup +# ------------------------------------------------------------------ + +EXPOSE 8888 6006 + +CMD jupyter lab --allow-root --ip=0.0.0.0 --no-browser --ServerApp.trust_xheaders=True --ServerApp.disable_check_xsrf=False --ServerApp.allow_remote_access=True --ServerApp.allow_origin='*' --ServerApp.allow_credentials=True \ No newline at end of file diff --git a/Editor.bat b/Editor.bat new file mode 100644 index 000000000..7b0836c4f --- /dev/null +++ b/Editor.bat @@ -0,0 +1,11 @@ +chcp 65001 > NUL +@echo off + +pushd %~dp0 +echo Running server_editor.py --inbroser +venv\Scripts\python server_editor.py --inbrowser + +if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) + +popd +pause \ No newline at end of file diff --git a/LGPL_LICENSE b/LGPL_LICENSE new file mode 100644 index 000000000..153d416dc --- /dev/null +++ b/LGPL_LICENSE @@ -0,0 +1,165 @@ + GNU LESSER GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + + This version of the GNU Lesser General Public License incorporates +the terms and conditions of version 3 of the GNU General Public +License, supplemented by the additional permissions listed below. + + 0. Additional Definitions. + + As used herein, "this License" refers to version 3 of the GNU Lesser +General Public License, and the "GNU GPL" refers to version 3 of the GNU +General Public License. + + "The Library" refers to a covered work governed by this License, +other than an Application or a Combined Work as defined below. + + An "Application" is any work that makes use of an interface provided +by the Library, but which is not otherwise based on the Library. +Defining a subclass of a class defined by the Library is deemed a mode +of using an interface provided by the Library. 
+ + A "Combined Work" is a work produced by combining or linking an +Application with the Library. The particular version of the Library +with which the Combined Work was made is also called the "Linked +Version". + + The "Minimal Corresponding Source" for a Combined Work means the +Corresponding Source for the Combined Work, excluding any source code +for portions of the Combined Work that, considered in isolation, are +based on the Application, and not on the Linked Version. + + The "Corresponding Application Code" for a Combined Work means the +object code and/or source code for the Application, including any data +and utility programs needed for reproducing the Combined Work from the +Application, but excluding the System Libraries of the Combined Work. + + 1. Exception to Section 3 of the GNU GPL. + + You may convey a covered work under sections 3 and 4 of this License +without being bound by section 3 of the GNU GPL. + + 2. Conveying Modified Versions. + + If you modify a copy of the Library, and, in your modifications, a +facility refers to a function or data to be supplied by an Application +that uses the facility (other than as an argument passed when the +facility is invoked), then you may convey a copy of the modified +version: + + a) under this License, provided that you make a good faith effort to + ensure that, in the event an Application does not supply the + function or data, the facility still operates, and performs + whatever part of its purpose remains meaningful, or + + b) under the GNU GPL, with none of the additional permissions of + this License applicable to that copy. + + 3. Object Code Incorporating Material from Library Header Files. + + The object code form of an Application may incorporate material from +a header file that is part of the Library. You may convey such object +code under terms of your choice, provided that, if the incorporated +material is not limited to numerical parameters, data structure +layouts and accessors, or small macros, inline functions and templates +(ten or fewer lines in length), you do both of the following: + + a) Give prominent notice with each copy of the object code that the + Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the object code with a copy of the GNU GPL and this license + document. + + 4. Combined Works. + + You may convey a Combined Work under terms of your choice that, +taken together, effectively do not restrict modification of the +portions of the Library contained in the Combined Work and reverse +engineering for debugging such modifications, if you also do each of +the following: + + a) Give prominent notice with each copy of the Combined Work that + the Library is used in it and that the Library and its use are + covered by this License. + + b) Accompany the Combined Work with a copy of the GNU GPL and this license + document. + + c) For a Combined Work that displays copyright notices during + execution, include the copyright notice for the Library among + these notices, as well as a reference directing the user to the + copies of the GNU GPL and this license document. 
+ + d) Do one of the following: + + 0) Convey the Minimal Corresponding Source under the terms of this + License, and the Corresponding Application Code in a form + suitable for, and under terms that permit, the user to + recombine or relink the Application with a modified version of + the Linked Version to produce a modified Combined Work, in the + manner specified by section 6 of the GNU GPL for conveying + Corresponding Source. + + 1) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (a) uses at run time + a copy of the Library already present on the user's computer + system, and (b) will operate properly with a modified version + of the Library that is interface-compatible with the Linked + Version. + + e) Provide Installation Information, but only if you would otherwise + be required to provide such information under section 6 of the + GNU GPL, and only to the extent that such information is + necessary to install and execute a modified version of the + Combined Work produced by recombining or relinking the + Application with a modified version of the Linked Version. (If + you use option 4d0, the Installation Information must accompany + the Minimal Corresponding Source and Corresponding Application + Code. If you use option 4d1, you must provide the Installation + Information in the manner specified by section 6 of the GNU GPL + for conveying Corresponding Source.) + + 5. Combined Libraries. + + You may place library facilities that are a work based on the +Library side by side in a single library together with other library +facilities that are not Applications and are not covered by this +License, and convey such a combined library under terms of your +choice, if you do both of the following: + + a) Accompany the combined library with a copy of the same work based + on the Library, uncombined with any other library facilities, + conveyed under the terms of this License. + + b) Give prominent notice with the combined library that part of it + is a work based on the Library, and explaining where to find the + accompanying uncombined form of the same work. + + 6. Revised Versions of the GNU Lesser General Public License. + + The Free Software Foundation may publish revised and/or new versions +of the GNU Lesser General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + + Each version is given a distinguishing version number. If the +Library as you received it specifies that a certain numbered version +of the GNU Lesser General Public License "or any later version" +applies to it, you have the option of following the terms and +conditions either of that published version or of any later version +published by the Free Software Foundation. If the Library as you +received it does not specify a version number of the GNU Lesser +General Public License, you may choose any version of the GNU Lesser +General Public License ever published by the Free Software Foundation. + + If the Library as you received it specifies that a proxy can decide +whether future versions of the GNU Lesser General Public License shall +apply, that proxy's public statement of acceptance of any version is +permanent authorization for you to choose that version for the +Library. 
\ No newline at end of file diff --git a/README.md b/README.md index bf2a05fe7..db2043d95 100644 --- a/README.md +++ b/README.md @@ -7,14 +7,15 @@ https://github.com/litagin02/Style-Bert-VITS2/assets/139731664/e853f9a2-db4a-420 - **解説チュートリアル動画** [YouTube](https://youtu.be/aTUSzgDl1iY) [ニコニコ動画](https://www.nicovideo.jp/watch/sm43391524) - [English README](docs/README_en.md) - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/litagin02/Style-Bert-VITS2/blob/master/colab.ipynb) -- [🤗 オンラインデモはこちらから](https://huggingface.co/spaces/litagin/Style-Bert-VITS2-JVNV) +- [🤗 オンラインデモはこちらから](https://huggingface.co/spaces/litagin/Style-Bert-VITS2-Editor-Demo) - [Zennの解説記事](https://zenn.dev/litagin/articles/034819a5256ff4) -- [**リリースページ**](https://github.com/litagin02/Style-Bert-VITS2/releases/)、[更新履歴](docs/CHANGELOG.md) +- [**リリースページ**](https://github.com/litagin02/Style-Bert-VITS2/releases/)、[更新履歴](/docs/CHANGELOG.md) + - 2024-02-26: ver 2.3 (辞書機能とエディター機能) - 2024-02-09: ver 2.2 - 2024-02-07: ver 2.1 - - 2024-02-03: ver 2.0 + - 2024-02-03: ver 2.0 (JP-Extra) - 2024-01-09: ver 1.3 - 2023-12-31: ver 1.2 - 2023-12-29: ver 1.1 @@ -33,7 +34,7 @@ This repository is based on [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2 ## 使い方 - +CLIでの使い方は[こちら](/docs/CLI.md)を参照してください。 ### 動作環境 @@ -45,13 +46,14 @@ This repository is based on [Bert-VITS2](https://github.com/fishaudio/Bert-VITS2 Windowsを前提としています。 -1. [このzipファイル](https://github.com/litagin02/Style-Bert-VITS2/releases/download/2.2/Style-Bert-VITS2.zip)を**パスに日本語や空白が含まれない場所に**ダウンロードして展開します。 +1. [このzipファイル](https://github.com/litagin02/Style-Bert-VITS2/releases/download/2.3/Style-Bert-VITS2.zip)を**パスに日本語や空白が含まれない場所に**ダウンロードして展開します。 - グラボがある方は、`Install-Style-Bert-VITS2.bat`をダブルクリックします。 - グラボがない方は、`Install-Style-Bert-VITS2-CPU.bat`をダブルクリックします。CPU版では学習はできませんが、音声合成とマージは可能です。 2. 待つと自動で必要な環境がインストールされます。 -3. その後、自動的に音声合成するためのWebUIが起動したらインストール成功です。デフォルトのモデルがダウンロードされるているので、そのまま遊ぶことができます。 +3. 
その後、自動的に音声合成するためのエディターが起動したらインストール成功です。デフォルトのモデルがダウンロードされるているので、そのまま遊ぶことができます。 -またアップデートをしたい場合は、`Update-Style-Bert-VITS2.bat`をダブルクリックしてください。ただし**1.x**から**2.x**へアップデートする場合は、[このbatファイル](https://github.com/litagin02/Style-Bert-VITS2/releases/download/2.2/Update-to-JP-Extra.bat)を`Style-Bert-VITS2`フォルダがあるフォルダ(`Update-Style-Bert-VITS2.bat`等があるフォルダ)へ保存してからダブルクリックしてください。 +またアップデートをしたい場合は、`Update-Style-Bert-VITS2.bat`をダブルクリックしてください。ただし以下の場合は、専用のアップデートbatファイルを`Style-Bert-VITS2`フォルダがあるフォルダ(`Update-Style-Bert-VITS2.bat`等があるフォルダ)へ保存してからダブルクリックしてください。 +- **2.3以前**から**2.3以上**(辞書・エディター付き)へアップデート: [Update-to-Dict-Editor.bat](https://github.com/litagin02/Style-Bert-VITS2/releases/download/2.3/Update-to-Dict-Editor.bat) #### GitやPython使える人 @@ -69,7 +71,12 @@ python initialize.py # 必要なモデルとデフォルトTTSモデルをダ ### 音声合成 -`App.bat`をダブルクリックか、`python app.py`するとWebUIが起動します(`python app.py --cpu`でCPUモードで起動、学習中チェックに便利です)。インストール時にデフォルトのモデルがダウンロードされているので、学習していなくてもそれを使うことができます。 +音声合成エディターは`Editor.bat`をダブルクリックか、`python server_editor.py --inbrowser`すると起動します(`--device cpu`でCPUモードで起動)。画面内で各セリフごとに設定を変えて原稿を作ったり、保存や読み込みや辞書の編集等ができます。 +インストール時にデフォルトのモデルがダウンロードされているので、学習していなくてもそれを使うことができます。 + +エディター部分は[別リポジトリ](https://github.com/litagin02/Style-Bert-VITS2-Editor)に分かれています。 + +バージョン2.2以前での音声合成WebUIは、`App.bat`をダブルクリックか、`python app.py`するとWebUIが起動します。 音声合成に必要なモデルファイルたちの構造は以下の通りです(手動で配置する必要はありません)。 ``` @@ -90,6 +97,9 @@ model_assets ### 学習 +- CLIでの学習の詳細は[こちら](docs/CLI.md)を参照してください。 +- paperspace上での学習の詳細は[こちら](docs/paperspace.md)、colabでの学習は[こちら](http://colab.research.google.com/github/litagin02/Style-Bert-VITS2/blob/master/colab.ipynb)を参照してください。 + 学習には2-14秒程度の音声ファイルが複数と、それらの書き起こしデータが必要です。 - 既存コーパスなどですでに分割された音声ファイルと書き起こしデータがある場合はそのまま(必要に応じて書き起こしファイルを修正して)使えます。下の「学習WebUI」を参照してください。 @@ -123,6 +133,10 @@ API仕様は起動後に`/docs`にて確認ください。 - 入力文字数はデフォルトで100文字が上限となっています。これは`config.yml`の`server.limit`で変更できます。 - デフォルトではCORS設定を全てのドメインで許可しています。できる限り、`config.yml`の`server.origins`の値を変更し、信頼できるドメインに制限ください(キーを消せばCORS設定を無効にできます)。 +また音声合成エディターのAPIサーバーは`python server_editor.py`で起動します。があまりまだ整備をしていません。[エディターのリポジトリ](https://github.com/litagin02/Style-Bert-VITS2-Editor)から必要な最低限のAPIしか現在は実装していません。 + +音声合成エディターのウェブデプロイについては[このDockerfile](Dockerfile.deploy)を参考にしてください。 + ### マージ 2つのモデルを、「声質」「声の高さ」「感情表現」「テンポ」の4点で混ぜ合わせて、新しいモデルを作ることが出来ます。 @@ -167,6 +181,18 @@ In addition to the original reference (written below), I used the following repo [The pretrained model](https://huggingface.co/litagin/Style-Bert-VITS2-1.0-base) and [JP-Extra version](https://huggingface.co/litagin/Style-Bert-VITS2-2.0-base-JP-Extra) is essentially taken from [the original base model of Bert-VITS2 v2.1](https://huggingface.co/Garydesu/bert-vits2_base_model-2.1) and [JP-Extra pretrained model of Bert-VITS2](https://huggingface.co/Stardust-minus/Bert-VITS2-Japanese-Extra), so all the credits go to the original author ([Fish Audio](https://github.com/fishaudio)): +In addition, [text/user_dict/](text/user_dict) module is based on the following repositories: +- [voicevox_engine](https://github.com/VOICEVOX/voicevox_engine)] +and the license of this module is LGPL v3. + +## LICENSE + +This repository is licensed under the GNU Affero General Public License v3.0, the same as the original Bert-VITS2 repository. For more details, see [LICENSE](LICENSE). + +In addition, [text/user_dict/](text/user_dict) module is licensed under the GNU Lesser General Public License v3.0, inherited from the original VOICEVOX engine repository. For more details, see [LGPL_LICENSE](LGPL_LICENSE). + + + Below is the original README.md. 
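As a supplement to the editor API server section above (`python server_editor.py`): its endpoints can also be called directly from Python. The following is a minimal sketch and not part of this patch; it assumes the server is running locally on the default port 8000 and that the routes are mounted at the root path (adjust `BASE_URL` if a prefix is used), with field names taken from the `SynthesisRequest` schema in `server_editor.py`.

```python
# Minimal sketch: call the editor API server (server_editor.py) from Python.
# Assumptions: server running locally on the default port 8000, routes mounted
# at the root path; "output.wav" is just a placeholder file name.
import requests

BASE_URL = "http://127.0.0.1:8000"

# Pick the first available model and model file from /models_info.
models = requests.get(f"{BASE_URL}/models_info").json()
model_name = models[0]["name"]
model_file = models[0]["files"][0]

# Get the mora/tone list for a sentence via /g2p, then synthesize it.
text = "こんにちは、これはテストです。"
mora_tone_list = requests.post(f"{BASE_URL}/g2p", json={"text": text}).json()

payload = {
    "model": model_name,
    "modelFile": model_file,
    "text": text,
    "moraToneList": mora_tone_list,  # [{"mora": ..., "tone": ...}, ...]
}
audio = requests.post(f"{BASE_URL}/synthesis", json=payload)
audio.raise_for_status()
with open("output.wav", "wb") as f:
    f.write(audio.content)
```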
--- diff --git a/app.py b/app.py index f9ceb5edc..e17a0a56a 100644 --- a/app.py +++ b/app.py @@ -3,6 +3,7 @@ import json import os import sys +from pathlib import Path from typing import Optional import gradio as gr @@ -57,6 +58,8 @@ def tts_fn( kata_tone_json_str, use_tone, speaker, + pitch_scale, + intonation_scale, ): model_holder.load_model_gr(model_name, model_path) @@ -111,6 +114,8 @@ def tts_fn( style_weight=style_weight, given_tone=tone, sid=speaker_id, + pitch_scale=pitch_scale, + intonation_scale=intonation_scale, ) except InvalidToneError as e: logger.error(f"Tone error: {e}") @@ -197,7 +202,9 @@ def tts_fn( initial_md = f""" # Style-Bert-VITS2 ver {LATEST_VERSION} 音声合成 -注意: 初期からある[jvnvのモデル](https://huggingface.co/litagin/style_bert_vits2_jvnv)は、[JVNVコーパス(言語音声と非言語音声を持つ日本語感情音声コーパス)](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvnv_corpus)で学習されたモデルです。ライセンスは[CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/deed.ja)です。 +- Ver 2.3で追加されたエディターのほうが実際に読み上げさせるには使いやすいかもしれません。`Editor.bat`か`python server_editor.py`で起動できます。 + +- 初期からある[jvnvのモデル](https://huggingface.co/litagin/style_bert_vits2_jvnv)は、[JVNVコーパス(言語音声と非言語音声を持つ日本語感情音声コーパス)](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvnv_corpus)で学習されたモデルです。ライセンスは[CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/deed.ja)です。 """ how_to_md = """ @@ -267,7 +274,7 @@ def gr_util(item): help="Do not launch app automatically", ) args = parser.parse_args() - model_dir = args.dir + model_dir = Path(args.dir) if args.cpu: device = "cpu" @@ -306,6 +313,22 @@ def gr_util(item): refresh_button = gr.Button("更新", scale=1, visible=True) load_button = gr.Button("ロード", scale=1, variant="primary") text_input = gr.TextArea(label="テキスト", value=initial_text) + pitch_scale = gr.Slider( + minimum=0.8, + maximum=1.5, + value=1, + step=0.05, + label="音程(1以外では音質劣化)", + visible=False, # pyworldが必要 + ) + intonation_scale = gr.Slider( + minimum=0, + maximum=2, + value=1, + step=0.1, + label="抑揚(1以外では音質劣化)", + visible=False, # pyworldが必要 + ) line_split = gr.Checkbox( label="改行で分けて生成(分けたほうが感情が乗ります)", @@ -441,6 +464,8 @@ def gr_util(item): tone, use_tone, speaker, + pitch_scale, + intonation_scale, ], outputs=[text_output, audio_output, tone], ) diff --git a/colab.ipynb b/colab.ipynb index f5cab3e79..f4e932b25 100644 --- a/colab.ipynb +++ b/colab.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Style-Bert-VITS2 (ver 2.2) のGoogle Colabでの学習\n", + "# Style-Bert-VITS2 (ver 2.3) のGoogle Colabでの学習\n", "\n", "Google Colab上でStyle-Bert-VITS2の学習を行うことができます。\n", "\n", @@ -115,8 +115,11 @@ "# モデル名(話者名)を入力\n", "model_name = \"your_model_name\"\n", "\n", + "# こういうふうに書き起こして欲しいという例文(句読点の入れ方・笑い方や固有名詞等)\n", + "initial_prompt = \"こんにちは。元気、ですかー?ふふっ、私は……ちゃんと元気だよ!\"\n", + "\n", "!python slice.py -i {input_dir} -o {dataset_root}/{model_name}/raw\n", - "!python transcribe.py -i {dataset_root}/{model_name}/raw -o {dataset_root}/{model_name}/esd.list --speaker_name {model_name} --compute_type float16" + "!python transcribe.py -i {dataset_root}/{model_name}/raw -o {dataset_root}/{model_name}/esd.list --speaker_name {model_name} --compute_type float16 --initial_prompt {initial_prompt}" ] }, { @@ -262,6 +265,7 @@ " freeze_JP_bert=False,\n", " freeze_ZH_bert=False,\n", " freeze_style=False,\n", + " freeze_decoder=False, # ここをTrueにするともしかしたら違う結果になるかもしれません。\n", " use_jp_extra=use_jp_extra,\n", " val_per_lang=0,\n", " log_interval=200,\n", diff --git a/common/constants.py b/common/constants.py index 
d2d6352c0..751d5e123 100644 --- a/common/constants.py +++ b/common/constants.py @@ -4,7 +4,10 @@ # See https://huggingface.co/spaces/gradio/theme-gallery for more themes GRADIO_THEME: str = "NoCrypt/miku" -LATEST_VERSION: str = "2.2" +LATEST_VERSION: str = "2.3" + +USER_DICT_DIR = "dict_data" + DEFAULT_STYLE: str = "Neutral" DEFAULT_STYLE_WEIGHT: float = 5.0 diff --git a/common/log.py b/common/log.py index 51dca5f3f..679bb2c77 100644 --- a/common/log.py +++ b/common/log.py @@ -1,6 +1,7 @@ """ logger封装 """ + from loguru import logger from .stdout_wrapper import SAFE_STDOUT diff --git a/common/tts_model.py b/common/tts_model.py index c14859813..de4e830df 100644 --- a/common/tts_model.py +++ b/common/tts_model.py @@ -1,17 +1,19 @@ -import numpy as np -import gradio as gr -import torch import os import warnings +from pathlib import Path +from typing import Optional, Union + +import gradio as gr +import numpy as np + +import torch from gradio.processing_utils import convert_to_16_bit_wav -from typing import Dict, List, Optional, Union import utils from infer import get_net_g, infer from models import SynthesizerTrn from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra -from .log import logger from .constants import ( DEFAULT_ASSIST_TEXT_WEIGHT, DEFAULT_LENGTH, @@ -23,25 +25,60 @@ DEFAULT_STYLE, DEFAULT_STYLE_WEIGHT, ) +from .log import logger + + +def adjust_voice(fs, wave, pitch_scale, intonation_scale): + if pitch_scale == 1.0 and intonation_scale == 1.0: + # 初期値の場合は、音質劣化を避けるためにそのまま返す + return fs, wave + + try: + import pyworld + except ImportError: + raise ImportError( + "pyworld is not installed. Please install it by `pip install pyworld`" + ) + + # pyworldでf0を加工して合成 + # pyworldよりもよいのがあるかもしれないが…… + + wave = wave.astype(np.double) + f0, t = pyworld.harvest(wave, fs) + # 質が高そうだしとりあえずharvestにしておく + + sp = pyworld.cheaptrick(wave, f0, t, fs) + ap = pyworld.d4c(wave, f0, t, fs) + + non_zero_f0 = [f for f in f0 if f != 0] + f0_mean = sum(non_zero_f0) / len(non_zero_f0) + + for i, f in enumerate(f0): + if f == 0: + continue + f0[i] = pitch_scale * f0_mean + intonation_scale * (f - f0_mean) + + wave = pyworld.synthesize(f0, sp, ap, fs) + return fs, wave class Model: def __init__( - self, model_path: str, config_path: str, style_vec_path: str, device: str + self, model_path: Path, config_path: Path, style_vec_path: Path, device: str ): - self.model_path: str = model_path - self.config_path: str = config_path + self.model_path: Path = model_path + self.config_path: Path = config_path + self.style_vec_path: Path = style_vec_path self.device: str = device - self.style_vec_path: str = style_vec_path self.hps: utils.HParams = utils.get_hparams_from_file(self.config_path) - self.spk2id: Dict[str, int] = self.hps.data.spk2id - self.id2spk: Dict[int, str] = {v: k for k, v in self.spk2id.items()} + self.spk2id: dict[str, int] = self.hps.data.spk2id + self.id2spk: dict[int, str] = {v: k for k, v in self.spk2id.items()} self.num_styles: int = self.hps.data.num_styles if hasattr(self.hps.data, "style2id"): - self.style2id: Dict[str, int] = self.hps.data.style2id + self.style2id: dict[str, int] = self.hps.data.style2id else: - self.style2id: Dict[str, int] = {str(i): i for i in range(self.num_styles)} + self.style2id: dict[str, int] = {str(i): i for i in range(self.num_styles)} if len(self.style2id) != self.num_styles: raise ValueError( f"Number of styles ({self.num_styles}) does not match the number of style2id ({len(self.style2id)})" @@ -57,7 +94,7 @@ def __init__( def load_net_g(self): 
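Note on the `adjust_voice` helper added above in `common/tts_model.py`: besides being called from `Model.infer` when `pitch_scale` or `intonation_scale` differ from 1.0, it can be applied to an existing waveform. A minimal sketch, not part of this patch; it assumes `pyworld` is installed and uses a placeholder mono 16-bit `input.wav`.

```python
# Minimal sketch: post-process an existing wav with the adjust_voice helper
# from common/tts_model.py. Assumes pyworld is installed; "input.wav" and
# "output.wav" are placeholder file names.
import numpy as np
from scipy.io import wavfile

from common.tts_model import adjust_voice

fs, wave = wavfile.read("input.wav")  # mono 16-bit PCM assumed
wave = wave.astype(np.float64)        # pyworld works on float64 waveforms

# pitch_scale scales the mean F0; intonation_scale scales the deviation from it.
fs, adjusted = adjust_voice(fs, wave, pitch_scale=1.1, intonation_scale=0.8)

wavfile.write("output.wav", fs, np.clip(adjusted, -32768, 32767).astype(np.int16))
```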
self.net_g = get_net_g( - model_path=self.model_path, + model_path=str(self.model_path), version=self.hps.version, device=self.device, hps=self.hps, @@ -97,6 +134,9 @@ def infer( style: str = DEFAULT_STYLE, style_weight: float = DEFAULT_STYLE_WEIGHT, given_tone: Optional[list[int]] = None, + pitch_scale: float = 1.0, + intonation_scale: float = 1.0, + ignore_unknown: bool = False, ) -> tuple[int, np.ndarray]: logger.info(f"Start generating audio data from text:\n{text}") if language != "JP" and self.hps.version.endswith("JP-Extra"): @@ -134,6 +174,7 @@ def infer( assist_text_weight=assist_text_weight, style_vec=style_vector, given_tone=given_tone, + ignore_unknown=ignore_unknown, ) else: texts = text.split("\n") @@ -156,55 +197,99 @@ def infer( assist_text=assist_text, assist_text_weight=assist_text_weight, style_vec=style_vector, + ignore_unknown=ignore_unknown, ) ) if i != len(texts) - 1: audios.append(np.zeros(int(44100 * split_interval))) audio = np.concatenate(audios) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - audio = convert_to_16_bit_wav(audio) logger.info("Audio data generated successfully") + if not (pitch_scale == 1.0 and intonation_scale == 1.0): + _, audio = adjust_voice( + fs=self.hps.data.sampling_rate, + wave=audio, + pitch_scale=pitch_scale, + intonation_scale=intonation_scale, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + audio = convert_to_16_bit_wav(audio) return (self.hps.data.sampling_rate, audio) class ModelHolder: - def __init__(self, root_dir: str, device: str): - self.root_dir: str = root_dir + def __init__(self, root_dir: Path, device: str): + self.root_dir: Path = root_dir self.device: str = device - self.model_files_dict: Dict[str, List[str]] = {} + self.model_files_dict: dict[str, list[Path]] = {} self.current_model: Optional[Model] = None - self.model_names: List[str] = [] - self.models: List[Model] = [] + self.model_names: list[str] = [] + self.models: list[Model] = [] self.refresh() def refresh(self): self.model_files_dict = {} self.model_names = [] self.current_model = None - model_dirs = [ - d - for d in os.listdir(self.root_dir) - if os.path.isdir(os.path.join(self.root_dir, d)) - ] - for model_name in model_dirs: - model_dir = os.path.join(self.root_dir, model_name) + + model_dirs = [d for d in self.root_dir.iterdir() if d.is_dir()] + for model_dir in model_dirs: model_files = [ - os.path.join(model_dir, f) - for f in os.listdir(model_dir) - if f.endswith(".pth") or f.endswith(".pt") or f.endswith(".safetensors") + f + for f in model_dir.iterdir() + if f.suffix in [".pth", ".pt", ".safetensors"] ] if len(model_files) == 0: + logger.warning(f"No model files found in {model_dir}, so skip it") + continue + config_path = model_dir / "config.json" + if not config_path.exists(): logger.warning( - f"No model files found in {self.root_dir}/{model_name}, so skip it" + f"Config file {config_path} not found, so skip {model_dir}" ) continue - self.model_files_dict[model_name] = model_files - self.model_names.append(model_name) + self.model_files_dict[model_dir.name] = model_files + self.model_names.append(model_dir.name) + + def models_info(self): + if hasattr(self, "_models_info"): + return self._models_info + result = [] + for name, files in self.model_files_dict.items(): + # Get styles + config_path = self.root_dir / name / "config.json" + hps = utils.get_hparams_from_file(config_path) + style2id: dict[str, int] = hps.data.style2id + styles = list(style2id.keys()) + result.append( + { + "name": name, + 
"files": [str(f) for f in files], + "styles": styles, + } + ) + self._models_info = result + return result + + def load_model(self, model_name: str, model_path_str: str): + model_path = Path(model_path_str) + if model_name not in self.model_files_dict: + raise ValueError(f"Model `{model_name}` is not found") + if model_path not in self.model_files_dict[model_name]: + raise ValueError(f"Model file `{model_path}` is not found") + if self.current_model is None or self.current_model.model_path != model_path: + self.current_model = Model( + model_path=model_path, + config_path=self.root_dir / model_name / "config.json", + style_vec_path=self.root_dir / model_name / "style_vectors.npy", + device=self.device, + ) + return self.current_model def load_model_gr( - self, model_name: str, model_path: str + self, model_name: str, model_path_str: str ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]: + model_path = Path(model_path_str) if model_name not in self.model_files_dict: raise ValueError(f"Model `{model_name}` is not found") if model_path not in self.model_files_dict[model_name]: @@ -223,8 +308,8 @@ def load_model_gr( ) self.current_model = Model( model_path=model_path, - config_path=os.path.join(self.root_dir, model_name, "config.json"), - style_vec_path=os.path.join(self.root_dir, model_name, "style_vectors.npy"), + config_path=self.root_dir / model_name / "config.json", + style_vec_path=self.root_dir / model_name / "style_vectors.npy", device=self.device, ) speakers = list(self.current_model.spk2id.keys()) diff --git a/configs/config.json b/configs/config.json index 6aa647ccf..25e86db6b 100644 --- a/configs/config.json +++ b/configs/config.json @@ -20,7 +20,8 @@ "freeze_ZH_bert": false, "freeze_JP_bert": false, "freeze_EN_bert": false, - "freeze_style": false + "freeze_style": false, + "freeze_encoder": false }, "data": { "training_files": "Data/your_model_name/filelists/train.list", @@ -67,5 +68,5 @@ "use_spectral_norm": false, "gin_channels": 256 }, - "version": "2.2" + "version": "2.3" } diff --git a/configs/configs_jp_extra.json b/configs/configs_jp_extra.json index 7e2698ea5..616d1d31d 100644 --- a/configs/configs_jp_extra.json +++ b/configs/configs_jp_extra.json @@ -22,7 +22,8 @@ "freeze_JP_bert": false, "freeze_EN_bert": false, "freeze_emo": false, - "freeze_style": false + "freeze_style": false, + "freeze_decoder": false }, "data": { "use_jp_extra": true, @@ -74,5 +75,5 @@ "initial_channel": 64 } }, - "version": "2.2-JP-Extra" + "version": "2.3-JP-Extra" } diff --git a/dict_data/.gitignore b/dict_data/.gitignore new file mode 100644 index 000000000..0b60cb635 --- /dev/null +++ b/dict_data/.gitignore @@ -0,0 +1,3 @@ +* +!.gitignore +!default.csv diff --git a/dict_data/default.csv b/dict_data/default.csv new file mode 100644 index 000000000..84085cde2 --- /dev/null +++ b/dict_data/default.csv @@ -0,0 +1,5 @@ +Bert,,,8609,名詞,固有名詞,一般,*,*,*,Bert,バアト,バアト,0/3,* +VITS,,,8609,名詞,固有名詞,一般,*,*,*,VITS,ビッツ,ビッツ,0/3,* +VITS二,,,8609,名詞,固有名詞,一般,*,*,*,VITS二,ビッツツー,ビッツツー,4/5,* +BertVITS,,,8609,名詞,固有名詞,一般,*,*,*,BertVITS,バアトビッツ,バアトビッツ,4/6,* +担々麺,,,8609,名詞,固有名詞,一般,*,*,*,担々麺,タンタンメン,タンタンメン,3/6,* diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index d9a301545..15e0ef931 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,5 +1,71 @@ # Changelog +## v2.3 (2024-02-25) + +### 大きな変更 + +大きい変更をいくつかしたため、**アップデートはまた専用の手順**が必要です。下記の指示にしたがってください。 + +#### ユーザー辞書機能 
+あらかじめ辞書に固有名詞を追加することができ、それが**学習時**・**音声合成時**の読み取得部分に適応されます。辞書の追加・編集は次のエディタ経由で行ってください。または、手持ちのOpenJTalkのcsv形式の辞書がある場合は、`dict_data/default.csv`ファイルを直接上書きや追加しても可能です。 + +使えそうな辞書(ライセンス等は各自ご確認ください)(他に良いのがあったら教えて下さい): + +- [WariHima/Kanayomi-dict](https://github.com/WariHima/KanaYomi-dict) +- [takana-v/tsumu_dic](https://github.com/takana-v/tsumu_dic) + + +辞書機能部分の[実装](/text/user_dict/) は、中のREADMEにある通り、[VOICEVOX Editor](https://github.com/VOICEVOX/voicevox) のものを使っており、この部分のコードライセンスはLGPL-3.0です。 + +#### 音声合成専用エディタ + +音声合成専用エディタを追加。今までのWebUIでできた機能のほか、次のような機能が使えます(つまり既存の日本語音声合成ソフトウェアのエディタを真似ました): +- セリフ単位でキャラや設定を変更しながら原稿を作り、それを一括で生成したり、原稿を保存等したり読み込んだり +- GUIよる分かりやすいアクセント調整 +- ユーザー辞書への単語追加や編集 + +`Editor.bat`をダブルクリックか`python server_editor.py --inbrowser`で起動します。エディター部分は[こちらの別リポジトリ](https://github.com/litagin02/Style-Bert-VITS2-Editor)になります。フロントエンド初心者なのでプルリクや改善案等をお待ちしています。 + +### バグ修正 + +- 特定の状況で読みが正しく取得できず `list index out of range` となるバグの修正 +- 前処理時に、書き起こしファイルのある行の形式が不正だと、書き起こしファイルのそれ以降の内容が消えてしまうバグの修正 +- faster-whisperが1.0.0にメジャーバージョンアップされ(今のところ)大幅に劣化したので、バージョンを0.10.1へ固定 + +### 改善 + +- テキスト前処理時に、読みの取得の失敗等があった場合に、処理を中断せず、エラーがおきた箇所を`text_error.log`ファイルへ保存するように変更。 +- 音声合成時に、読めない文字があったときはエラーを起こさず、その部分を無視して読み上げるように変更(学習段階ではエラーを出します) +- コマンドラインで前処理や学習が簡単にできるよう、前処理を行う`preprocess_all.py`を追加(詳しくは[CLI.md](/docs/CLI.md)を参照) +- 学習の際に、自動的に自分のhugging faceリポジトリへ結果をアップロードするオプションを追加。コマンドライン引数で`--repo_id username/my_model`のように指定してください(詳しくは[CLI.md](/docs/CLI.md)を参照)。🤗の無制限ストレージが使えるのでクラウドでの学習に便利です。 +- 学習時にデコーダー部分を凍結するオプションの追加。品質がもしかしたら上がるかもしれません。 +- `initialize.py`に引数`--dataset_root`と`--assets_root`を追加し、`configs/paths.yml`をその時点で変更できるようにした + +### その他 + +- [paperspaceでの学習の手引きを追加](/docs/paperspace.md)、paperspaceでのimageに使える[Dockerfile](/Dockerfile.train)を追加 +- [CLIでの各種処理の実行の仕方を追加](/docs/CLI.md) +- [Hugging Face spacesで遊べる音声合成エディタ](https://huggingface.co/spaces/litagin/Style-Bert-VITS2-Editor-Demo)をデプロイするための[Dockerfile](Dockerfile.deploy)を追加 + +### アップデート手順 + +- [Update-to-Dict-Editor.bat](https://github.com/litagin02/Style-Bert-VITS2/releases/download/2.3/Update-to-Dict-Editor.bat)をダウンロードし、`Style-Bert-VITS2`フォルダがある場所(インストールbatファイルとかがあったところ)においてダブルクリックしてください。 + +- 手動での場合は、以下の手順で実行してください: +```bash +git pull +venv\Scripts\activate +pip uninstall pyopenjtalk-prebuilt +pip install -r requirements.txt +# python initialize.py # これを1.x系からのアップデートの場合は実行してください +python server_editor.py --inbrowser +``` + +### 新規インストール手順 +[このzip](https://github.com/litagin02/Style-Bert-VITS2/releases/download/2.3/Style-Bert-VITS2.zip)をダウンロードし、解凍してください。 +を展開し、`Install-Style-Bert-VITS2.bat`をダブルクリックしてください。 + + ## v2.2 (2024-02-09) ### 変更・機能追加 diff --git a/docs/CLI.md b/docs/CLI.md index ca955e169..08e2fd03e 100644 --- a/docs/CLI.md +++ b/docs/CLI.md @@ -1,35 +1,49 @@ # CLI -**WIP** +## 0. Install and global paths settings -## Dataset +```bash +git clone https://github.com/litagin02/Style-Bert-VITS2.git +cd Style-Bert-VITS2 +python -m venv venv +venv\Scripts\activate +pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118 +pip install -r requirements.txt +``` + +Then download the necessary models and the default TTS model, and set the global paths. +```bash +python initialize.py [--skip_jvnv] [--dataset_root ] [--assets_root ] +``` + +Optional: +- `--skip_jvnv`: Skip downloading the default JVNV voice models (use this if you only have to train your own models). +- `--dataset_root`: Default: `Data`. Root directory of the training dataset. 
The training dataset of `{model_name}` should be placed in `{dataset_root}/{model_name}`. +- `--assets_root`: Default: `model_assets`. Root directory of the model assets (for inference). In training, the model assets will be saved to `{assets_root}/{model_name}`, and in inference, we load all the models from `{assets_root}`. -`Dataset.bat` webui (`python webui_dataset.py`) consists of **slice audio** and **transcribe wavs**. -### Slice audio +## 1. Dataset preparation +### 1.1. Slice wavs ```bash -python slice.py -i -o -m -M +python slice.py --model_name [-i ] [-m ] [-M ] ``` Required: -- `input_dir`: Path to the directory containing the audio files to slice. -- `output_dir`: Path to the directory where the sliced audio files will be saved. +- `model_name`: Name of the speaker (to be used as the name of the trained model). Optional: -- `min_sec`: Minimum duration of the sliced audio files in seconds (default 2). -- `max_sec`: Maximum duration of the sliced audio files in seconds (default 12). +- `input_dir`: Path to the directory containing the audio files to slice (default: `inputs`) +- `min_sec`: Minimum duration of the sliced audio files in seconds (default: 2). +- `max_sec`: Maximum duration of the sliced audio files in seconds (default: 12). -### Transcribe wavs +### 1.2. Transcribe wavs ```bash -python transcribe.py -i -o --speaker_name +python transcribe.py --model_name ``` - Required: -- `input_dir`: Path to the directory containing the audio files to transcribe. -- `output_file`: Path to the file where the transcriptions will be saved. -- `speaker_name`: Name of the speaker. +- `model_name`: Name of the speaker (to be used as the name of the trained model). Optional - `--initial_prompt`: Initial prompt to use for the transcription (default value is specific to Japanese). @@ -38,19 +52,45 @@ Optional - `--model`: Whisper model, default: `large-v3` - `--compute_type`: default: `bfloat16` -## Train +## 2. Preprocess -`Train.bat` webui (`python webui_train.py`) consists of the following. - -### Preprocess audio ```bash -python resample.py -i -o [--normalize] [--trim] +python preprocess_all.py -m [--use_jp_extra] [-b ] [-e ] [-s ] [--num_processes ] [--normalize] [--trim] [--val_per_lang ] [--log_interval ] [--freeze_EN_bert] [--freeze_JP_bert] [--freeze_ZH_bert] [--freeze_style] [--freeze_decoder] ``` Required: -- `input_dir`: Path to the directory containing the audio files to preprocess. -- `output_dir`: Path to the directory where the preprocessed audio files will be saved. +- `model_name`: Name of the speaker (to be used as the name of the trained model). + +Optional: +- `--batch_size`, `-b`: Batch size (default: 2). +- `--epochs`, `-e`: Number of epochs (default: 100). +- `--save_every_steps`, `-s`: Save every steps (default: 1000). +- `--num_processes`: Number of processes (default: half of the number of CPU cores). +- `--normalize`: Loudness normalize audio. +- `--trim`: Trim silence. +- `--freeze_EN_bert`: Freeze English BERT. +- `--freeze_JP_bert`: Freeze Japanese BERT. +- `--freeze_ZH_bert`: Freeze Chinese BERT. +- `--freeze_style`: Freeze style vector. +- `--freeze_decoder`: Freeze decoder. +- `--use_jp_extra`: Use JP-Extra model. +- `--val_per_lang`: Validation data per language (default: 0). +- `--log_interval`: Log interval (default: 200). + +## 3. Train + +Training settings are automatically loaded from the above process. 
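The preprocessing and training steps can also be driven from a single Python script instead of the shell. The sketch below mirrors what `preprocess_all.py` does internally and then launches the JP-Extra training script as a subprocess; `MyModel`, the hyperparameters, and the `--repo_id` value are placeholders, and the dataset (`raw/` and `esd.list` under `{dataset_root}/MyModel`) is assumed to be prepared already (steps 1.1 and 1.2).

```python
# Rough sketch: run preprocessing and (JP-Extra) training from Python.
# "MyModel", the hyperparameters, and the repo id are placeholders; the keyword
# arguments mirror the command-line flags of preprocess_all.py.
import subprocess
import sys
from multiprocessing import cpu_count

from webui_train import preprocess_all

preprocess_all(
    model_name="MyModel",
    batch_size=2,
    epochs=100,
    save_every_steps=1000,
    num_processes=cpu_count() // 2,
    normalize=False,
    trim=False,
    freeze_EN_bert=False,
    freeze_JP_bert=False,
    freeze_ZH_bert=False,
    freeze_style=False,
    freeze_decoder=False,
    use_jp_extra=True,
    val_per_lang=0,
    log_interval=200,
)

# Training settings are read from the config written by the preprocessing step;
# --repo_id is optional and uploads results to the given Hugging Face repo.
subprocess.run(
    [sys.executable, "train_ms_jp_extra.py", "--repo_id", "username/my_model"],
    check=True,
)
```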
+ +If NOT using JP-Extra model: +```bash +python train_ms.py [--repo_id /] +``` -TO BE WRITTEN (WIP) +If using JP-Extra model: +```bash +python train_ms_jp_extra.py [--repo_id /] [--skip_default_style] +``` -これいる? +Optional: +- `--repo_id`: Hugging Face repository ID to upload the trained model to. You should have logged in using `huggingface-cli login` before running this command. +- `--skip_default_style`: Skip making the default style vector. Use this if you want to resume training (since the default style vector is already made). diff --git a/docs/paperspace.md b/docs/paperspace.md new file mode 100644 index 000000000..2ae9895e4 --- /dev/null +++ b/docs/paperspace.md @@ -0,0 +1,86 @@ +# Paperspace gradient で学習する + +詳しいコマンドの叩き方は[こちら](CLI.md)を参照してください。 + +## 事前準備 +- Paperspace のアカウントを作成し必要なら課金する +- Projectを作る +- NotebookはStart from Scratchを選択して空いてるGPUマシンを選ぶ + +## 使い方 + +以下では次のような方針でやっています。 + +- `/storage/`は永続ストレージなので、事前学習モデルとかを含めてリポジトリをクローンするとよい。 +- `/notebooks/`はノートブックごとに変わるストレージなので(同一ノートブック違うランタイムだと共有されるらしい)、データセットやその結果を保存する。ただ容量が多い場合はあふれる可能性があるので`/tmp/`に保存するとよいかもしれない。 +- hugging faceアカウントを作り、(プライベートな)リポジトリを作って、学習元データを置いたり、学習結果を随時アップロードする。 + +### 1. 環境を作る + +以下はデフォルトの`Start from Scratch`で作成した環境の場合。[Dockerfile.train](../Dockerfile.train)を使ったカスタムイメージをするとPythonの環境構築の手間がちょっと省けるので、それを使いたい人は`Advanced Options / Container / Name`に[`litagin/mygradient:latest`](https://hub.docker.com/r/litagin/mygradient/tags)を指定すると使えます(pipの箇所が不要になる等)。 + +まずは永続ストレージにgit clone +```bash +mkdir -p /storage/sbv2 +cd /storage/sbv2 +git clone https://github.com/litagin02/Style-Bert-VITS2.git +``` +環境構築(デフォルトはPyTorch 1.x系、Python 3.9の模様) +```bash +cd /storage/sbv2/Style-Bert-VITS2 +pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118 && pip install -r requirements.txt +``` +事前学習済みモデル等のダウンロード、またパスを`/notebooks/`以下のものに設定 +```bash +python initialize.py --skip_jvnv --dataset_root /notebooks/Data --assets_root /notebooks/model_assets +``` + +### 2. データセットの準備 +以下では`username/voices`というデータセットリポジトリにある`Foo.zip`というデータセットを使うことを想定しています。 +```bash +cd /notebooks +huggingface-cli login # 事前にトークンが必要 +huggingface-cli download username/voices Foo.zip --repo-type dataset --local-dir . +``` + +- zipファイル中身が既に`raw`と`esd.list`があるデータ(スライス・書き起こし済み)の場合 +```bash +mkdir -p Data/Foo +unzip Foo.zip -d Data/Foo +rm Foo.zip +cd /storage/sbv2/Style-Bert-VITS2 +``` + +- zipファイルが音声ファイルのみの場合 +```bash +mkdir inputs +unzip Foo.zip -d inputs +cd /storage/sbv2/Style-Bert-VITS2 +python slice.py --model_name Foo -i /notebooks/inputs +python transcribe.py --model_name Foo +``` + +それが終わったら、以下のコマンドで一括前処理を行う(パラメータは各自お好み、バッチサイズ5か6でVRAM 16GBギリくらい)。 +```bash +python preprocess_all.py --model_name Foo -b 5 -e 300 --use_jp_extra +``` + +### 3. 学習 + +Hugging faceの`username/sbv2-private`というモデルリポジトリに学習済みモデルをアップロードすることを想定しています。事前に`huggingface-cli login`でログインしておくこと。 +```bash +python train_ms_jp_extra.py --repo_id username/sbv2-private +``` +(JP-Extraでない場合は`train_ms.py`を使う) + +### 4. 学習再開 + +Notebooksの時間制限が切れてから別Notebooksで同じモデルを学習を再開する場合(環境構築は必要)。 +```bash +huggingface-cli login +cd /notebooks +huggingface-cli download username/sbv2-private --include "Data/Foo/*" --local-dir . 
+cd /storage/sbv2/Style-Bert-VITS2 +python train_ms_jp_extra.py --repo_id username/sbv2-private --skip_default_style +``` +前回の設定が残っているので特に前処理等は不要。 \ No newline at end of file diff --git a/gen_yaml.py b/gen_yaml.py new file mode 100644 index 000000000..91301accb --- /dev/null +++ b/gen_yaml.py @@ -0,0 +1,32 @@ +import os +import shutil +import yaml +import argparse + +parser = argparse.ArgumentParser( + description="config.ymlの生成。あらかじめ前準備をしたデータをバッチファイルなどで連続で学習する時にtrain_ms.pyより前に使用する。" +) +# そうしないと最後の前準備したデータで学習してしまう +parser.add_argument("--model_name", type=str, help="Model name", required=True) +parser.add_argument( + "--dataset_path", + type=str, + help="Dataset path(example: Data\\your_model_name)", + required=True, +) +args = parser.parse_args() + + +def gen_yaml(model_name, dataset_path): + if not os.path.exists("config.yml"): + shutil.copy(src="default_config.yml", dst="config.yml") + with open("config.yml", "r", encoding="utf-8") as f: + yml_data = yaml.safe_load(f) + yml_data["model_name"] = model_name + yml_data["dataset_path"] = dataset_path + with open("config.yml", "w", encoding="utf-8") as f: + yaml.dump(yml_data, f, allow_unicode=True) + + +if __name__ == "__main__": + gen_yaml(args.model_name, args.dataset_path) diff --git a/infer.py b/infer.py index 6febc0851..914a3554e 100644 --- a/infer.py +++ b/infer.py @@ -52,9 +52,12 @@ def get_text( assist_text=None, assist_text_weight=0.7, given_tone=None, + ignore_unknown=False, ): use_jp_extra = hps.version.endswith("JP-Extra") - norm_text, phone, tone, word2ph = clean_text(text, language_str, use_jp_extra) + norm_text, phone, tone, word2ph = clean_text( + text, language_str, use_jp_extra, ignore_unknown=ignore_unknown + ) if given_tone is not None: if len(given_tone) != len(phone): raise InvalidToneError( @@ -71,7 +74,13 @@ def get_text( word2ph[i] = word2ph[i] * 2 word2ph[0] += 1 bert_ori = get_bert( - norm_text, word2ph, language_str, device, assist_text, assist_text_weight + norm_text, + word2ph, + language_str, + device, + assist_text, + assist_text_weight, + ignore_unknown, ) del word2ph assert bert_ori.shape[-1] == len(phone), phone @@ -118,6 +127,7 @@ def infer( assist_text=None, assist_text_weight=0.7, given_tone=None, + ignore_unknown=False, ): is_jp_extra = hps.version.endswith("JP-Extra") bert, ja_bert, en_bert, phones, tones, lang_ids = get_text( @@ -128,6 +138,7 @@ def infer( assist_text=assist_text, assist_text_weight=assist_text_weight, given_tone=given_tone, + ignore_unknown=ignore_unknown, ) if skip_start: phones = phones[3:] diff --git a/initialize.py b/initialize.py index c163ef9f6..5e35061f2 100644 --- a/initialize.py +++ b/initialize.py @@ -2,6 +2,7 @@ import json from pathlib import Path +import yaml from huggingface_hub import hf_hub_download from common.log import logger @@ -90,9 +91,21 @@ def download_jvnv_models(): ) -if __name__ == "__main__": +def main(): parser = argparse.ArgumentParser() parser.add_argument("--skip_jvnv", action="store_true") + parser.add_argument( + "--dataset_root", + type=str, + help="Dataset root path (default: Data)", + default=None, + ) + parser.add_argument( + "--assets_root", + type=str, + help="Assets root path (default: model_assets)", + default=None, + ) args = parser.parse_args() download_bert_models() @@ -105,3 +118,21 @@ def download_jvnv_models(): if not args.skip_jvnv: download_jvnv_models() + + if args.dataset_root is None and args.assets_root is None: + return + + # Change default paths if necessary + paths_yml = Path("configs/paths.yml") + with open(paths_yml, "r", 
encoding="utf-8") as f: + yml_data = yaml.safe_load(f) + if args.assets_root is not None: + yml_data["assets_root"] = args.assets_root + if args.dataset_root is not None: + yml_data["dataset_root"] = args.dataset_root + with open(paths_yml, "w", encoding="utf-8") as f: + yaml.dump(yml_data, f, allow_unicode=True) + + +if __name__ == "__main__": + main() diff --git a/preprocess_all.py b/preprocess_all.py new file mode 100644 index 000000000..82a3c202c --- /dev/null +++ b/preprocess_all.py @@ -0,0 +1,96 @@ +import argparse +from webui_train import preprocess_all +from multiprocessing import cpu_count + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", "-m", type=str, help="Model name", required=True + ) + parser.add_argument("--batch_size", "-b", type=int, help="Batch size", default=2) + parser.add_argument("--epochs", "-e", type=int, help="Epochs", default=100) + parser.add_argument( + "--save_every_steps", + "-s", + type=int, + help="Save every steps", + default=1000, + ) + parser.add_argument( + "--num_processes", + type=int, + help="Number of processes", + default=cpu_count() // 2, + ) + parser.add_argument( + "--normalize", + action="store_true", + help="Loudness normalize audio", + ) + parser.add_argument( + "--trim", + action="store_true", + help="Trim silence", + ) + parser.add_argument( + "--freeze_EN_bert", + action="store_true", + help="Freeze English BERT", + ) + parser.add_argument( + "--freeze_JP_bert", + action="store_true", + help="Freeze Japanese BERT", + ) + parser.add_argument( + "--freeze_ZH_bert", + action="store_true", + help="Freeze Chinese BERT", + ) + parser.add_argument( + "--freeze_style", + action="store_true", + help="Freeze style vector", + ) + parser.add_argument( + "--freeze_decoder", + action="store_true", + help="Freeze decoder", + ) + parser.add_argument( + "--use_jp_extra", + action="store_true", + help="Use JP-Extra model", + ) + parser.add_argument( + "--val_per_lang", + type=int, + help="Validation per language", + default=0, + ) + parser.add_argument( + "--log_interval", + type=int, + help="Log interval", + default=200, + ) + + args = parser.parse_args() + + preprocess_all( + model_name=args.model_name, + batch_size=args.batch_size, + epochs=args.epochs, + save_every_steps=args.save_every_steps, + num_processes=args.num_processes, + normalize=args.normalize, + trim=args.trim, + freeze_EN_bert=args.freeze_EN_bert, + freeze_JP_bert=args.freeze_JP_bert, + freeze_ZH_bert=args.freeze_ZH_bert, + freeze_style=args.freeze_style, + freeze_decoder=args.freeze_decoder, + use_jp_extra=args.use_jp_extra, + val_per_lang=args.val_per_lang, + log_interval=args.log_interval, + ) diff --git a/preprocess_text.py b/preprocess_text.py index 9fba5ebd0..2c26941e7 100644 --- a/preprocess_text.py +++ b/preprocess_text.py @@ -15,6 +15,12 @@ preprocess_text_config = config.preprocess_text_config +# Count lines for tqdm +def count_lines(file_path: str): + with open(file_path, "r", encoding="utf-8") as file: + return sum(1 for _ in file) + + @click.command() @click.option( "--transcription-path", @@ -49,10 +55,14 @@ def preprocess( if cleaned_path == "" or cleaned_path is None: cleaned_path = transcription_path + ".cleaned" + error_log_path = os.path.join(os.path.dirname(cleaned_path), "text_error.log") + error_count = 0 + if clean: + total_lines = count_lines(transcription_path) with open(cleaned_path, "w", encoding="utf-8") as out_file: with open(transcription_path, "r", encoding="utf-8") as trans_file: - for line in 
tqdm(trans_file, file=SAFE_STDOUT): + for line in tqdm(trans_file, file=SAFE_STDOUT, total=total_lines): try: utt, spk, language, text = line.strip().split("|") norm_text, phones, tones, word2ph = clean_text( @@ -70,10 +80,10 @@ def preprocess( ) ) except Exception as e: - logger.error( - f"An error occurred while generating the training set and validation set, at line:\n{line}\nDetails:\n{e}" - ) - raise + logger.error(f"An error occurred at line:\n{line.strip()}\n{e}") + with open(error_log_path, "a", encoding="utf-8") as error_log: + error_log.write(f"{line.strip()}\n{e}\n\n") + error_count += 1 transcription_path = cleaned_path spk_utt_map = defaultdict(list) @@ -101,9 +111,10 @@ def preprocess( if spk not in spk_id_map.keys(): spk_id_map[spk] = current_sid current_sid += 1 - logger.info( - f"Total repeated audios: {countSame}, Total number of audio not found: {countNotFound}" - ) + if countSame > 0 or countNotFound > 0: + logger.warning( + f"Total repeated audios: {countSame}, Total number of audio not found: {countNotFound}" + ) train_list = [] val_list = [] @@ -139,7 +150,17 @@ def preprocess( ) with open(config_path, "w", encoding="utf-8") as f: json.dump(json_config, f, indent=2, ensure_ascii=False) - logger.info("Training set and validation set generation from texts is complete!") + if error_count > 0: + logger.error( + f"An error occurred in {error_count} lines. Please check {error_log_path} for details. You can proceed with lines that do not have errors." + ) + raise Exception( + f"An error occurred in {error_count} lines. Please check {error_log_path} for details. You can proceed with lines that do not have errors." + ) + else: + logger.info( + "Training set and validation set generation from texts is complete!" + ) if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index f58619a5d..6bf329ffb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,29 +1,29 @@ cmudict cn2an -faster-whisper>=0.10.0,<1.0.0 +faster-whisper==0.10.1 g2p_en GPUtil gradio -jaconv jieba langid librosa==0.9.2 loguru matplotlib -mecab-python3 num2words numba numpy psutil pyannote.audio>=3.1.0 pyloudnorm -pyopenjtalk-prebuilt +# pyopenjtalk-prebuilt # Should be manually uninstalled +pyopenjtalk-dict pypinyin +# pyworld # Not supported on Windows without Cython... PyYAML requests safetensors scipy tensorboard -torch>=2.1,<2.2 # For users without GPU or colab +torch>=2.1,<2.2 transformers umap-learn diff --git a/scripts/Install-Style-Bert-VITS2-CPU.bat b/scripts/Install-Style-Bert-VITS2-CPU.bat index e5d106fa8..aca6f50cd 100644 --- a/scripts/Install-Style-Bert-VITS2-CPU.bat +++ b/scripts/Install-Style-Bert-VITS2-CPU.bat @@ -33,7 +33,7 @@ xcopy /QSY .\Style-Bert-VITS2-master\ .\Style-Bert-VITS2\ rmdir /s /q Style-Bert-VITS2-master echo ---------------------------------------- -echo Python環境の構築を開始します。 +echo Setup Python and Virtual Environment echo ---------------------------------------- @REM Pythonと仮想環境のセットアップを呼び出す(仮想環境が有効化されて戻ってくる) @@ -45,7 +45,7 @@ pip install -r Style-Bert-VITS2\requirements.txt if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) echo ---------------------------------------- -echo 環境構築が完了しました。モデルのダウンロードを開始します。 +echo Environment setup is complete. Start downloading the model. echo ---------------------------------------- @REM Style-Bert-VITS2フォルダに移動 @@ -55,12 +55,11 @@ pushd Style-Bert-VITS2 python initialize.py echo ---------------------------------------- -echo モデルのダウンロードが完了し、インストールが完了しました! -echo 音声合成のWebUIを起動します。 +echo Model download is complete. 
Start Style-Bert-VITS2 Editor. echo ---------------------------------------- -@REM 音声合成WebUIの起動 -python app.py +@REM エディターの起動 +python server_editor.py --inbrowser pause diff --git a/scripts/Install-Style-Bert-VITS2.bat b/scripts/Install-Style-Bert-VITS2.bat index 6e99230c3..59fe97ec3 100644 --- a/scripts/Install-Style-Bert-VITS2.bat +++ b/scripts/Install-Style-Bert-VITS2.bat @@ -58,11 +58,11 @@ pushd Style-Bert-VITS2 python initialize.py echo ---------------------------------------- -echo Model download is complete. Start the WebUI of the voice synthesis. +echo Model download is complete. Start Style-Bert-VITS2 Editor. echo ---------------------------------------- -@REM 音声合成WebUIの起動 -python app.py +@REM エディターの起動 +python server_editor.py --inbrowser pause diff --git a/scripts/Update-Style-Bert-VITS2.bat b/scripts/Update-Style-Bert-VITS2.bat index 782f2efe9..b4f8b8d6c 100644 --- a/scripts/Update-Style-Bert-VITS2.bat +++ b/scripts/Update-Style-Bert-VITS2.bat @@ -37,7 +37,10 @@ if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) pip install -U -r Style-Bert-VITS2\requirements.txt if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) -echo Style-Bert-VITS2のアップデートが完了しました。 +echo Update completed. Running Style-Bert-VITS2 Editor... + +@REM Style-Bert-VITS2 Editorを起動 +python server_editor.py --inbrowser pause diff --git a/scripts/Update-to-JP-Extra.bat b/scripts/Update-to-Dict-Editor.bat similarity index 59% rename from scripts/Update-to-JP-Extra.bat rename to scripts/Update-to-Dict-Editor.bat index 95738c698..1eff1bd60 100644 --- a/scripts/Update-to-JP-Extra.bat +++ b/scripts/Update-to-Dict-Editor.bat @@ -28,21 +28,39 @@ xcopy /QSY .\Style-Bert-VITS2-master\ .\Style-Bert-VITS2\ rmdir /s /q Style-Bert-VITS2-master if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) -@REM 仮想環境のpip requirements.txtを更新 +@REM 仮想環境を有効化 echo call .\Style-Bert-VITS2\scripts\activate.bat call .\Style-Bert-VITS2\venv\Scripts\activate.bat if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) -pip install -U -r Style-Bert-VITS2\requirements.txt +@REM pyopenjtalk-prebuiltやpyopenjtalkが入っていたら削除 +echo python -m pip uninstall -y pyopenjtalk-prebuilt pyopenjtalk +python -m pip uninstall -y pyopenjtalk-prebuilt pyopenjtalk +if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) + +@REM pyopenjtalk-dictをインストール +echo python -m pip install -U pyopenjtalk-dict +python -m pip install -U pyopenjtalk-dict + +@REM その他のrequirements.txtも一応更新 +python -m pip install -U -r Style-Bert-VITS2\requirements.txt if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) pushd Style-Bert-VITS2 -@REM 初期化(必要なモデルのダウンロード) +@REM JP-Extra版以前からの場合のために一応initialize.pyを実行 + +echo python initialize.py python initialize.py +if %errorlevel% neq 0 ( pause & popd & exit /b %errorlevel% ) + +echo ---------------------------------------- +echo Update completed. Running Style-Bert-VITS2 Editor... 
+echo ---------------------------------------- -echo Style-Bert-VITS2の2.xへのアップデートが完了しました。 +@REM Style-Bert-VITS2 Editorを起動 +python server_editor.py --inbrowser pause diff --git a/server_editor.py b/server_editor.py new file mode 100644 index 000000000..93c7c58c6 --- /dev/null +++ b/server_editor.py @@ -0,0 +1,421 @@ +""" +Style-Bert-VITS2-Editor用のサーバー。 +次のリポジトリ +https://github.com/litagin02/Style-Bert-VITS2-Editor +をビルドしてできあがったファイルをWebフォルダに入れて実行する。 + +TODO: リファクタリングやドキュメンテーションやAPI整理、辞書周りの改善などが必要。 +""" + +import argparse +import io +import shutil +import sys +import webbrowser +import zipfile +from datetime import datetime +from io import BytesIO +from pathlib import Path +import yaml + +import numpy as np +import pyopenjtalk +import requests +import torch +import uvicorn +from fastapi import APIRouter, FastAPI, HTTPException, status +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, Response +from fastapi.staticfiles import StaticFiles +from pydantic import BaseModel +from scipy.io import wavfile + +from common.constants import ( + DEFAULT_ASSIST_TEXT_WEIGHT, + DEFAULT_NOISE, + DEFAULT_NOISEW, + DEFAULT_SDP_RATIO, + DEFAULT_STYLE, + DEFAULT_STYLE_WEIGHT, + LATEST_VERSION, + Languages, +) +from common.log import logger +from common.tts_model import ModelHolder +from text.japanese import g2kata_tone, kata_tone2phone_tone, text_normalize +from text.user_dict import apply_word, update_dict, read_dict, rewrite_word, delete_word + + +# ---フロントエンド部分に関する処理--- + +# エディターのビルドファイルを配置するディレクトリ +STATIC_DIR = Path("static") +# エディターの最新のビルドファイルのダウンロード日時を記録するファイル +LAST_DOWNLOAD_FILE = STATIC_DIR / "last_download.txt" + + +def download_static_files(user, repo, asset_name): + """Style-Bert-VITS2エディターの最新のビルドzipをダウンロードして展開する。""" + + logger.info("Checking for new release...") + latest_release = get_latest_release(user, repo) + if latest_release is None: + logger.warning( + "Failed to fetch the latest release. Proceeding without static files." + ) + return + + if not new_release_available(latest_release): + logger.info("No new release available. Proceeding with existing static files.") + return + + logger.info("New release available. Downloading static files...") + asset_url = get_asset_url(latest_release, asset_name) + if asset_url: + if STATIC_DIR.exists(): + shutil.rmtree(STATIC_DIR) + STATIC_DIR.mkdir(parents=True, exist_ok=True) + download_and_extract(asset_url, STATIC_DIR) + save_last_download(latest_release) + else: + logger.warning("Asset not found. 
Proceeding without static files.") + + +def get_latest_release(user, repo): + url = f"https://api.github.com/repos/{user}/{repo}/releases/latest" + try: + response = requests.get(url) + response.raise_for_status() + return response.json() + except requests.RequestException: + return None + + +def get_asset_url(release, asset_name): + for asset in release["assets"]: + if asset["name"] == asset_name: + return asset["browser_download_url"] + return None + + +def download_and_extract(url, extract_to: Path): + response = requests.get(url) + response.raise_for_status() + with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref: + zip_ref.extractall(extract_to) + + # 展開先が1つのディレクトリだけの場合、その中身を直下に移動する + extracted_dirs = list(extract_to.iterdir()) + if len(extracted_dirs) == 1 and extracted_dirs[0].is_dir(): + for file in extracted_dirs[0].iterdir(): + file.rename(extract_to / file.name) + extracted_dirs[0].rmdir() + + # index.htmlが存在するかチェック + if not (extract_to / "index.html").exists(): + logger.warning("index.html not found in the extracted files.") + + +def new_release_available(latest_release): + if LAST_DOWNLOAD_FILE.exists(): + with open(LAST_DOWNLOAD_FILE, "r") as file: + last_download_str = file.read().strip() + # 'Z'を除去して日時オブジェクトに変換 + last_download_str = last_download_str.replace("Z", "+00:00") + last_download = datetime.fromisoformat(last_download_str) + return ( + datetime.fromisoformat( + latest_release["published_at"].replace("Z", "+00:00") + ) + > last_download + ) + return True + + +def save_last_download(latest_release): + with open(LAST_DOWNLOAD_FILE, "w") as file: + file.write(latest_release["published_at"]) + + +# ---フロントエンド部分に関する処理ここまで--- +# 以降はAPIの設定 + + +class AudioResponse(Response): + media_type = "audio/wav" + + +origins = [ + "http://localhost:3000", + "http://localhost:8000", + "http://127.0.0.1:3000", + "http://127.0.0.1:8000", +] + +# Get path settings +with open(Path("configs/paths.yml"), "r", encoding="utf-8") as f: + path_config: dict[str, str] = yaml.safe_load(f.read()) + # dataset_root = path_config["dataset_root"] + assets_root = path_config["assets_root"] + +parser = argparse.ArgumentParser() +parser.add_argument("--model_dir", type=str, default="model_assets/") +parser.add_argument("--device", type=str, default="cuda") +parser.add_argument("--port", type=int, default=8000) +parser.add_argument("--inbrowser", action="store_true") +parser.add_argument("--line_length", type=int, default=None) +parser.add_argument("--line_count", type=int, default=None) +parser.add_argument( + "--dir", "-d", type=str, help="Model directory", default=assets_root +) + +args = parser.parse_args() +device = args.device +if device == "cuda" and not torch.cuda.is_available(): + device = "cpu" +model_dir = Path(args.model_dir) +port = int(args.port) + +model_holder = ModelHolder(model_dir, device) +if len(model_holder.model_names) == 0: + logger.error(f"Models not found in {model_dir}.") + sys.exit(1) + +app = FastAPI() + + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +router = APIRouter() + + +@router.get("/version") +def version() -> str: + return LATEST_VERSION + + +class MoraTone(BaseModel): + mora: str + tone: int + + +class TextRequest(BaseModel): + text: str + + +@router.post("/g2p") +async def read_item(item: TextRequest): + try: + # 最初に正規化しないと整合性がとれない + text = text_normalize(item.text) + kata_tone_list = g2kata_tone(text, ignore_unknown=True) + except Exception as e: + raise 
HTTPException( + status_code=400, + detail=f"Failed to convert {item.text} to katakana and tone, {e}", + ) + return [MoraTone(mora=kata, tone=tone) for kata, tone in kata_tone_list] + + +@router.post("/normalize") +async def normalize_text(item: TextRequest): + return text_normalize(item.text) + + +@router.get("/models_info") +def models_info(): + return model_holder.models_info() + + +class SynthesisRequest(BaseModel): + model: str + modelFile: str + text: str + moraToneList: list[MoraTone] + style: str = DEFAULT_STYLE + styleWeight: float = DEFAULT_STYLE_WEIGHT + assistText: str = "" + assistTextWeight: float = DEFAULT_ASSIST_TEXT_WEIGHT + speed: float = 1.0 + noise: float = DEFAULT_NOISE + noisew: float = DEFAULT_NOISEW + sdpRatio: float = DEFAULT_SDP_RATIO + language: Languages = Languages.JP + silenceAfter: float = 0.5 + pitchScale: float = 1.0 + intonationScale: float = 1.0 + + +@router.post("/synthesis", response_class=AudioResponse) +def synthesis(request: SynthesisRequest): + if args.line_length is not None and len(request.text) > args.line_length: + raise HTTPException( + status_code=400, + detail=f"1行の文字数は{args.line_length}文字以下にしてください。", + ) + try: + model = model_holder.load_model( + model_name=request.model, model_path_str=request.modelFile + ) + except Exception as e: + logger.error(e) + raise HTTPException( + status_code=500, + detail=f"Failed to load model {request.model} from {request.modelFile}, {e}", + ) + text = request.text + kata_tone_list = [ + (mora_tone.mora, mora_tone.tone) for mora_tone in request.moraToneList + ] + phone_tone = kata_tone2phone_tone(kata_tone_list) + tone = [t for _, t in phone_tone] + sr, audio = model.infer( + text=text, + language=request.language.value, + sdp_ratio=request.sdpRatio, + noise=request.noise, + noisew=request.noisew, + length=1 / request.speed, + given_tone=tone, + style=request.style, + style_weight=request.styleWeight, + assist_text=request.assistText, + assist_text_weight=request.assistTextWeight, + use_assist_text=bool(request.assistText), + line_split=False, + ignore_unknown=True, + pitch_scale=request.pitchScale, + intonation_scale=request.intonationScale, + ) + + with BytesIO() as wavContent: + wavfile.write(wavContent, sr, audio) + return Response(content=wavContent.getvalue(), media_type="audio/wav") + + +class MultiSynthesisRequest(BaseModel): + lines: list[SynthesisRequest] + + +@router.post("/multi_synthesis", response_class=AudioResponse) +def multi_synthesis(request: MultiSynthesisRequest): + lines = request.lines + if args.line_count is not None and len(lines) > args.line_count: + raise HTTPException( + status_code=400, + detail=f"行数は{args.line_count}行以下にしてください。", + ) + audios = [] + for i, req in enumerate(lines): + if args.line_length is not None and len(req.text) > args.line_length: + raise HTTPException( + status_code=400, + detail=f"1行の文字数は{args.line_length}文字以下にしてください。", + ) + try: + model = model_holder.load_model( + model_name=req.model, model_path_str=req.modelFile + ) + except Exception as e: + logger.error(e) + raise HTTPException( + status_code=500, + detail=f"Failed to load model {req.model} from {req.modelFile}, {e}", + ) + text = req.text + kata_tone_list = [ + (mora_tone.mora, mora_tone.tone) for mora_tone in req.moraToneList + ] + phone_tone = kata_tone2phone_tone(kata_tone_list) + tone = [t for _, t in phone_tone] + sr, audio = model.infer( + text=text, + language=req.language.value, + sdp_ratio=req.sdpRatio, + noise=req.noise, + noisew=req.noisew, + length=1 / req.speed, + given_tone=tone, + 
style=req.style, + style_weight=req.styleWeight, + assist_text=req.assistText, + assist_text_weight=req.assistTextWeight, + use_assist_text=bool(req.assistText), + line_split=False, + ignore_unknown=True, + pitch_scale=req.pitchScale, + intonation_scale=req.intonationScale, + ) + audios.append(audio) + if i < len(lines) - 1: + silence = int(sr * req.silenceAfter) + audios.append(np.zeros(silence, dtype=np.int16)) + audio = np.concatenate(audios) + + with BytesIO() as wavContent: + wavfile.write(wavContent, sr, audio) + return Response(content=wavContent.getvalue(), media_type="audio/wav") + + +class UserDictWordRequest(BaseModel): + surface: str + pronunciation: str + accent_type: int # アクセント核位置(存在しない場合は0、1文字目は1) + priority: int = 5 + + +@router.get("/user_dict") +def get_user_dict(): + return read_dict() + + +@router.post("/user_dict_word") +def add_user_dict_word(request: UserDictWordRequest): + uuid = apply_word( + surface=request.surface, + pronunciation=request.pronunciation, + accent_type=request.accent_type, + priority=request.priority, + ) + update_dict() + + return JSONResponse( + status_code=status.HTTP_201_CREATED, + content={"uuid": uuid}, + ) + + +@router.put("/user_dict_word/{uuid}") +def update_user_dict_word(uuid: str, request: UserDictWordRequest): + rewrite_word( + word_uuid=uuid, + surface=request.surface, + pronunciation=request.pronunciation, + accent_type=request.accent_type, + priority=request.priority, + ) + update_dict() + return JSONResponse(status_code=status.HTTP_200_OK, content={"uuid": uuid}) + + +@router.delete("/user_dict_word/{uuid}") +def delete_user_dict_word(uuid: str): + delete_word(uuid) + update_dict() + return JSONResponse(status_code=status.HTTP_200_OK, content={"uuid": uuid}) + + +app.include_router(router, prefix="/api") + +if __name__ == "__main__": + download_static_files("litagin02", "Style-Bert-VITS2-Editor", "out.zip") + app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static") + if args.inbrowser: + webbrowser.open(f"http://localhost:{port}") + uvicorn.run(app, host="0.0.0.0", port=port) diff --git a/server_fastapi.py b/server_fastapi.py index 40877fd98..ce6ed0432 100644 --- a/server_fastapi.py +++ b/server_fastapi.py @@ -1,10 +1,13 @@ """ API server for TTS +TODO: server_editor.pyと統合する? 
""" + import argparse import os import sys from io import BytesIO +from pathlib import Path from typing import Dict, Optional, Union from urllib.parse import unquote @@ -53,10 +56,8 @@ def load_models(model_holder: ModelHolder): for model_name, model_paths in model_holder.model_files_dict.items(): model = Model( model_path=model_paths[0], - config_path=os.path.join(model_holder.root_dir, model_name, "config.json"), - style_vec_path=os.path.join( - model_holder.root_dir, model_name, "style_vectors.npy" - ), + config_path=model_holder.root_dir / model_name / "config.json", + style_vec_path=model_holder.root_dir / model_name / "style_vectors.npy", device=model_holder.device, ) model.load_net_g() @@ -76,7 +77,7 @@ def load_models(model_holder: ModelHolder): else: device = "cuda" if torch.cuda.is_available() else "cpu" - model_dir = args.dir + model_dir = Path(args.dir) model_holder = ModelHolder(model_dir, device) if len(model_holder.model_names) == 0: logger.error(f"Models not found in {model_dir}.") @@ -105,9 +106,12 @@ async def voice( request: Request, text: str = Query(..., min_length=1, max_length=limit, description=f"セリフ"), encoding: str = Query(None, description="textをURLデコードする(ex, `utf-8`)"), - model_id: int = Query(0, description="モデルID。`GET /models/info`のkeyの値を指定ください"), + model_id: int = Query( + 0, description="モデルID。`GET /models/info`のkeyの値を指定ください" + ), speaker_name: str = Query( - None, description="話者名(speaker_idより優先)。esd.listの2列目の文字列を指定" + None, + description="話者名(speaker_idより優先)。esd.listの2列目の文字列を指定", ), speaker_id: int = Query( 0, description="話者ID。model_assets>[model]>config.json内のspk2idを確認" @@ -116,12 +120,17 @@ async def voice( DEFAULT_SDP_RATIO, description="SDP(Stochastic Duration Predictor)/DP混合比。比率が高くなるほどトーンのばらつきが大きくなる", ), - noise: float = Query(DEFAULT_NOISE, description="サンプルノイズの割合。大きくするほどランダム性が高まる"), + noise: float = Query( + DEFAULT_NOISE, + description="サンプルノイズの割合。大きくするほどランダム性が高まる", + ), noisew: float = Query( - DEFAULT_NOISEW, description="SDPノイズ。大きくするほど発音の間隔にばらつきが出やすくなる" + DEFAULT_NOISEW, + description="SDPノイズ。大きくするほど発音の間隔にばらつきが出やすくなる", ), length: float = Query( - DEFAULT_LENGTH, description="話速。基準は1で大きくするほど音声は長くなり読み上げが遅まる" + DEFAULT_LENGTH, + description="話速。基準は1で大きくするほど音声は長くなり読み上げが遅まる", ), language: Languages = Query(ln, description=f"textの言語"), auto_split: bool = Query(DEFAULT_LINE_SPLIT, description="改行で分けて生成"), @@ -129,20 +138,25 @@ async def voice( DEFAULT_SPLIT_INTERVAL, description="分けた場合に挟む無音の長さ(秒)" ), assist_text: Optional[str] = Query( - None, description="このテキストの読み上げと似た声音・感情になりやすくなる。ただし抑揚やテンポ等が犠牲になる傾向がある" + None, + description="このテキストの読み上げと似た声音・感情になりやすくなる。ただし抑揚やテンポ等が犠牲になる傾向がある", ), assist_text_weight: float = Query( DEFAULT_ASSIST_TEXT_WEIGHT, description="assist_textの強さ" ), style: Optional[Union[int, str]] = Query(DEFAULT_STYLE, description="スタイル"), style_weight: float = Query(DEFAULT_STYLE_WEIGHT, description="スタイルの強さ"), - reference_audio_path: Optional[str] = Query(None, description="スタイルを音声ファイルで行う"), + reference_audio_path: Optional[str] = Query( + None, description="スタイルを音声ファイルで行う" + ), ): """Infer text to speech(テキストから感情付き音声を生成する)""" logger.info( f"{request.client.host}:{request.client.port}/voice { unquote(str(request.query_params) )}" ) - if model_id >= len(model_holder.models): # /models/refresh があるためQuery(le)で表現不可 + if model_id >= len( + model_holder.models + ): # /models/refresh があるためQuery(le)で表現不可 raise_validation_error(f"model_id={model_id} not found", "model_id") model = model_holder.models[model_id] diff --git a/slice.py b/slice.py 
index 3b6e2d789..2d56427d1 100644 --- a/slice.py +++ b/slice.py @@ -5,6 +5,7 @@ import soundfile as sf import torch +import yaml from tqdm import tqdm from common.log import logger @@ -76,6 +77,7 @@ def split_wav( os.makedirs(target_dir, exist_ok=True) total_time_ms = 0 + count = 0 # タイムスタンプに従って分割し、ファイルに保存 for i, ts in enumerate(speech_timestamps): @@ -88,8 +90,9 @@ def split_wav( sf.write(os.path.join(target_dir, f"{file_name}-{i}.wav"), segment, sr) total_time_ms += end_ms - start_ms + count += 1 - return total_time_ms / 1000 + return total_time_ms / 1000, count if __name__ == "__main__": @@ -108,11 +111,10 @@ def split_wav( help="Directory of input wav files", ) parser.add_argument( - "--output_dir", - "-o", + "--model_name", type=str, - default="raw", - help="Directory of output wav files", + required=True, + help="The result will be in Data/{model_name}/raw/ (if Data is dataset_root in configs/paths.yml)", ) parser.add_argument( "--min_silence_dur_ms", @@ -123,8 +125,12 @@ def split_wav( ) args = parser.parse_args() + with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: + path_config: dict[str, str] = yaml.safe_load(f.read()) + dataset_root = path_config["dataset_root"] + input_dir = args.input_dir - output_dir = args.output_dir + output_dir = os.path.join(dataset_root, args.model_name, "raw") min_sec = args.min_sec max_sec = args.max_sec min_silence_dur_ms = args.min_silence_dur_ms @@ -137,8 +143,9 @@ def split_wav( shutil.rmtree(output_dir) total_sec = 0 + total_count = 0 for wav_file in tqdm(wav_files, file=SAFE_STDOUT): - time_sec = split_wav( + time_sec, count = split_wav( audio_file=str(wav_file), target_dir=output_dir, min_sec=min_sec, @@ -146,5 +153,8 @@ def split_wav( min_silence_dur_ms=min_silence_dur_ms, ) total_sec += time_sec + total_count += count - logger.info(f"Slice done! Total time: {total_sec / 60:.2f} min.") + logger.info( + f"Slice done! Total time: {total_sec / 60:.2f} min, {total_count} files." 
+ ) diff --git a/speech_mos.py b/speech_mos.py index 8464223a5..d69a23ab5 100644 --- a/speech_mos.py +++ b/speech_mos.py @@ -99,7 +99,7 @@ def get_model(model_file: Path): logger.info(f"{model_file}: {scores[-1]}") with open( - mos_result_dir / f"mos_{model_name}.csv", "w", encoding="utf-8", newline="" + mos_result_dir / f"mos_{model_name}.csv", "w", encoding="utf_8_sig", newline="" ) as f: writer = csv.writer(f) writer.writerow(["model_path"] + ["step"] + test_texts + ["mean"]) diff --git a/text/__init__.py b/text/__init__.py index 495e57b50..2151cafc5 100644 --- a/text/__init__.py +++ b/text/__init__.py @@ -19,14 +19,27 @@ def cleaned_text_to_sequence(cleaned_text, tones, language): def get_bert( - norm_text, word2ph, language, device, assist_text=None, assist_text_weight=0.7 + text, + word2ph, + language, + device, + assist_text=None, + assist_text_weight=0.7, + ignore_unknown=False, ): - from .chinese_bert import get_bert_feature as zh_bert - from .english_bert_mock import get_bert_feature as en_bert - from .japanese_bert import get_bert_feature as jp_bert - - lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert} - bert = lang_bert_func_map[language]( - norm_text, word2ph, device, assist_text, assist_text_weight - ) - return bert + if language == "ZH": + from .chinese_bert import get_bert_feature as zh_bert + + return zh_bert(text, word2ph, device, assist_text, assist_text_weight) + elif language == "EN": + from .english_bert_mock import get_bert_feature as en_bert + + return en_bert(text, word2ph, device, assist_text, assist_text_weight) + elif language == "JP": + from .japanese_bert import get_bert_feature as jp_bert + + return jp_bert( + text, word2ph, device, assist_text, assist_text_weight, ignore_unknown + ) + else: + raise ValueError(f"Language {language} not supported") diff --git a/text/cleaner.py b/text/cleaner.py index ab59b4e1a..8da4c7b6e 100644 --- a/text/cleaner.py +++ b/text/cleaner.py @@ -1,31 +1,26 @@ -from text import chinese, japanese, english, cleaned_text_to_sequence +def clean_text(text, language, use_jp_extra=True, ignore_unknown=False): + # Changed to import inside if condition to avoid unnecessary import + if language == "ZH": + from . import chinese as language_module + norm_text = language_module.text_normalize(text) + phones, tones, word2ph = language_module.g2p(norm_text) + elif language == "EN": + from . import english as language_module -language_module_map = {"ZH": chinese, "JP": japanese, "EN": english} - + norm_text = language_module.text_normalize(text) + phones, tones, word2ph = language_module.g2p(norm_text) + elif language == "JP": + from . 
import japanese as language_module -def clean_text(text, language, use_jp_extra=True): - language_module = language_module_map[language] - norm_text = language_module.text_normalize(text) - if language == "JP": - phones, tones, word2ph = language_module.g2p(norm_text, use_jp_extra) + norm_text = language_module.text_normalize(text) + phones, tones, word2ph = language_module.g2p( + norm_text, use_jp_extra, ignore_unknown=ignore_unknown + ) else: - phones, tones, word2ph = language_module.g2p(norm_text) + raise ValueError(f"Language {language} not supported") return norm_text, phones, tones, word2ph -def clean_text_bert(text, language): - language_module = language_module_map[language] - norm_text = language_module.text_normalize(text) - phones, tones, word2ph = language_module.g2p(norm_text) - bert = language_module.get_bert_feature(norm_text, word2ph) - return phones, tones, bert - - -def text_to_sequence(text, language): - norm_text, phones, tones, word2ph = clean_text(text, language) - return cleaned_text_to_sequence(phones, tones, language) - - if __name__ == "__main__": pass diff --git a/text/japanese.py b/text/japanese.py index 2b0f24ba0..5a36eb69e 100644 --- a/text/japanese.py +++ b/text/japanese.py @@ -2,6 +2,7 @@ # compatible with Julius https://github.com/julius-speech/segmentation-kit import re import unicodedata +from pathlib import Path import pyopenjtalk from num2words import num2words @@ -14,6 +15,11 @@ mora_phonemes_to_mora_kata, ) +from text.user_dict import update_dict + +# 最初にpyopenjtalkの辞書を更新 +update_dict() + # 子音の集合 COSONANTS = set( [ @@ -160,7 +166,7 @@ def japanese_convert_numbers_to_words(text: str) -> str: def g2p( - norm_text: str, use_jp_extra: bool = True + norm_text: str, use_jp_extra: bool = True, ignore_unknown: bool = False ) -> tuple[list[str], list[int], list[int]]: """ 他で使われるメインの関数。`text_normalize()`で正規化された`norm_text`を受け取り、 @@ -182,7 +188,7 @@ def g2p( # sep_text: 単語単位の単語のリスト # sep_kata: 単語単位の単語のカタカナ読みのリスト - sep_text, sep_kata = text2sep_kata(norm_text) + sep_text, sep_kata = text2sep_kata(norm_text, ignore_unknown=ignore_unknown) # sep_phonemes: 各単語ごとの音素のリストのリスト sep_phonemes = handle_long([kata2phoneme_list(i) for i in sep_kata]) @@ -231,8 +237,8 @@ def g2p( return phones, tones, word2ph -def g2kata_tone(norm_text: str) -> list[tuple[str, int]]: - phones, tones, _ = g2p(norm_text, use_jp_extra=True) +def g2kata_tone(norm_text: str, ignore_unknown: bool = False) -> list[tuple[str, int]]: + phones, tones, _ = g2p(norm_text, use_jp_extra=True, ignore_unknown=ignore_unknown) return phone_tone2kata_tone(list(zip(phones, tones))) @@ -325,7 +331,9 @@ def g2phone_tone_wo_punct(text: str) -> list[tuple[str, int]]: return result -def text2sep_kata(norm_text: str) -> tuple[list[str], list[str]]: +def text2sep_kata( + norm_text: str, ignore_unknown: bool = False +) -> tuple[list[str], list[str]]: """ `text_normalize`で正規化済みの`norm_text`を受け取り、それを単語分割し、 分割された単語リストとその読み(カタカナor記号1文字)のリストのタプルを返す。 @@ -361,6 +369,9 @@ def text2sep_kata(norm_text: str) -> tuple[list[str], list[str]]: # wordは正規化されているので、`.`, `,`, `!`, `'`, `-`, `--` のいずれか if not set(word).issubset(set(punctuation)): # 記号繰り返しか判定 # ここはpyopenjtalkが読めない文字等のときに起こる + if ignore_unknown: + logger.error(f"Ignoring unknown: {word} in:\n{norm_text}") + continue raise ValueError(f"Cannot read: {word} in:\n{norm_text}") # yomiは元の記号のままに変更 yomi = word @@ -500,6 +511,9 @@ def handle_long(sep_phonemes: list[list[str]]) -> list[list[str]]: おそらく長音記号とダッシュを勘違いしていると思われるので、ダッシュに対応する音素`-`に変換する。 """ for i in range(len(sep_phonemes)): + 
if len(sep_phonemes[i]) == 0: + # 空白文字等でリストが空の場合 + continue if sep_phonemes[i][0] == "ー": if i != 0: prev_phoneme = sep_phonemes[i - 1][-1] diff --git a/text/japanese_bert.py b/text/japanese_bert.py index 2974c76e6..9efe7a46a 100644 --- a/text/japanese_bert.py +++ b/text/japanese_bert.py @@ -4,7 +4,7 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer from config import config -from text.japanese import text2sep_kata +from text.japanese import text2sep_kata, text_normalize LOCAL_PATH = "./bert/deberta-v2-large-japanese-char-wwm" @@ -19,8 +19,10 @@ def get_bert_feature( device=config.bert_gen_config.device, assist_text=None, assist_text_weight=0.7, + ignore_unknown=False, ): - text = "".join(text2sep_kata(text)[0]) + text = "".join(text2sep_kata(text, ignore_unknown=ignore_unknown)[0]) + # text = text_normalize(text) if assist_text: assist_text = "".join(text2sep_kata(assist_text)[0]) if ( diff --git a/text/tone_sandhi.py b/text/tone_sandhi.py index 372308604..38f313785 100644 --- a/text/tone_sandhi.py +++ b/text/tone_sandhi.py @@ -497,7 +497,10 @@ def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]: # 个做量词 elif ( ge_idx >= 1 - and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是") + and ( + word[ge_idx - 1].isnumeric() + or word[ge_idx - 1] in "几有两半多各整每做是" + ) ) or word == "个": finals[ge_idx] = finals[ge_idx][:-1] + "5" else: diff --git a/text/user_dict/README.md b/text/user_dict/README.md new file mode 100644 index 000000000..6f5618eda --- /dev/null +++ b/text/user_dict/README.md @@ -0,0 +1,19 @@ +このフォルダに含まれるユーザー辞書関連のコードは、[VOICEVOX engine](https://github.com/VOICEVOX/voicevox_engine)プロジェクトのコードを改変したものを使用しています。VOICEVOXプロジェクトのチームに深く感謝し、その貢献を尊重します。 + +**元のコード**: + +- [voicevox_engine/user_dict/](https://github.com/VOICEVOX/voicevox_engine/tree/f181411ec69812296989d9cc583826c22eec87ae/voicevox_engine/user_dict) +- [voicevox_engine/model.py](https://github.com/VOICEVOX/voicevox_engine/blob/f181411ec69812296989d9cc583826c22eec87ae/voicevox_engine/model.py#L207) + +**改変の詳細**: + +- ファイル名の書き換えおよびそれに伴うimport文の書き換え。 +- VOICEVOX固有の部分をコメントアウト。 +- mutexを使用している部分をコメントアウト。 +- 参照しているpyopenjtalkの違いによるメソッド名の書き換え。 +- UserDictWordのmora_countのデフォルト値をNoneに指定。 +- Pydanticのモデルで必要な箇所のみを抽出。 + +**ライセンス**: + +元のVOICEVOX engineのリポジトリのコードは、LGPL v3 と、ソースコードの公開が不要な別ライセンスのデュアルライセンスの下で使用されています。当プロジェクトにおけるこのモジュールもLGPLライセンスの下にあります。詳細については、プロジェクトのルートディレクトリにある[LGPL_LICENSE](/LGPL_LICENSE)ファイルをご参照ください。また、元のVOICEVOX engineプロジェクトのライセンスについては、[こちら](https://github.com/VOICEVOX/voicevox_engine/blob/master/LICENSE)をご覧ください。 diff --git a/text/user_dict/__init__.py b/text/user_dict/__init__.py new file mode 100644 index 000000000..c12b3d1aa --- /dev/null +++ b/text/user_dict/__init__.py @@ -0,0 +1,467 @@ +# このファイルは、VOICEVOXプロジェクトのVOICEVOX engineからお借りしています。 +# 引用元: +# https://github.com/VOICEVOX/voicevox_engine/blob/f181411ec69812296989d9cc583826c22eec87ae/voicevox_engine/user_dict/user_dict.py +# ライセンス: LGPL-3.0 +# 詳しくは、このファイルと同じフォルダにあるREADME.mdを参照してください。 +import json +import sys +import threading +import traceback +from pathlib import Path +from typing import Dict, List, Optional +from uuid import UUID, uuid4 + +import numpy as np +import pyopenjtalk +from fastapi import HTTPException + +from .word_model import UserDictWord, WordTypes + +# from ..utility.mutex_utility import mutex_wrapper +# from ..utility.path_utility import engine_root, get_save_dir +from .part_of_speech_data import MAX_PRIORITY, MIN_PRIORITY, part_of_speech_data +from common.constants import USER_DICT_DIR + +# 
root_dir = engine_root() +# save_dir = get_save_dir() +root_dir = Path(USER_DICT_DIR) +save_dir = Path(USER_DICT_DIR) + + +if not save_dir.is_dir(): + save_dir.mkdir(parents=True) + +default_dict_path = root_dir / "default.csv" # VOICEVOXデフォルト辞書ファイルのパス +user_dict_path = save_dir / "user_dict.json" # ユーザー辞書ファイルのパス +compiled_dict_path = save_dir / "user.dic" # コンパイル済み辞書ファイルのパス + + +# # 同時書き込みの制御 +# mutex_user_dict = threading.Lock() +# mutex_openjtalk_dict = threading.Lock() + + +# @mutex_wrapper(mutex_user_dict) +def _write_to_json(user_dict: Dict[str, UserDictWord], user_dict_path: Path) -> None: + """ + ユーザー辞書ファイルへのユーザー辞書データ書き込み + Parameters + ---------- + user_dict : Dict[str, UserDictWord] + ユーザー辞書データ + user_dict_path : Path + ユーザー辞書ファイルのパス + """ + converted_user_dict = {} + for word_uuid, word in user_dict.items(): + word_dict = word.dict() + word_dict["cost"] = _priority2cost( + word_dict["context_id"], word_dict["priority"] + ) + del word_dict["priority"] + converted_user_dict[word_uuid] = word_dict + # 予めjsonに変換できることを確かめる + user_dict_json = json.dumps(converted_user_dict, ensure_ascii=False) + + # ユーザー辞書ファイルへの書き込み + user_dict_path.write_text(user_dict_json, encoding="utf-8") + + +# @mutex_wrapper(mutex_openjtalk_dict) +def update_dict( + default_dict_path: Path = default_dict_path, + user_dict_path: Path = user_dict_path, + compiled_dict_path: Path = compiled_dict_path, +) -> None: + """ + 辞書の更新 + Parameters + ---------- + default_dict_path : Path + デフォルト辞書ファイルのパス + user_dict_path : Path + ユーザー辞書ファイルのパス + compiled_dict_path : Path + コンパイル済み辞書ファイルのパス + """ + random_string = uuid4() + tmp_csv_path = compiled_dict_path.with_suffix( + f".dict_csv-{random_string}.tmp" + ) # csv形式辞書データの一時保存ファイル + tmp_compiled_path = compiled_dict_path.with_suffix( + f".dict_compiled-{random_string}.tmp" + ) # コンパイル済み辞書データの一時保存ファイル + + try: + # 辞書.csvを作成 + csv_text = "" + + # デフォルト辞書データの追加 + if not default_dict_path.is_file(): + print("Warning: Cannot find default dictionary.", file=sys.stderr) + return + default_dict = default_dict_path.read_text(encoding="utf-8") + if default_dict == default_dict.rstrip(): + default_dict += "\n" + csv_text += default_dict + + # ユーザー辞書データの追加 + user_dict = read_dict(user_dict_path=user_dict_path) + for word_uuid in user_dict: + word = user_dict[word_uuid] + csv_text += ( + "{surface},{context_id},{context_id},{cost},{part_of_speech}," + + "{part_of_speech_detail_1},{part_of_speech_detail_2}," + + "{part_of_speech_detail_3},{inflectional_type}," + + "{inflectional_form},{stem},{yomi},{pronunciation}," + + "{accent_type}/{mora_count},{accent_associative_rule}\n" + ).format( + surface=word.surface, + context_id=word.context_id, + cost=_priority2cost(word.context_id, word.priority), + part_of_speech=word.part_of_speech, + part_of_speech_detail_1=word.part_of_speech_detail_1, + part_of_speech_detail_2=word.part_of_speech_detail_2, + part_of_speech_detail_3=word.part_of_speech_detail_3, + inflectional_type=word.inflectional_type, + inflectional_form=word.inflectional_form, + stem=word.stem, + yomi=word.yomi, + pronunciation=word.pronunciation, + accent_type=word.accent_type, + mora_count=word.mora_count, + accent_associative_rule=word.accent_associative_rule, + ) + # 辞書データを辞書.csv へ一時保存 + tmp_csv_path.write_text(csv_text, encoding="utf-8") + + # 辞書.csvをOpenJTalk用にコンパイル + # pyopenjtalk.create_user_dict(str(tmp_csv_path), str(tmp_compiled_path)) + pyopenjtalk.mecab_dict_index(str(tmp_csv_path), str(tmp_compiled_path)) + if not tmp_compiled_path.is_file(): + raise 
RuntimeError("辞書のコンパイル時にエラーが発生しました。") + + # コンパイル済み辞書の置き換え・読み込み + pyopenjtalk.unset_user_dict() + tmp_compiled_path.replace(compiled_dict_path) + if compiled_dict_path.is_file(): + # pyopenjtalk.set_user_dict(str(compiled_dict_path.resolve(strict=True))) + pyopenjtalk.update_global_jtalk_with_user_dict(str(compiled_dict_path)) + + except Exception as e: + print("Error: Failed to update dictionary.", file=sys.stderr) + traceback.print_exc(file=sys.stderr) + raise e + + finally: + # 後処理 + if tmp_csv_path.exists(): + tmp_csv_path.unlink() + if tmp_compiled_path.exists(): + tmp_compiled_path.unlink() + + +# @mutex_wrapper(mutex_user_dict) +def read_dict(user_dict_path: Path = user_dict_path) -> Dict[str, UserDictWord]: + """ + ユーザー辞書の読み出し + Parameters + ---------- + user_dict_path : Path + ユーザー辞書ファイルのパス + Returns + ------- + result : Dict[str, UserDictWord] + ユーザー辞書 + """ + # 指定ユーザー辞書が存在しない場合、空辞書を返す + if not user_dict_path.is_file(): + return {} + + with user_dict_path.open(encoding="utf-8") as f: + result: Dict[str, UserDictWord] = {} + for word_uuid, word in json.load(f).items(): + # cost2priorityで変換を行う際にcontext_idが必要となるが、 + # 0.12以前の辞書は、context_idがハードコーディングされていたためにユーザー辞書内に保管されていない + # ハードコーディングされていたcontext_idは固有名詞を意味するものなので、固有名詞のcontext_idを補完する + if word.get("context_id") is None: + word["context_id"] = part_of_speech_data[ + WordTypes.PROPER_NOUN + ].context_id + word["priority"] = _cost2priority(word["context_id"], word["cost"]) + del word["cost"] + result[str(UUID(word_uuid))] = UserDictWord(**word) + + return result + + +def _create_word( + surface: str, + pronunciation: str, + accent_type: int, + word_type: Optional[WordTypes] = None, + priority: Optional[int] = None, +) -> UserDictWord: + """ + 単語オブジェクトの生成 + Parameters + ---------- + surface : str + 単語情報 + pronunciation : str + 単語情報 + accent_type : int + 単語情報 + word_type : Optional[WordTypes] + 品詞 + priority : Optional[int] + 優先度 + Returns + ------- + : UserDictWord + 単語オブジェクト + """ + if word_type is None: + word_type = WordTypes.PROPER_NOUN + if word_type not in part_of_speech_data.keys(): + raise HTTPException(status_code=422, detail="不明な品詞です") + if priority is None: + priority = 5 + if not MIN_PRIORITY <= priority <= MAX_PRIORITY: + raise HTTPException(status_code=422, detail="優先度の値が無効です") + pos_detail = part_of_speech_data[word_type] + return UserDictWord( + surface=surface, + context_id=pos_detail.context_id, + priority=priority, + part_of_speech=pos_detail.part_of_speech, + part_of_speech_detail_1=pos_detail.part_of_speech_detail_1, + part_of_speech_detail_2=pos_detail.part_of_speech_detail_2, + part_of_speech_detail_3=pos_detail.part_of_speech_detail_3, + inflectional_type="*", + inflectional_form="*", + stem="*", + yomi=pronunciation, + pronunciation=pronunciation, + accent_type=accent_type, + accent_associative_rule="*", + ) + + +def apply_word( + surface: str, + pronunciation: str, + accent_type: int, + word_type: Optional[WordTypes] = None, + priority: Optional[int] = None, + user_dict_path: Path = user_dict_path, + compiled_dict_path: Path = compiled_dict_path, +) -> str: + """ + 新規単語の追加 + Parameters + ---------- + surface : str + 単語情報 + pronunciation : str + 単語情報 + accent_type : int + 単語情報 + word_type : Optional[WordTypes] + 品詞 + priority : Optional[int] + 優先度 + user_dict_path : Path + ユーザー辞書ファイルのパス + compiled_dict_path : Path + コンパイル済み辞書ファイルのパス + Returns + ------- + word_uuid : UserDictWord + 追加された単語に発行されたUUID + """ + # 新規単語の追加による辞書データの更新 + word = _create_word( + surface=surface, + pronunciation=pronunciation, + 
accent_type=accent_type, + word_type=word_type, + priority=priority, + ) + user_dict = read_dict(user_dict_path=user_dict_path) + word_uuid = str(uuid4()) + user_dict[word_uuid] = word + + # 更新された辞書データの保存と適用 + _write_to_json(user_dict, user_dict_path) + update_dict(user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path) + + return word_uuid + + +def rewrite_word( + word_uuid: str, + surface: str, + pronunciation: str, + accent_type: int, + word_type: Optional[WordTypes] = None, + priority: Optional[int] = None, + user_dict_path: Path = user_dict_path, + compiled_dict_path: Path = compiled_dict_path, +) -> None: + """ + 既存単語の上書き更新 + Parameters + ---------- + word_uuid : str + 単語UUID + surface : str + 単語情報 + pronunciation : str + 単語情報 + accent_type : int + 単語情報 + word_type : Optional[WordTypes] + 品詞 + priority : Optional[int] + 優先度 + user_dict_path : Path + ユーザー辞書ファイルのパス + compiled_dict_path : Path + コンパイル済み辞書ファイルのパス + """ + word = _create_word( + surface=surface, + pronunciation=pronunciation, + accent_type=accent_type, + word_type=word_type, + priority=priority, + ) + + # 既存単語の上書きによる辞書データの更新 + user_dict = read_dict(user_dict_path=user_dict_path) + if word_uuid not in user_dict: + raise HTTPException( + status_code=422, detail="UUIDに該当するワードが見つかりませんでした" + ) + user_dict[word_uuid] = word + + # 更新された辞書データの保存と適用 + _write_to_json(user_dict, user_dict_path) + update_dict(user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path) + + +def delete_word( + word_uuid: str, + user_dict_path: Path = user_dict_path, + compiled_dict_path: Path = compiled_dict_path, +) -> None: + """ + 単語の削除 + Parameters + ---------- + word_uuid : str + 単語UUID + user_dict_path : Path + ユーザー辞書ファイルのパス + compiled_dict_path : Path + コンパイル済み辞書ファイルのパス + """ + # 既存単語の削除による辞書データの更新 + user_dict = read_dict(user_dict_path=user_dict_path) + if word_uuid not in user_dict: + raise HTTPException( + status_code=422, detail="IDに該当するワードが見つかりませんでした" + ) + del user_dict[word_uuid] + + # 更新された辞書データの保存と適用 + _write_to_json(user_dict, user_dict_path) + update_dict(user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path) + + +def import_user_dict( + dict_data: Dict[str, UserDictWord], + override: bool = False, + user_dict_path: Path = user_dict_path, + default_dict_path: Path = default_dict_path, + compiled_dict_path: Path = compiled_dict_path, +) -> None: + """ + ユーザー辞書のインポート + Parameters + ---------- + dict_data : Dict[str, UserDictWord] + インポートするユーザー辞書のデータ + override : bool + 重複したエントリがあった場合、上書きするかどうか + user_dict_path : Path + ユーザー辞書ファイルのパス + default_dict_path : Path + デフォルト辞書ファイルのパス + compiled_dict_path : Path + コンパイル済み辞書ファイルのパス + """ + # インポートする辞書データのバリデーション + for word_uuid, word in dict_data.items(): + UUID(word_uuid) + assert isinstance(word, UserDictWord) + for pos_detail in part_of_speech_data.values(): + if word.context_id == pos_detail.context_id: + assert word.part_of_speech == pos_detail.part_of_speech + assert ( + word.part_of_speech_detail_1 == pos_detail.part_of_speech_detail_1 + ) + assert ( + word.part_of_speech_detail_2 == pos_detail.part_of_speech_detail_2 + ) + assert ( + word.part_of_speech_detail_3 == pos_detail.part_of_speech_detail_3 + ) + assert ( + word.accent_associative_rule in pos_detail.accent_associative_rules + ) + break + else: + raise ValueError("対応していない品詞です") + + # 既存辞書の読み出し + old_dict = read_dict(user_dict_path=user_dict_path) + + # 辞書データの更新 + # 重複エントリの上書き + if override: + new_dict = {**old_dict, **dict_data} + # 重複エントリの保持 + else: + new_dict = {**dict_data, **old_dict} + + 
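apply_word, rewrite_word, and delete_word above are the same helpers that the /user_dict_word endpoints in server_editor.py call. A minimal usage sketch, assuming it is run from the repository root so that text.user_dict is importable and that a default.csv exists in the user-dictionary directory; the surface form and reading are illustrative only.

# Illustrative use of the user-dictionary helpers shown above.
from text.user_dict import apply_word, delete_word, read_dict, rewrite_word

# Register a word. The pronunciation must be katakana; accent_type=0 marks a flat
# (nucleus-less) accent, as noted in server_editor.py's UserDictWordRequest.
uuid = apply_word(
    surface="新世界",            # placeholder surface form
    pronunciation="シンセカイ",  # placeholder reading
    accent_type=0,
    priority=5,
)
print(read_dict()[uuid].pronunciation)

# Entries are addressed by the UUID returned above.
rewrite_word(
    word_uuid=uuid,
    surface="新世界",
    pronunciation="シンセカイ",
    accent_type=1,
    priority=7,
)
delete_word(uuid)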
# 更新された辞書データの保存と適用 + _write_to_json(user_dict=new_dict, user_dict_path=user_dict_path) + update_dict( + default_dict_path=default_dict_path, + user_dict_path=user_dict_path, + compiled_dict_path=compiled_dict_path, + ) + + +def _search_cost_candidates(context_id: int) -> List[int]: + for value in part_of_speech_data.values(): + if value.context_id == context_id: + return value.cost_candidates + raise HTTPException(status_code=422, detail="品詞IDが不正です") + + +def _cost2priority(context_id: int, cost: int) -> int: + assert -32768 <= cost <= 32767 + cost_candidates = _search_cost_candidates(context_id) + # cost_candidatesの中にある値で最も近い値を元にpriorityを返す + # 参考: https://qiita.com/Krypf/items/2eada91c37161d17621d + # この関数とpriority2cost関数によって、辞書ファイルのcostを操作しても最も近いpriorityのcostに上書きされる + return MAX_PRIORITY - np.argmin(np.abs(np.array(cost_candidates) - cost)).item() + + +def _priority2cost(context_id: int, priority: int) -> int: + assert MIN_PRIORITY <= priority <= MAX_PRIORITY + cost_candidates = _search_cost_candidates(context_id) + return cost_candidates[MAX_PRIORITY - priority] diff --git a/text/user_dict/part_of_speech_data.py b/text/user_dict/part_of_speech_data.py new file mode 100644 index 000000000..7e22699b7 --- /dev/null +++ b/text/user_dict/part_of_speech_data.py @@ -0,0 +1,150 @@ +# このファイルは、VOICEVOXプロジェクトのVOICEVOX engineからお借りしています。 +# 引用元: +# https://github.com/VOICEVOX/voicevox_engine/blob/f181411ec69812296989d9cc583826c22eec87ae/voicevox_engine/user_dict/part_of_speech_data.py +# ライセンス: LGPL-3.0 +# 詳しくは、このファイルと同じフォルダにあるREADME.mdを参照してください。 + +from typing import Dict + +from .word_model import ( + USER_DICT_MAX_PRIORITY, + USER_DICT_MIN_PRIORITY, + PartOfSpeechDetail, + WordTypes, +) + +MIN_PRIORITY = USER_DICT_MIN_PRIORITY +MAX_PRIORITY = USER_DICT_MAX_PRIORITY + +part_of_speech_data: Dict[WordTypes, PartOfSpeechDetail] = { + WordTypes.PROPER_NOUN: PartOfSpeechDetail( + part_of_speech="名詞", + part_of_speech_detail_1="固有名詞", + part_of_speech_detail_2="一般", + part_of_speech_detail_3="*", + context_id=1348, + cost_candidates=[ + -988, + 3488, + 4768, + 6048, + 7328, + 8609, + 8734, + 8859, + 8984, + 9110, + 14176, + ], + accent_associative_rules=[ + "*", + "C1", + "C2", + "C3", + "C4", + "C5", + ], + ), + WordTypes.COMMON_NOUN: PartOfSpeechDetail( + part_of_speech="名詞", + part_of_speech_detail_1="一般", + part_of_speech_detail_2="*", + part_of_speech_detail_3="*", + context_id=1345, + cost_candidates=[ + -4445, + 49, + 1473, + 2897, + 4321, + 5746, + 6554, + 7362, + 8170, + 8979, + 15001, + ], + accent_associative_rules=[ + "*", + "C1", + "C2", + "C3", + "C4", + "C5", + ], + ), + WordTypes.VERB: PartOfSpeechDetail( + part_of_speech="動詞", + part_of_speech_detail_1="自立", + part_of_speech_detail_2="*", + part_of_speech_detail_3="*", + context_id=642, + cost_candidates=[ + 3100, + 6160, + 6360, + 6561, + 6761, + 6962, + 7414, + 7866, + 8318, + 8771, + 13433, + ], + accent_associative_rules=[ + "*", + ], + ), + WordTypes.ADJECTIVE: PartOfSpeechDetail( + part_of_speech="形容詞", + part_of_speech_detail_1="自立", + part_of_speech_detail_2="*", + part_of_speech_detail_3="*", + context_id=20, + cost_candidates=[ + 1527, + 3266, + 3561, + 3857, + 4153, + 4449, + 5149, + 5849, + 6549, + 7250, + 10001, + ], + accent_associative_rules=[ + "*", + ], + ), + WordTypes.SUFFIX: PartOfSpeechDetail( + part_of_speech="名詞", + part_of_speech_detail_1="接尾", + part_of_speech_detail_2="一般", + part_of_speech_detail_3="*", + context_id=1358, + cost_candidates=[ + 4399, + 5373, + 6041, + 6710, + 7378, + 8047, + 9440, + 10834, + 
12228, + 13622, + 15847, + ], + accent_associative_rules=[ + "*", + "C1", + "C2", + "C3", + "C4", + "C5", + ], + ), +} diff --git a/text/user_dict/word_model.py b/text/user_dict/word_model.py new file mode 100644 index 000000000..f05d8dc47 --- /dev/null +++ b/text/user_dict/word_model.py @@ -0,0 +1,129 @@ +# このファイルは、VOICEVOXプロジェクトのVOICEVOX engineからお借りしています。 +# 引用元: +# https://github.com/VOICEVOX/voicevox_engine/blob/f181411ec69812296989d9cc583826c22eec87ae/voicevox_engine/model.py#L207 +# ライセンス: LGPL-3.0 +# 詳しくは、このファイルと同じフォルダにあるREADME.mdを参照してください。 +from enum import Enum +from re import findall, fullmatch +from typing import List, Optional + +from pydantic import BaseModel, Field, validator + +USER_DICT_MIN_PRIORITY = 0 +USER_DICT_MAX_PRIORITY = 10 + + +class UserDictWord(BaseModel): + """ + 辞書のコンパイルに使われる情報 + """ + + surface: str = Field(title="表層形") + priority: int = Field( + title="優先度", ge=USER_DICT_MIN_PRIORITY, le=USER_DICT_MAX_PRIORITY + ) + context_id: int = Field(title="文脈ID", default=1348) + part_of_speech: str = Field(title="品詞") + part_of_speech_detail_1: str = Field(title="品詞細分類1") + part_of_speech_detail_2: str = Field(title="品詞細分類2") + part_of_speech_detail_3: str = Field(title="品詞細分類3") + inflectional_type: str = Field(title="活用型") + inflectional_form: str = Field(title="活用形") + stem: str = Field(title="原形") + yomi: str = Field(title="読み") + pronunciation: str = Field(title="発音") + accent_type: int = Field(title="アクセント型") + mora_count: Optional[int] = Field(title="モーラ数", default=None) + accent_associative_rule: str = Field(title="アクセント結合規則") + + class Config: + validate_assignment = True + + @validator("surface") + def convert_to_zenkaku(cls, surface): + return surface.translate( + str.maketrans( + "".join(chr(0x21 + i) for i in range(94)), + "".join(chr(0xFF01 + i) for i in range(94)), + ) + ) + + @validator("pronunciation", pre=True) + def check_is_katakana(cls, pronunciation): + if not fullmatch(r"[ァ-ヴー]+", pronunciation): + raise ValueError("発音は有効なカタカナでなくてはいけません。") + sutegana = ["ァ", "ィ", "ゥ", "ェ", "ォ", "ャ", "ュ", "ョ", "ヮ", "ッ"] + for i in range(len(pronunciation)): + if pronunciation[i] in sutegana: + # 「キャット」のように、捨て仮名が連続する可能性が考えられるので、 + # 「ッ」に関しては「ッ」そのものが連続している場合と、「ッ」の後にほかの捨て仮名が連続する場合のみ無効とする + if i < len(pronunciation) - 1 and ( + pronunciation[i + 1] in sutegana[:-1] + or ( + pronunciation[i] == sutegana[-1] + and pronunciation[i + 1] == sutegana[-1] + ) + ): + raise ValueError("無効な発音です。(捨て仮名の連続)") + if pronunciation[i] == "ヮ": + if i != 0 and pronunciation[i - 1] not in ["ク", "グ"]: + raise ValueError( + "無効な発音です。(「くゎ」「ぐゎ」以外の「ゎ」の使用)" + ) + return pronunciation + + @validator("mora_count", pre=True, always=True) + def check_mora_count_and_accent_type(cls, mora_count, values): + if "pronunciation" not in values or "accent_type" not in values: + # 適切な場所でエラーを出すようにする + return mora_count + + if mora_count is None: + rule_others = ( + "[イ][ェ]|[ヴ][ャュョ]|[トド][ゥ]|[テデ][ィャュョ]|[デ][ェ]|[クグ][ヮ]" + ) + rule_line_i = "[キシチニヒミリギジビピ][ェャュョ]" + rule_line_u = "[ツフヴ][ァ]|[ウスツフヴズ][ィ]|[ウツフヴ][ェォ]" + rule_one_mora = "[ァ-ヴー]" + mora_count = len( + findall( + f"(?:{rule_others}|{rule_line_i}|{rule_line_u}|{rule_one_mora})", + values["pronunciation"], + ) + ) + + if not 0 <= values["accent_type"] <= mora_count: + raise ValueError( + "誤ったアクセント型です({})。 expect: 0 <= accent_type <= {}".format( + values["accent_type"], mora_count + ) + ) + return mora_count + + +class PartOfSpeechDetail(BaseModel): + """ + 品詞ごとの情報 + """ + + part_of_speech: str = Field(title="品詞") + part_of_speech_detail_1: str = 
Field(title="品詞細分類1") + part_of_speech_detail_2: str = Field(title="品詞細分類2") + part_of_speech_detail_3: str = Field(title="品詞細分類3") + # context_idは辞書の左・右文脈IDのこと + # https://github.com/VOICEVOX/open_jtalk/blob/427cfd761b78efb6094bea3c5bb8c968f0d711ab/src/mecab-naist-jdic/_left-id.def # noqa + context_id: int = Field(title="文脈ID") + cost_candidates: List[int] = Field(title="コストのパーセンタイル") + accent_associative_rules: List[str] = Field(title="アクセント結合規則の一覧") + + +class WordTypes(str, Enum): + """ + fastapiでword_type引数を検証する時に使用するクラス + """ + + PROPER_NOUN = "PROPER_NOUN" + COMMON_NOUN = "COMMON_NOUN" + VERB = "VERB" + ADJECTIVE = "ADJECTIVE" + SUFFIX = "SUFFIX" diff --git a/tools/translate.py b/tools/translate.py index 9368b5f8e..be0f7ea45 100644 --- a/tools/translate.py +++ b/tools/translate.py @@ -1,6 +1,7 @@ """ 翻译api """ + from config import config import random diff --git a/train_ms.py b/train_ms.py index 57eb5f6a1..ea4b65fbb 100644 --- a/train_ms.py +++ b/train_ms.py @@ -6,6 +6,7 @@ import torch import torch.distributed as dist +from huggingface_hub import HfApi from torch.cuda.amp import GradScaler, autocast from torch.nn import functional as F from torch.nn.parallel import DistributedDataParallel as DDP @@ -45,6 +46,8 @@ global_step = 0 +api = HfApi() + def run(): # Command line configuration is not recommended unless necessary, use config.yml @@ -84,6 +87,11 @@ def run(): action="store_true", help="Speed up training by disabling logging and evaluation.", ) + parser.add_argument( + "--repo_id", + help="Huggingface model repo id to backup the model.", + default=None, + ) args = parser.parse_args() # Set log file @@ -151,6 +159,30 @@ def run(): config.out_dir: The directory for model assets of this model (for inference). default: `model_assets/{model_name}`. """ + + if args.repo_id is not None: + # First try to upload config.json to check if the repo exists + try: + api.upload_file( + path_or_fileobj=args.config, + path_in_repo=f"Data/{config.model_name}/config.json", + repo_id=hps.repo_id, + ) + except Exception as e: + logger.error(e) + logger.error( + f"Failed to upload files to the repo {hps.repo_id}. Please check if the repo exists and you have logged in using `huggingface-cli login`." 
+ ) + raise e + # Upload Data dir for resuming training + api.upload_folder( + repo_id=hps.repo_id, + folder_path=config.dataset_path, + path_in_repo=f"Data/{config.model_name}", + delete_patterns="*.pth", # Only keep the latest checkpoint + run_as_future=True, + ) + os.makedirs(config.out_dir, exist_ok=True) if not args.skip_default_style: @@ -274,6 +306,11 @@ def run(): for param in net_g.enc_p.style_proj.parameters(): param.requires_grad = False + if getattr(hps.train, "freeze_decoder", False): + logger.info("Freezing decoder !!!") + for param in net_g.dec.parameters(): + param.requires_grad = False + net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank) optim_g = torch.optim.AdamW( filter(lambda p: p.requires_grad, net_g.parameters()), @@ -478,6 +515,25 @@ def lr_lambda(epoch): ), for_infer=True, ) + if hps.repo_id is not None: + future1 = api.upload_folder( + repo_id=hps.repo_id, + folder_path=config.dataset_path, + path_in_repo=f"Data/{config.model_name}", + delete_patterns="*.pth", # Only keep the latest checkpoint + run_as_future=True, + ) + future2 = api.upload_folder( + repo_id=hps.repo_id, + folder_path=config.out_dir, + path_in_repo=f"model_assets/{config.model_name}", + run_as_future=True, + ) + try: + future1.result() + future2.result() + except Exception as e: + logger.error(e) if pbar is not None: pbar.close() @@ -760,6 +816,20 @@ def train_and_evaluate( ), for_infer=True, ) + if hps.repo_id is not None: + api.upload_folder( + repo_id=hps.repo_id, + folder_path=config.dataset_path, + path_in_repo=f"Data/{config.model_name}", + delete_patterns="*.pth", # Only keep the latest checkpoint + run_as_future=True, + ) + api.upload_folder( + repo_id=hps.repo_id, + folder_path=config.out_dir, + path_in_repo=f"model_assets/{config.model_name}", + run_as_future=True, + ) global_step += 1 if pbar is not None: diff --git a/train_ms_jp_extra.py b/train_ms_jp_extra.py index 2c46eb25f..bae922d52 100644 --- a/train_ms_jp_extra.py +++ b/train_ms_jp_extra.py @@ -12,6 +12,7 @@ from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm +from huggingface_hub import HfApi # logging.getLogger("numba").setLevel(logging.WARNING) import commons @@ -48,6 +49,8 @@ ) # Not available if torch version is lower than 2.0 global_step = 0 +api = HfApi() + def run(): # Command line configuration is not recommended unless necessary, use config.yml @@ -87,6 +90,11 @@ def run(): action="store_true", help="Speed up training by disabling logging and evaluation.", ) + parser.add_argument( + "--repo_id", + help="Huggingface model repo id to backup the model.", + default=None, + ) args = parser.parse_args() # Set log file @@ -126,6 +134,7 @@ def run(): # This is needed because we have to pass values to `train_and_evaluate() hps.model_dir = model_dir hps.speedup = args.speedup + hps.repo_id = args.repo_id # 比较路径是否相同 if os.path.realpath(args.config) != os.path.realpath( @@ -154,6 +163,30 @@ def run(): config.out_dir: The directory for model assets of this model (for inference). default: `model_assets/{model_name}`. """ + + if args.repo_id is not None: + # First try to upload config.json to check if the repo exists + try: + api.upload_file( + path_or_fileobj=args.config, + path_in_repo=f"Data/{config.model_name}/config.json", + repo_id=hps.repo_id, + ) + except Exception as e: + logger.error(e) + logger.error( + f"Failed to upload files to the repo {hps.repo_id}. 
Please check if the repo exists and you have logged in using `huggingface-cli login`." + ) + raise e + # Upload Data dir for resuming training + api.upload_folder( + repo_id=hps.repo_id, + folder_path=config.dataset_path, + path_in_repo=f"Data/{config.model_name}", + delete_patterns="*.pth", # Only keep the latest checkpoint + ignore_patterns=f"{config.dataset_path}/raw", # Ignore raw data + run_as_future=True, + ) os.makedirs(config.out_dir, exist_ok=True) if not args.skip_default_style: @@ -277,6 +310,11 @@ def run(): for param in net_g.enc_p.style_proj.parameters(): param.requires_grad = False + if getattr(hps.train, "freeze_decoder", False): + logger.info("Freezing decoder !!!") + for param in net_g.dec.parameters(): + param.requires_grad = False + net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank) optim_g = torch.optim.AdamW( filter(lambda p: p.requires_grad, net_g.parameters()), @@ -565,6 +603,26 @@ def lr_lambda(epoch): ), for_infer=True, ) + if hps.repo_id is not None: + future1 = api.upload_folder( + repo_id=hps.repo_id, + folder_path=config.dataset_path, + path_in_repo=f"Data/{config.model_name}", + delete_patterns="*.pth", # Only keep the latest checkpoint + ignore_patterns=f"{config.dataset_path}/raw", # Ignore raw data + run_as_future=True, + ) + future2 = api.upload_folder( + repo_id=hps.repo_id, + folder_path=config.out_dir, + path_in_repo=f"model_assets/{config.model_name}", + run_as_future=True, + ) + try: + future1.result() + future2.result() + except Exception as e: + logger.error(e) if pbar is not None: pbar.close() @@ -916,6 +974,21 @@ def train_and_evaluate( ), for_infer=True, ) + if hps.repo_id is not None: + api.upload_folder( + repo_id=hps.repo_id, + folder_path=config.dataset_path, + path_in_repo=f"Data/{config.model_name}", + delete_patterns="*.pth", # Only keep the latest checkpoint + ignore_patterns=f"{config.dataset_path}/raw", # Ignore raw data + run_as_future=True, + ) + api.upload_folder( + repo_id=hps.repo_id, + folder_path=config.out_dir, + path_in_repo=f"model_assets/{config.model_name}", + run_as_future=True, + ) global_step += 1 if pbar is not None: @@ -926,9 +999,8 @@ def train_and_evaluate( ) pbar.update() # 本家ではこれをスピードアップのために消すと書かれていたので、一応消してみる - # と思ったけどメモリ使用量が減るかもしれないのでつけてみる - gc.collect() - torch.cuda.empty_cache() + # gc.collect() + # torch.cuda.empty_cache() if pbar is None and rank == 0: logger.info(f"====> Epoch: {epoch}, step: {global_step}") diff --git a/transcribe.py b/transcribe.py index d9fb98be8..4a28fd193 100644 --- a/transcribe.py +++ b/transcribe.py @@ -2,6 +2,7 @@ import os import sys +import yaml from faster_whisper import WhisperModel from tqdm import tqdm @@ -20,25 +21,29 @@ def transcribe(wav_path, initial_prompt=None, language="ja"): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--input_dir", "-i", type=str, default="raw") - parser.add_argument("--output_file", "-o", type=str, default="esd.list") + parser.add_argument("--model_name", type=str, required=True) parser.add_argument( - "--initial_prompt", type=str, default="こんにちは。元気、ですかー?ふふっ、私は……ちゃんと元気だよ!" 
+ "--initial_prompt", + type=str, + default="こんにちは。元気、ですかー?ふふっ、私は……ちゃんと元気だよ!", ) parser.add_argument( "--language", type=str, default="ja", choices=["ja", "en", "zh"] ) - parser.add_argument("--speaker_name", type=str, required=True) parser.add_argument("--model", type=str, default="large-v3") parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--compute_type", type=str, default="bfloat16") args = parser.parse_args() - speaker_name = args.speaker_name + with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f: + path_config: dict[str, str] = yaml.safe_load(f.read()) + dataset_root = path_config["dataset_root"] - input_dir = args.input_dir - output_file = args.output_file + model_name = args.model_name + + input_dir = os.path.join(dataset_root, model_name, "raw") + output_file = os.path.join(dataset_root, model_name, "esd.list") initial_prompt = args.initial_prompt language = args.language device = args.device @@ -73,12 +78,9 @@ def transcribe(wav_path, initial_prompt=None, language="ja"): language_id = Languages.ZH.value else: raise ValueError(f"{language} is not supported.") - with open(output_file, "w", encoding="utf-8") as f: - for wav_file in tqdm(wav_files, file=SAFE_STDOUT): - file_name = os.path.basename(wav_file) - text = transcribe( - wav_file, initial_prompt=initial_prompt, language=language - ) - f.write(f"{file_name}|{speaker_name}|{language_id}|{text}\n") - f.flush() + for wav_file in tqdm(wav_files, file=SAFE_STDOUT): + file_name = os.path.basename(wav_file) + text = transcribe(wav_file, initial_prompt=initial_prompt, language=language) + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"{file_name}|{model_name}|{language_id}|{text}\n") sys.exit(0) diff --git a/update_status.py b/update_status.py index 65b4f93f9..7d768c663 100644 --- a/update_status.py +++ b/update_status.py @@ -38,7 +38,9 @@ def update_c_files(): c_files.append(os.path.join(root, file)) cnt += 1 print(c_files) - return f"更新模型列表完成, 共找到{cnt}个配置文件", gr.Dropdown.update(choices=c_files) + return f"更新模型列表完成, 共找到{cnt}个配置文件", gr.Dropdown.update( + choices=c_files + ) def update_model_folders(): @@ -50,7 +52,9 @@ def update_model_folders(): subdirs.append(os.path.join(root, dir_name)) cnt += 1 print(subdirs) - return f"更新模型文件夹列表完成, 共找到{cnt}个文件夹", gr.Dropdown.update(choices=subdirs) + return f"更新模型文件夹列表完成, 共找到{cnt}个文件夹", gr.Dropdown.update( + choices=subdirs + ) def update_wav_lab_pairs(): diff --git a/webui.py b/webui.py index e7c1abd80..90318a1e7 100644 --- a/webui.py +++ b/webui.py @@ -1,6 +1,7 @@ """ Original `webui.py` for Bert-VITS2, not working with Style-Bert-VITS2 yet. """ + # flake8: noqa: E402 import os import logging diff --git a/webui_dataset.py b/webui_dataset.py index 2885e2703..fec7a9ac9 100644 --- a/webui_dataset.py +++ b/webui_dataset.py @@ -25,11 +25,10 @@ def do_slice( if model_name == "": return "Error: モデル名を入力してください。" logger.info("Start slicing...") - output_dir = os.path.join(dataset_root, model_name, "raw") cmd = [ "slice.py", - "--output_dir", - output_dir, + "--model_name", + model_name, "--min_sec", str(min_sec), "--max_sec", @@ -47,24 +46,15 @@ def do_slice( def do_transcribe( - model_name, whisper_model, compute_type, language, initial_prompt, input_dir, device + model_name, whisper_model, compute_type, language, initial_prompt, device ): if model_name == "": return "Error: モデル名を入力してください。" - if initial_prompt == "": - initial_prompt = "こんにちは。元気、ですかー?私は……ふふっ、ちゃんと元気だよ!" 
- # logger.debug(f"initial_prompt: {initial_prompt}") - if input_dir == "": - input_dir = os.path.join(dataset_root, model_name, "raw") - output_file = os.path.join(dataset_root, model_name, "esd.list") + success, message = run_script_with_log( [ "transcribe.py", - "--input_dir", - input_dir, - "--output_file", - output_file, - "--speaker_name", + "--model_name", model_name, "--model", whisper_model, @@ -154,9 +144,6 @@ def do_transcribe( result1 = gr.Textbox(label="結果") with gr.Row(): with gr.Column(): - raw_dir = gr.Textbox( - label="書き起こしたい音声ファイルが入っているフォルダ(スライスした場合など、`Data/{モデル名}/raw`の場合は省略可", - ) whisper_model = gr.Dropdown( ["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], label="Whisperモデル", @@ -180,8 +167,8 @@ def do_transcribe( language = gr.Dropdown(["ja", "en", "zh"], value="ja", label="言語") initial_prompt = gr.Textbox( label="初期プロンプト", - placeholder="こんにちは。元気、ですかー?ふふっ、私は……ちゃんと元気だよ!", - info="このように書き起こしてほしいという例文、日本語なら省略可、英語等なら書いてください", + value="こんにちは。元気、ですかー?ふふっ、私は……ちゃんと元気だよ!", + info="このように書き起こしてほしいという例文(句読点の入れ方・笑い方・固有名詞等)", ) transcribe_button = gr.Button("音声の文字起こし") result2 = gr.Textbox(label="結果") @@ -198,7 +185,6 @@ def do_transcribe( compute_type, language, initial_prompt, - raw_dir, device, ], outputs=[result2], diff --git a/webui_train.py b/webui_train.py index ebd292e7d..cd7af20e9 100644 --- a/webui_train.py +++ b/webui_train.py @@ -9,6 +9,7 @@ import webbrowser from datetime import datetime from multiprocessing import cpu_count +from pathlib import Path import gradio as gr import yaml @@ -47,6 +48,7 @@ def initialize( freeze_JP_bert, freeze_ZH_bert, freeze_style, + freeze_decoder, use_jp_extra, log_interval, ): @@ -61,7 +63,7 @@ def initialize( logger_handler = logger.add(os.path.join(dataset_path, file_name)) logger.info( - f"Step 1: start initialization...\nmodel_name: {model_name}, batch_size: {batch_size}, epochs: {epochs}, save_every_steps: {save_every_steps}, freeze_ZH_bert: {freeze_ZH_bert}, freeze_JP_bert: {freeze_JP_bert}, freeze_EN_bert: {freeze_EN_bert}, freeze_style: {freeze_style}, use_jp_extra: {use_jp_extra}" + f"Step 1: start initialization...\nmodel_name: {model_name}, batch_size: {batch_size}, epochs: {epochs}, save_every_steps: {save_every_steps}, freeze_ZH_bert: {freeze_ZH_bert}, freeze_JP_bert: {freeze_JP_bert}, freeze_EN_bert: {freeze_EN_bert}, freeze_style: {freeze_style}, freeze_decoder: {freeze_decoder}, use_jp_extra: {use_jp_extra}" ) default_config_path = ( @@ -82,12 +84,15 @@ def initialize( config["train"]["freeze_JP_bert"] = freeze_JP_bert config["train"]["freeze_ZH_bert"] = freeze_ZH_bert config["train"]["freeze_style"] = freeze_style + config["train"]["freeze_decoder"] = freeze_decoder config["train"]["bf16_run"] = False # デフォルトでFalseのはずだが念のため model_path = os.path.join(dataset_path, "models") if os.path.exists(model_path): - logger.warning(f"Step 1: {model_path} already exists, so copy it to backup.") + logger.warning( + f"Step 1: {model_path} already exists, so copy it to backup to {model_path}_backup" + ) shutil.copytree( src=model_path, dst=os.path.join(dataset_path, "models_backup"), @@ -158,13 +163,20 @@ def preprocess_text(model_name, use_jp_extra, val_per_lang): except FileNotFoundError: logger.error(f"Step 3: {lbl_path} not found.") return False, f"Step 3, Error: 書き起こしファイル {lbl_path} が見つかりません。" - with open(lbl_path, "w", encoding="utf-8") as f: - for line in lines: - path, spk, language, text = line.strip().split("|") - path = os.path.join(dataset_path, "wavs", os.path.basename(path)).replace( - "\\", "/" + 
new_lines = [] + for line in lines: + if len(line.strip().split("|")) != 4: + logger.error(f"Step 3: {lbl_path} has invalid format at line:\n{line}") + return ( + False, + f"Step 3, Error: 書き起こしファイル次の行の形式が不正です:\n{line}", ) - f.writelines(f"{path}|{spk}|{language}|{text}\n") + path, spk, language, text = line.strip().split("|") + # pathをファイル名だけ取り出して正しいパスに変更 + path = Path(dataset_path) / "wavs" / Path(path).name + new_lines.append(f"{path}|{spk}|{language}|{text}\n") + with open(lbl_path, "w", encoding="utf-8") as f: + f.writelines(new_lines) cmd = [ "preprocess_text.py", "--config-path", @@ -262,6 +274,7 @@ def preprocess_all( freeze_JP_bert, freeze_ZH_bert, freeze_style, + freeze_decoder, use_jp_extra, val_per_lang, log_interval, @@ -269,29 +282,39 @@ def preprocess_all( if model_name == "": return False, "Error: モデル名を入力してください" success, message = initialize( - model_name, - batch_size, - epochs, - save_every_steps, - freeze_EN_bert, - freeze_JP_bert, - freeze_ZH_bert, - freeze_style, - use_jp_extra, - log_interval, + model_name=model_name, + batch_size=batch_size, + epochs=epochs, + save_every_steps=save_every_steps, + freeze_EN_bert=freeze_EN_bert, + freeze_JP_bert=freeze_JP_bert, + freeze_ZH_bert=freeze_ZH_bert, + freeze_style=freeze_style, + freeze_decoder=freeze_decoder, + use_jp_extra=use_jp_extra, + log_interval=log_interval, ) if not success: return False, message - success, message = resample(model_name, normalize, trim, num_processes) + success, message = resample( + model_name=model_name, + normalize=normalize, + trim=trim, + num_processes=num_processes, + ) if not success: return False, message - success, message = preprocess_text(model_name, use_jp_extra, val_per_lang) + success, message = preprocess_text( + model_name=model_name, use_jp_extra=use_jp_extra, val_per_lang=val_per_lang + ) if not success: return False, message - success, message = bert_gen(model_name) # bert_genは重いのでプロセス数いじらない + success, message = bert_gen( + model_name=model_name + ) # bert_genは重いのでプロセス数いじらない if not success: return False, message - success, message = style_gen(model_name, num_processes) + success, message = style_gen(model_name=model_name, num_processes=num_processes) if not success: return False, message logger.success("Success: All preprocess finished!") @@ -507,6 +530,10 @@ def run_tensorboard(model_name): label="スタイル部分を凍結", value=False, ) + freeze_decoder = gr.Checkbox( + label="デコーダ部分を凍結", + value=False, + ) with gr.Column(): preprocess_button = gr.Button( @@ -565,6 +592,10 @@ def run_tensorboard(model_name): label="スタイル部分を凍結", value=False, ) + freeze_decoder_manual = gr.Checkbox( + label="デコーダ部分を凍結", + value=False, + ) with gr.Column(): generate_config_btn = gr.Button(value="実行", variant="primary") info_init = gr.Textbox(label="状況") @@ -658,6 +689,7 @@ def run_tensorboard(model_name): freeze_JP_bert, freeze_ZH_bert, freeze_style, + freeze_decoder, use_jp_extra, val_per_lang, log_interval, @@ -677,6 +709,7 @@ def run_tensorboard(model_name): freeze_JP_bert_manual, freeze_ZH_bert_manual, freeze_style_manual, + freeze_decoder_manual, use_jp_extra_manual, log_interval_manual, ], @@ -694,7 +727,11 @@ def run_tensorboard(model_name): ) preprocess_text_btn.click( second_elem_of(preprocess_text), - inputs=[model_name, use_jp_extra_manual, val_per_lang_manual], + inputs=[ + model_name, + use_jp_extra_manual, + val_per_lang_manual, + ], outputs=[info_preprocess_text], ) bert_gen_btn.click(