Skip to content

Commit

Permalink
feat: 2-3x inference speedup, faster than real-time (#71)
Browse files Browse the repository at this point in the history
* feat: 2.5x inference speedup, faster than real-time

* fix pr comments

* improve README

* feat: add faster inferencing to serving.py

* feat: faster inference on gradio app

* feat: support for bfloat16 & float16

* feat: calc RTF

* ckpt

* update: error messaging

* update: output location

* feat: improving terminal messaging

* update: README.md

* update: dockerfile

* fix: containerisation

* fix: t4 detection

* update: README.md

---------

Co-authored-by: sid <sid@themetavoice.xyz>
  • Loading branch information
vatsalaggarwal and sid authored Feb 25, 2024
1 parent 182ec71 commit 26fc3df
Show file tree
Hide file tree
Showing 14 changed files with 916 additions and 242 deletions.
5 changes: 3 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ WORKDIR /app
COPY requirements.txt requirements.txt

RUN pip install --no-cache-dir packaging wheel torch
RUN pip install --no-cache-dir flash-attn
RUN pip install --no-cache-dir audiocraft # HACK: installation fails within the requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
RUN pip install --no-cache-dir --upgrade torch torchaudio

COPY . .

RUN pip install --no-cache-dir -e .

ENTRYPOINT ["python3.10", "fam/llm/serving.py"]
ENTRYPOINT ["python3.10", "serving.py"]
30 changes: 17 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ docker-compose up -d server && docker-compose ps && docker-compose logs -f
## Installation

**Pre-requisites:**
- GPU VRAM >=16GB
- GPU VRAM >=12GB
- Python >=3.10,<3.12

**Environment setup**
Expand All @@ -49,32 +49,36 @@ rm -rf ffmpeg-git-*
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

pip install -r requirements.txt

# Flash Attention works only on latest Nvidia GPUs (Hopper, Ampere & Ada). If you have a different GPU (Tesla or Turing), do not install this.
pip install flash-attn

pip install --upgrade torch torchaudio  # for torch.compile improvements
pip install -e .
```

## Usage
1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/sample.py)
1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/fast_inference.py)
```bash
python fam/llm/sample.py --spk_cond_path="assets/bria.mp3" --text="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
python -i fam/llm/fast_inference.py

# Run e.g. of API usage within the interactive python session
tts.synthesise(text="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model.", spk_ref_path="assets/bria.mp3")
```
> Note: The script takes 30-90s to startup (depending on hardware). This is because we torch.compile the model for fast inference.
> On Ampere, Ada-Lovelace, and Hopper architecture GPUs, once compiled, the synthesise() API runs faster than real-time, with a Real-Time Factor (RTF) < 1.0.
2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](/fam/llm/serving.py) or [web UI](/fam/ui/app.py)
2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](serving.py) or [web UI](app.py)
```bash
python fam/llm/serving.py
python serving.py
python app.py
```

3. Use it via [Hugging Face](https://huggingface.co/metavoiceio)
4. [Google Collab](https://colab.research.google.com/drive/1UmjE1mzfG4td0rCjJEaAWGQXpn_GuwwY?authuser=0#scrollTo=mPgTfUdBJF1B)

## Soon
- Faster inference ⚡
- Fine-tuning code
- Synthesis of arbitrary length text

## Upcoming
- [x] Faster inference ⚡
- [ ] Fine-tuning code
- [ ] Synthesis of arbitrary length text


## Architecture
We predict EnCodec tokens from text, and speaker information. This is then diffused up to the waveform level, with post-processing applied to clean up the audio.
Expand Down
70 changes: 8 additions & 62 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,58 +7,12 @@


import gradio as gr
from huggingface_hub import snapshot_download

from fam.llm.sample import (
InferenceConfig,
SamplingControllerConfig,
build_models,
get_first_stage_path,
get_second_stage_path,
sample_utterance,
)

from fam.llm.fast_inference import TTS
from fam.llm.utils import check_audio_file

#### setup model
sampling_config = SamplingControllerConfig(
huggingface_repo_id="metavoiceio/metavoice-1B-v0.1", spk_cond_path=""
) # spk_cond_path added later
model_dir = snapshot_download(repo_id=sampling_config.huggingface_repo_id)
first_stage_ckpt_path = get_first_stage_path(model_dir)
second_stage_ckpt_path = get_second_stage_path(model_dir)

config_first_stage = InferenceConfig(
ckpt_path=first_stage_ckpt_path,
num_samples=sampling_config.num_samples,
seed=sampling_config.seed,
device=sampling_config.device,
dtype=sampling_config.dtype,
compile=sampling_config.compile,
init_from=sampling_config.init_from,
output_dir=sampling_config.output_dir,
)

config_second_stage = InferenceConfig(
ckpt_path=second_stage_ckpt_path,
num_samples=sampling_config.num_samples,
seed=sampling_config.seed,
device=sampling_config.device,
dtype=sampling_config.dtype,
compile=sampling_config.compile,
init_from=sampling_config.init_from,
output_dir=sampling_config.output_dir,
)

sampling_config.max_new_tokens *= 2 # deal with max_new_tokens for flattened interleaving!

# define models
smodel, llm_first_stage, llm_second_stage = build_models(
config_first_stage,
config_second_stage,
model_dir=model_dir,
device=sampling_config.device,
use_kv_cache=sampling_config.use_kv_cache,
)
TTS_MODEL = TTS()

#### setup interface
RADIO_CHOICES = ["Preset voices", "Upload target voice (atleast 30s)"]
Expand Down Expand Up @@ -115,20 +69,12 @@ def tts(to_say, top_p, guidance, toggle, preset_dropdown, upload_target):
_handle_edge_cases(to_say, upload_target)

to_say = to_say if len(to_say) < MAX_CHARS else to_say[:MAX_CHARS]
return sample_utterance(
to_say,
spk_cond_path=PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else upload_target,
spkemb_model=smodel,
first_stage_model=llm_first_stage,
second_stage_model=llm_second_stage,
enhancer=sampling_config.enhancer,
guidance_scale=(d_guidance, 1.0),
max_new_tokens=sampling_config.max_new_tokens,
temperature=sampling_config.temperature,
top_k=sampling_config.top_k,

return TTS_MODEL.synthesise(
text=to_say,
spk_ref_path=PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else upload_target,
top_p=d_top_p,
first_stage_ckpt_path=None,
second_stage_ckpt_path=None,
guidance_scale=d_guidance,
)
except Exception as e:
raise gr.Error(f"Something went wrong. Reason: {str(e)}")
Expand Down
5 changes: 2 additions & 3 deletions fam/llm/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def decode(
# TODO: this has strange behaviour -- if causal is True, it returns tokens. if causal is False, it SAVES the audio file.
text_ids, extracted_audio_ids = self._data_adapter_fn(tokens)
text = self.tokeniser_decode_fn(text_ids)
print(f"Text: {text}")
# print(f"Text: {text}")

tokens = torch.tensor(extracted_audio_ids, device="cuda").unsqueeze(0)

Expand All @@ -92,11 +92,10 @@ def decode(
try:
wav_file_name = self.output_dir / f"synth_{text.replace(' ', '_')[:25]}_{uuid.uuid4()}"
self._save_audio(wav_file_name, wav)
print(f"\nSaved audio to {wav_file_name}.wav")
return wav_file_name
except Exception as e:
print(f"Failed to save audio! Reason: {e}")

wav_file_name = self.output_dir / f"synth_{uuid.uuid4()}"
self._save_audio(wav_file_name, wav)
print(f"\nSaved audio to {wav_file_name}.wav")
return wav_file_name
6 changes: 6 additions & 0 deletions fam/llm/enhancers.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer:
"""

if enhancer_name == "df":
import warnings

warnings.filterwarnings(
"ignore",
message='"sinc_interpolation" resampling method name is being deprecated and replaced by "sinc_interp_hann" in the next release. The default behavior remains unchanged.',
)
return DFEnhancer()
else:
raise ValueError(f"Unknown enhancer name: {enhancer_name}")
141 changes: 141 additions & 0 deletions fam/llm/fast_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import os
import shutil
import tempfile
import time
from pathlib import Path

import librosa
import torch
from huggingface_hub import snapshot_download

from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
from fam.llm.decoders import EncodecDecoder
from fam.llm.fast_inference_utils import build_model, main
from fam.llm.inference import (
EncodecDecoder,
InferenceConfig,
Model,
TiltedEncodec,
TrainedBPETokeniser,
get_cached_embedding,
get_cached_file,
get_enhancer,
)
from fam.llm.utils import (
check_audio_file,
get_default_dtype,
get_device,
normalize_text,
)


class TTS:
def __init__(
self, model_name: str = "metavoiceio/metavoice-1B-v0.1", *, seed: int = 1337, output_dir: str = "outputs"
):
"""
model_name (str): refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/metavoiceio)
"""

# NOTE: this needs to come first so that we don't change global state when we want to use
# the torch.compiled-model.
self._dtype = get_default_dtype()
self._device = get_device()
self._model_dir = snapshot_download(repo_id=model_name)
self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
self.output_dir = output_dir
os.makedirs(self.output_dir, exist_ok=True)

second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt"
config_second_stage = InferenceConfig(
ckpt_path=second_stage_ckpt_path,
num_samples=1,
seed=seed,
device=self._device,
dtype=self._dtype,
compile=False,
init_from="resume",
output_dir=self.output_dir,
)
data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024)
self.llm_second_stage = Model(
config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
)
self.enhancer = get_enhancer("df")

self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype]
self.model, self.tokenizer, self.smodel, self.model_size = build_model(
precision=self.precision,
checkpoint_path=Path(f"{self._model_dir}/first_stage.pt"),
spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"),
device=self._device,
compile=True,
compile_prefill=True,
)

def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
"""
text: Text to speak
spk_ref_path: Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3
top_p: Top p for sampling applied to first-stage model. Range [0.9, 1.0] are good. This is a measure of speech stability - improves text following for a challenging speaker
guidance_scale: Guidance scale [1.0, 3.0] for sampling. This is a measure of speaker similarity - how closely to match speaker identity and speech style.
temperature: Temperature for sampling applied to both LLMs (first & second stage)
returns: path to speech .wav file
"""
text = normalize_text(text)
spk_ref_path = get_cached_file(spk_ref_path)
check_audio_file(spk_ref_path)
spk_emb = get_cached_embedding(
spk_ref_path,
self.smodel,
).to(device=self._device, dtype=self.precision)

start = time.time()
# first stage LLM
tokens = main(
model=self.model,
tokenizer=self.tokenizer,
model_size=self.model_size,
prompt=text,
spk_emb=spk_emb,
top_p=torch.tensor(top_p, device=self._device, dtype=self.precision),
guidance_scale=torch.tensor(guidance_scale, device=self._device, dtype=self.precision),
temperature=torch.tensor(temperature, device=self._device, dtype=self.precision),
)
text_ids, extracted_audio_ids = self.first_stage_adapter.decode([tokens])

b_speaker_embs = spk_emb.unsqueeze(0)

# second stage LLM + multi-band diffusion model
wav_files = self.llm_second_stage(
texts=[text],
encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device=self._device).unsqueeze(0)],
speaker_embs=b_speaker_embs,
batch_size=1,
guidance_scale=None,
top_p=None,
top_k=200,
temperature=1.0,
max_new_tokens=None,
)

# enhance using deepfilternet
wav_file = wav_files[0]
with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
self.enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")
print(f"\nSaved audio to {wav_file}.wav")

# calculating real-time factor (RTF)
time_to_synth_s = time.time() - start
audio, sr = librosa.load(str(wav_file) + ".wav")
duration_s = librosa.get_duration(y=audio, sr=sr)
print(f"\nTotal time to synth (s): {time_to_synth_s}")
print(f"Real-time factor: {time_to_synth_s / duration_s:.2f}")

return str(wav_file) + ".wav"


if __name__ == "__main__":
tts = TTS()
Loading

0 comments on commit 26fc3df

Please sign in to comment.