Skip to content

Commit

Permalink
feat: add int4 and int8 weight-only quantisation (metavoiceio#95)
Browse files Browse the repository at this point in the history
* feat: add int4 and int8 quantisation

* formatting fixes

* update README

* update

* fix pr comments

* add comment

* add comment

---------

Co-authored-by: EC2 Default User <ec2-user@ip-172-31-30-234.eu-west-1.compute.internal>
  • Loading branch information
vatsalaggarwal and EC2 Default User authored Mar 15, 2024
1 parent e2a9c84 commit af516be
Show file tree
Hide file tree
Showing 6 changed files with 431 additions and 31 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ poetry install && poetry run pip install torch==2.2.1 torchaudio==2.2.1
## Usage
1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/fast_inference.py)
```bash
# You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio.
# Note: int8 is slower than bf16/fp16 for undebugged reasons. If you want fast, try int4 which is roughly 2x faster than bf16/fp16.
poetry run python -i fam/llm/fast_inference.py

# Run e.g. of API usage within the interactive python session
Expand All @@ -71,6 +73,8 @@ tts.synthesise(text="This is a demo of text to speech by MetaVoice-1B, an open-s
2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](serving.py) or [web UI](app.py)
```bash
# You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio.
# Note: int8 is slower than bf16/fp16 for undebugged reasons. If you want fast, try int4 which is roughly 2x faster than bf16/fp16.
poetry run python serving.py
poetry run python app.py
```
Expand Down
3 changes: 2 additions & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@


import gradio as gr
import tyro

from fam.llm.fast_inference import TTS
from fam.llm.utils import check_audio_file

#### setup model
TTS_MODEL = TTS()
TTS_MODEL = tyro.cli(TTS)

#### setup interface
RADIO_CHOICES = ["Preset voices", "Upload target voice (atleast 30s)"]
Expand Down
24 changes: 21 additions & 3 deletions fam/llm/fast_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import tempfile
import time
from pathlib import Path
from typing import Literal, Optional

import librosa
import torch
import tyro
from huggingface_hub import snapshot_download

from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
Expand Down Expand Up @@ -33,10 +35,25 @@ class TTS:
END_OF_AUDIO_TOKEN = 1024

def __init__(
self, model_name: str = "metavoiceio/metavoice-1B-v0.1", *, seed: int = 1337, output_dir: str = "outputs"
self,
model_name: str = "metavoiceio/metavoice-1B-v0.1",
*,
seed: int = 1337,
output_dir: str = "outputs",
quantisation_mode: Optional[Literal["int4", "int8"]] = None,
):
"""
model_name (str): refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/metavoiceio)
Initialise the TTS model.
Args:
model_name: refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/metavoiceio)
seed: random seed for reproducibility
output_dir: directory to save output files
quantisation_mode: quantisation mode for first-stage LLM.
Options:
- None for no quantisation (bf16 or fp16 based on device),
- int4 for int4 weight-only quantisation,
- int8 for int8 weight-only quantisation.
"""

# NOTE: this needs to come first so that we don't change global state when we want to use
Expand Down Expand Up @@ -73,6 +90,7 @@ def __init__(
device=self._device,
compile=True,
compile_prefill=True,
quantisation_mode=quantisation_mode,
)

def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
Expand Down Expand Up @@ -140,4 +158,4 @@ def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.


if __name__ == "__main__":
tts = TTS()
tts = tyro.cli(TTS)
64 changes: 39 additions & 25 deletions fam/llm/fast_inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import itertools
import time
import warnings
from pathlib import Path
from typing import Optional, Tuple
from typing import Literal, Optional, Tuple

import torch
import torch._dynamo.config
import torch._inductor.config
import tqdm

from fam.llm.fast_quantize import WeightOnlyInt4QuantHandler, WeightOnlyInt8QuantHandler


def device_sync(device):
if "cuda" in device:
Expand Down Expand Up @@ -230,28 +233,13 @@ def encode_tokens(tokenizer: TrainedBPETokeniser, text: str, device="cuda") -> t
return torch.tensor(tokens, dtype=torch.int, device=device)


def _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision):
def _load_model(
checkpoint_path, spk_emb_ckpt_path, device, precision, quantisation_mode: Optional[Literal["int4", "int8"]] = None
):
##### MODEL
with torch.device("meta"):
model = Transformer.from_name("metavoice-1B")

# TODO(quantization): enable
# if "int8" in str(checkpoint_path):
# print("Using int8 weight-only quantization!")
# from quantize import WeightOnlyInt8QuantHandler
# simple_quantizer = WeightOnlyInt8QuantHandler(model)
# model = simple_quantizer.convert_for_runtime()
# from quantize import WeightOnlyInt8QuantHandler

# if "int4" in str(checkpoint_path):
# print("Using int4 quantization!")
# path_comps = checkpoint_path.name.split(".")
# assert path_comps[-2].startswith("g")
# groupsize = int(path_comps[-2][1:])
# from quantize import WeightOnlyInt4QuantHandler
# simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
# model = simple_quantizer.convert_for_runtime()

checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False)
state_dict = checkpoint["model"]
# convert MetaVoice-1B model weights naming to gptfast naming
Expand Down Expand Up @@ -290,11 +278,34 @@ def _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision):
k = k.replace(".mlp.c_proj.", ".feed_forward.w2.")

model.load_state_dict(state_dict, assign=True)
# simple_quantizer = WeightOnlyInt8QuantHandler(model)
# quantized_state_dict = simple_quantizer.create_quantized_state_dict()
# model = simple_quantizer.convert_for_runtime()
# model.load_state_dict(quantized_state_dict, assign=True)
model = model.to(device=device, dtype=precision)
model = model.to(device=device, dtype=torch.bfloat16)

if quantisation_mode == "int8":
warnings.warn(
"int8 quantisation is slower than bf16/fp16 for undebugged reasons! Please set optimisation_mode to `None` or to `int4`."
)
warnings.warn(
"quantisation will degrade the quality of the audio! Please set optimisation_mode to `None` for best quality."
)
simple_quantizer = WeightOnlyInt8QuantHandler(model)
quantized_state_dict = simple_quantizer.create_quantized_state_dict()
model = simple_quantizer.convert_for_runtime()
model.load_state_dict(quantized_state_dict, assign=True)
model = model.to(device=device, dtype=torch.bfloat16)
# TODO: int8/int4 doesn't decrease VRAM usage substantially... fix that (might be linked to kv-cache)
torch.cuda.empty_cache()
elif quantisation_mode == "int4":
warnings.warn(
"quantisation will degrade the quality of the audio! Please set optimisation_mode to `None` for best quality."
)
simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize=128)
quantized_state_dict = simple_quantizer.create_quantized_state_dict()
model = simple_quantizer.convert_for_runtime(use_cuda=True)
model.load_state_dict(quantized_state_dict, assign=True)
model = model.to(device=device, dtype=torch.bfloat16)
torch.cuda.empty_cache()
elif quantisation_mode is not None:
raise Exception(f"Invalid quantisation mode {quantisation_mode}! Must be either 'int4' or 'int8'!")

###### TOKENIZER
tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {})
Expand All @@ -318,14 +329,17 @@ def build_model(
compile_prefill: bool = False,
compile: bool = True,
device: str = "cuda",
quantisation_mode: Optional[Literal["int4", "int8"]] = None,
):
assert checkpoint_path.is_file(), checkpoint_path

print(f"Using device={device}")

print("Loading model ...")
t0 = time.time()
model, tokenizer, smodel = _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision)
model, tokenizer, smodel = _load_model(
checkpoint_path, spk_emb_ckpt_path, device, precision, quantisation_mode=quantisation_mode
)

device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")
Expand Down
Loading

0 comments on commit af516be

Please sign in to comment.