Skip to content

Commit

Permalink
Export MeloTTS to ONNX (#1129)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Jul 15, 2024
1 parent de04b3b commit 04c2319
Show file tree
Hide file tree
Showing 4 changed files with 573 additions and 0 deletions.
101 changes: 101 additions & 0 deletions .github/workflows/export-melo-tts-to-onnx.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
name: export-melo-tts-to-onnx

on:
push:
branches:
- export-melo-tts-onnx
workflow_dispatch:

concurrency:
group: export-melo-tts-to-onnx-${{ github.ref }}
cancel-in-progress: true

jobs:
export-melo-tts-to-onnx:
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
name: export melo-tts
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ["3.10"]

steps:
- uses: actions/checkout@v4

- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Run
shell: bash
run: |
cd scripts/melo-tts
./run.sh
- uses: actions/upload-artifact@v4
with:
name: test.wav
path: scripts/melo-tts/test.wav

- name: Publish to huggingface (aishell)
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v3
with:
max_attempts: 20
timeout_seconds: 200
shell: bash
command: |
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"
rm -rf huggingface
export GIT_LFS_SKIP_SMUDGE=1
export GIT_CLONE_PROTECTION_ACTIVE=false
git clone https://huggingface.co/csukuangfj/vits-melo-tts-zh_en huggingface
cd huggingface
git fetch
git pull
echo "pwd: $PWD"
ls -lh ../scripts/melo-tts
cp -v ../scripts/melo-tts/*.onnx .
cp -v ../scripts/melo-tts/lexicon.txt .
cp -v ../scripts/melo-tts/tokens.txt .
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst
curl -SL -O https://github.com/csukuangfj/cppjieba/releases/download/sherpa-onnx-2024-04-19/dict.tar.bz2
tar xvf dict.tar.bz2
rm dict.tar.bz2
git lfs track "*.onnx"
git add .
git commit -m "add models"
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-zh_en main || true
cd ..
rm -rf huggingface/.git*
dst=vits-melo-tts-zh_en
mv huggingface $dst
tar cjvf $dst.tar.bz2 $dst
rm -rf $dst
- name: Release
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
file: ./*.tar.bz2
overwrite: true
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: tts-models
256 changes: 256 additions & 0 deletions scripts/melo-tts/export-onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
#!/usr/bin/env python3
from typing import Any, Dict

import onnx
import torch
from melo.api import TTS
from melo.text import language_id_map, language_tone_start_map
from melo.text.chinese import pinyin_to_symbol_map
from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict

for k, v in pinyin_to_symbol_map.items():
pinyin_to_symbol_map[k] = v.split()


def get_initial_final_tone(word: str):
initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)

ans_phone = []
ans_tone = []

for c, v in zip(initials, finals):
raw_pinyin = c + v
v_without_tone = v[:-1]
try:
tone = v[-1]
except:
print("skip", word, initials, finals)
return [], []

pinyin = c + v_without_tone
assert tone in "12345"

if c:
v_rep_map = {
"uei": "ui",
"iou": "iu",
"uen": "un",
}
if v_without_tone in v_rep_map.keys():
pinyin = c + v_rep_map[v_without_tone]
else:
pinyin_rep_map = {
"ing": "ying",
"i": "yi",
"in": "yin",
"u": "wu",
}
if pinyin in pinyin_rep_map.keys():
pinyin = pinyin_rep_map[pinyin]
else:
single_rep_map = {
"v": "yu",
"e": "e",
"i": "y",
"u": "w",
}
if pinyin[0] in single_rep_map.keys():
pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
# print(word, initials, finals, pinyin)

if pinyin not in pinyin_to_symbol_map:
print("skip", pinyin, word, c, v, raw_pinyin)
continue
phone = pinyin_to_symbol_map[pinyin]
ans_phone += phone
ans_tone += [tone] * len(phone)

return ans_phone, ans_tone


def generate_tokens(symbol_list):
with open("tokens.txt", "w", encoding="utf-8") as f:
for i, s in enumerate(symbol_list):
f.write(f"{s} {i}\n")


def generate_lexicon():
word_dict = pinyin_dict.pinyin_dict
phrases = phrases_dict.phrases_dict
with open("lexicon.txt", "w", encoding="utf-8") as f:
for key in word_dict:
if not (0x4E00 <= key <= 0x9FA5):
continue
w = chr(key)
phone, tone = get_initial_final_tone(w)
if not phone:
continue
phone = " ".join(phone)
tone = " ".join(tone)
f.write(f"{w} {phone} {tone}\n")

for w in phrases:
phone, tone = get_initial_final_tone(w)
if not phone:
continue
assert len(phone) == len(tone), (len(phone), len(tone), phone, tone)
phone = " ".join(phone)
tone = " ".join(tone)
f.write(f"{w} {phone} {tone}\n")


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()

for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)

onnx.save(model, filename)


class ModelWrapper(torch.nn.Module):
def __init__(self, model: "SynthesizerTrn"):
super().__init__()
self.model = model

def forward(
self,
x,
x_lengths,
tones,
lang_id,
bert,
ja_bert,
sid,
noise_scale,
length_scale,
noise_scale_w,
max_len=None,
):
"""
Args:
x: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
tones: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
lang_id: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
sid: an integer
"""
return self.model.infer(
x=x,
x_lengths=x_lengths,
sid=sid,
tone=tones,
language=lang_id,
bert=bert,
ja_bert=ja_bert,
noise_scale=noise_scale,
noise_scale_w=noise_scale_w,
length_scale=length_scale,
)[0]


def main():
generate_lexicon()

language = "ZH"
model = TTS(language=language, device="cpu")

generate_tokens(model.hps["symbols"])

torch_model = ModelWrapper(model.model)

opset_version = 13
x = torch.randint(low=0, high=10, size=(60,), dtype=torch.int64)
print(x.shape)
x_lengths = torch.tensor([x.size(0)], dtype=torch.int64)
sid = torch.tensor([1], dtype=torch.int64)
tones = torch.zeros_like(x)
lang_id = torch.ones_like(x)
noise_scale = torch.tensor([1.0], dtype=torch.float32)
length_scale = torch.tensor([1.0], dtype=torch.float32)
noise_scale_w = torch.tensor([1.0], dtype=torch.float32)

bert = torch.zeros(1024, x.shape[0], dtype=torch.float32)
ja_bert = torch.zeros(768, x.shape[0], dtype=torch.float32)

x = x.unsqueeze(0)
tones = tones.unsqueeze(0)
lang_id = lang_id.unsqueeze(0)
bert = bert.unsqueeze(0)
ja_bert = ja_bert.unsqueeze(0)

filename = "model.onnx"

torch.onnx.export(
torch_model,
(
x,
x_lengths,
tones,
lang_id,
bert,
ja_bert,
sid,
noise_scale,
length_scale,
noise_scale_w,
),
filename,
opset_version=opset_version,
input_names=[
"x",
"x_lengths",
"tones",
"lang_id",
"bert",
"ja_bert",
"sid",
"noise_scale",
"length_scale",
"noise_scale_w",
],
output_names=["y"],
dynamic_axes={
"x": {0: "N", 1: "L"},
"x_lengths": {0: "N"},
"tones": {0: "N", 1: "L"},
"lang_id": {0: "N", 1: "L"},
"bert": {0: "N", 2: "L"},
"ja_bert": {0: "N", 2: "L"},
"y": {0: "N", 1: "S", 2: "T"},
},
)

meta_data = {
"model_type": "melo-vits",
"comment": "melo",
"language": "Chinese + English",
"add_blank": int(model.hps.data.add_blank),
"n_speakers": 1,
"sample_rate": model.hps.data.sampling_rate,
"bert_dim": 1024,
"ja_bert_dim": 768,
"speaker_id": list(model.hps.data.spk2id.values())[0],
"lang_id": language_id_map[model.language],
"tone_start": language_tone_start_map[model.language],
"url": "https://github.com/myshell-ai/MeloTTS",
"license": "MIT license",
"description": "MeloTTS is a high-quality multi-lingual text-to-speech library by MyShell.ai",
}
add_meta_data(filename, meta_data)


if __name__ == "__main__":
main()
Loading

0 comments on commit 04c2319

Please sign in to comment.