Skip to content

Commit

Permalink
feat: reduce memory footprint & support Turing+Tesla arch (metavoicei…
Browse files Browse the repository at this point in the history
…o#62)

* ckpt

* feat: fix bug

* feat: auto dtype detection

* feat: support fd & vanilla kv caching

* update: README.md

* nit: elif

* update: README.md

* remove: dead code

* fixes

---------

Co-authored-by: Vatsal <vatsal@themetavoice.xyz>
Co-authored-by: sid <sid@themetavoice.xyz>
  • Loading branch information
3 people committed Feb 19, 2024
1 parent 11550bb commit d1cd125
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 111 deletions.
24 changes: 16 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
# MetaVoice-1B

Try out the [demo](https://ttsdemo.themetavoice.xyz/)!

MetaVoice-1B is a 1.2B parameter base model trained on 100K hours of speech for TTS (text-to-speech). It has been built with the following priorities:
* **Emotional speech rhythm and tone** in English. No hallucinations.
* **Emotional speech rhythm and tone** in English.
* **Zero-shot cloning for American & British voices**, with 30s reference audio.
* Support for (cross-lingual) **voice cloning with finetuning**.
* We have had success with as little as 1 minute training data for Indian speakers.
* Support for **long-form synthesis**.

We’re releasing MetaVoice-1B under the Apache 2.0 license, *it can be used without restrictions*.

Try out the [demo](https://ttsdemo.themetavoice.xyz/)!

## Installation

**Pre-requisites:** Python >=3.10,<3.12; GPU with >=24GB RAM.
**Pre-requisites:**
- GPU VRAM >=16GB
- Python >=3.10,<3.12

**Environment setup**
```bash
# install ffmpeg
wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz
Expand All @@ -24,23 +28,27 @@ tar xvf ffmpeg-git-amd64-static.tar.xz
sudo mv ffmpeg-git-*-static/ffprobe ffmpeg-git-*-static/ffmpeg /usr/local/bin/
rm -rf ffmpeg-git-*

# install rust if not installed (ensure you've restarted your terminal after installation)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

pip install -r requirements.txt

# Works only on lasest NVidia GPUs. If you have a different GPU, do not install this.
# Flash Attention works only on latest Nvidia GPUs (Hopper, Ampere & Ada). If you have a different GPU (Tesla or Turing), do not install this.
pip install flash-attn

pip install -e .
```

## Usage
1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/sample.py),
1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/sample.py)
```bash
python fam/llm/sample.py --huggingface_repo_id="metavoiceio/metavoice-1B-v0.1" --spk_cond_path="assets/bria.mp3"
python fam/llm/sample.py --spk_cond_path="assets/bria.mp3"
```

2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](/fam/llm/serving.py)
2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](/fam/llm/serving.py) and [UI](/fam/ui/app.py).
```bash
python fam/llm/serving.py --huggingface_repo_id="metavoiceio/metavoice-1B-v0.1"
python fam/llm/serving.py
python fam/ui/app.py
```

3. Use it via [Hugging Face](https://huggingface.co/metavoiceio)
Expand Down
5 changes: 3 additions & 2 deletions fam/llm/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from audiocraft.data.audio import audio_read, audio_write
from audiocraft.models import MultiBandDiffusion # type: ignore

mbd = MultiBandDiffusion.get_mbd_24khz(bw=6) # 1.5


class Decoder(ABC):
@abstractmethod
Expand All @@ -23,11 +25,10 @@ def __init__(
data_adapter_fn: Callable[[list[list[int]]], tuple[list[int], list[list[int]]]],
output_dir: str,
):
self._mbd_bandwidth = 6 # 1.5
self._mbd_sample_rate = 24_000
self._end_of_audio_token = 1024
self._num_codebooks = 8
self.mbd = MultiBandDiffusion.get_mbd_24khz(bw=self._mbd_bandwidth)
self.mbd = mbd

self.tokeniser_decode_fn = tokeniser_decode_fn
self._data_adapter_fn = data_adapter_fn
Expand Down
92 changes: 9 additions & 83 deletions fam/llm/layers/attn.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
import math
import warnings

import torch
import torch.nn as nn
from torch.nn import functional as F

try:
from flash_attn import ( # type: ignore
flash_attn_func,
flash_attn_qkvpacked_func,
flash_attn_with_kvcache,
)
from flash_attn import flash_attn_with_kvcache # type: ignore
except ImportError:
warnings.warn("flash_attn not installed, make sure to replace attention mechanism with torch_attn")
from torch.nn import functional as F


class SelfAttention(nn.Module):
Expand Down Expand Up @@ -129,50 +124,13 @@ def _update_kv_cache(self, q, k, v):

return k, v

def _fa2_attention(self, c_x: torch.Tensor) -> torch.Tensor:
"""
Performs Flash Attention 2.0 CUDA kernel based attention.
Args:
c_x: The input tensor.
Returns:
The output tensor.
"""
if self.kv_cache_enabled:
q, k, v = c_x.split(1, dim=2)
q = q.squeeze(2)
k = k.squeeze(2)
v = v.squeeze(2)

k, v = self._update_kv_cache(q, k, v)

y = flash_attn_func(
q,
k,
v,
dropout_p=self.dropout if self.training else 0,
softmax_scale=None,
causal=self.causal,
)
else:
# efficient attention using Flash Attention 2.0 CUDA kernels
y = flash_attn_qkvpacked_func(
c_x, dropout_p=self.dropout if self.training else 0, softmax_scale=None, causal=self.causal
) # outputs (B, T, nh, hs)

return y

def _fd_attention(self, c_x: torch.Tensor) -> torch.Tensor:
"""
Performs Flash decoding based attention.
Args:
c_x: The input tensor.
Returns:
The output tensor.
Raises:
Exception: If key-value caching is not enabled.
Exception: If non-causal attention is activated.
Expand Down Expand Up @@ -217,6 +175,11 @@ def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor:
k = k.squeeze(2) # (B, T, nh, hs)
v = v.squeeze(2) # (B, T, nh, hs)

# if kv-caching and causal, for the "prefill" stage, we need to use a causal mask, and
# use no mask for the "one time step" parts.
# calculate this before updating kv_caching so we have the right value for kv_cache_first_empty_index
is_causal_attn_mask = self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0)

if self.kv_cache_enabled:
k, v = self._update_kv_cache(q, k, v)

Expand All @@ -229,46 +192,13 @@ def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor:
v,
attn_mask=None,
dropout_p=self.dropout if self.training else 0,
is_causal=self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0),
is_causal=is_causal_attn_mask,
).transpose(
1, 2
) # (B, nh, T, hs) -> (B, T, nh, hs)

return y

def _vanilla_attn(self, c_x: torch.Tensor) -> torch.Tensor:
"""
Performs vanilla attention.
Args:
c_x: The input tensor.
Returns:
The output tensor.
"""
q, k, v = c_x.split(1, dim=2) # q, k, v of shape (B, T, nh, hs)
q = q.squeeze(2) # (B, T, nh, hs)
k = k.squeeze(2) # (B, T, nh, hs)
v = v.squeeze(2) # (B, T, nh, hs)

if self.kv_cache_enabled:
k, v = self._update_kv_cache(q, k, v)

q = q.transpose(1, 2) # (B, nh, T, hs)
k = k.transpose(1, 2) # (B, nh, T, hs)
v = v.transpose(1, 2) # (B, nh, T, hs)
att = q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))) # (B, nh, T, T)
if self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0):
att = att.masked_fill(
torch.triu(torch.ones_like(att, dtype=torch.bool), diagonal=1), float("-inf")
) # (B, nh, T, T)
att = F.softmax(att, dim=-1) # (B, nh, T, T)
att = self.attn_dropout(att) # (B, nh, T, T)
y = att @ v # (B, nh, T, hs)
y = y.transpose(1, 2) # (B, T, nh, hs)

return y

def forward(self, x):
"""
Performs the forward pass of the SelfAttention module.
Expand All @@ -285,14 +215,10 @@ def forward(self, x):
c_x = self.c_attn(x).view(B, T, 3, self.n_head, C // self.n_head) # (B, T, 3, nh, hs)

# causal self-attention;
if self.attn_kernel_type == "fa2":
y = self._fa2_attention(c_x)
elif self.attn_kernel_type == "fd":
if self.attn_kernel_type == "fd":
y = self._fd_attention(c_x)
elif self.attn_kernel_type == "torch_attn":
y = self._torch_attn(c_x)
elif self.attn_kernel_type == "hand":
y = self._vanilla_attn(c_x)
else:
raise Exception(f"Unknown attention kernel type: {self.attn_kernel_type}")

Expand Down
2 changes: 1 addition & 1 deletion fam/llm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class GPTConfig:
rmsnorm_eps: Optional[float] = None # only used for rmsnorm
nonlinearity_type: str = "gelu" # "gelu" or "swiglu"
swiglu_multiple_of: Optional[int] = None # MLP hidden layer (using SwiGLU) will be multiple of this
attn_kernel_type: Literal["fa2", "torch_attn", "hand"] = "fa2"
attn_kernel_type: Literal["fd", "torch_attn"] = "torch_attn"
kv_cache_enabled: bool = False # whether to use key-value cache for attention


Expand Down
21 changes: 9 additions & 12 deletions fam/llm/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from fam.llm.decoders import Decoder, EncodecDecoder
from fam.llm.enhancers import BaseEnhancer, get_enhancer
from fam.llm.model import GPT, GPTConfig
from fam.llm.utils import normalize_text
from fam.llm.utils import get_default_dtype, get_default_use_kv_cache, normalize_text
from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
from fam.quantiser.text.tokenise import TrainedBPETokeniser

Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(
tokenizer_cls: Type[TrainedBPETokeniser],
decoder_cls: Type[Decoder],
data_adapter_fn,
use_kv_cache: Optional[Literal["none", "flash_decoding", "vanilla"]] = None,
use_kv_cache: Optional[Literal["flash_decoding", "vanilla"]] = None,
):
# TODO: disentangle the encodec stuff and numbers etc with rest of this code (esp at encoder-only / second stage model inference)
# TODO: remove magic number
Expand Down Expand Up @@ -151,8 +151,6 @@ def _init_model(self):
for block in self.model.transformer.h:
block.attn.attn_kernel_type = "fd"
elif self.use_kv_cache == "vanilla":
for block in self.model.transformer.h:
block.attn.attn_kernel_type = "torch_attn"
self.model.enable_kv_cache()
else:
raise NotImplementedError(f"kv_cache type {self.use_kv_cache} not implemented!")
Expand Down Expand Up @@ -596,12 +594,12 @@ class SamplingControllerConfig:
Sample from a trained model.
"""

huggingface_repo_id: str
"""Absolute path to the model directory."""

spk_cond_path: str
"""Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3"""

huggingface_repo_id: str = "metavoiceio/metavoice-1B-v0.1"
"""Absolute path to the model directory."""

text: str = (
"This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model by MetaVoice."
)
Expand All @@ -628,7 +626,7 @@ class SamplingControllerConfig:
device: Literal["cuda", "cpu"] = "cuda"
"""Device to use for sampling."""

dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = "bfloat16"
dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = get_default_dtype()
"""Data type to use for sampling."""

compile: bool = False
Expand All @@ -640,9 +638,9 @@ class SamplingControllerConfig:
init_from: str = "resume"
"""Either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')."""

use_kv_cache: Optional[Literal["flash_decoding", "vanilla"]] = None
use_kv_cache: Optional[Literal["flash_decoding", "vanilla"]] = get_default_use_kv_cache()
"""Type of kv caching to use for inference: 1) [none] no kv caching, 2) [flash_decoding] use the
flash decoding kernel, 3) [vanilla] use flash attention 2 with hand implemented kv-cache."""
flash decoding kernel, 3) [vanilla] use torch attention with hand implemented kv-cache."""

output_dir: str = "samples/"
"""Relative path to output directory"""
Expand Down Expand Up @@ -696,10 +694,9 @@ class SamplingControllerConfig:
config_second_stage,
model_dir=model_dir,
device=sampling_config.device,
use_kv_cache=sampling_config.use_kv_cache
use_kv_cache=sampling_config.use_kv_cache,
)

print(f"Synthesising utterance...")
sample_utterance(
sampling_config.text,
os.path.expanduser(sampling_config.spk_cond_path),
Expand Down
11 changes: 8 additions & 3 deletions fam/llm/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
get_second_stage_path,
sample_utterance,
)
from fam.llm.utils import get_default_dtype, get_default_use_kv_cache

logger = logging.getLogger(__name__)

Expand All @@ -35,7 +36,7 @@

@dataclass
class ServingConfig:
huggingface_repo_id: str
huggingface_repo_id: str = "metavoiceio/metavoice-1B-v0.1"
"""Absolute path to the model directory."""

max_new_tokens: int = 864 * 2
Expand All @@ -50,7 +51,7 @@ class ServingConfig:
seed: int = 1337
"""Random seed for sampling."""

dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = "bfloat16"
dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = get_default_dtype()
"""Data type to use for sampling."""

enhancer: Optional[Literal["df"]] = "df"
Expand All @@ -59,6 +60,10 @@ class ServingConfig:
compile: bool = False
"""Whether to compile the model using PyTorch 2.0."""

use_kv_cache: Optional[Literal["flash_decoding", "vanilla"]] = get_default_use_kv_cache()
"""Type of kv caching to use for inference: 1) [none] no kv caching, 2) [flash_decoding] use the
flash decoding kernel, 3) [vanilla] use torch attention with hand implemented kv-cache."""

port: int = 58003


Expand Down Expand Up @@ -185,7 +190,7 @@ def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
)

spkemb, llm_stg1, llm_stg2 = build_models(
config1, config2, model_dir=model_dir, device=device, use_kv_cache="flash_decoding"
config1, config2, model_dir=model_dir, device=device, use_kv_cache=GlobalState.config.use_kv_cache
)
GlobalState.spkemb_model = spkemb
GlobalState.first_stage_model = llm_stg1
Expand Down
22 changes: 22 additions & 0 deletions fam/llm/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re

import torch


def normalize_text(text: str) -> str:
unicode_conversion = {
Expand Down Expand Up @@ -45,3 +47,23 @@ def normalize_text(text: str) -> str:
text = text.strip()
text = re.sub("\s\s+", " ", text) # remove multiple spaces
return text


def get_default_use_kv_cache() -> str:
"""Compute default value for 'use_kv_cache' based on GPU architecture"""
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
device_properties = torch.cuda.get_device_properties(i)
return "vanilla" if "Turing" or "Tesla" in device_properties else "flash_decoding"
else:
return "vanilla"


def get_default_dtype() -> str:
"""Compute default 'dtype' based on GPU architecture"""
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
device_properties = torch.cuda.get_device_properties(i)
return "float16" if "Turing" or "Tesla" in device_properties else "bfloat16"
else:
return "float16"
Loading

0 comments on commit d1cd125

Please sign in to comment.