Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
litagin02 committed Mar 15, 2024
1 parent 867855a commit 3d8b60c
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 198 deletions.
269 changes: 163 additions & 106 deletions preprocess_text.py
Original file line number Diff line number Diff line change
@@ -1,145 +1,161 @@
import argparse
import json
import os
from collections import defaultdict
from pathlib import Path
from random import shuffle
from typing import Optional

import click
from tqdm import tqdm

from config import config
from config import Preprocess_text_config, config
from style_bert_vits2.logging import logger
from style_bert_vits2.nlp import clean_text
from style_bert_vits2.nlp.japanese import pyopenjtalk_worker
from style_bert_vits2.nlp.japanese.user_dict import update_dict
from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT


# このプロセスからはワーカーを起動して辞書を使いたいので、ここで初期化
pyopenjtalk_worker.initialize_worker()

# dict_data/ 以下の辞書データを pyopenjtalk に適用
update_dict()

preprocess_text_config = config.preprocess_text_config

preprocess_text_config: Preprocess_text_config = config.preprocess_text_config


# Count lines for tqdm
def count_lines(file_path: str):
with open(file_path, "r", encoding="utf-8") as file:
def count_lines(file_path: Path):
with file_path.open("r", encoding="utf-8") as file:
return sum(1 for _ in file)


@click.command()
@click.option(
"--transcription-path",
default=preprocess_text_config.transcription_path,
type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
@click.option("--cleaned-path", default=preprocess_text_config.cleaned_path)
@click.option("--train-path", default=preprocess_text_config.train_path)
@click.option("--val-path", default=preprocess_text_config.val_path)
@click.option(
"--config-path",
default=preprocess_text_config.config_path,
type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
@click.option("--val-per-lang", default=preprocess_text_config.val_per_lang)
@click.option("--max-val-total", default=preprocess_text_config.max_val_total)
@click.option("--clean/--no-clean", default=preprocess_text_config.clean)
@click.option("-y", "--yml_config")
@click.option("--use_jp_extra", is_flag=True)
@click.option("--yomi_error", default="raise")
def write_error_log(error_log_path: Path, line: str, error: Exception):
with error_log_path.open("a", encoding="utf-8") as error_log:
error_log.write(f"{line.strip()}\n{error}\n\n")


def process_line(
line: str,
transcription_path: Path,
correct_path: bool,
use_jp_extra: bool,
yomi_error: str,
):
splitted_line = line.strip().split("|")
if len(splitted_line) != 4:
raise ValueError(f"Invalid line format: {line.strip()}")
utt, spk, language, text = splitted_line
norm_text, phones, tones, word2ph = clean_text(
text=text,
language=language, # type: ignore
use_jp_extra=use_jp_extra,
raise_yomi_error=(yomi_error != "use"),
)
if correct_path:
utt = str(transcription_path.parent / "wavs" / utt)

return "{}|{}|{}|{}|{}|{}|{}\n".format(
utt,
spk,
language,
norm_text,
" ".join(phones),
" ".join([str(i) for i in tones]),
" ".join([str(i) for i in word2ph]),
)


def preprocess(
transcription_path: str,
cleaned_path: Optional[str],
train_path: str,
val_path: str,
config_path: str,
transcription_path: Path,
cleaned_path: Optional[Path],
train_path: Path,
val_path: Path,
config_path: Path,
val_per_lang: int,
max_val_total: int,
clean: bool,
yml_config: str, # 这个不要删
# clean: bool,
use_jp_extra: bool,
yomi_error: str,
correct_path: bool,
):
assert yomi_error in ["raise", "skip", "use"]
if cleaned_path == "" or cleaned_path is None:
cleaned_path = transcription_path + ".cleaned"
cleaned_path = transcription_path.with_name(
transcription_path.name + ".cleaned"
)

error_log_path = os.path.join(os.path.dirname(cleaned_path), "text_error.log")
if os.path.exists(error_log_path):
os.remove(error_log_path)
error_log_path = transcription_path.parent / "text_error.log"
if error_log_path.exists():
error_log_path.unlink()
error_count = 0

if clean:
total_lines = count_lines(transcription_path)
with open(cleaned_path, "w", encoding="utf-8") as out_file:
with open(transcription_path, "r", encoding="utf-8") as trans_file:
for line in tqdm(trans_file, file=SAFE_STDOUT, total=total_lines):
try:
utt, spk, language, text = line.strip().split("|")
norm_text, phones, tones, word2ph = clean_text(
text=text,
language=language, # type: ignore
use_jp_extra=use_jp_extra,
raise_yomi_error=(yomi_error != "use"),
)

out_file.write(
"{}|{}|{}|{}|{}|{}|{}\n".format(
utt,
spk,
language,
norm_text,
" ".join(phones),
" ".join([str(i) for i in tones]),
" ".join([str(i) for i in word2ph]),
)
)
except Exception as e:
logger.error(
f"An error occurred at line:\n{line.strip()}\n{e}",
encoding="utf-8",
)
with open(error_log_path, "a", encoding="utf-8") as error_log:
error_log.write(f"{line.strip()}\n{e}\n\n")
error_count += 1
total_lines = count_lines(transcription_path)

# transcription_path から 1行ずつ読み込んで文章処理して cleaned_path に書き込む
with (
transcription_path.open("r", encoding="utf-8") as trans_file,
cleaned_path.open("w", encoding="utf-8") as out_file,
):
for line in tqdm(trans_file, file=SAFE_STDOUT, total=total_lines):
try:
processed_line = process_line(
line,
transcription_path,
correct_path,
use_jp_extra,
yomi_error,
)
out_file.write(processed_line)
except Exception as e:
logger.error(
f"An error occurred at line:\n{line.strip()}\n{e}", encoding="utf-8"
)
write_error_log(error_log_path, line, e)
error_count += 1

transcription_path = cleaned_path
spk_utt_map = defaultdict(list)
spk_id_map = {}
current_sid = 0

with open(transcription_path, "r", encoding="utf-8") as f:
audioPaths = set()
countSame = 0
countNotFound = 0

# 各話者ごとのlineの辞書
spk_utt_map: dict[str, list[str]] = defaultdict(list)

# 話者からIDへの写像
spk_id_map: dict[str, int] = {}

# 話者ID
current_sid: int = 0

# 音源ファイルのチェックや、spk_id_mapの作成
with transcription_path.open("r", encoding="utf-8") as f:
audio_paths: set[str] = set()
count_same = 0
count_not_found = 0
for line in f.readlines():
utt, spk, language, text, phones, tones, word2ph = line.strip().split("|")
if utt in audioPaths:
# 过滤数据集错误:相同的音频匹配多个文本,导致后续bert出问题
logger.warning(f"Same audio matches multiple texts: {line}")
countSame += 1
utt, spk = line.strip().split("|")[:2]
if utt in audio_paths:
logger.warning(f"Same audio file appears multiple times: {utt}")
count_same += 1
continue
if not os.path.isfile(utt):
# 过滤数据集错误:不存在对应音频
if not Path(utt).is_file():
logger.warning(f"Audio not found: {utt}")
countNotFound += 1
count_not_found += 1
continue
audioPaths.add(utt)
spk_utt_map[language].append(line)
audio_paths.add(utt)
spk_utt_map[spk].append(line)

# 新しい話者が出てきたら話者IDを割り当て、current_sidを1増やす
if spk not in spk_id_map.keys():
spk_id_map[spk] = current_sid
current_sid += 1
if countSame > 0 or countNotFound > 0:
if count_same > 0 or count_not_found > 0:
logger.warning(
f"Total repeated audios: {countSame}, Total number of audio not found: {countNotFound}"
f"Total repeated audios: {count_same}, Total number of audio not found: {count_not_found}"
)

train_list = []
val_list = []
train_list: list[str] = []
val_list: list[str] = []

# 各話者ごとにシャッフルして、val_per_lang個をval_listに、残りをtrain_listに追加
for spk, utts in spk_utt_map.items():
shuffle(utts)
val_list += utts[:val_per_lang]
Expand All @@ -150,26 +166,21 @@ def preprocess(
train_list += val_list[max_val_total:]
val_list = val_list[:max_val_total]

with open(train_path, "w", encoding="utf-8") as f:
with train_path.open("w", encoding="utf-8") as f:
for line in train_list:
f.write(line)

with open(val_path, "w", encoding="utf-8") as f:
with val_path.open("w", encoding="utf-8") as f:
for line in val_list:
f.write(line)

json_config = json.load(open(config_path, encoding="utf-8"))
with config_path.open("r", encoding="utf-8") as f:
json_config = json.load(f)

json_config["data"]["spk2id"] = spk_id_map
json_config["data"]["n_speakers"] = len(spk_id_map)
# 新增写入:写入训练版本、数据集路径
# json_config["version"] = latest_version
json_config["data"]["training_files"] = os.path.normpath(train_path).replace(
"\\", "/"
)
json_config["data"]["validation_files"] = os.path.normpath(val_path).replace(
"\\", "/"
)
with open(config_path, "w", encoding="utf-8") as f:

with config_path.open("w", encoding="utf-8") as f:
json.dump(json_config, f, indent=2, ensure_ascii=False)
if error_count > 0:
if yomi_error == "skip":
Expand All @@ -194,4 +205,50 @@ def preprocess(


if __name__ == "__main__":
preprocess()
parser = argparse.ArgumentParser()
parser.add_argument(
"--transcription-path", default=preprocess_text_config.transcription_path
)
parser.add_argument("--cleaned-path", default=preprocess_text_config.cleaned_path)
parser.add_argument("--train-path", default=preprocess_text_config.train_path)
parser.add_argument("--val-path", default=preprocess_text_config.val_path)
parser.add_argument("--config-path", default=preprocess_text_config.config_path)

# 「話者ごと」のバリデーションデータ数、言語ごとではない!
# 元のコードや設定ファイルでval_per_langとなっていたので名前をそのままにしている
parser.add_argument(
"--val-per-lang",
default=preprocess_text_config.val_per_lang,
help="Number of validation data per SPEAKER, not per language (due to compatibility with the original code).",
)
parser.add_argument("--max-val-total", default=preprocess_text_config.max_val_total)
parser.add_argument("--use_jp_extra", action="store_true")
parser.add_argument("--yomi_error", default="raise")
parser.add_argument("--correct_path", action="store_true")

args = parser.parse_args()
logger.debug(f"args: {args}")

transcription_path = Path(args.transcription_path)
cleaned_path = Path(args.cleaned_path) if args.cleaned_path else None
train_path = Path(args.train_path)
val_path = Path(args.val_path)
config_path = Path(args.config_path)
val_per_lang = int(args.val_per_lang)
max_val_total = int(args.max_val_total)
use_jp_extra: bool = args.use_jp_extra
yomi_error: str = args.yomi_error
correct_path: bool = args.correct_path

preprocess(
transcription_path=transcription_path,
cleaned_path=cleaned_path,
train_path=train_path,
val_path=val_path,
config_path=config_path,
val_per_lang=val_per_lang,
max_val_total=max_val_total,
use_jp_extra=use_jp_extra,
yomi_error=yomi_error,
correct_path=correct_path,
)
4 changes: 4 additions & 0 deletions resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from style_bert_vits2.logging import logger
from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT


DEFAULT_BLOCK_SIZE: float = 0.400 # seconds


Expand All @@ -33,6 +34,9 @@ def normalize_audio(data: NDArray[Any], sr: int):


def resample(file: Path, output_dir: Path, target_sr: int, normalize: bool, trim: bool):
"""
fileを読み込んで、target_srなwavファイルに変換してoutput_dir直下に保存する
"""
try:
# librosaが読めるファイルかチェック
# wav以外にもmp3やoggやflacなども読める
Expand Down
1 change: 0 additions & 1 deletion style_bert_vits2/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@
format="<g>{time:MM-DD HH:mm:ss}</g> |<lvl>{level:^8}</lvl>| {file}:{line} | {message}",
backtrace=True,
diagnose=True,
level="TRACE",
)
2 changes: 1 addition & 1 deletion style_bert_vits2/nlp/japanese/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def normalize_text(text: str) -> str:

res = unicodedata.normalize("NFKC", text) # ここでアルファベットは半角になる
res = __convert_numbers_to_words(res) # 「100円」→「百円」等
# 「~」と「〜」とと「~」も長音記号として扱う
# 「~」と「〜」「~」も長音記号として扱う
res = res.replace("~", "ー")
res = res.replace("~", "ー")
res = res.replace("〜", "ー")
Expand Down
1 change: 0 additions & 1 deletion style_gen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Any

Expand Down
2 changes: 1 addition & 1 deletion webui/style_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import matplotlib.pyplot as plt
import numpy as np
import yaml
from config import config
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import DBSCAN, AgglomerativeClustering, KMeans
from sklearn.manifold import TSNE
from umap import UMAP

from config import config
from style_bert_vits2.constants import DEFAULT_STYLE, GRADIO_THEME
from style_bert_vits2.logging import logger

Expand Down
Loading

0 comments on commit 3d8b60c

Please sign in to comment.