From 085c15b7c393434db39d2f78a31e8406b3290557 Mon Sep 17 00:00:00 2001
From: mobicham
Date: Sun, 13 Oct 2024 16:20:09 +0000
Subject: [PATCH] add gemlite support

---
 hqq/backends/gemlite.py | 53 +++++++++++++++++++++++++++++++++++++++++
 hqq/utils/patching.py   | 17 +++++++++----
 2 files changed, 65 insertions(+), 5 deletions(-)
 create mode 100755 hqq/backends/gemlite.py

diff --git a/hqq/backends/gemlite.py b/hqq/backends/gemlite.py
new file mode 100755
index 0000000..7e8e278
--- /dev/null
+++ b/hqq/backends/gemlite.py
@@ -0,0 +1,53 @@
+# Written by Dr. Hicham Badri @Mobius Labs GmbH - 2024
+#####################################################
+
+import torch
+from gemlite.core import GemLiteLinearTriton, DType
+
+from ..core.quantize import HQQLinear
+from ..core.peft import HQQLinearLoRA
+
+
+def patch_hqq_to_gemlite(layer, patch_params):
+    hqq_layer = None
+    if isinstance(layer, HQQLinear):
+        hqq_layer = layer
+    if isinstance(layer, HQQLinearLoRA):
+        hqq_layer = layer.linear_layer
+
+    if hqq_layer is None:
+        return layer
+
+    if hqq_layer.meta["group_size"] is None:
+        hqq_layer.meta["group_size"] = hqq_layer.in_features
+
+    gemlite_linear = GemLiteLinearTriton(
+        hqq_layer.meta["nbits"],
+        group_size=hqq_layer.meta["group_size"],
+        in_features=hqq_layer.in_features,
+        out_features=hqq_layer.out_features,
+        input_dtype=DType.FP16,
+        output_dtype=DType.FP16,
+        acc_dtype=DType.FP16,
+        exhaustive=False,
+    )
+
+    orig_shape = hqq_layer.meta["shape"]
+    W_q = hqq_layer.unpack().view(orig_shape)
+    scales = hqq_layer.meta["scale"].clone()
+    zeros = hqq_layer.meta["zero"].clone()
+    gemlite_linear.pack(W_q, scales, zeros, None)
+    gemlite_linear.name = hqq_layer.name
+
+    del hqq_layer.W_q
+    del hqq_layer.meta
+    del hqq_layer
+    torch.cuda.empty_cache()
+
+    if isinstance(layer, HQQLinear):
+        return gemlite_linear
+
+    if isinstance(layer, HQQLinearLoRA):
+        layer.linear_layer = gemlite_linear
+
+    return layer
diff --git a/hqq/utils/patching.py b/hqq/utils/patching.py
index 4bbcd21..b7d3210 100755
--- a/hqq/utils/patching.py
+++ b/hqq/utils/patching.py
@@ -19,6 +19,13 @@
     patch_hqq_to_bitblas = None
     print(colored('Warning: failed to import the BitBlas backend. Check if BitBlas is correctly installed if you want to use the bitblas backend (https://github.com/microsoft/BitBLAS).','yellow'))
 
+try:
+    from ..backends.gemlite import patch_hqq_to_gemlite
+except Exception:
+    patch_hqq_to_gemlite = None
+    print(colored('Warning: failed to import the GemLite backend. Check if GemLite is correctly installed if you want to use the gemlite backend (https://github.com/mobiusml/gemlite/).','yellow'))
+
+
 def patch_linearlayers(model, fct, patch_param=None, verbose=False):
     base_class = model.base_class if (hasattr(model, "base_class")) else AutoHQQHFModel
     base_class.setup_model(model)
@@ -109,22 +116,22 @@ def prepare_for_inference(model, allow_merge=False, backend="default", verbose=F
     patch_linearlayers(model, patch_lora_inference)
     cleanup()
 
+    if backend == "gemlite" and (patch_hqq_to_gemlite is not None):
+        patch_linearlayers(model, patch_hqq_to_gemlite, verbose=verbose)
     if backend == "bitblas" and (patch_hqq_to_bitblas is not None):
         patch_linearlayers(model, patch_hqq_to_bitblas, verbose=verbose)
-        cleanup()
     if backend == "torchao_int4":
         patch_linearlayers(model, patch_hqq_to_aoint4, verbose=verbose)
         recommended_inductor_config_setter()
-        cleanup()
     if allow_merge:  # only compatible with symmetric quant kernels
         patch_linearlayers(
             model,
             patch_merge_zeros_with_lora,
             {"z_shift": 8, "keep_lora": False},
             verbose=verbose,
-        )
-        cleanup()
+        )
     if backend == "marlin" and (patch_hqq_to_marlin is not None):
         patch_linearlayers(model, patch_hqq_to_marlin, verbose=verbose)
-        cleanup()
+
+    cleanup()
 
     patch_linearlayers(
         model, patch_add_weight_param, {"device": model.device, "dtype": model.dtype}
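
Reviewer note: a minimal usage sketch of the new backend path follows. The model id and quantization settings are illustrative assumptions, not part of this patch; the import paths for AutoHQQHFModel and BaseQuantizeConfig follow the existing hqq layout, and only the final call exercises the "gemlite" branch added to prepare_for_inference above.

# Minimal usage sketch; model id and quant settings are assumptions for illustration.
import torch
from transformers import AutoModelForCausalLM

from hqq.models.hf.base import AutoHQQHFModel
from hqq.core.quantize import BaseQuantizeConfig
from hqq.utils.patching import prepare_for_inference

model_id = "meta-llama/Llama-2-7b-hf"  # hypothetical model, not part of the patch
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

# 4-bit HQQ quantization; GemLiteLinearTriton is configured above for FP16 activations.
quant_config = BaseQuantizeConfig(nbits=4, group_size=64)
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=torch.float16, device="cuda")

# Routes through the new "gemlite" branch in prepare_for_inference, which applies
# patch_hqq_to_gemlite to every HQQLinear / HQQLinearLoRA layer in the model.
prepare_for_inference(model, backend="gemlite")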