forked from RVC-Boss/GPT-SoVITS
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
324 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,324 @@ | ||
import argparse | ||
import os | ||
import signal | ||
import sys | ||
from time import time as ttime | ||
import torch | ||
import librosa | ||
import soundfile as sf | ||
from fastapi import FastAPI, Request, HTTPException | ||
from fastapi.responses import StreamingResponse | ||
import uvicorn | ||
from transformers import AutoModelForMaskedLM, AutoTokenizer | ||
import numpy as np | ||
from feature_extractor import cnhubert | ||
from io import BytesIO | ||
from module.models import SynthesizerTrn | ||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule | ||
from text import cleaned_text_to_sequence | ||
from text.cleaner import clean_text | ||
from module.mel_processing import spectrogram_torch | ||
from my_utils import load_audio | ||
|
||
DEFAULT_PORT = 9880 | ||
DEFAULT_CNHUBERT = "GPT_SoVITS/pretrained_models/chinese-hubert-base" | ||
DEFAULT_BERT = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large" | ||
DEFAULT_HALF = True | ||
|
||
DEFAULT_GPT = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt" | ||
DEFAULT_SOVITS = "GPT_SoVITS/pretrained_models/s2G488k.pth" | ||
|
||
AVAILABLE_COMPUTE = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
||
parser = argparse.ArgumentParser(description="GPT-SoVITS api") | ||
|
||
parser.add_argument("-g", "--gpt_path", type=str, default="", help="GPT模型路径") | ||
parser.add_argument("-s", "--sovits_path", type=str, default="", help="SoVITS模型路径") | ||
|
||
parser.add_argument("-dr", "--default_refer_path", type=str, default="", | ||
help="默认参考音频路径, 请求缺少参考音频时调用") | ||
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本") | ||
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种") | ||
|
||
parser.add_argument("-d", "--device", type=str, default=AVAILABLE_COMPUTE, help="cuda / cpu") | ||
parser.add_argument("-p", "--port", type=int, default=DEFAULT_PORT, help="default: 9880") | ||
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1") | ||
parser.add_argument("-hp", "--half_precision", action='store_true', default=False) | ||
|
||
parser.add_argument("-hb", "--hubert_path", type=str, default=DEFAULT_CNHUBERT) | ||
parser.add_argument("-b", "--bert_path", type=str, default=DEFAULT_BERT) | ||
|
||
args = parser.parse_args() | ||
|
||
gpt_path = args.gpt_path | ||
sovits_path = args.sovits_path | ||
|
||
default_refer_path = args.default_refer_path | ||
default_refer_text = args.default_refer_text | ||
default_refer_language = args.default_refer_language | ||
has_preset = False | ||
|
||
device = args.device | ||
port = args.port | ||
host = args.bind_addr | ||
is_half = args.half_precision | ||
|
||
cnhubert_base_path = args.hubert_path | ||
bert_path = args.bert_path | ||
|
||
if gpt_path == "": | ||
gpt_path = DEFAULT_GPT | ||
print("[WARN] 未指定GPT模型路径") | ||
if sovits_path == "": | ||
sovits_path = DEFAULT_SOVITS | ||
print("[WARN] 未指定SoVITS模型路径") | ||
|
||
if default_refer_path == "" or default_refer_text == "" or default_refer_language == "": | ||
default_refer_path, default_refer_text, default_refer_language = "", "", "" | ||
print("[INFO] 未指定默认参考音频") | ||
has_preset = False | ||
else: | ||
print(f"[INFO] 默认参考音频路径: {default_refer_path}") | ||
print(f"[INFO] 默认参考音频文本: {default_refer_text}") | ||
print(f"[INFO] 默认参考音频语种: {default_refer_language}") | ||
has_preset = True | ||
|
||
cnhubert.cnhubert_base_path = cnhubert_base_path | ||
tokenizer = AutoTokenizer.from_pretrained(bert_path) | ||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) | ||
# bert_model = AutoModelForSequenceClassification.from_pretrained(bert_path, config=bert_path+"/config.json") | ||
if (is_half == True): | ||
bert_model = bert_model.half().to(device) | ||
else: | ||
bert_model = bert_model.to(device) | ||
|
||
|
||
# bert_model=bert_model.to(device) | ||
def get_bert_feature(text, word2ph): | ||
with torch.no_grad(): | ||
inputs = tokenizer(text, return_tensors="pt") | ||
for i in inputs: | ||
inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model | ||
res = bert_model(**inputs, output_hidden_states=True) | ||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] | ||
assert len(word2ph) == len(text) | ||
phone_level_feature = [] | ||
for i in range(len(word2ph)): | ||
repeat_feature = res[i].repeat(word2ph[i], 1) | ||
phone_level_feature.append(repeat_feature) | ||
phone_level_feature = torch.cat(phone_level_feature, dim=0) | ||
# if(is_half==True):phone_level_feature=phone_level_feature.half() | ||
return phone_level_feature.T | ||
|
||
|
||
n_semantic = 1024 | ||
dict_s2 = torch.load(sovits_path, map_location="cpu") | ||
hps = dict_s2["config"] | ||
|
||
|
||
class DictToAttrRecursive: | ||
def __init__(self, input_dict): | ||
for key, value in input_dict.items(): | ||
if isinstance(value, dict): | ||
# 如果值是字典,递归调用构造函数 | ||
setattr(self, key, DictToAttrRecursive(value)) | ||
else: | ||
setattr(self, key, value) | ||
|
||
|
||
hps = DictToAttrRecursive(hps) | ||
hps.model.semantic_frame_rate = "25hz" | ||
dict_s1 = torch.load(gpt_path, map_location="cpu") | ||
config = dict_s1["config"] | ||
ssl_model = cnhubert.get_model() | ||
if is_half: | ||
ssl_model = ssl_model.half().to(device) | ||
else: | ||
ssl_model = ssl_model.to(device) | ||
|
||
vq_model = SynthesizerTrn( | ||
hps.data.filter_length // 2 + 1, | ||
hps.train.segment_size // hps.data.hop_length, | ||
n_speakers=hps.data.n_speakers, | ||
**hps.model) | ||
if is_half: | ||
vq_model = vq_model.half().to(device) | ||
else: | ||
vq_model = vq_model.to(device) | ||
vq_model.eval() | ||
print(vq_model.load_state_dict(dict_s2["weight"], strict=False)) | ||
hz = 50 | ||
max_sec = config['data']['max_sec'] | ||
t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False) | ||
t2s_model.load_state_dict(dict_s1["weight"]) | ||
if is_half: | ||
t2s_model = t2s_model.half() | ||
t2s_model = t2s_model.to(device) | ||
t2s_model.eval() | ||
total = sum([param.nelement() for param in t2s_model.parameters()]) | ||
print("Number of parameter: %.2fM" % (total / 1e6)) | ||
|
||
|
||
def get_spepc(hps, filename): | ||
audio = load_audio(filename, int(hps.data.sampling_rate)) | ||
audio = torch.FloatTensor(audio) | ||
audio_norm = audio | ||
audio_norm = audio_norm.unsqueeze(0) | ||
spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length, | ||
hps.data.win_length, center=False) | ||
return spec | ||
|
||
|
||
dict_language = { | ||
"中文": "zh", | ||
"英文": "en", | ||
"日文": "ja", | ||
"ZH": "zh", | ||
"EN": "en", | ||
"JA": "ja", | ||
"zh": "zh", | ||
"en": "en", | ||
"ja": "ja" | ||
} | ||
|
||
|
||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language): | ||
t0 = ttime() | ||
prompt_text = prompt_text.strip("\n") | ||
prompt_language, text = prompt_language, text.strip("\n") | ||
with torch.no_grad(): | ||
wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙 | ||
wav16k = torch.from_numpy(wav16k) | ||
if (is_half == True): | ||
wav16k = wav16k.half().to(device) | ||
else: | ||
wav16k = wav16k.to(device) | ||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() | ||
codes = vq_model.extract_latent(ssl_content) | ||
prompt_semantic = codes[0, 0] | ||
t1 = ttime() | ||
prompt_language = dict_language[prompt_language] | ||
text_language = dict_language[text_language] | ||
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language) | ||
phones1 = cleaned_text_to_sequence(phones1) | ||
texts = text.split("\n") | ||
audio_opt = [] | ||
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) | ||
for text in texts: | ||
phones2, word2ph2, norm_text2 = clean_text(text, text_language) | ||
phones2 = cleaned_text_to_sequence(phones2) | ||
if (prompt_language == "zh"): | ||
bert1 = get_bert_feature(norm_text1, word2ph1).to(device) | ||
else: | ||
bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to( | ||
device) | ||
if (text_language == "zh"): | ||
bert2 = get_bert_feature(norm_text2, word2ph2).to(device) | ||
else: | ||
bert2 = torch.zeros((1024, len(phones2))).to(bert1) | ||
bert = torch.cat([bert1, bert2], 1) | ||
|
||
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) | ||
bert = bert.to(device).unsqueeze(0) | ||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) | ||
prompt = prompt_semantic.unsqueeze(0).to(device) | ||
t2 = ttime() | ||
with torch.no_grad(): | ||
# pred_semantic = t2s_model.model.infer( | ||
pred_semantic, idx = t2s_model.model.infer_panel( | ||
all_phoneme_ids, | ||
all_phoneme_len, | ||
prompt, | ||
bert, | ||
# prompt_phone_len=ph_offset, | ||
top_k=config['inference']['top_k'], | ||
early_stop_num=hz * max_sec) | ||
t3 = ttime() | ||
# print(pred_semantic.shape,idx) | ||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次 | ||
refer = get_spepc(hps, ref_wav_path) # .to(device) | ||
if (is_half == True): | ||
refer = refer.half().to(device) | ||
else: | ||
refer = refer.to(device) | ||
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0] | ||
audio = \ | ||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), | ||
refer).detach().cpu().numpy()[ | ||
0, 0] ###试试重建不带上prompt部分 | ||
audio_opt.append(audio) | ||
audio_opt.append(zero_wav) | ||
t4 = ttime() | ||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) | ||
yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16) | ||
|
||
|
||
def restart(): | ||
python = sys.executable | ||
os.execl(python, python, *sys.argv) | ||
|
||
|
||
def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language): | ||
if command == "/restart": | ||
restart() | ||
elif command == "/exit": | ||
os.kill(os.getpid(), signal.SIGTERM) | ||
exit(0) | ||
|
||
if ( | ||
refer_wav_path == "" or refer_wav_path is None | ||
or prompt_text == "" or prompt_text is None | ||
or prompt_language == "" or prompt_language is None | ||
): | ||
refer_wav_path, prompt_text, prompt_language = ( | ||
default_refer_path, | ||
default_refer_text, | ||
default_refer_language, | ||
) | ||
if not has_preset: | ||
raise HTTPException(status_code=400, detail="未指定参考音频且接口无预设") | ||
|
||
with torch.no_grad(): | ||
gen = get_tts_wav( | ||
refer_wav_path, prompt_text, prompt_language, text, text_language | ||
) | ||
sampling_rate, audio_data = next(gen) | ||
|
||
wav = BytesIO() | ||
sf.write(wav, audio_data, sampling_rate, format="wav") | ||
wav.seek(0) | ||
|
||
torch.cuda.empty_cache() | ||
return StreamingResponse(wav, media_type="audio/wav") | ||
|
||
|
||
app = FastAPI() | ||
|
||
|
||
@app.post("/") | ||
async def tts_endpoint(request: Request): | ||
json_post_raw = await request.json() | ||
return handle( | ||
json_post_raw.get("command"), | ||
json_post_raw.get("refer_wav_path"), | ||
json_post_raw.get("prompt_text"), | ||
json_post_raw.get("prompt_language"), | ||
json_post_raw.get("text"), | ||
json_post_raw.get("text_language"), | ||
) | ||
|
||
|
||
@app.get("/") | ||
async def tts_endpoint( | ||
command: str = None, | ||
refer_wav_path: str = None, | ||
prompt_text: str = None, | ||
prompt_language: str = None, | ||
text: str = None, | ||
text_language: str = None, | ||
): | ||
return handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language) | ||
|
||
|
||
if __name__ == "__main__": | ||
uvicorn.run(app, host=host, port=port, workers=1) |