From af516be2e3f448946d48c79b1c0fa16989866598 Mon Sep 17 00:00:00 2001 From: Vatsal Aggarwal Date: Fri, 15 Mar 2024 11:14:17 +0000 Subject: [PATCH] feat: add int4 and int8 weight-only quantisation (#95) * feat: add int4 and int8 quantisation * formatting fixes * update README * update * fix pr comments * add comment * add comment --------- Co-authored-by: EC2 Default User --- README.md | 4 + app.py | 3 +- fam/llm/fast_inference.py | 24 ++- fam/llm/fast_inference_utils.py | 64 +++--- fam/llm/fast_quantize.py | 361 ++++++++++++++++++++++++++++++++ serving.py | 6 +- 6 files changed, 431 insertions(+), 31 deletions(-) create mode 100644 fam/llm/fast_quantize.py diff --git a/README.md b/README.md index 6bac9a2..0400149 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,8 @@ poetry install && poetry run pip install torch==2.2.1 torchaudio==2.2.1 ## Usage 1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/fast_inference.py) ```bash +# You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio. +# Note: int8 is slower than bf16/fp16 for undebugged reasons. If you want fast, try int4 which is roughly 2x faster than bf16/fp16. poetry run python -i fam/llm/fast_inference.py # Run e.g. of API usage within the interactive python session @@ -71,6 +73,8 @@ tts.synthesise(text="This is a demo of text to speech by MetaVoice-1B, an open-s 2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](serving.py) or [web UI](app.py) ```bash +# You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio. +# Note: int8 is slower than bf16/fp16 for undebugged reasons. If you want fast, try int4 which is roughly 2x faster than bf16/fp16. poetry run python serving.py poetry run python app.py ``` diff --git a/app.py b/app.py index ddb4b6d..37ee06a 100644 --- a/app.py +++ b/app.py @@ -7,12 +7,13 @@ import gradio as gr +import tyro from fam.llm.fast_inference import TTS from fam.llm.utils import check_audio_file #### setup model -TTS_MODEL = TTS() +TTS_MODEL = tyro.cli(TTS) #### setup interface RADIO_CHOICES = ["Preset voices", "Upload target voice (atleast 30s)"] diff --git a/fam/llm/fast_inference.py b/fam/llm/fast_inference.py index daed440..b93b75c 100644 --- a/fam/llm/fast_inference.py +++ b/fam/llm/fast_inference.py @@ -3,9 +3,11 @@ import tempfile import time from pathlib import Path +from typing import Literal, Optional import librosa import torch +import tyro from huggingface_hub import snapshot_download from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook @@ -33,10 +35,25 @@ class TTS: END_OF_AUDIO_TOKEN = 1024 def __init__( - self, model_name: str = "metavoiceio/metavoice-1B-v0.1", *, seed: int = 1337, output_dir: str = "outputs" + self, + model_name: str = "metavoiceio/metavoice-1B-v0.1", + *, + seed: int = 1337, + output_dir: str = "outputs", + quantisation_mode: Optional[Literal["int4", "int8"]] = None, ): """ - model_name (str): refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/metavoiceio) + Initialise the TTS model. + + Args: + model_name: refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/metavoiceio) + seed: random seed for reproducibility + output_dir: directory to save output files + quantisation_mode: quantisation mode for first-stage LLM. 
+ Options: + - None for no quantisation (bf16 or fp16 based on device), + - int4 for int4 weight-only quantisation, + - int8 for int8 weight-only quantisation. """ # NOTE: this needs to come first so that we don't change global state when we want to use @@ -73,6 +90,7 @@ def __init__( device=self._device, compile=True, compile_prefill=True, + quantisation_mode=quantisation_mode, ) def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str: @@ -140,4 +158,4 @@ def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3. if __name__ == "__main__": - tts = TTS() + tts = tyro.cli(TTS) diff --git a/fam/llm/fast_inference_utils.py b/fam/llm/fast_inference_utils.py index 97c2d95..cbeb708 100644 --- a/fam/llm/fast_inference_utils.py +++ b/fam/llm/fast_inference_utils.py @@ -25,14 +25,17 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import itertools import time +import warnings from pathlib import Path -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple import torch import torch._dynamo.config import torch._inductor.config import tqdm +from fam.llm.fast_quantize import WeightOnlyInt4QuantHandler, WeightOnlyInt8QuantHandler + def device_sync(device): if "cuda" in device: @@ -230,28 +233,13 @@ def encode_tokens(tokenizer: TrainedBPETokeniser, text: str, device="cuda") -> t return torch.tensor(tokens, dtype=torch.int, device=device) -def _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision): +def _load_model( + checkpoint_path, spk_emb_ckpt_path, device, precision, quantisation_mode: Optional[Literal["int4", "int8"]] = None +): ##### MODEL with torch.device("meta"): model = Transformer.from_name("metavoice-1B") - # TODO(quantization): enable - # if "int8" in str(checkpoint_path): - # print("Using int8 weight-only quantization!") - # from quantize import WeightOnlyInt8QuantHandler - # simple_quantizer = WeightOnlyInt8QuantHandler(model) - # model = simple_quantizer.convert_for_runtime() - # from quantize import WeightOnlyInt8QuantHandler - - # if "int4" in str(checkpoint_path): - # print("Using int4 quantization!") - # path_comps = checkpoint_path.name.split(".") - # assert path_comps[-2].startswith("g") - # groupsize = int(path_comps[-2][1:]) - # from quantize import WeightOnlyInt4QuantHandler - # simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) - # model = simple_quantizer.convert_for_runtime() - checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False) state_dict = checkpoint["model"] # convert MetaVoice-1B model weights naming to gptfast naming @@ -290,11 +278,34 @@ def _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision): k = k.replace(".mlp.c_proj.", ".feed_forward.w2.") model.load_state_dict(state_dict, assign=True) - # simple_quantizer = WeightOnlyInt8QuantHandler(model) - # quantized_state_dict = simple_quantizer.create_quantized_state_dict() - # model = simple_quantizer.convert_for_runtime() - # model.load_state_dict(quantized_state_dict, assign=True) - model = model.to(device=device, dtype=precision) + model = model.to(device=device, dtype=torch.bfloat16) + + if quantisation_mode == "int8": + warnings.warn( + "int8 quantisation is slower than bf16/fp16 for undebugged reasons! Please set optimisation_mode to `None` or to `int4`." + ) + warnings.warn( + "quantisation will degrade the quality of the audio! Please set optimisation_mode to `None` for best quality." 
+ ) + simple_quantizer = WeightOnlyInt8QuantHandler(model) + quantized_state_dict = simple_quantizer.create_quantized_state_dict() + model = simple_quantizer.convert_for_runtime() + model.load_state_dict(quantized_state_dict, assign=True) + model = model.to(device=device, dtype=torch.bfloat16) + # TODO: int8/int4 doesn't decrease VRAM usage substantially... fix that (might be linked to kv-cache) + torch.cuda.empty_cache() + elif quantisation_mode == "int4": + warnings.warn( + "quantisation will degrade the quality of the audio! Please set optimisation_mode to `None` for best quality." + ) + simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize=128) + quantized_state_dict = simple_quantizer.create_quantized_state_dict() + model = simple_quantizer.convert_for_runtime(use_cuda=True) + model.load_state_dict(quantized_state_dict, assign=True) + model = model.to(device=device, dtype=torch.bfloat16) + torch.cuda.empty_cache() + elif quantisation_mode is not None: + raise Exception(f"Invalid quantisation mode {quantisation_mode}! Must be either 'int4' or 'int8'!") ###### TOKENIZER tokenizer_info = checkpoint.get("meta", {}).get("tokenizer", {}) @@ -318,6 +329,7 @@ def build_model( compile_prefill: bool = False, compile: bool = True, device: str = "cuda", + quantisation_mode: Optional[Literal["int4", "int8"]] = None, ): assert checkpoint_path.is_file(), checkpoint_path @@ -325,7 +337,9 @@ def build_model( print("Loading model ...") t0 = time.time() - model, tokenizer, smodel = _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision) + model, tokenizer, smodel = _load_model( + checkpoint_path, spk_emb_ckpt_path, device, precision, quantisation_mode=quantisation_mode + ) device_sync(device=device) # MKG print(f"Time to load model: {time.time() - t0:.02f} seconds") diff --git a/fam/llm/fast_quantize.py b/fam/llm/fast_quantize.py new file mode 100644 index 0000000..aa3ea8b --- /dev/null +++ b/fam/llm/fast_quantize.py @@ -0,0 +1,361 @@ +# Copyright (c) MetaVoice Labs Inc., Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, this +# list of conditions and the following disclaimer in the documentation and/or other +# materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import time +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F + +default_device = "cuda" if torch.cuda.is_available() else "cpu" + +##### Quantization Primitives ###### + + +def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): + # assumes symmetric quantization + # assumes axis == 0 + # assumes dense memory format + # TODO(future): relax ^ as needed + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scales and zero_points based on min and max + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scales = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scales is the same dtype as the original tensor + scales = torch.clamp(scales, min=eps).to(x.dtype) + zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + # quantize based on qmin/qmax/scales/zp + x_div = x / scales.unsqueeze(-1) + x_round = torch.round(x_div) + x_zp = x_round + zero_points.unsqueeze(-1) + quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) + + return quant, scales, zero_points + + +def get_group_qparams(w, n_bit=4, groupsize=128): + # needed for GPTQ with padding + if groupsize > w.shape[-1]: + groupsize = w.shape[-1] + assert groupsize > 1 + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + zeros = min_val + scales * (2 ** (n_bit - 1)) + return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(torch.bfloat16).reshape(w.shape[0], -1) + + +def pack_scales_and_zeros(scales, zeros): + assert scales.shape == zeros.shape + assert scales.dtype == torch.bfloat16 + assert zeros.dtype == torch.bfloat16 + return ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + + +def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128): + assert groupsize > 1 + # needed for GPTQ single column quantize + if groupsize > w.shape[-1] and scales.shape[-1] == 1: + groupsize = w.shape[-1] + + assert w.shape[-1] % groupsize == 0 + assert w.dim() == 2 + + to_quant = w.reshape(-1, groupsize) + assert torch.isnan(to_quant).sum() == 0 + + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + min_val = zeros - scales * (2 ** (n_bit - 1)) + max_int = 2**n_bit - 1 + min_int = 0 + w_int32 = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int).to(torch.int32).reshape_as(w) + + return w_int32 + 
+ +def group_quantize_tensor(w, n_bit=4, groupsize=128): + scales, zeros = get_group_qparams(w, n_bit, groupsize) + w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize) + scales_and_zeros = pack_scales_and_zeros(scales, zeros) + return w_int32, scales_and_zeros + + +def group_dequantize_tensor_from_qparams(w_int32, scales, zeros, n_bit=4, groupsize=128): + assert groupsize > 1 + # needed for GPTQ single column dequantize + if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1: + groupsize = w_int32.shape[-1] + assert w_int32.shape[-1] % groupsize == 0 + assert w_int32.dim() == 2 + + w_int32_grouped = w_int32.reshape(-1, groupsize) + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + + w_dq = w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32) + return w_dq + + +##### Weight-only int8 per-channel quantized code ###### + + +def replace_linear_weight_only_int8_per_channel(module): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features)) + else: + replace_linear_weight_only_int8_per_channel(child) + + +class WeightOnlyInt8QuantHandler: + def __init__(self, mod): + self.mod = mod + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + # TODO: quantise RMSNorm as well. + if isinstance(mod, torch.nn.Linear): + int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8) + cur_state_dict[f"{fqn}.weight"] = int8_weight.to("cpu") + cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype).to("cpu") + + return cur_state_dict + + def convert_for_runtime(self): + replace_linear_weight_only_int8_per_channel(self.mod) + return self.mod + + +class WeightOnlyInt8Linear(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) + self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16)) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales + + +##### weight only int4 per channel groupwise quantized code ###### + + +def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): + weight_int32, scales_and_zeros = group_quantize_tensor(weight_bf16, n_bit=4, groupsize=groupsize) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles) + return weight_int4pack, scales_and_zeros + + +def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): + origin_x_size = x.size() + x = x.reshape(-1, origin_x_size[-1]) + c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros) + new_shape = origin_x_size[:-1] + (out_features,) + c = c.reshape(new_shape) + return c + + +def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1): + return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 + + +def replace_linear_int4(module, groupsize, inner_k_tiles, padding, use_cuda): + for name, 
child in module.named_children(): + if isinstance(child, nn.Linear) and child.out_features % 8 == 0: + if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles): + setattr( + module, + name, + WeightOnlyInt4Linear( + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + padding=False, + use_cuda=use_cuda, + ), + ) + elif padding: + setattr( + module, + name, + WeightOnlyInt4Linear( + child.in_features, + child.out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + padding=True, + use_cuda=use_cuda, + ), + ) + else: + replace_linear_int4(child, groupsize, inner_k_tiles, padding, use_cuda) + + +class WeightOnlyInt4QuantHandler: + def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True): + self.mod = mod + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.padding = padding + assert groupsize in [32, 64, 128, 256] + assert inner_k_tiles in [2, 4, 8] + + @torch.no_grad() + def create_quantized_state_dict(self): + cur_state_dict = self.mod.state_dict() + for fqn, mod in self.mod.named_modules(): + if isinstance(mod, torch.nn.Linear): + assert not mod.bias + out_features = mod.out_features + in_features = mod.in_features + if out_features % 8 != 0: + continue + assert out_features % 8 == 0, "require out_features % 8 == 0" + print(f"linear: {fqn}, in={in_features}, out={out_features}") + + weight = mod.weight.data + if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles): + if self.padding: + import torch.nn.functional as F + from model import find_multiple + + print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") + padded_in_features = find_multiple(in_features, 1024) + weight = F.pad(weight, pad=(0, padded_in_features - in_features)) + else: + print( + f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that groupsize and inner_k_tiles*16 evenly divide into it" + ) + continue + weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros( + weight.to(torch.bfloat16), self.groupsize, self.inner_k_tiles + ) + cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu") + cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu") + + return cur_state_dict + + def convert_for_runtime(self, use_cuda): + replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda) + return self.mod + + +class WeightOnlyInt4Linear(torch.nn.Module): + __constants__ = ["in_features", "out_features"] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__( + self, + in_features: int, + out_features: int, + bias=True, + device=None, + dtype=None, + groupsize: int = 128, + inner_k_tiles: int = 8, + padding: bool = True, + use_cuda=True, + ) -> None: + super().__init__() + self.padding = padding + if padding: + from model import find_multiple + + self.origin_in_features = in_features + in_features = find_multiple(in_features, 1024) + + self.in_features = in_features + self.out_features = out_features + assert not bias, "require bias=False" + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + + assert out_features % 8 == 0, "require out_features % 8 == 0" + assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" + if use_cuda: + self.register_buffer( + "weight", + torch.empty( + (out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), 
dtype=torch.int32 + ), + ) + else: + self.register_buffer("weight", torch.empty((out_features, in_features // 2), dtype=torch.uint8)) + self.register_buffer( + "scales_and_zeros", torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = input.to(torch.bfloat16) + if self.padding: + import torch.nn.functional as F + + input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) + return linear_forward_int4(input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize) diff --git a/serving.py b/serving.py index 48daf26..94119b0 100644 --- a/serving.py +++ b/serving.py @@ -5,7 +5,7 @@ import tempfile import warnings from pathlib import Path -from typing import Optional +from typing import Literal, Optional import fastapi import fastapi.middleware.cors @@ -38,6 +38,8 @@ class ServingConfig: port: int = 58003 + quantisation_mode: Optional[Literal["int4", "int8"]] = None + # Singleton class _GlobalState: @@ -127,7 +129,7 @@ def _convert_audiodata_to_wav_path(audiodata, wav_tmp): logging.root.setLevel(logging.INFO) GlobalState.config = tyro.cli(ServingConfig) - GlobalState.tts = TTS(seed=GlobalState.config.seed) + GlobalState.tts = TTS(seed=GlobalState.config.seed, quantisation_mode=GlobalState.config.quantisation_mode) app.add_middleware( fastapi.middleware.cors.CORSMiddleware,
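
For anyone trying the new flag, a minimal usage sketch of the API surface this patch exposes (illustrative, not part of the patch; the reference clip path is a placeholder you should replace with any ~30 s speaker reference):

```python
# Sketch only: exercises the new quantisation_mode parameter added to TTS.
# "assets/speaker_ref.wav" is a placeholder path, not a file shipped with the repo.
from fam.llm.fast_inference import TTS

# quantisation_mode=None keeps bf16/fp16; "int4" is roughly 2x faster but degrades audio quality.
tts = TTS(quantisation_mode="int4")
wav_file = tts.synthesise(
    text="This is a demo of text to speech by MetaVoice-1B.",
    spk_ref_path="assets/speaker_ref.wav",
)
print(wav_file)  # path to the generated audio
```

Because both `fam/llm/fast_inference.py` and `serving.py` now go through `tyro.cli`, the same option should also be reachable from the command line, e.g. `poetry run python serving.py --quantisation_mode int4`, as the README comments added above describe.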
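
The int8 path quantises each `nn.Linear` weight per output channel with a symmetric scale and leaves activations in floating point, dequantising on the fly inside the matmul. A standalone sketch of that scheme, mirroring the arithmetic in `dynamically_quantize_per_channel` and `WeightOnlyInt8Linear.forward` (shapes are arbitrary, chosen only for the example):

```python
# Minimal sketch of per-channel symmetric int8 weight-only quantisation.
import torch

w = torch.randn(4, 16)                           # (out_features, in_features)
max_abs = w.abs().amax(dim=1)                    # one scale per output channel
scales = (max_abs / 127.5).clamp(min=torch.finfo(torch.float32).eps)  # (quant_max - quant_min) / 2
w_int8 = (w / scales[:, None]).round().clamp(-128, 127).to(torch.int8)

# Weight-only: activations stay in float; the int8 weight is cast back and rescaled in the matmul.
x = torch.randn(2, 16)
y = (x @ w_int8.t().to(x.dtype)) * scales        # same as F.linear(x, w_int8.to(x.dtype)) * scales
print((y - x @ w.t()).abs().max())               # small per-channel reconstruction error
```

Since only the weights are quantised, activations and the kv-cache are untouched, which is consistent with the TODO above about VRAM usage not dropping substantially.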
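
The int4 path is groupwise and asymmetric: each group of 128 weights along the input dimension gets its own bf16 scale and zero point before being packed for the int4 matmul kernel. A standalone round-trip of that scheme, mirroring `get_group_qparams`, `group_quantize_tensor_from_qparams` and `group_dequantize_tensor_from_qparams` (the `_convert_weight_to_int4pack` packing step is omitted; shapes are arbitrary):

```python
# Minimal round-trip of the groupwise asymmetric int4 scheme (packing omitted).
import torch

n_bit, groupsize = 4, 128
w = torch.randn(256, 512)                                 # (out_features, in_features)

g = w.reshape(-1, groupsize)                              # one (scale, zero) pair per group of 128
max_val = g.amax(dim=1, keepdim=True)
min_val = g.amin(dim=1, keepdim=True)
scales = (max_val - min_val).clamp(min=1e-6) / (2**n_bit - 1)  # 16 quantisation levels
zeros = min_val + scales * 2 ** (n_bit - 1)               # real value of the middle code (8)

q = ((g - min_val) / scales).round().clamp(0, 2**n_bit - 1).to(torch.int32)
w_dq = ((q - 2 ** (n_bit - 1)) * scales + zeros).reshape_as(w)

print((w_dq - w).abs().max())                             # per-group reconstruction error
```

In the patch the quantised weights then go through `torch.ops.aten._convert_weight_to_int4pack` and `torch.ops.aten._weight_int4pack_mm`, which is why `_check_linear_int4_k` requires `in_features` to be divisible by both the groupsize and `inner_k_tiles * 16`, padding to a multiple of 1024 otherwise.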