feat: 2-3x inference speedup, faster than real-time #71

Merged (17 commits) on Feb 25, 2024
14 changes: 14 additions & 0 deletions README.md
@@ -70,6 +70,20 @@ python app.py

3. Use it via [Hugging Face](https://huggingface.co/metavoiceio)

### Quick-and-dirty faster inference (~2-3x faster)

0. Upgrade to PyTorch 2.2.0: `pip install --upgrade torch torchvision torchaudio`
1. Place voice reference samples in the `fam/llm/assets` folder.
2. Run inference as follows (a short programmatic sketch follows this snippet):
```bash
cd fam/llm
python -i gptfast_inference.py
>>> print(inferencer.synthesize("Hello world!", "assets/male.wav"))
>>> print(inferencer.synthesize("Crazy fast speed coming right at you!", "assets/male.wav"))
>>> print(inferencer.synthesize("Crazy fast speed coming right at you!", "assets/female.wav"))
```
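
If you prefer to drive the model from Python rather than the interactive shell, a minimal sketch mirroring the script's `__main__` block is shown below; the reference `.wav` paths are placeholders for the files you placed under `assets/`.

```python
# Minimal programmatic usage sketch (run from inside fam/llm).
# The reference .wav paths are placeholders, not files shipped with the repo.
from gptfast_inference import Inferencer

inferencer = Inferencer()  # downloads checkpoints and compiles the first-stage model
out_path = inferencer.synthesize("Hello world!", "assets/male.wav")
print(out_path)  # path to the enhanced output .wav
```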


## Soon
- Faster inference ⚡
- Fine-tuning code
98 changes: 98 additions & 0 deletions fam/llm/gptfast_inference.py
@@ -0,0 +1,98 @@
import shutil
import tempfile
from pathlib import Path

import torch
from huggingface_hub import snapshot_download

from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
from fam.llm.decoders import EncodecDecoder
from fam.llm.gptfast_sample_utils import build_model, main
from fam.llm.sample import (
EncodecDecoder,
InferenceConfig,
Model,
TiltedEncodec,
TrainedBPETokeniser,
get_cached_embedding,
get_enhancer,
)
from fam.llm.utils import check_audio_file, normalize_text


class Inferencer:
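    """Loads the MetaVoice checkpoints and wires the compiled first-stage model to the
    second-stage model and the audio enhancer; synthesize() returns the path to an enhanced .wav."""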
def __init__(self):
        # NOTE: this needs to come first so that we don't change global state when we want to use
        # the torch.compile'd model.
self._model_dir = snapshot_download(repo_id="metavoiceio/metavoice-1B-v0.1")
self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt"
config_second_stage = InferenceConfig(
ckpt_path=second_stage_ckpt_path,
num_samples=1,
seed=1337,
device="cuda",
dtype="bfloat16",
[Review comment from a Member] auto handle dtype. check fam/llm/utils.py
[Reply from the Author] done

compile=False,
init_from="resume",
output_dir=".",
)
data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024)
self.llm_second_stage = Model(
config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
)
self.enhancer = get_enhancer("df")

self.model, self.tokenizer, self.smodel, self.precision, self.model_size = build_model(
checkpoint_path=Path(f"{self._model_dir}/first_stage.pt"),
spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"),
device="cuda",
compile=True,
compile_prefill=True,
)

def synthesize(self, text, spk_ref_path, top_p=0.95, guidance_scale=3.0, temperature=1.0):
text = normalize_text(text)
check_audio_file(spk_ref_path)
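        # speaker embedding computed (and cached) from the reference audio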
spk_emb = get_cached_embedding(
spk_ref_path,
self.smodel,
).to(device="cuda", dtype=self.precision)
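        # first stage: generate the interleaved text/audio token stream with the compiled model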
tokens = main(
model=self.model,
tokenizer=self.tokenizer,
smodel=self.smodel,
precision=self.precision,
model_size=self.model_size,
prompt=text,
spk_ref_path=spk_ref_path,
spk_emb=spk_emb,
top_p=torch.tensor(top_p, device="cuda", dtype=self.precision),
guidance_scale=torch.tensor(guidance_scale, device="cuda", dtype=self.precision),
temperature=torch.tensor(temperature, device="cuda", dtype=self.precision),
)
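        # split the flattened token stream into text ids and the extracted audio ids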
text_ids, extracted_audio_ids = self.first_stage_adapter.decode([tokens])
b_speaker_embs = spk_emb.unsqueeze(0)
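        # second stage: expand the audio tokens and decode them to waveform files via the Encodec decoder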
wav_files = self.llm_second_stage(
texts=[text],
encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device="cuda").unsqueeze(0)],
speaker_embs=b_speaker_embs,
batch_size=1,
guidance_scale=None,
top_p=None,
top_k=200,
temperature=1.0,
max_new_tokens=None,
)
wav_file = wav_files[0]
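        # enhance into a temporary file, then overwrite the generated .wav in place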
with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
self.enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")
return str(wav_file) + ".wav"


if __name__ == "__main__":
inferencer = Inferencer()
print(inferencer.synthesize("Hello world!", "assets/male.wav"))
print(inferencer.synthesize("Crazy fast speed coming right at you!", "assets/male.wav"))
print(inferencer.synthesize("Crazy fast speed coming right at you!", "assets/female.wav"))
257 changes: 257 additions & 0 deletions fam/llm/gptfast_model.py
@@ -0,0 +1,257 @@
# Copyright (c) MetaVoice Labs Inc., Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of
# conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this
# list of conditions and the following disclaimer in the documentation and/or other
# materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from dataclasses import dataclass
from functools import reduce
from math import gcd
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F


def find_multiple(n: int, *args: Tuple[int]) -> int:
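    """Round n up to the nearest multiple of the LCM of args (e.g. find_multiple(5461, 256) == 5632)."""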
k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
if n % k == 0:
return n
return n + k - (n % k)


@dataclass
class ModelArgs:
block_size: int = 2048
vocab_size: int = 32000
n_layer: int = 32
n_head: int = 32
dim: int = 4096
speaker_emb_dim: int = 256
    intermediate_size: Optional[int] = None
n_local_heads: int = -1
head_dim: int = 64
norm_eps: float = 1e-5

def __post_init__(self):
if self.n_local_heads == -1:
self.n_local_heads = self.n_head
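        # SwiGLU MLP sizing: 2/3 of 4*dim, rounded up to a multiple of 256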
if self.intermediate_size is None:
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = find_multiple(n_hidden, 256)
self.head_dim = self.dim // self.n_head

@classmethod
def from_name(cls, name: str):
if name in transformer_configs:
return cls(**transformer_configs[name])
# fuzzy search
config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
assert len(config) == 1, name
return cls(**transformer_configs[config[0]])


transformer_configs = {
"metavoice-1B": dict(
n_layer=24,
n_head=16,
dim=2048,
vocab_size=2562,
),
}


class KVCache(nn.Module):
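    """Statically pre-allocated key/value buffers; update() writes new keys/values at the given positions and returns the full caches."""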
def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
super().__init__()
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val: [B, H, S, D]
assert input_pos.shape[0] == k_val.shape[2]

k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val

return k_out, v_out


class Transformer(nn.Module):
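    """Decoder-only transformer with learned positional embeddings; the projected speaker embedding is added to the input embeddings (masked for speaker-condition-free guidance)."""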
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config

self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.pos_embeddings = nn.Embedding(config.block_size, config.dim)
self.speaker_cond_pos = nn.Linear(config.speaker_emb_dim, config.dim, bias=False)
self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

self.freqs_cis: Optional[Tensor] = None
self.mask_cache: Optional[Tensor] = None
self.max_batch_size = -1
self.max_seq_length = -1

def setup_spk_cond_mask(self):
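        """Boolean mask of shape (2, 1, dim): batch row 0 keeps the speaker conditioning, row 1 zeroes it, enabling speaker-condition-free guidance."""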
self.spk_cond_mask = torch.zeros((2, 1, self.config.dim), dtype=torch.bool)
self.spk_cond_mask[0] = 1

def setup_caches(self, max_batch_size, max_seq_length):
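        """Allocate per-layer KV caches and the causal mask, sized to max_seq_length rounded up to a multiple of 8."""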
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
return
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
for b in self.layers:
b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim)

self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

def forward(self, idx: Tensor, spk_emb: Tensor, input_pos: Tensor) -> Tensor:
mask = self.causal_mask[None, None, input_pos]
x = (
self.tok_embeddings(idx)
+ self.pos_embeddings(input_pos)
# masking for speaker condition free guidance
+ self.speaker_cond_pos(spk_emb) * self.spk_cond_mask
)

for i, layer in enumerate(self.layers):
x = layer(x, input_pos, mask)
x = self.norm(x)
logits = self.output(x)
return logits

@classmethod
def from_name(cls, name: str):
return cls(ModelArgs.from_name(name))


class TransformerBlock(nn.Module):
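    """Pre-norm transformer block: attention and SwiGLU feed-forward, each with a residual connection."""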
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.attention = Attention(config)
self.feed_forward = FeedForward(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)

def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor:
h = x + self.attention(self.attention_norm(x), mask, input_pos)
out = h + self.feed_forward(self.ffn_norm(h))
return out


class Attention(nn.Module):
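    """Multi-head attention with a fused QKV projection, an optional KV cache, and grouped-query KV heads."""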
def __init__(self, config: ModelArgs):
super().__init__()
assert config.dim % config.n_head == 0

total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
# key, query, value projections for all heads, but in a batch
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
self.wo = nn.Linear(config.dim, config.dim, bias=False)
self.kv_cache = None

self.n_head = config.n_head
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.dim = config.dim
self._register_load_state_dict_pre_hook(self.load_hook)

def load_hook(self, state_dict, prefix, *args):
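        """Fuse separate wq/wk/wv checkpoint weights into the single batched wqkv projection when present in the state dict."""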
if prefix + "wq.weight" in state_dict:
wq = state_dict.pop(prefix + "wq.weight")
wk = state_dict.pop(prefix + "wk.weight")
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

def forward(
self,
x: Tensor,
mask: Tensor,
input_pos: Optional[Tensor] = None,
) -> Tensor:
bsz, seqlen, _ = x.shape

kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

if self.kv_cache is not None:
k, v = self.kv_cache.update(input_pos, k, v)

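        # grouped-query attention: repeat K/V heads to match the number of query heads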
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

y = self.wo(y)
return y


class SwiGLU(nn.Module):
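    """SwiGLU gating: silu(w1(x)) * w3(x)."""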
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)

def forward(self, x: Tensor) -> Tensor:
return F.silu(self.w1(x)) * self.w3(x)


class FeedForward(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.swiglu = SwiGLU(config)
self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

def forward(self, x: Tensor) -> Tensor:
return self.w2(self.swiglu(x))


class RMSNorm(nn.Module):
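    """Root-mean-square normalization with a learned scale (no mean subtraction or bias)."""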
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

def forward(self, x: Tensor) -> Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight