Skip to content

Commit

Permalink
feat: gradio interface (metavoiceio#49)
Browse files Browse the repository at this point in the history
* feat: gradio interface for hf space

* feat: move speaker encoder to model hub

* feat: fix missing guidance val

---------

Co-authored-by: sid <sid@themetavoice.xyz>
  • Loading branch information
sidroopdaska and sid authored Feb 21, 2024
1 parent 56e8b54 commit bf1d63c
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 63 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
---
tags: ["text-to-speech", "metavoice", "english", "pretrained"]
suggested_hardware: "a10g-small"
---

# MetaVoice-1B


Expand Down Expand Up @@ -50,10 +55,10 @@ pip install -e .
python fam/llm/sample.py --spk_cond_path="assets/bria.mp3" --text="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
```

2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](/fam/llm/serving.py) and [UI](/fam/ui/app.py).
2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](/fam/llm/serving.py) or [web UI](/app.py).
```bash
python fam/llm/serving.py
python fam/ui/app.py
python app.py
```

3. Use it via [Hugging Face](https://huggingface.co/metavoiceio)
Expand Down
151 changes: 99 additions & 52 deletions fam/ui/app.py → app.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,67 @@
import io
import json
import os
import sys

project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
sys.path.insert(0, project_root)


import gradio as gr
import requests
import soundfile as sf
from huggingface_hub import snapshot_download

from fam.llm.sample import (
InferenceConfig,
SamplingControllerConfig,
build_models,
get_first_stage_path,
get_second_stage_path,
sample_utterance,
)
from fam.llm.utils import check_audio_file

#### setup model
sampling_config = SamplingControllerConfig(
huggingface_repo_id="metavoiceio/metavoice-1B-v0.1", spk_cond_path=""
) # spk_cond_path added later
model_dir = snapshot_download(repo_id=sampling_config.huggingface_repo_id)
first_stage_ckpt_path = get_first_stage_path(model_dir)
second_stage_ckpt_path = get_second_stage_path(model_dir)

config_first_stage = InferenceConfig(
ckpt_path=first_stage_ckpt_path,
num_samples=sampling_config.num_samples,
seed=sampling_config.seed,
device=sampling_config.device,
dtype=sampling_config.dtype,
compile=sampling_config.compile,
init_from=sampling_config.init_from,
output_dir=sampling_config.output_dir,
)

config_second_stage = InferenceConfig(
ckpt_path=second_stage_ckpt_path,
num_samples=sampling_config.num_samples,
seed=sampling_config.seed,
device=sampling_config.device,
dtype=sampling_config.dtype,
compile=sampling_config.compile,
init_from=sampling_config.init_from,
output_dir=sampling_config.output_dir,
)

API_SERVER_URL = "http://127.0.0.1:58003/tts"
RADIO_CHOICES = ["Preset voices", "Upload target voice"]
sampling_config.max_new_tokens *= 2 # deal with max_new_tokens for flattened interleaving!

# define models
smodel, llm_first_stage, llm_second_stage = build_models(
config_first_stage,
config_second_stage,
model_dir=model_dir,
device=sampling_config.device,
use_kv_cache=sampling_config.use_kv_cache,
)

#### setup interface
RADIO_CHOICES = ["Preset voices", "Upload target voice (atleast 30s)"]
MAX_CHARS = 220
PRESET_VOICES = {
# female
Expand All @@ -28,6 +82,15 @@ def denormalise_guidance(guidance):
return 1 + ((guidance - 1) * (3 - 1)) / (5 - 1)


def _check_file_size(path):
if not path:
return
filesize = os.path.getsize(path)
filesize_mb = filesize / 1024 / 1024
if filesize_mb >= 50:
raise gr.Error(f"Please upload a sample less than 20MB for voice cloning. Provided: {round(filesize_mb)} MB")


def _handle_edge_cases(to_say, upload_target):
if not to_say:
raise gr.Error("Please provide text to synthesise")
Expand All @@ -37,52 +100,38 @@ def _handle_edge_cases(to_say, upload_target):
f"Max {MAX_CHARS} characters allowed. Provided: {len(to_say)} characters. Truncating and generating speech...Result at the end can be unstable as a result."
)

def _check_file_size(path):
if not path:
return
filesize = os.path.getsize(path)
filesize_mb = filesize / 1024 / 1024
if filesize_mb >= 50:
raise gr.Error(
f"Please upload a sample less than 20MB for voice cloning. Provided: {round(filesize_mb)} MB"
)
if not upload_target:
return

check_audio_file(upload_target) # check file duration to be atleast 30s
_check_file_size(upload_target)


def tts(to_say, top_p, guidance, toggle, preset_dropdown, upload_target):
d_top_p = denormalise_top_p(top_p)
d_guidance = denormalise_guidance(guidance)

_handle_edge_cases(to_say, upload_target)

to_say = to_say if len(to_say) < MAX_CHARS else to_say[:MAX_CHARS]

custom_target_path = upload_target if toggle == RADIO_CHOICES[1] else None

config = {
"text": to_say,
"guidance": (d_guidance, 1.0),
"top_p": d_top_p,
"speaker_ref_path": PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else None,
}
headers = {"Content-Type": "audio/wav", "X-Payload": json.dumps(config)}
if not custom_target_path:
response = requests.post(API_SERVER_URL, headers=headers, data=None)
else:
with open(custom_target_path, "rb") as f:
data = f.read()
response = requests.post(API_SERVER_URL, headers=headers, data=data)

wav, sr = None, None
if response.status_code == 200:
audio_buffer = io.BytesIO(response.content)
audio_buffer.seek(0)
wav, sr = sf.read(audio_buffer, dtype="float32")
else:
print(f"Something went wrong. response status code: {response.status_code}")

return sr, wav
try:
d_top_p = denormalise_top_p(top_p)
d_guidance = denormalise_guidance(guidance)

_handle_edge_cases(to_say, upload_target)

to_say = to_say if len(to_say) < MAX_CHARS else to_say[:MAX_CHARS]
return sample_utterance(
to_say,
spk_cond_path=PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else upload_target,
spkemb_model=smodel,
first_stage_model=llm_first_stage,
second_stage_model=llm_second_stage,
enhancer=sampling_config.enhancer,
guidance_scale=(d_guidance, 1.0),
max_new_tokens=sampling_config.max_new_tokens,
temperature=sampling_config.temperature,
top_k=sampling_config.top_k,
top_p=d_top_p,
first_stage_ckpt_path=None,
second_stage_ckpt_path=None,
)
except Exception as e:
raise gr.Error(f"Something went wrong. Reason: {str(e)}")


def change_voice_selection_layout(choice):
Expand All @@ -105,9 +154,9 @@ def change_voice_selection_layout(choice):
<strong>MetaVoice-1B</strong> is a 1.2B parameter base model for TTS (text-to-speech). It has been built with the following priorities:
\n
* <strong>Emotional speech rhythm and tone</strong> in English.
* <strong>Zero-shot cloning for American & British voices</strong>, with 30s reference audio.
* Support for <strong>voice cloning with finetuning</strong>.
* We have had success with as little as 1 minute training data for Indian speakers.
* <strong>Zero-shot cloning for American & British voices</strong>, with 30s reference audio.
* Support for <strong>long-form synthesis</strong>.
We are releasing the model under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0). See [Github](https://github.com/metavoiceio/metavoice-src) for details and to contribute.
Expand Down Expand Up @@ -169,7 +218,7 @@ def change_voice_selection_layout(choice):

with gr.Column():
speech = gr.Audio(
type="numpy",
type="filepath",
label="MetaVoice-1B says...",
)

Expand All @@ -182,6 +231,4 @@ def change_voice_selection_layout(choice):


demo.queue(default_concurrency_limit=2)
demo.launch(
favicon_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets/favicon.ico"),
)
demo.launch(favicon_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets/favicon.ico"))
File renamed without changes.
15 changes: 8 additions & 7 deletions fam/llm/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from dataclasses import dataclass
from typing import List, Literal, Optional, Tuple, Type, Union

import librosa
import torch
import tqdm
import tqdm.contrib.concurrent
Expand All @@ -21,7 +20,12 @@
from fam.llm.decoders import Decoder, EncodecDecoder
from fam.llm.enhancers import BaseEnhancer, get_enhancer
from fam.llm.model import GPT, GPTConfig
from fam.llm.utils import get_default_dtype, get_default_use_kv_cache, normalize_text
from fam.llm.utils import (
check_audio_file,
get_default_dtype,
get_default_use_kv_cache,
normalize_text,
)
from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
from fam.quantiser.text.tokenise import TrainedBPETokeniser

Expand Down Expand Up @@ -413,11 +417,6 @@ def get_cached_file(file_or_uri: str):
cache_path = file_or_uri
else:
raise FileNotFoundError(f"File {file_or_uri} not found!")

# check audio file is at min. 30s in length
audio, sr = librosa.load(cache_path)
assert librosa.get_duration(y=audio, sr=sr) >= 30, "Speaker reference audio file needs to be >= 30s in duration."

return cache_path


Expand Down Expand Up @@ -658,6 +657,8 @@ class SamplingControllerConfig:
# TODO: add support for batch sampling via CLI. Function has been implemented above.
sampling_config = tyro.cli(SamplingControllerConfig, use_underscores=True)

check_audio_file(sampling_config.spk_cond_path)

model_dir = snapshot_download(repo_id=sampling_config.huggingface_repo_id)
first_stage_ckpt_path = get_first_stage_path(model_dir)
second_stage_ckpt_path = get_second_stage_path(model_dir)
Expand Down
3 changes: 2 additions & 1 deletion fam/llm/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
get_second_stage_path,
sample_utterance,
)
from fam.llm.utils import get_default_dtype, get_default_use_kv_cache
from fam.llm.utils import check_audio_file, get_default_dtype, get_default_use_kv_cache

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -102,6 +102,7 @@ async def text_to_speech(req: Request):
with tempfile.NamedTemporaryFile(suffix=".wav") as wav_tmp:
if tts_req.speaker_ref_path is None:
wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp)
check_audio_file(wav_path)
else:
wav_path = tts_req.speaker_ref_path
if wav_path is None:
Expand Down
26 changes: 26 additions & 0 deletions fam/llm/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import os
import re
import subprocess
import tempfile

import librosa
import torch


Expand Down Expand Up @@ -49,6 +53,28 @@ def normalize_text(text: str) -> str:
return text


def check_audio_file(path_or_uri, threshold_s=30):
    """Check that a speaker-reference audio clip is at least ``threshold_s`` long.

    Args:
        path_or_uri: Local file path, or an ``http://`` / ``https://`` URL. A URL
            is downloaded with ``curl`` into a temporary file before inspection.
        threshold_s: Minimum acceptable duration, in seconds.

    Raises:
        Exception: If the audio is shorter than ``threshold_s`` seconds.
        subprocess.CalledProcessError: If the ``curl`` download exits non-zero.
    """
    # Match on the URL scheme only: a plain substring test would treat a local
    # path such as "data/http_clips/a.wav" as a remote resource.
    is_remote = str(path_or_uri).startswith(("http://", "https://"))

    if is_remote:
        temp_fd, filepath = tempfile.mkstemp()
        os.close(temp_fd)  # curl opens the path itself; we only need the name
    else:
        filepath = path_or_uri

    try:
        if is_remote:
            subprocess.run(["curl", "-L", path_or_uri, "-o", filepath], check=True)
        audio, sr = librosa.load(filepath)
        duration_s = librosa.get_duration(y=audio, sr=sr)
    finally:
        # Always remove the downloaded copy, including when the download or
        # decoding fails, so failed requests do not leave temp files behind.
        if is_remote:
            os.remove(filepath)

    if duration_s < threshold_s:
        raise Exception(
            f"The audio file is too short. Please provide an audio file that is at least {threshold_s} seconds long to proceed."
        )


def get_default_use_kv_cache() -> str:
"""Compute default value for 'use_kv_cache' based on GPU architecture"""
if torch.cuda.is_available():
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,5 @@ uvicorn
tyro
deepfilternet
pydub
soundfile
gradio
huggingface_hub

0 comments on commit bf1d63c

Please sign in to comment.