From 01e3bc09a37a803fc2b2baa92ba61aa5029cf57a Mon Sep 17 00:00:00 2001
From: lucapericlp
Date: Tue, 26 Mar 2024 11:35:36 +0000
Subject: [PATCH] Parametrising fast inference so that finetuned models can be used (#113)

---
 README.md                 | 11 ++++++++---
 fam/llm/fast_inference.py |  7 ++++++-
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 158d53e..5419fd9 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@
 Open In Colab
-[![](https://dcbadge.vercel.app/api/server/Cpy6U3na8Z?style=flat&compact=True)](https://discord.gg/tbTbkGEgJM)
+[![](https://dcbadge.vercel.app/api/server/Cpy6U3na8Z?style=flat&compact=True)](https://discord.gg/tbTbkGEgJM)
 [![Twitter](https://img.shields.io/twitter/url/https/twitter.com/OnusFM.svg?style=social&label=@metavoiceio)](https://twitter.com/metavoiceio)
@@ -69,7 +69,7 @@ poetry install && poetry run pip install torch==2.2.1 torchaudio==2.2.1
 ## Usage
 1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/fast_inference.py)
 ```bash
-# You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio.
+# You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio.
 # Note: int8 is slower than bf16/fp16 for undebugged reasons. If you want fast, try int4 which is roughly 2x faster than bf16/fp16.
 poetry run python -i fam/llm/fast_inference.py
 
tts.synthesise(text="This is a demo of text to speech by MetaVoice-1B, an open-s
 
 2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](serving.py) or [web UI](app.py)
 ```bash
-# You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio.
+# You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio.
 # Note: int8 is slower than bf16/fp16 for undebugged reasons. If you want fast, try int4 which is roughly 2x faster than bf16/fp16.
 poetry run python serving.py
 poetry run python app.py
@@ -108,6 +108,11 @@ Try it out using our sample datasets via:
 poetry run finetune --train ./datasets/sample_dataset.csv --val ./datasets/sample_val_dataset.csv
 ```
 
+Once you've trained your model, you can use it for inference via:
+```bash
+poetry run python -i fam/llm/fast_inference.py --first_stage_path ./my-finetuned_model.pt
+```
+
 ### Configuration
 
 In order to set hyperparameters such as learning rate, what to freeze, etc, you
diff --git a/fam/llm/fast_inference.py b/fam/llm/fast_inference.py
index b93b75c..904aeb9 100644
--- a/fam/llm/fast_inference.py
+++ b/fam/llm/fast_inference.py
@@ -41,6 +41,7 @@ def __init__(
         seed: int = 1337,
         output_dir: str = "outputs",
         quantisation_mode: Optional[Literal["int4", "int8"]] = None,
+        first_stage_path: Optional[str] = None,
     ):
         """
         Initialise the TTS model.
@@ -54,6 +55,7 @@ def __init__(
                 - None for no quantisation (bf16 or fp16 based on device),
                 - int4 for int4 weight-only quantisation,
                 - int8 for int8 weight-only quantisation.
+            first_stage_path: path to first-stage LLM checkpoint. If provided, this will override the one grabbed from Hugging Face via `model_name`.
""" # NOTE: this needs to come first so that we don't change global state when we want to use @@ -64,6 +66,9 @@ def __init__( self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN) self.output_dir = output_dir os.makedirs(self.output_dir, exist_ok=True) + if first_stage_path: + print(f"Overriding first stage checkpoint via provided model: {first_stage_path}") + first_stage_ckpt = first_stage_path or f"{self._model_dir}/first_stage.pt" second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt" config_second_stage = InferenceConfig( @@ -85,7 +90,7 @@ def __init__( self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype] self.model, self.tokenizer, self.smodel, self.model_size = build_model( precision=self.precision, - checkpoint_path=Path(f"{self._model_dir}/first_stage.pt"), + checkpoint_path=Path(first_stage_ckpt), spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"), device=self._device, compile=True,