From d1cd12530be98a3f14daebb6f07c436d352ad864 Mon Sep 17 00:00:00 2001 From: Siddharth Sharma Date: Mon, 19 Feb 2024 11:12:10 +0000 Subject: [PATCH] feat: reduce memory footprint & support Turing+Tesla arch (#62) * ckpt * feat: fix bug * feat: auto dtype detection * feat: support fd & vanilla kv caching * update: README.md * nit: elif * update: README.md * remove: dead code * fixes --------- Co-authored-by: Vatsal Co-authored-by: sid --- README.md | 24 +++++++---- fam/llm/decoders.py | 5 ++- fam/llm/layers/attn.py | 92 +++++------------------------------------- fam/llm/model.py | 2 +- fam/llm/sample.py | 21 +++++----- fam/llm/serving.py | 11 +++-- fam/llm/utils.py | 22 ++++++++++ fam/ui/app.py | 3 +- 8 files changed, 69 insertions(+), 111 deletions(-) diff --git a/README.md b/README.md index 440d883..63df30c 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,9 @@ # MetaVoice-1B +Try out the [demo](https://ttsdemo.themetavoice.xyz/)! + MetaVoice-1B is a 1.2B parameter base model trained on 100K hours of speech for TTS (text-to-speech). It has been built with the following priorities: -* **Emotional speech rhythm and tone** in English. No hallucinations. +* **Emotional speech rhythm and tone** in English. * **Zero-shot cloning for American & British voices**, with 30s reference audio. * Support for (cross-lingual) **voice cloning with finetuning**. * We have had success with as little as 1 minute training data for Indian speakers. @@ -9,12 +11,14 @@ MetaVoice-1B is a 1.2B parameter base model trained on 100K hours of speech for We’re releasing MetaVoice-1B under the Apache 2.0 license, *it can be used without restrictions*. -Try out the [demo](https://ttsdemo.themetavoice.xyz/)! ## Installation -**Pre-requisites:** Python >=3.10,<3.12; GPU with >=24GB RAM. +**Pre-requisites:** +- GPU VRAM >=16GB +- Python >=3.10,<3.12 +**Environment setup** ```bash # install ffmpeg wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz @@ -24,23 +28,27 @@ tar xvf ffmpeg-git-amd64-static.tar.xz sudo mv ffmpeg-git-*-static/ffprobe ffmpeg-git-*-static/ffmpeg /usr/local/bin/ rm -rf ffmpeg-git-* +# install rust if not installed (ensure you've restarted your terminal after installation) +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh + pip install -r requirements.txt -# Works only on lasest NVidia GPUs. If you have a different GPU, do not install this. +# Flash Attention works only on latest Nvidia GPUs (Hopper, Ampere & Ada). If you have a different GPU (Tesla or Turing), do not install this. pip install flash-attn pip install -e . ``` ## Usage -1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/sample.py), +1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/sample.py) ```bash -python fam/llm/sample.py --huggingface_repo_id="metavoiceio/metavoice-1B-v0.1" --spk_cond_path="assets/bria.mp3" +python fam/llm/sample.py --spk_cond_path="assets/bria.mp3" ``` -2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](/fam/llm/serving.py) +2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](/fam/llm/serving.py) and [UI](/fam/ui/app.py). ```bash -python fam/llm/serving.py --huggingface_repo_id="metavoiceio/metavoice-1B-v0.1" +python fam/llm/serving.py +python fam/ui/app.py ``` 3. 
Use it via [Hugging Face](https://huggingface.co/metavoiceio) diff --git a/fam/llm/decoders.py b/fam/llm/decoders.py index 4b2fd07..e59be38 100644 --- a/fam/llm/decoders.py +++ b/fam/llm/decoders.py @@ -9,6 +9,8 @@ from audiocraft.data.audio import audio_read, audio_write from audiocraft.models import MultiBandDiffusion # type: ignore +mbd = MultiBandDiffusion.get_mbd_24khz(bw=6) # 1.5 + class Decoder(ABC): @abstractmethod @@ -23,11 +25,10 @@ def __init__( data_adapter_fn: Callable[[list[list[int]]], tuple[list[int], list[list[int]]]], output_dir: str, ): - self._mbd_bandwidth = 6 # 1.5 self._mbd_sample_rate = 24_000 self._end_of_audio_token = 1024 self._num_codebooks = 8 - self.mbd = MultiBandDiffusion.get_mbd_24khz(bw=self._mbd_bandwidth) + self.mbd = mbd self.tokeniser_decode_fn = tokeniser_decode_fn self._data_adapter_fn = data_adapter_fn diff --git a/fam/llm/layers/attn.py b/fam/llm/layers/attn.py index a90211e..8110306 100644 --- a/fam/llm/layers/attn.py +++ b/fam/llm/layers/attn.py @@ -1,18 +1,13 @@ -import math import warnings import torch import torch.nn as nn +from torch.nn import functional as F try: - from flash_attn import ( # type: ignore - flash_attn_func, - flash_attn_qkvpacked_func, - flash_attn_with_kvcache, - ) + from flash_attn import flash_attn_with_kvcache # type: ignore except ImportError: warnings.warn("flash_attn not installed, make sure to replace attention mechanism with torch_attn") -from torch.nn import functional as F class SelfAttention(nn.Module): @@ -129,50 +124,13 @@ def _update_kv_cache(self, q, k, v): return k, v - def _fa2_attention(self, c_x: torch.Tensor) -> torch.Tensor: - """ - Performs Flash Attention 2.0 CUDA kernel based attention. - - Args: - c_x: The input tensor. - - Returns: - The output tensor. - """ - if self.kv_cache_enabled: - q, k, v = c_x.split(1, dim=2) - q = q.squeeze(2) - k = k.squeeze(2) - v = v.squeeze(2) - - k, v = self._update_kv_cache(q, k, v) - - y = flash_attn_func( - q, - k, - v, - dropout_p=self.dropout if self.training else 0, - softmax_scale=None, - causal=self.causal, - ) - else: - # efficient attention using Flash Attention 2.0 CUDA kernels - y = flash_attn_qkvpacked_func( - c_x, dropout_p=self.dropout if self.training else 0, softmax_scale=None, causal=self.causal - ) # outputs (B, T, nh, hs) - - return y - def _fd_attention(self, c_x: torch.Tensor) -> torch.Tensor: """ Performs Flash decoding based attention. - Args: c_x: The input tensor. - Returns: The output tensor. - Raises: Exception: If key-value caching is not enabled. Exception: If non-causal attention is activated. @@ -217,6 +175,11 @@ def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor: k = k.squeeze(2) # (B, T, nh, hs) v = v.squeeze(2) # (B, T, nh, hs) + # if kv-caching and causal, for the "prefill" stage, we need to use a causal mask, and + # use no mask for the "one time step" parts. 
+ # calculate this before updating kv_caching so we have the right value for kv_cache_first_empty_index + is_causal_attn_mask = self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0) + if self.kv_cache_enabled: k, v = self._update_kv_cache(q, k, v) @@ -229,46 +192,13 @@ def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor: v, attn_mask=None, dropout_p=self.dropout if self.training else 0, - is_causal=self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0), + is_causal=is_causal_attn_mask, ).transpose( 1, 2 ) # (B, nh, T, hs) -> (B, T, nh, hs) return y - def _vanilla_attn(self, c_x: torch.Tensor) -> torch.Tensor: - """ - Performs vanilla attention. - - Args: - c_x: The input tensor. - - Returns: - The output tensor. - """ - q, k, v = c_x.split(1, dim=2) # q, k, v of shape (B, T, nh, hs) - q = q.squeeze(2) # (B, T, nh, hs) - k = k.squeeze(2) # (B, T, nh, hs) - v = v.squeeze(2) # (B, T, nh, hs) - - if self.kv_cache_enabled: - k, v = self._update_kv_cache(q, k, v) - - q = q.transpose(1, 2) # (B, nh, T, hs) - k = k.transpose(1, 2) # (B, nh, T, hs) - v = v.transpose(1, 2) # (B, nh, T, hs) - att = q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1))) # (B, nh, T, T) - if self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0): - att = att.masked_fill( - torch.triu(torch.ones_like(att, dtype=torch.bool), diagonal=1), float("-inf") - ) # (B, nh, T, T) - att = F.softmax(att, dim=-1) # (B, nh, T, T) - att = self.attn_dropout(att) # (B, nh, T, T) - y = att @ v # (B, nh, T, hs) - y = y.transpose(1, 2) # (B, T, nh, hs) - - return y - def forward(self, x): """ Performs the forward pass of the SelfAttention module. @@ -285,14 +215,10 @@ def forward(self, x): c_x = self.c_attn(x).view(B, T, 3, self.n_head, C // self.n_head) # (B, T, 3, nh, hs) # causal self-attention; - if self.attn_kernel_type == "fa2": - y = self._fa2_attention(c_x) - elif self.attn_kernel_type == "fd": + if self.attn_kernel_type == "fd": y = self._fd_attention(c_x) elif self.attn_kernel_type == "torch_attn": y = self._torch_attn(c_x) - elif self.attn_kernel_type == "hand": - y = self._vanilla_attn(c_x) else: raise Exception(f"Unknown attention kernel type: {self.attn_kernel_type}") diff --git a/fam/llm/model.py b/fam/llm/model.py index c6e3c2f..8c42735 100644 --- a/fam/llm/model.py +++ b/fam/llm/model.py @@ -42,7 +42,7 @@ class GPTConfig: rmsnorm_eps: Optional[float] = None # only used for rmsnorm nonlinearity_type: str = "gelu" # "gelu" or "swiglu" swiglu_multiple_of: Optional[int] = None # MLP hidden layer (using SwiGLU) will be multiple of this - attn_kernel_type: Literal["fa2", "torch_attn", "hand"] = "fa2" + attn_kernel_type: Literal["fd", "torch_attn"] = "torch_attn" kv_cache_enabled: bool = False # whether to use key-value cache for attention diff --git a/fam/llm/sample.py b/fam/llm/sample.py index 5a48402..af9191c 100644 --- a/fam/llm/sample.py +++ b/fam/llm/sample.py @@ -21,7 +21,7 @@ from fam.llm.decoders import Decoder, EncodecDecoder from fam.llm.enhancers import BaseEnhancer, get_enhancer from fam.llm.model import GPT, GPTConfig -from fam.llm.utils import normalize_text +from fam.llm.utils import get_default_dtype, get_default_use_kv_cache, normalize_text from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder from fam.quantiser.text.tokenise import TrainedBPETokeniser @@ -53,7 +53,7 @@ def __init__( tokenizer_cls: Type[TrainedBPETokeniser], decoder_cls: Type[Decoder], data_adapter_fn, - use_kv_cache: 
Optional[Literal["none", "flash_decoding", "vanilla"]] = None, + use_kv_cache: Optional[Literal["flash_decoding", "vanilla"]] = None, ): # TODO: disentangle the encodec stuff and numbers etc with rest of this code (esp at encoder-only / second stage model inference) # TODO: remove magic number @@ -151,8 +151,6 @@ def _init_model(self): for block in self.model.transformer.h: block.attn.attn_kernel_type = "fd" elif self.use_kv_cache == "vanilla": - for block in self.model.transformer.h: - block.attn.attn_kernel_type = "torch_attn" self.model.enable_kv_cache() else: raise NotImplementedError(f"kv_cache type {self.use_kv_cache} not implemented!") @@ -596,12 +594,12 @@ class SamplingControllerConfig: Sample from a trained model. """ - huggingface_repo_id: str - """Absolute path to the model directory.""" - spk_cond_path: str """Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3""" + huggingface_repo_id: str = "metavoiceio/metavoice-1B-v0.1" + """Absolute path to the model directory.""" + text: str = ( "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model by MetaVoice." ) @@ -628,7 +626,7 @@ class SamplingControllerConfig: device: Literal["cuda", "cpu"] = "cuda" """Device to use for sampling.""" - dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = "bfloat16" + dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = get_default_dtype() """Data type to use for sampling.""" compile: bool = False @@ -640,9 +638,9 @@ class SamplingControllerConfig: init_from: str = "resume" """Either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl').""" - use_kv_cache: Optional[Literal["flash_decoding", "vanilla"]] = None + use_kv_cache: Optional[Literal["flash_decoding", "vanilla"]] = get_default_use_kv_cache() """Type of kv caching to use for inference: 1) [none] no kv caching, 2) [flash_decoding] use the - flash decoding kernel, 3) [vanilla] use flash attention 2 with hand implemented kv-cache.""" + flash decoding kernel, 3) [vanilla] use torch attention with hand implemented kv-cache.""" output_dir: str = "samples/" """Relative path to output directory""" @@ -696,10 +694,9 @@ class SamplingControllerConfig: config_second_stage, model_dir=model_dir, device=sampling_config.device, - use_kv_cache=sampling_config.use_kv_cache + use_kv_cache=sampling_config.use_kv_cache, ) - print(f"Synthesising utterance...") sample_utterance( sampling_config.text, os.path.expanduser(sampling_config.spk_cond_path), diff --git a/fam/llm/serving.py b/fam/llm/serving.py index f698bc1..ff3fb8b 100644 --- a/fam/llm/serving.py +++ b/fam/llm/serving.py @@ -25,6 +25,7 @@ get_second_stage_path, sample_utterance, ) +from fam.llm.utils import get_default_dtype, get_default_use_kv_cache logger = logging.getLogger(__name__) @@ -35,7 +36,7 @@ @dataclass class ServingConfig: - huggingface_repo_id: str + huggingface_repo_id: str = "metavoiceio/metavoice-1B-v0.1" """Absolute path to the model directory.""" max_new_tokens: int = 864 * 2 @@ -50,7 +51,7 @@ class ServingConfig: seed: int = 1337 """Random seed for sampling.""" - dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = "bfloat16" + dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = get_default_dtype() """Data type to use for sampling.""" enhancer: Optional[Literal["df"]] = "df" @@ -59,6 +60,10 @@ class ServingConfig: compile: bool = False """Whether to compile the model using PyTorch 2.0.""" + use_kv_cache: 
Optional[Literal["flash_decoding", "vanilla"]] = get_default_use_kv_cache() + """Type of kv caching to use for inference: 1) [none] no kv caching, 2) [flash_decoding] use the + flash decoding kernel, 3) [vanilla] use torch attention with hand implemented kv-cache.""" + port: int = 58003 @@ -185,7 +190,7 @@ def _convert_audiodata_to_wav_path(audiodata, wav_tmp): ) spkemb, llm_stg1, llm_stg2 = build_models( - config1, config2, model_dir=model_dir, device=device, use_kv_cache="flash_decoding" + config1, config2, model_dir=model_dir, device=device, use_kv_cache=GlobalState.config.use_kv_cache ) GlobalState.spkemb_model = spkemb GlobalState.first_stage_model = llm_stg1 diff --git a/fam/llm/utils.py b/fam/llm/utils.py index 60824a6..f2f60cf 100644 --- a/fam/llm/utils.py +++ b/fam/llm/utils.py @@ -1,5 +1,7 @@ import re +import torch + def normalize_text(text: str) -> str: unicode_conversion = { @@ -45,3 +47,23 @@ def normalize_text(text: str) -> str: text = text.strip() text = re.sub("\s\s+", " ", text) # remove multiple spaces return text + + +def get_default_use_kv_cache() -> str: + """Compute default value for 'use_kv_cache' based on GPU architecture""" + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + device_properties = torch.cuda.get_device_properties(i) + return "vanilla" if "Turing" or "Tesla" in device_properties else "flash_decoding" + else: + return "vanilla" + + +def get_default_dtype() -> str: + """Compute default 'dtype' based on GPU architecture""" + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + device_properties = torch.cuda.get_device_properties(i) + return "float16" if "Turing" or "Tesla" in device_properties else "bfloat16" + else: + return "float16" diff --git a/fam/ui/app.py b/fam/ui/app.py index eddbe10..82720c7 100644 --- a/fam/ui/app.py +++ b/fam/ui/app.py @@ -11,7 +11,6 @@ MAX_CHARS = 220 PRESET_VOICES = { # female - "Ava": "https://cdn.themetavoice.xyz/speakers/ava.flac", "Bria": "https://cdn.themetavoice.xyz/speakers/bria.mp3", # male "Alex": "https://cdn.themetavoice.xyz/speakers/alex.mp3", @@ -63,7 +62,7 @@ def tts(to_say, top_p, guidance, toggle, preset_dropdown, upload_target): config = { "text": to_say, - "guidance": d_guidance, + "guidance": (d_guidance, 1.0), "top_p": d_top_p, "speaker_ref_path": PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else None, }