diff --git a/Dockerfile b/Dockerfile index 8b787ef..a1c6d67 100644 --- a/Dockerfile +++ b/Dockerfile @@ -22,11 +22,12 @@ WORKDIR /app COPY requirements.txt requirements.txt RUN pip install --no-cache-dir packaging wheel torch -RUN pip install --no-cache-dir flash-attn +RUN pip install --no-cache-dir audiocraft # HACK: installation fails within the requirements.txt RUN pip install --no-cache-dir -r requirements.txt +RUN pip install --no-cache-dir --upgrade torch torchaudio COPY . . RUN pip install --no-cache-dir -e . -ENTRYPOINT ["python3.10", "fam/llm/serving.py"] \ No newline at end of file +ENTRYPOINT ["python3.10", "serving.py"] \ No newline at end of file diff --git a/README.md b/README.md index 4bfd630..3360b8b 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ docker-compose up -d server && docker-compose ps && docker-compose logs -f ## Installation **Pre-requisites:** -- GPU VRAM >=16GB +- GPU VRAM >=12GB - Python >=3.10,<3.12 **Environment setup** @@ -49,32 +49,36 @@ rm -rf ffmpeg-git-* curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh pip install -r requirements.txt - -# Flash Attention works only on latest Nvidia GPUs (Hopper, Ampere & Ada). If you have a different GPU (Tesla or Turing), do not install this. -pip install flash-attn - +pip install --upgrade torch torchaudio  # for torch.compile improvements pip install -e . ``` ## Usage -1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/sample.py) +1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/fast_inference.py) ```bash -python fam/llm/sample.py --spk_cond_path="assets/bria.mp3" --text="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model." +python -i fam/llm/fast_inference.py + +# Example of API usage within the interactive Python session +tts.synthesise(text="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model.", spk_ref_path="assets/bria.mp3") ``` +> Note: The script takes 30-90s to start up (depending on hardware). This is because we torch.compile the model for fast inference. +> On Ampere, Ada-Lovelace, and Hopper architecture GPUs, once compiled, the synthesise() API runs faster than real-time, with a Real-Time Factor (RTF) < 1.0. -2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](/fam/llm/serving.py) or [web UI](/fam/ui/app.py) +2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](serving.py) or [web UI](app.py) ```bash -python fam/llm/serving.py +python serving.py python app.py ``` 3. Use it via [Hugging Face](https://huggingface.co/metavoiceio) 4. [Google Colab](https://colab.research.google.com/drive/1UmjE1mzfG4td0rCjJEaAWGQXpn_GuwwY?authuser=0#scrollTo=mPgTfUdBJF1B) -## Soon -- Faster inference ⚡ -- Fine-tuning code -- Synthesis of arbitrary length text + +## Upcoming - [x] Faster inference ⚡ - [ ] Fine-tuning code - [ ] Synthesis of arbitrary length text + ## Architecture We predict EnCodec tokens from text and speaker information. This is then diffused up to the waveform level, with post-processing applied to clean up the audio. 
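For reference, here is a minimal non-interactive sketch of the new `fam.llm.fast_inference` API that the README changes above describe. The class name, constructor defaults, and `synthesise()` arguments are taken from `fam/llm/fast_inference.py` in this diff; the reference clip path is a placeholder (any >=30s wav/flac/mp3 file or public URI should work):

```python
# Minimal usage sketch of the TTS class added in this diff (assumes `pip install -e .` has been run).
from fam.llm.fast_inference import TTS

tts = TTS()  # downloads metavoiceio/metavoice-1B-v0.1 and torch.compiles the first-stage model
wav_path = tts.synthesise(
    text="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model.",
    spk_ref_path="assets/bria.mp3",  # placeholder: >=30s speaker reference (local path or public URI)
    top_p=0.95,          # speech stability
    guidance_scale=3.0,  # speaker similarity
)
print(wav_path)  # path to the enhanced .wav written under the default "outputs" directory
```

The first call is slow because of compilation; subsequent `synthesise()` calls reuse the compiled model and, per the README note above, run faster than real-time on Ampere, Ada-Lovelace, and Hopper GPUs.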
diff --git a/app.py b/app.py index 3a64dc6..ddb4b6d 100644 --- a/app.py +++ b/app.py @@ -7,58 +7,12 @@ import gradio as gr -from huggingface_hub import snapshot_download - -from fam.llm.sample import ( - InferenceConfig, - SamplingControllerConfig, - build_models, - get_first_stage_path, - get_second_stage_path, - sample_utterance, -) + +from fam.llm.fast_inference import TTS from fam.llm.utils import check_audio_file #### setup model -sampling_config = SamplingControllerConfig( - huggingface_repo_id="metavoiceio/metavoice-1B-v0.1", spk_cond_path="" -) # spk_cond_path added later -model_dir = snapshot_download(repo_id=sampling_config.huggingface_repo_id) -first_stage_ckpt_path = get_first_stage_path(model_dir) -second_stage_ckpt_path = get_second_stage_path(model_dir) - -config_first_stage = InferenceConfig( - ckpt_path=first_stage_ckpt_path, - num_samples=sampling_config.num_samples, - seed=sampling_config.seed, - device=sampling_config.device, - dtype=sampling_config.dtype, - compile=sampling_config.compile, - init_from=sampling_config.init_from, - output_dir=sampling_config.output_dir, -) - -config_second_stage = InferenceConfig( - ckpt_path=second_stage_ckpt_path, - num_samples=sampling_config.num_samples, - seed=sampling_config.seed, - device=sampling_config.device, - dtype=sampling_config.dtype, - compile=sampling_config.compile, - init_from=sampling_config.init_from, - output_dir=sampling_config.output_dir, -) - -sampling_config.max_new_tokens *= 2 # deal with max_new_tokens for flattened interleaving! - -# define models -smodel, llm_first_stage, llm_second_stage = build_models( - config_first_stage, - config_second_stage, - model_dir=model_dir, - device=sampling_config.device, - use_kv_cache=sampling_config.use_kv_cache, -) +TTS_MODEL = TTS() #### setup interface RADIO_CHOICES = ["Preset voices", "Upload target voice (atleast 30s)"] @@ -115,20 +69,12 @@ def tts(to_say, top_p, guidance, toggle, preset_dropdown, upload_target): _handle_edge_cases(to_say, upload_target) to_say = to_say if len(to_say) < MAX_CHARS else to_say[:MAX_CHARS] - return sample_utterance( - to_say, - spk_cond_path=PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else upload_target, - spkemb_model=smodel, - first_stage_model=llm_first_stage, - second_stage_model=llm_second_stage, - enhancer=sampling_config.enhancer, - guidance_scale=(d_guidance, 1.0), - max_new_tokens=sampling_config.max_new_tokens, - temperature=sampling_config.temperature, - top_k=sampling_config.top_k, + + return TTS_MODEL.synthesise( + text=to_say, + spk_ref_path=PRESET_VOICES[preset_dropdown] if toggle == RADIO_CHOICES[0] else upload_target, top_p=d_top_p, - first_stage_ckpt_path=None, - second_stage_ckpt_path=None, + guidance_scale=d_guidance, ) except Exception as e: raise gr.Error(f"Something went wrong. Reason: {str(e)}") diff --git a/fam/llm/decoders.py b/fam/llm/decoders.py index e59be38..9561582 100644 --- a/fam/llm/decoders.py +++ b/fam/llm/decoders.py @@ -68,7 +68,7 @@ def decode( # TODO: this has strange behaviour -- if causal is True, it returns tokens. if causal is False, it SAVES the audio file. 
text_ids, extracted_audio_ids = self._data_adapter_fn(tokens) text = self.tokeniser_decode_fn(text_ids) - print(f"Text: {text}") + # print(f"Text: {text}") tokens = torch.tensor(extracted_audio_ids, device="cuda").unsqueeze(0) @@ -92,11 +92,10 @@ def decode( try: wav_file_name = self.output_dir / f"synth_{text.replace(' ', '_')[:25]}_{uuid.uuid4()}" self._save_audio(wav_file_name, wav) - print(f"\nSaved audio to {wav_file_name}.wav") return wav_file_name except Exception as e: print(f"Failed to save audio! Reason: {e}") + wav_file_name = self.output_dir / f"synth_{uuid.uuid4()}" self._save_audio(wav_file_name, wav) - print(f"\nSaved audio to {wav_file_name}.wav") return wav_file_name diff --git a/fam/llm/enhancers.py b/fam/llm/enhancers.py index b9522dd..f4338c7 100644 --- a/fam/llm/enhancers.py +++ b/fam/llm/enhancers.py @@ -97,6 +97,12 @@ def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer: """ if enhancer_name == "df": + import warnings + + warnings.filterwarnings( + "ignore", + message='"sinc_interpolation" resampling method name is being deprecated and replaced by "sinc_interp_hann" in the next release. The default behavior remains unchanged.', + ) return DFEnhancer() else: raise ValueError(f"Unknown enhancer name: {enhancer_name}") diff --git a/fam/llm/fast_inference.py b/fam/llm/fast_inference.py new file mode 100644 index 0000000..a3b1f89 --- /dev/null +++ b/fam/llm/fast_inference.py @@ -0,0 +1,141 @@ +import os +import shutil +import tempfile +import time +from pathlib import Path + +import librosa +import torch +from huggingface_hub import snapshot_download + +from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook +from fam.llm.decoders import EncodecDecoder +from fam.llm.fast_inference_utils import build_model, main +from fam.llm.inference import ( + EncodecDecoder, + InferenceConfig, + Model, + TiltedEncodec, + TrainedBPETokeniser, + get_cached_embedding, + get_cached_file, + get_enhancer, +) +from fam.llm.utils import ( + check_audio_file, + get_default_dtype, + get_device, + normalize_text, +) + + +class TTS: + def __init__( + self, model_name: str = "metavoiceio/metavoice-1B-v0.1", *, seed: int = 1337, output_dir: str = "outputs" + ): + """ + model_name (str): refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/metavoiceio) + """ + + # NOTE: this needs to come first so that we don't change global state when we want to use + # the torch.compiled-model. 
+ self._dtype = get_default_dtype() + self._device = get_device() + self._model_dir = snapshot_download(repo_id=model_name) + self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024) + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) + + second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt" + config_second_stage = InferenceConfig( + ckpt_path=second_stage_ckpt_path, + num_samples=1, + seed=seed, + device=self._device, + dtype=self._dtype, + compile=False, + init_from="resume", + output_dir=self.output_dir, + ) + data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024) + self.llm_second_stage = Model( + config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode + ) + self.enhancer = get_enhancer("df") + + self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype] + self.model, self.tokenizer, self.smodel, self.model_size = build_model( + precision=self.precision, + checkpoint_path=Path(f"{self._model_dir}/first_stage.pt"), + spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"), + device=self._device, + compile=True, + compile_prefill=True, + ) + + def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str: + """ + text: Text to speak + spk_ref_path: Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3 + top_p: Top-p for sampling, applied to the first-stage model. Values in [0.9, 1.0] work well. Acts as a speech-stability control and improves text following for challenging speakers. + guidance_scale: Guidance scale in [1.0, 3.0] for sampling. Controls speaker similarity: how closely the output matches the reference speaker's identity and speech style. 
+ temperature: Temperature for sampling applied to both LLMs (first & second stage) + + returns: path to speech .wav file + """ + text = normalize_text(text) + spk_ref_path = get_cached_file(spk_ref_path) + check_audio_file(spk_ref_path) + spk_emb = get_cached_embedding( + spk_ref_path, + self.smodel, + ).to(device=self._device, dtype=self.precision) + + start = time.time() + # first stage LLM + tokens = main( + model=self.model, + tokenizer=self.tokenizer, + model_size=self.model_size, + prompt=text, + spk_emb=spk_emb, + top_p=torch.tensor(top_p, device=self._device, dtype=self.precision), + guidance_scale=torch.tensor(guidance_scale, device=self._device, dtype=self.precision), + temperature=torch.tensor(temperature, device=self._device, dtype=self.precision), + ) + text_ids, extracted_audio_ids = self.first_stage_adapter.decode([tokens]) + + b_speaker_embs = spk_emb.unsqueeze(0) + + # second stage LLM + multi-band diffusion model + wav_files = self.llm_second_stage( + texts=[text], + encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device=self._device).unsqueeze(0)], + speaker_embs=b_speaker_embs, + batch_size=1, + guidance_scale=None, + top_p=None, + top_k=200, + temperature=1.0, + max_new_tokens=None, + ) + + # enhance using deepfilternet + wav_file = wav_files[0] + with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp: + self.enhancer(str(wav_file) + ".wav", enhanced_tmp.name) + shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav") + print(f"\nSaved audio to {wav_file}.wav") + + # calculating real-time factor (RTF) + time_to_synth_s = time.time() - start + audio, sr = librosa.load(str(wav_file) + ".wav") + duration_s = librosa.get_duration(y=audio, sr=sr) + print(f"\nTotal time to synth (s): {time_to_synth_s}") + print(f"Real-time factor: {time_to_synth_s / duration_s:.2f}") + + return str(wav_file) + ".wav" + + +if __name__ == "__main__": + tts = TTS() diff --git a/fam/llm/fast_inference_utils.py b/fam/llm/fast_inference_utils.py new file mode 100644 index 0000000..6ed7b9e --- /dev/null +++ b/fam/llm/fast_inference_utils.py @@ -0,0 +1,432 @@ +# Copyright (c) MetaVoice Labs Inc., Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, this +# list of conditions and the following disclaimer in the documentation and/or other +# materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import itertools +import time +from pathlib import Path +from typing import Optional, Tuple + +import torch +import torch._dynamo.config +import torch._inductor.config +import tqdm + + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize() + elif "cpu" in device: + pass + else: + print(f"device={device} is not yet supported") + + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True +torch._inductor.config.fx_graph_cache = ( + True # Experimental feature to reduce compilation times, will be on by default in future +) + +# imports need to happen after setting above flags +from fam.llm.fast_model import Transformer +from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder +from fam.quantiser.text.tokenise import TrainedBPETokeniser + + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def top_p_sample(logits: torch.Tensor, top_p: torch.Tensor): + # ref: huggingface/transformers + + sorted_logits, sorted_indices = torch.sort(logits, descending=False) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs <= (1 - top_p) + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[-1:] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove) + scores = logits.masked_fill(indices_to_remove, -float("Inf")) + return scores + + +def logits_to_probs( + logits, + *, + temperature: torch.Tensor, + top_p: Optional[torch.Tensor] = None, + top_k: Optional[torch.Tensor] = None, +): + logits = logits / torch.max(temperature, 1e-5 * torch.ones_like(temperature)) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + + if top_p is not None: + logits = top_p_sample(logits, top_p) + + probs = torch.nn.functional.softmax(logits, dim=-1) + + return probs + + +def sample( + logits, + guidance_scale: torch.Tensor, + temperature: torch.Tensor, + top_p: Optional[torch.Tensor] = None, + top_k: Optional[torch.Tensor] = None, +): + # (b, t, vocab_size) + logits = logits[:, -1] + logits_cond, logits_uncond_spkemb = logits.split(logits.size(0) // 2, dim=0) + logits = guidance_scale * logits_cond + (1 - guidance_scale) * logits_uncond_spkemb + probs = logits_to_probs(logits[0], temperature=temperature, top_p=top_p, top_k=top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +def prefill( + model: Transformer, + x: torch.Tensor, + spk_emb: torch.Tensor, + input_pos: torch.Tensor, + **sampling_kwargs, +) -> 
torch.Tensor: + # input_pos: [B, S] + logits = model(x, spk_emb, input_pos) + return sample(logits, **sampling_kwargs)[0] + + +def decode_one_token( + model: Transformer, + x: torch.Tensor, + spk_emb: torch.Tensor, + input_pos: torch.Tensor, + **sampling_kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = model(x, spk_emb, input_pos) + return sample(logits, **sampling_kwargs) + + +def decode_n_tokens( + model: Transformer, + cur_token: torch.Tensor, + spk_emb: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + callback=lambda _: _, + return_probs: bool = False, + end_of_audio_token: int = 2048, + **sampling_kwargs, +): + new_tokens, new_probs = [], [] + for i in tqdm.tqdm(range(num_new_tokens)): + if (cur_token == end_of_audio_token).any(): + break + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token(model, cur_token, spk_emb, input_pos, **sampling_kwargs) + input_pos += 1 + new_tokens.append(next_token.clone()) + callback(new_tokens[-1]) + if return_probs: + new_probs.append(next_prob.clone()) + cur_token = next_token.view(1, -1).repeat(2, 1) + + return new_tokens, new_probs + + +def model_forward(model, x, spk_emb, input_pos): + return model(x, spk_emb, input_pos) + + +@torch.no_grad() +def generate( + model: Transformer, + prompt: torch.Tensor, + spk_emb: torch.Tensor, + *, + max_new_tokens: Optional[int] = None, + callback=lambda x: x, + end_of_audio_token: int = 2048, + **sampling_kwargs, +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
+ """ + # create an empty tensor of the expected final shape and fill in the current tokens + T = prompt.size(0) + if max_new_tokens is None: + max_seq_length = model.config.block_size + else: + max_seq_length = T + max_new_tokens + max_seq_length = min(max_seq_length, model.config.block_size) + max_new_tokens = max_seq_length - T + if max_new_tokens <= 0: + raise ValueError("Prompt is too long to generate more tokens") + + device, dtype = prompt.device, prompt.dtype + + seq = torch.clone(prompt) + input_pos = torch.arange(0, T, device=device) + + next_token = prefill(model, prompt.view(1, -1).repeat(2, 1), spk_emb, input_pos, **sampling_kwargs) + seq = torch.cat([seq, next_token.view(1)]) + + input_pos = torch.tensor([T], device=device, dtype=torch.int) + + generated_tokens, _ = decode_n_tokens( + model, + next_token.view(1, -1).repeat(2, 1), + spk_emb, + input_pos, + max_new_tokens - 1, + callback=callback, + end_of_audio_token=end_of_audio_token, + **sampling_kwargs, + ) + seq = torch.cat([seq, torch.cat(generated_tokens)]) + + return seq + + +def encode_tokens(tokenizer, string, device="cuda"): + tokens = tokenizer.encode(string) + return torch.tensor(tokens, dtype=torch.int, device=device) + + +def _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision): + ##### MODEL + with torch.device("meta"): + model = Transformer.from_name("metavoice-1B") + + # TODO(quantization): enable + # if "int8" in str(checkpoint_path): + # print("Using int8 weight-only quantization!") + # from quantize import WeightOnlyInt8QuantHandler + # simple_quantizer = WeightOnlyInt8QuantHandler(model) + # model = simple_quantizer.convert_for_runtime() + # from quantize import WeightOnlyInt8QuantHandler + + # if "int4" in str(checkpoint_path): + # print("Using int4 quantization!") + # path_comps = checkpoint_path.name.split(".") + # assert path_comps[-2].startswith("g") + # groupsize = int(path_comps[-2][1:]) + # from quantize import WeightOnlyInt4QuantHandler + # simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) + # model = simple_quantizer.convert_for_runtime() + + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False) + state_dict = checkpoint["model"] + # convert MetaVoice-1B model weights naming to gptfast naming + unwanted_prefix = "_orig_mod." + for k, v in list(state_dict.items()): + if k.startswith(unwanted_prefix): + state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) + state_dict["tok_embeddings.weight"] = state_dict.pop("transformer.wtes.0.weight") + state_dict["pos_embeddings.weight"] = state_dict.pop("transformer.wpe.weight") + state_dict["output.weight"] = state_dict.pop("lm_heads.0.weight") + state_dict["norm.weight"] = state_dict.pop("transformer.ln_f.weight") + for k, v in list(state_dict.items()): + if k.startswith("transformer.h."): + state_dict[k.replace("transformer.h.", "layers.")] = state_dict.pop(k) + k = k.replace("transformer.h.", "layers.") + if ".attn.c_attn." in k: + state_dict[k.replace(".attn.c_attn.", ".attention.wqkv.")] = state_dict.pop(k) + k = k.replace(".attn.c_attn.", ".attention.wqkv.") + if ".attn.c_proj." in k: + state_dict[k.replace(".attn.c_proj.", ".attention.wo.")] = state_dict.pop(k) + k = k.replace(".attn.c_proj.", ".attention.wo.") + if ".mlp.swiglu.w1." in k: + state_dict[k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.")] = state_dict.pop(k) + k = k.replace(".mlp.swiglu.w1.", ".feed_forward.swiglu.w1.") + if ".mlp.swiglu.w3." 
in k: + state_dict[k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.")] = state_dict.pop(k) + k = k.replace(".mlp.swiglu.w3.", ".feed_forward.swiglu.w3.") + if ".ln_1." in k: + state_dict[k.replace(".ln_1.", ".attention_norm.")] = state_dict.pop(k) + k = k.replace(".ln_1.", ".attention_norm.") + if ".ln_2." in k: + state_dict[k.replace(".ln_2.", ".ffn_norm.")] = state_dict.pop(k) + k = k.replace(".ln_2.", ".ffn_norm.") + if ".mlp.c_proj." in k: + state_dict[k.replace(".mlp.c_proj.", ".feed_forward.w2.")] = state_dict.pop(k) + k = k.replace(".mlp.c_proj.", ".feed_forward.w2.") + + model.load_state_dict(state_dict, assign=True) + # simple_quantizer = WeightOnlyInt8QuantHandler(model) + # quantized_state_dict = simple_quantizer.create_quantized_state_dict() + # model = simple_quantizer.convert_for_runtime() + # model.load_state_dict(quantized_state_dict, assign=True) + model = model.to(device=device, dtype=precision) + + ###### TOKENIZER + tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {}) + tokenizer = TrainedBPETokeniser(**tokenizer_info) + + ###### SPEAKER EMBEDDER + # TODO: fix! + smodel = SpeakerEncoder( + weights_fpath=spk_emb_ckpt_path, + device=device, + eval=True, + verbose=False, + ) + return model.eval(), tokenizer, smodel + + +def build_model( + *, + precision: torch.dtype, + checkpoint_path: Path = Path(""), + spk_emb_ckpt_path: Path = Path(""), + compile_prefill: bool = False, + compile: bool = True, + device: str = "cuda", +): + assert checkpoint_path.is_file(), checkpoint_path + + print(f"Using device={device}") + + print("Loading model ...") + t0 = time.time() + model, tokenizer, smodel = _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision) + + device_sync(device=device) # MKG + print(f"Time to load model: {time.time() - t0:.02f} seconds") + + torch.manual_seed(1234) + model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())]) + + with torch.device(device): + model.setup_spk_cond_mask() + model.setup_caches(max_batch_size=2, max_seq_length=model.config.block_size) + + if compile: + print("Compiling...Can take up to 2 mins.") + global decode_one_token, prefill + decode_one_token = torch.compile( + decode_one_token, + mode="max-autotune", + fullgraph=True, + ) + + if compile_prefill: + prefill = torch.compile( + prefill, + fullgraph=True, + dynamic=True, + ) + + encoded = encode_tokens(tokenizer, "Hello, what's up?", device=device) + spk_emb = torch.randn((1, 256), device=device, dtype=precision) + + device_sync(device=device) # MKG + t0 = time.perf_counter() + y = generate( + model, + encoded, + spk_emb, + max_new_tokens=200, + callback=lambda x: x, + temperature=torch.tensor(1.0, device=device, dtype=precision), + top_k=None, + top_p=torch.tensor(0.95, device=device, dtype=precision), + guidance_scale=torch.tensor(3.0, device=device, dtype=precision), + end_of_audio_token=9999, # don't end early for compilation stage. 
+ ) + + device_sync(device=device) # MKG + + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + + return model, tokenizer, smodel, model_size + + +def main( + *, + model, + tokenizer, + model_size, + prompt: str, + guidance_scale: torch.Tensor, + temperature: torch.Tensor, + spk_emb: torch.Tensor, + top_k: Optional[torch.Tensor] = None, + top_p: Optional[torch.Tensor] = None, + device: str = "cuda", +) -> list: + """Generates text samples based on a pre-trained Transformer model and tokenizer.""" + + encoded = encode_tokens(tokenizer, prompt, device=device) + prompt_length = encoded.size(0) + + aggregate_metrics: dict = { + "tokens_per_sec": [], + } + + device_sync(device=device) # MKG + + if True: + callback = lambda x: x + t0 = time.perf_counter() + + y = generate( + model, + encoded, + spk_emb, + callback=callback, + temperature=temperature, + top_k=top_k, + top_p=top_p, + guidance_scale=guidance_scale, + ) + + device_sync(device=device) # MKG + t = time.perf_counter() - t0 + + tokens_generated = y.size(0) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + print(f"Time for 1st stage LLM inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") + print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s") + # print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}") + print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB\n") + + return y.tolist() diff --git a/fam/llm/fast_model.py b/fam/llm/fast_model.py new file mode 100644 index 0000000..5d74bd9 --- /dev/null +++ b/fam/llm/fast_model.py @@ -0,0 +1,261 @@ +# Copyright (c) MetaVoice Labs Inc., Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, this +# list of conditions and the following disclaimer in the documentation and/or other +# materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+from dataclasses import dataclass +from functools import reduce +from math import gcd +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + +from fam.llm.utils import get_default_dtype + +import logging + +# Adjust the logging level +logger = logging.getLogger("torch") +logger.setLevel(logging.ERROR) + + +def find_multiple(n: int, *args: Tuple[int]) -> int: + k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,)) + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + speaker_emb_dim: int = 256 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + norm_eps: float = 1e-5 + dtype: torch.dtype = torch.bfloat16 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + self.dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[get_default_dtype()] + + @classmethod + def from_name(cls, name: str): + if name in transformer_configs: + return cls(**transformer_configs[name]) + # fuzzy search + config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)] + assert len(config) == 1, name + return cls(**transformer_configs[config[0]]) + + +transformer_configs = { + "metavoice-1B": dict( + n_layer=24, + n_head=16, + dim=2048, + vocab_size=2562, + ), +} + + +class KVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.pos_embeddings = nn.Embedding(config.block_size, config.dim) + self.speaker_cond_pos = nn.Linear(config.speaker_emb_dim, config.dim, bias=False) + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_spk_cond_mask(self): + self.spk_cond_mask = torch.zeros((2, 1, self.config.dim), dtype=torch.bool) + self.spk_cond_mask[0] = 1 + + def setup_caches(self, max_batch_size, max_seq_length): + if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + for b in self.layers: + b.attention.kv_cache = KVCache( + max_batch_size, 
max_seq_length, self.config.n_local_heads, head_dim, dtype=self.config.dtype + ) + + self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) + + def forward(self, idx: Tensor, spk_emb: Tensor, input_pos: Tensor) -> Tensor: + mask = self.causal_mask[None, None, input_pos] + x = ( + self.tok_embeddings(idx) + + self.pos_embeddings(input_pos) + # masking for speaker condition free guidance + + self.speaker_cond_pos(spk_emb) * self.spk_cond_mask + ) + + for i, layer in enumerate(self.layers): + x = layer(x, input_pos, mask) + x = self.norm(x) + logits = self.output(x) + return logits + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor: + h = x + self.attention(self.attention_norm(x), mask, input_pos) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + + def forward( + self, + x: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + y = self.wo(y) + return y + + +class SwiGLU(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return F.silu(self.w1(x)) * self.w3(x) + + +class FeedForward(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.swiglu = SwiGLU(config) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(self.swiglu(x)) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 
+ + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight diff --git a/fam/llm/sample.py b/fam/llm/inference.py similarity index 97% rename from fam/llm/sample.py rename to fam/llm/inference.py index 4d80c97..13ddfac 100644 --- a/fam/llm/sample.py +++ b/fam/llm/inference.py @@ -1,3 +1,7 @@ +""" +Command: python fam/llm/inference.py --spk_cond_path="assets/bria.mp3" --text="This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model." +""" + import dataclasses import hashlib import json @@ -6,6 +10,7 @@ import shutil import subprocess import tempfile +import time from contextlib import nullcontext from dataclasses import dataclass from typing import List, Literal, Optional, Tuple, Type, Union @@ -20,12 +25,7 @@ from fam.llm.decoders import Decoder, EncodecDecoder from fam.llm.enhancers import BaseEnhancer, get_enhancer from fam.llm.model import GPT, GPTConfig -from fam.llm.utils import ( - check_audio_file, - get_default_dtype, - get_default_use_kv_cache, - normalize_text, -) +from fam.llm.utils import check_audio_file, get_default_dtype, normalize_text from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder from fam.quantiser.text.tokenise import TrainedBPETokeniser @@ -57,7 +57,7 @@ def __init__( tokenizer_cls: Type[TrainedBPETokeniser], decoder_cls: Type[Decoder], data_adapter_fn, - use_kv_cache: Optional[Literal["flash_decoding", "vanilla"]] = None, + use_kv_cache: Optional[Literal["vanilla"]] = None, ): # TODO: disentangle the encodec stuff and numbers etc with rest of this code (esp at encoder-only / second stage model inference) # TODO: remove magic number @@ -150,11 +150,7 @@ def _init_model(self): if "causal" in self.checkpoint_config and self.checkpoint_config["causal"] is False: raise Exception("kv_cache not supported for non-causal models!") - if self.use_kv_cache == "flash_decoding": - self.model.enable_kv_cache() - for block in self.model.transformer.h: - block.attn.attn_kernel_type = "fd" - elif self.use_kv_cache == "vanilla": + if self.use_kv_cache == "vanilla": self.model.enable_kv_cache() else: raise NotImplementedError(f"kv_cache type {self.use_kv_cache} not implemented!") @@ -471,6 +467,8 @@ def _sample_utterance_batch( speaker_embs.append(get_cached_embedding(spk_cond_path, spkemb_model) if spk_cond_path else None) b_speaker_embs = torch.cat(speaker_embs, dim=0) + + start = time.time() b_tokens = first_stage_model( texts=texts, speaker_embs=b_speaker_embs, @@ -514,6 +512,8 @@ def _sample_utterance_batch( first_stage_ckpt_path, second_stage_ckpt_path, ) + + print(f"time_to_synth_s: {time.time() - start}") return [str(w) + ".wav" if not str(w).endswith(".wav") else str(w) for w in wav_files] @@ -637,9 +637,8 @@ class SamplingControllerConfig: init_from: str = "resume" """Either 'resume' (from an out_dir) or a gpt2 variant (e.g. 
'gpt2-xl').""" - use_kv_cache: Optional[Literal["flash_decoding", "vanilla"]] = get_default_use_kv_cache() - """Type of kv caching to use for inference: 1) [none] no kv caching, 2) [flash_decoding] use the - flash decoding kernel, 3) [vanilla] use torch attention with hand implemented kv-cache.""" + use_kv_cache: Optional[Literal["vanilla"]] = "vanilla" + """Type of kv caching to use for inference: 1) [none] no kv caching, 2) [vanilla] use torch attention with hand implemented kv-cache.""" output_dir: str = "samples/" """Relative path to output directory""" diff --git a/fam/llm/layers/attn.py b/fam/llm/layers/attn.py index 8110306..a1053ee 100644 --- a/fam/llm/layers/attn.py +++ b/fam/llm/layers/attn.py @@ -4,11 +4,6 @@ import torch.nn as nn from torch.nn import functional as F -try: - from flash_attn import flash_attn_with_kvcache # type: ignore -except ImportError: - warnings.warn("flash_attn not installed, make sure to replace attention mechanism with torch_attn") - class SelfAttention(nn.Module): def __init__(self, config): @@ -124,42 +119,6 @@ def _update_kv_cache(self, q, k, v): return k, v - def _fd_attention(self, c_x: torch.Tensor) -> torch.Tensor: - """ - Performs Flash decoding based attention. - Args: - c_x: The input tensor. - Returns: - The output tensor. - Raises: - Exception: If key-value caching is not enabled. - Exception: If non-causal attention is activated. - """ - if self.kv_cache_enabled is False: - raise Exception("Flash decoding required kv_cache to be enabled") - - if self.causal is False: - raise Exception("It is only supported for causal attention") - - q, k, v = c_x.split(1, dim=2) - q = q.squeeze(2) - k = k.squeeze(2) - v = v.squeeze(2) - - y = flash_attn_with_kvcache( - q, - self.kv_cache[0], - self.kv_cache[1], - k, - v, - cache_seqlens=self.kv_cache_first_empty_index, - softmax_scale=None, - causal=self.causal, - ) - self.kv_cache_first_empty_index += q.shape[1] - - return y - def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor: """ Performs attention using the torch.nn.functional.scaled_dot_product_attention function. 
@@ -215,9 +174,7 @@ def forward(self, x): c_x = self.c_attn(x).view(B, T, 3, self.n_head, C // self.n_head) # (B, T, 3, nh, hs) # causal self-attention; - if self.attn_kernel_type == "fd": - y = self._fd_attention(c_x) - elif self.attn_kernel_type == "torch_attn": + if self.attn_kernel_type == "torch_attn": y = self._torch_attn(c_x) else: raise Exception(f"Unknown attention kernel type: {self.attn_kernel_type}") diff --git a/fam/llm/model.py b/fam/llm/model.py index 8c42735..3f98bf7 100644 --- a/fam/llm/model.py +++ b/fam/llm/model.py @@ -42,7 +42,7 @@ class GPTConfig: rmsnorm_eps: Optional[float] = None # only used for rmsnorm nonlinearity_type: str = "gelu" # "gelu" or "swiglu" swiglu_multiple_of: Optional[int] = None # MLP hidden layer (using SwiGLU) will be multiple of this - attn_kernel_type: Literal["fd", "torch_attn"] = "torch_attn" + attn_kernel_type: Literal["torch_attn"] = "torch_attn" kv_cache_enabled: bool = False # whether to use key-value cache for attention @@ -395,7 +395,7 @@ def generate( raise Exception("top_p is not supported for non-causal sampling") out = [] - for start_index in tqdm.tqdm(range(0, idx.shape[0], batch_size), desc="non-causal batching"): + for start_index in tqdm.tqdm(range(0, idx.shape[0], batch_size), desc="Non-causal batching"): end_index = min(start_index + batch_size, idx.shape[0]) out.append( self._non_causal_sample( diff --git a/fam/llm/utils.py b/fam/llm/utils.py index 79fe4a2..bc2eed6 100644 --- a/fam/llm/utils.py +++ b/fam/llm/utils.py @@ -43,7 +43,7 @@ def normalize_text(text: str) -> str: non_bpe_chars = set([c for c in list(text) if ord(c) >= 256]) if len(non_bpe_chars) > 0: non_bpe_points = [(c, ord(c)) for c in non_bpe_chars] - raise ValueError(f"Non-BPE single token characters found: {non_bpe_points}") + raise ValueError(f"Unsupported character found: {non_bpe_points}") text = text.replace("\t", " ") text = text.replace("\n", " ") @@ -75,21 +75,18 @@ def check_audio_file(path_or_uri, threshold_s=30): os.remove(filepath) -def get_default_use_kv_cache() -> str: - """Compute default value for 'use_kv_cache' based on GPU architecture""" - if torch.cuda.is_available(): - for i in range(torch.cuda.device_count()): - device_properties = torch.cuda.get_device_properties(i) - return "vanilla" if "Turing" or "Tesla" in device_properties else "flash_decoding" - else: - return "vanilla" - - def get_default_dtype() -> str: """Compute default 'dtype' based on GPU architecture""" if torch.cuda.is_available(): for i in range(torch.cuda.device_count()): device_properties = torch.cuda.get_device_properties(i) - return "float16" if "Turing" or "Tesla" in device_properties else "bfloat16" + dtype = "float16" if device_properties.major <= 7 else "bfloat16" # float16 for compute capability <= 7.x (e.g. Tesla, Turing); bfloat16 for newer GPUs else: - return "float16" + dtype = "float16" + + print(f"using dtype={dtype}") + return dtype + + +def get_device() -> str: + return "cuda" if torch.cuda.is_available() else "cpu" diff --git a/requirements.txt b/requirements.txt index 0d4c697..4bac4a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ -torch>=2.1.0 -transformers +torch>=2.1.0 # required to install audiocraft librosa tqdm tiktoken==0.5.1 @@ -12,4 +11,4 @@ tyro deepfilternet pydub gradio -huggingface_hub \ No newline at end of file +huggingface_hub diff --git a/fam/llm/serving.py b/serving.py similarity index 53% rename from fam/llm/serving.py rename to serving.py index 018997f..48daf26 100644 --- a/fam/llm/serving.py +++ b/serving.py @@ -5,27 +5,18 @@ import tempfile import warnings from 
pathlib import Path -from typing import Literal, Optional, Tuple +from typing import Optional import fastapi import fastapi.middleware.cors -import torch import tyro import uvicorn from attr import dataclass from fastapi import Request from fastapi.responses import Response -from huggingface_hub import snapshot_download - -from fam.llm.sample import ( - InferenceConfig, - Model, - build_models, - get_first_stage_path, - get_second_stage_path, - sample_utterance, -) -from fam.llm.utils import check_audio_file, get_default_dtype, get_default_use_kv_cache + +from fam.llm.fast_inference import TTS +from fam.llm.utils import check_audio_file logger = logging.getLogger(__name__) @@ -39,41 +30,19 @@ class ServingConfig: huggingface_repo_id: str = "metavoiceio/metavoice-1B-v0.1" """Absolute path to the model directory.""" - max_new_tokens: int = 864 * 2 - """Maximum number of new tokens to generate from the first stage model.""" - temperature: float = 1.0 """Temperature for sampling applied to both models.""" - top_k: int = 200 - """Top k for sampling applied to both models.""" - seed: int = 1337 """Random seed for sampling.""" - dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = get_default_dtype() - """Data type to use for sampling.""" - - enhancer: Optional[Literal["df"]] = "df" - """Enhancer to use for post-processing.""" - - compile: bool = False - """Whether to compile the model using PyTorch 2.0.""" - - use_kv_cache: Optional[Literal["flash_decoding", "vanilla"]] = get_default_use_kv_cache() - """Type of kv caching to use for inference: 1) [none] no kv caching, 2) [flash_decoding] use the - flash decoding kernel, 3) [vanilla] use torch attention with hand implemented kv-cache.""" - port: int = 58003 # Singleton class _GlobalState: - spkemb_model: torch.nn.Module - first_stage_model: Model - second_stage_model: Model config: ServingConfig - enhancer: object + tts: TTS GlobalState = _GlobalState() @@ -82,9 +51,9 @@ class _GlobalState: @dataclass(frozen=True) class TTSRequest: text: str - guidance: Optional[Tuple[float, float]] = (3.0, 1.0) - top_p: Optional[float] = 0.95 speaker_ref_path: Optional[str] = None + guidance: float = 3.0 + top_p: float = 0.95 top_k: Optional[int] = None @@ -109,25 +78,20 @@ async def text_to_speech(req: Request): wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp) check_audio_file(wav_path) else: + # TODO: fix wav_path = tts_req.speaker_ref_path + if wav_path is None: warnings.warn("Running without speaker reference") assert tts_req.guidance is None - wav_out_path = sample_utterance( - tts_req.text, - wav_path, - GlobalState.spkemb_model, - GlobalState.first_stage_model, - GlobalState.second_stage_model, - enhancer=GlobalState.enhancer, - first_stage_ckpt_path=None, - second_stage_ckpt_path=None, - guidance_scale=tts_req.guidance, - max_new_tokens=GlobalState.config.max_new_tokens, - temperature=GlobalState.config.temperature, - top_k=tts_req.top_k, + + wav_out_path = GlobalState.tts.synthesise( + text=tts_req.text, + spk_ref_path=wav_path, top_p=tts_req.top_p, + guidance_scale=tts_req.guidance, ) + with open(wav_out_path, "rb") as f: return Response(content=f.read(), media_type="audio/wav") except Exception as e: @@ -157,15 +121,14 @@ def _convert_audiodata_to_wav_path(audiodata, wav_tmp): if __name__ == "__main__": - # This has to be here to avoid some weird audiocraft shenaningans messing up matplotlib - from fam.llm.enhancers import get_enhancer - for name in logging.root.manager.loggerDict: logger = logging.getLogger(name) 
logger.setLevel(logging.INFO) logging.root.setLevel(logging.INFO) GlobalState.config = tyro.cli(ServingConfig) + GlobalState.tts = TTS(seed=GlobalState.config.seed) + app.add_middleware( fastapi.middleware.cors.CORSMiddleware, allow_origins=["*", f"http://localhost:{GlobalState.config.port}", "http://localhost:3000"], @@ -173,37 +136,6 @@ def _convert_audiodata_to_wav_path(audiodata, wav_tmp): allow_methods=["*"], allow_headers=["*"], ) - - device = "cuda" if torch.cuda.is_available() else "cpu" - common_config = dict( - num_samples=1, - seed=1337, - device=device, - dtype=GlobalState.config.dtype, - compile=GlobalState.config.compile, - init_from="resume", - output_dir=tempfile.mkdtemp(), - ) - model_dir = snapshot_download(repo_id=GlobalState.config.huggingface_repo_id) - config1 = InferenceConfig( - ckpt_path=get_first_stage_path(model_dir), - **common_config, - ) - - config2 = InferenceConfig( - ckpt_path=get_second_stage_path(model_dir), - **common_config, - ) - - spkemb, llm_stg1, llm_stg2 = build_models( - config1, config2, model_dir=model_dir, device=device, use_kv_cache=GlobalState.config.use_kv_cache - ) - GlobalState.spkemb_model = spkemb - GlobalState.first_stage_model = llm_stg1 - GlobalState.second_stage_model = llm_stg2 - GlobalState.enhancer = get_enhancer(GlobalState.config.enhancer) - - # start server uvicorn.run( app, host="0.0.0.0",
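To make the first-stage decoding step in `fam/llm/fast_inference_utils.py` concrete: the model is run on a batch of two (speaker-conditioned and speaker-unconditioned), the two sets of logits are combined with the guidance scale, and the result is sampled with temperature and top-p. Below is an illustrative, self-contained sketch of that single step; it mirrors `sample()` and `top_p_sample()` above but is not part of the diff, and uses dummy random logits in place of model output (2562 is the vocab size from the `metavoice-1B` config in `fast_model.py`).

```python
# Illustrative sketch (not part of this diff): one first-stage decoding step with dummy logits.
import torch

guidance_scale, top_p, temperature = 3.0, 0.95, 1.0

logits_cond = torch.randn(2562)    # logits from the speaker-conditioned half of the batch
logits_uncond = torch.randn(2562)  # logits from the speaker-unconditioned half

# Speaker classifier-free guidance: push the logits toward the conditioned prediction
# and away from the unconditioned one.
logits = guidance_scale * logits_cond + (1 - guidance_scale) * logits_uncond
logits = logits / temperature

# Top-p (nucleus) filtering, ascending-sort formulation as in top_p_sample().
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)  # drop the low-probability tail
sorted_indices_to_remove[-1:] = False                       # always keep the most likely token
indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(indices_to_remove, float("-inf"))

probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # sampled EnCodec/text token id
print(next_token.item())
```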