feat: 2-3x inference speedup, faster than real-time #71

Merged
merged 17 commits on Feb 25, 2024
5 changes: 3 additions & 2 deletions Dockerfile
@@ -22,11 +22,12 @@ WORKDIR /app
COPY requirements.txt requirements.txt

RUN pip install --no-cache-dir packaging wheel torch
RUN pip install --no-cache-dir flash-attn
RUN pip install --no-cache-dir audiocraft # HACK: installation fails within the requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
RUN pip install --no-cache-dir --upgrade torch torchaudio

COPY . .

RUN pip install --no-cache-dir -e .

ENTRYPOINT ["python3.10", "fam/llm/serving.py"]
ENTRYPOINT ["python3.10", "serving.py"]
30 changes: 17 additions & 13 deletions README.md
@@ -32,7 +32,7 @@ docker-compose up -d server && docker-compose ps && docker-compose logs -f
## Installation

**Pre-requisites:**
- GPU VRAM >=16GB
- GPU VRAM >=12GB
- Python >=3.10,<3.12

**Environment setup**
@@ -49,32 +49,36 @@ rm -rf ffmpeg-git-*
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

pip install -r requirements.txt

# Flash Attention works only on latest Nvidia GPUs (Hopper, Ampere & Ada). If you have a different GPU (Tesla or Turing), do not install this.
pip install flash-attn

pip install --upgrade torch torchaudio  # for torch.compile improvements
pip install -e .
```
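To check whether flash-attn applies to your GPU before installing it, one quick test (a sketch; compute capability 8.0 or higher corresponds to Ampere/Ada/Hopper, while Turing reports 7.5):

```bash
# Prints e.g. (8, 0) on Ampere, (8, 9) on Ada, (9, 0) on Hopper, (7, 5) on Turing
python -c "import torch; print(torch.cuda.get_device_capability())"
```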

## Usage
1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/sample.py)
1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/fast_inference.py)
```bash
python fam/llm/sample.py --spk_cond_path="assets/bria.mp3" --text="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
python -i fam/llm/fast_inference.py

# Run e.g. of API usage within the interactive python session
tts.synthesise(text="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model.", spk_ref_path="assets/bria.mp3")
```
> Note: The script takes 30-90s to start up (depending on hardware). This is because we torch.compile the model for fast inference.
> On Ampere, Ada-Lovelace, and Hopper architecture GPUs, once compiled, the synthesise() API runs faster than real-time, with a Real-Time Factor (RTF) < 1.0.

2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](/fam/llm/serving.py) or [web UI](/fam/ui/app.py)
2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](serving.py) or [web UI](app.py)
```bash
python fam/llm/serving.py
python serving.py
python app.py
```

3. Use it via [Hugging Face](https://huggingface.co/metavoiceio)
4. [Google Colab](https://colab.research.google.com/drive/1UmjE1mzfG4td0rCjJEaAWGQXpn_GuwwY?authuser=0#scrollTo=mPgTfUdBJF1B)

## Soon
- Faster inference ⚡
- Fine-tuning code
- Synthesis of arbitrary length text

## Upcoming
- [x] Faster inference ⚡
- [ ] Fine-tuning code
- [ ] Synthesis of arbitrary length text


## Architecture
We predict EnCodec tokens from text and speaker information. This is then diffused up to the waveform level, with post-processing applied to clean up the audio.
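Concretely, the fast path added in this PR (fam/llm/fast_inference.py) wires these stages together behind `TTS.synthesise()`. A minimal runnable sketch, assuming the package is installed and the bundled assets/bria.mp3 reference is available:

```python
from fam.llm.fast_inference import TTS

# Downloads checkpoints and torch.compiles the first-stage model (30-90s warm-up).
tts = TTS()

# synthesise() runs: text -> EnCodec tokens (first-stage LLM), tokens -> waveform
# (second-stage LLM + multi-band diffusion), then DeepFilterNet post-processing.
wav_path = tts.synthesise(
    text="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model.",
    spk_ref_path="assets/bria.mp3",  # >=30s speaker reference; local path or public URI
)
print(wav_path)  # path to the enhanced .wav file
```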
70 changes: 8 additions & 62 deletions app.py
@@ -7,58 +7,12 @@


import gradio as gr
from huggingface_hub import snapshot_download

from fam.llm.sample import (
InferenceConfig,
SamplingControllerConfig,
build_models,
get_first_stage_path,
get_second_stage_path,
sample_utterance,
)

from fam.llm.fast_inference import TTS
from fam.llm.utils import check_audio_file

#### setup model
sampling_config = SamplingControllerConfig(
huggingface_repo_id="metavoiceio/metavoice-1B-v0.1", spk_cond_path=""
) # spk_cond_path added later
model_dir = snapshot_download(repo_id=sampling_config.huggingface_repo_id)
first_stage_ckpt_path = get_first_stage_path(model_dir)
second_stage_ckpt_path = get_second_stage_path(model_dir)

config_first_stage = InferenceConfig(
ckpt_path=first_stage_ckpt_path,
num_samples=sampling_config.num_samples,
seed=sampling_config.seed,
device=sampling_config.device,
dtype=sampling_config.dtype,
compile=sampling_config.compile,
init_from=sampling_config.init_from,
output_dir=sampling_config.output_dir,
)

config_second_stage = InferenceConfig(
ckpt_path=second_stage_ckpt_path,
num_samples=sampling_config.num_samples,
seed=sampling_config.seed,
device=sampling_config.device,
dtype=sampling_config.dtype,
compile=sampling_config.compile,
init_from=sampling_config.init_from,
output_dir=sampling_config.output_dir,
)

sampling_config.max_new_tokens *= 2 # deal with max_new_tokens for flattened interleaving!

# define models
smodel, llm_first_stage, llm_second_stage = build_models(
config_first_stage,
config_second_stage,
model_dir=model_dir,
device=sampling_config.device,
use_kv_cache=sampling_config.use_kv_cache,
)
TTS_MODEL = TTS()

#### setup interface
RADIO_CHOICES = ["Preset voices", "Upload target voice (atleast 30s)"]
@@ -115,20 +69,12 @@ def tts(to_say, top_p, guidance, toggle, preset_dropdown, upload_target):
_handle_edge_cases(to_say, upload_target)

to_say = to_say if len(to_say) < MAX_CHARS else to_say[:MAX_CHARS]
return sample_utterance(
to_say,
spk_cond_path=PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else upload_target,
spkemb_model=smodel,
first_stage_model=llm_first_stage,
second_stage_model=llm_second_stage,
enhancer=sampling_config.enhancer,
guidance_scale=(d_guidance, 1.0),
max_new_tokens=sampling_config.max_new_tokens,
temperature=sampling_config.temperature,
top_k=sampling_config.top_k,

return TTS_MODEL.synthesise(
text=to_say,
spk_ref_path=PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else upload_target,
top_p=d_top_p,
first_stage_ckpt_path=None,
second_stage_ckpt_path=None,
guidance_scale=d_guidance,
)
except Exception as e:
raise gr.Error(f"Something went wrong. Reason: {str(e)}")
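The refactored app.py above drops the hand-built two-stage setup and calls the new `TTS` API directly. A minimal standalone sketch of the same pattern (component choices, labels, and slider ranges here are illustrative assumptions, not the repo's actual app.py):

```python
import gradio as gr
from fam.llm.fast_inference import TTS

TTS_MODEL = TTS()  # load once at startup, reuse across requests

def speak(text: str, speaker_ref: str, top_p: float, guidance: float) -> str:
    # synthesise() returns the path of the generated (and enhanced) .wav file,
    # which gr.Audio(type="filepath") can serve directly.
    return TTS_MODEL.synthesise(text=text, spk_ref_path=speaker_ref, top_p=top_p, guidance_scale=guidance)

demo = gr.Interface(
    fn=speak,
    inputs=[
        gr.Textbox(label="Text to say"),
        gr.Audio(type="filepath", label="Speaker reference (>= 30s)"),
        gr.Slider(0.9, 1.0, value=0.95, label="top_p"),
        gr.Slider(1.0, 3.0, value=3.0, label="guidance_scale"),
    ],
    outputs=gr.Audio(type="filepath", label="Synthesised speech"),
)

if __name__ == "__main__":
    demo.launch()
```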
5 changes: 2 additions & 3 deletions fam/llm/decoders.py
@@ -68,7 +68,7 @@ def decode(
# TODO: this has strange behaviour -- if causal is True, it returns tokens. if causal is False, it SAVES the audio file.
text_ids, extracted_audio_ids = self._data_adapter_fn(tokens)
text = self.tokeniser_decode_fn(text_ids)
print(f"Text: {text}")
# print(f"Text: {text}")

tokens = torch.tensor(extracted_audio_ids, device="cuda").unsqueeze(0)

@@ -92,11 +92,10 @@ def decode(
try:
wav_file_name = self.output_dir / f"synth_{text.replace(' ', '_')[:25]}_{uuid.uuid4()}"
self._save_audio(wav_file_name, wav)
print(f"\nSaved audio to {wav_file_name}.wav")
return wav_file_name
except Exception as e:
print(f"Failed to save audio! Reason: {e}")

wav_file_name = self.output_dir / f"synth_{uuid.uuid4()}"
self._save_audio(wav_file_name, wav)
print(f"\nSaved audio to {wav_file_name}.wav")
return wav_file_name
6 changes: 6 additions & 0 deletions fam/llm/enhancers.py
@@ -97,6 +97,12 @@ def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer:
"""

if enhancer_name == "df":
import warnings

warnings.filterwarnings(
"ignore",
message='"sinc_interpolation" resampling method name is being deprecated and replaced by "sinc_interp_hann" in the next release. The default behavior remains unchanged.',
)
return DFEnhancer()
else:
raise ValueError(f"Unknown enhancer name: {enhancer_name}")
141 changes: 141 additions & 0 deletions fam/llm/fast_inference.py
@@ -0,0 +1,141 @@
import os
import shutil
import tempfile
import time
from pathlib import Path

import librosa
import torch
from huggingface_hub import snapshot_download

from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
from fam.llm.decoders import EncodecDecoder
from fam.llm.fast_inference_utils import build_model, main
from fam.llm.inference import (
EncodecDecoder,
InferenceConfig,
Model,
TiltedEncodec,
TrainedBPETokeniser,
get_cached_embedding,
get_cached_file,
get_enhancer,
)
from fam.llm.utils import (
check_audio_file,
get_default_dtype,
get_device,
normalize_text,
)


class TTS:
def __init__(
self, model_name: str = "metavoiceio/metavoice-1B-v0.1", *, seed: int = 1337, output_dir: str = "outputs"
):
"""
model_name (str): refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/metavoiceio)
"""

# NOTE: this needs to come first so that we don't change global state when we want to use
# the torch.compiled-model.
self._dtype = get_default_dtype()
self._device = get_device()
self._model_dir = snapshot_download(repo_id=model_name)
self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
self.output_dir = output_dir
os.makedirs(self.output_dir, exist_ok=True)

second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt"
config_second_stage = InferenceConfig(
ckpt_path=second_stage_ckpt_path,
num_samples=1,
seed=seed,
device=self._device,
dtype=self._dtype,
compile=False,
init_from="resume",
output_dir=self.output_dir,
)
data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024)
self.llm_second_stage = Model(
config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
)
self.enhancer = get_enhancer("df")

self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype]
self.model, self.tokenizer, self.smodel, self.model_size = build_model(
precision=self.precision,
checkpoint_path=Path(f"{self._model_dir}/first_stage.pt"),
spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"),
device=self._device,
compile=True,
compile_prefill=True,
)

def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
"""
text: Text to speak
spk_ref_path: Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3
top_p: Top p for sampling applied to first-stage model. Range [0.9, 1.0] are good. This is a measure of speech stability - improves text following for a challenging speaker
guidance_scale: Guidance scale [1.0, 3.0] for sampling. This is a measure of speaker similarity - how closely to match speaker identity and speech style.
temperature: Temperature for sampling applied to both LLMs (first & second stage)

returns: path to speech .wav file
"""
text = normalize_text(text)
spk_ref_path = get_cached_file(spk_ref_path)
check_audio_file(spk_ref_path)
spk_emb = get_cached_embedding(
spk_ref_path,
self.smodel,
).to(device=self._device, dtype=self.precision)

start = time.time()
# first stage LLM
tokens = main(
model=self.model,
tokenizer=self.tokenizer,
model_size=self.model_size,
prompt=text,
spk_emb=spk_emb,
top_p=torch.tensor(top_p, device=self._device, dtype=self.precision),
guidance_scale=torch.tensor(guidance_scale, device=self._device, dtype=self.precision),
temperature=torch.tensor(temperature, device=self._device, dtype=self.precision),
)
text_ids, extracted_audio_ids = self.first_stage_adapter.decode([tokens])

b_speaker_embs = spk_emb.unsqueeze(0)

# second stage LLM + multi-band diffusion model
wav_files = self.llm_second_stage(
texts=[text],
encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device=self._device).unsqueeze(0)],
speaker_embs=b_speaker_embs,
batch_size=1,
guidance_scale=None,
top_p=None,
top_k=200,
temperature=1.0,
max_new_tokens=None,
)

# enhance using deepfilternet
wav_file = wav_files[0]
with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
self.enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")
print(f"\nSaved audio to {wav_file}.wav")

# calculating real-time factor (RTF)
time_to_synth_s = time.time() - start
audio, sr = librosa.load(str(wav_file) + ".wav")
duration_s = librosa.get_duration(y=audio, sr=sr)
print(f"\nTotal time to synth (s): {time_to_synth_s}")
print(f"Real-time factor: {time_to_synth_s / duration_s:.2f}")

return str(wav_file) + ".wav"


if __name__ == "__main__":
tts = TTS()
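The sampling knobs documented on `synthesise()` can be adjusted per call; a sketch staying within the ranges the docstring suggests (the text and chosen values are illustrative):

```python
# Assumes the `tts` instance constructed above.
wav_path = tts.synthesise(
    text="A quick test of the new fast inference path.",
    spk_ref_path="assets/bria.mp3",  # >=30s reference clip (wav, flac or mp3)
    top_p=0.95,          # [0.9, 1.0]: speech stability / text following
    guidance_scale=3.0,  # [1.0, 3.0]: how closely to match the reference speaker
    temperature=1.0,     # sampling temperature
)
print(wav_path)
```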