add gemlite support
mobicham committed Oct 13, 2024
1 parent 6a3e4b4 commit 085c15b
Showing 2 changed files with 65 additions and 5 deletions.
53 changes: 53 additions & 0 deletions hqq/backends/gemlite.py
@@ -0,0 +1,53 @@
# Written by Dr. Hicham Badri @Mobius Labs GmbH - 2024
#####################################################

import torch
from gemlite.core import GemLiteLinearTriton, DType

from ..core.quantize import HQQLinear
from ..core.peft import HQQLinearLoRA


def patch_hqq_to_gemlite(layer, patch_params):
    hqq_layer = None
    if isinstance(layer, HQQLinear):
        hqq_layer = layer
    if isinstance(layer, HQQLinearLoRA):
        hqq_layer = layer.linear_layer

    if hqq_layer is None:
        return layer

    if hqq_layer.meta["group_size"] is None:
        hqq_layer.meta["group_size"] = hqq_layer.in_features

    gemlite_linear = GemLiteLinearTriton(
        hqq_layer.meta["nbits"],
        group_size=hqq_layer.meta["group_size"],
        in_features=hqq_layer.in_features,
        out_features=hqq_layer.out_features,
        input_dtype=DType.FP16,
        output_dtype=DType.FP16,
        acc_dtype=DType.FP16,
        exhaustive=False,
    )

    orig_shape = hqq_layer.meta["shape"]
    W_q = hqq_layer.unpack().view(orig_shape)
    scales = hqq_layer.meta["scale"].clone()
    zeros = hqq_layer.meta["zero"].clone()
    gemlite_linear.pack(W_q, scales, zeros, None)
    gemlite_linear.name = hqq_layer.name

    del hqq_layer.W_q
    del hqq_layer.meta
    del hqq_layer
    torch.cuda.empty_cache()

    if isinstance(layer, HQQLinear):
        return gemlite_linear

    if isinstance(layer, HQQLinearLoRA):
        layer.linear_layer = gemlite_linear

    return layer
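For orientation (not part of the committed file), here is a minimal sketch of how this patch function might be applied to a single HQQ-quantized layer. The layer shape, the BaseQuantizeConfig values, and the explicit name tag are illustrative assumptions; it assumes hqq and gemlite are installed and a CUDA device is available.

# Hedged single-layer sketch: quantize one nn.Linear with HQQ, then swap it for a GemLite kernel.
# All sizes and quantization settings below are illustrative assumptions, not part of this commit.
import torch
import torch.nn as nn
from hqq.core.quantize import HQQLinear, BaseQuantizeConfig
from hqq.backends.gemlite import patch_hqq_to_gemlite

linear = nn.Linear(4096, 4096, bias=False)
quant_config = BaseQuantizeConfig(nbits=4, group_size=64)
hqq_layer = HQQLinear(linear, quant_config, compute_dtype=torch.float16, device="cuda")
hqq_layer.name = "layer_0"  # patch_hqq_to_gemlite copies this tag onto the GemLite layer

# patch_params is unused by this backend; it exists to match the patch_linearlayers interface.
gemlite_layer = patch_hqq_to_gemlite(hqq_layer, patch_params=None)
out = gemlite_layer(torch.randn(1, 4096, dtype=torch.float16, device="cuda"))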
17 changes: 12 additions & 5 deletions hqq/utils/patching.py
@@ -19,6 +19,13 @@
    patch_hqq_to_bitblas = None
    print(colored('Warning: failed to import the BitBlas backend. Check if BitBlas is correctly installed if you want to use the bitblas backend (https://github.com/microsoft/BitBLAS).','yellow'))

try:
    from ..backends.gemlite import patch_hqq_to_gemlite
except Exception:
    patch_hqq_to_gemlite = None
    print(colored('Warning: failed to import the GemLite backend. Check if GemLite is correctly installed if you want to use the gemlite backend (https://github.com/mobiusml/gemlite/).','yellow'))


def patch_linearlayers(model, fct, patch_param=None, verbose=False):
    base_class = model.base_class if (hasattr(model, "base_class")) else AutoHQQHFModel
    base_class.setup_model(model)
@@ -109,22 +116,22 @@ def prepare_for_inference(model, allow_merge=False, backend="default", verbose=False):
    patch_linearlayers(model, patch_lora_inference)
    cleanup()

    if backend == "gemlite" and (patch_hqq_to_gemlite is not None):
        patch_linearlayers(model, patch_hqq_to_gemlite, verbose=verbose)
    if backend == "bitblas" and (patch_hqq_to_bitblas is not None):
        patch_linearlayers(model, patch_hqq_to_bitblas, verbose=verbose)
        cleanup()
    if backend == "torchao_int4":
        patch_linearlayers(model, patch_hqq_to_aoint4, verbose=verbose)
        recommended_inductor_config_setter()
        cleanup()
        if allow_merge:  # only compatible with symmetric quant kernels
            patch_linearlayers(
                model, patch_merge_zeros_with_lora, {"z_shift": 8, "keep_lora": False},
                verbose=verbose,
            )
            cleanup()
if backend == "marlin" and (patch_hqq_to_marlin is not None):
patch_linearlayers(model, patch_hqq_to_marlin, verbose=verbose)
cleanup()

cleanup()

patch_linearlayers(
model, patch_add_weight_param, {"device": model.device, "dtype": model.dtype}
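As a quick end-to-end orientation (also not part of the diff), a hedged sketch of how the new backend option introduced by this commit might be selected when preparing a model for inference. The model id and quantization settings are illustrative assumptions; it assumes transformers, hqq, gemlite, and a CUDA device.

# Hedged end-to-end sketch: quantize a HF model with HQQ, then route matmuls through GemLite.
# Model id and quant settings are illustrative assumptions, not part of this commit.
import torch
from transformers import AutoModelForCausalLM
from hqq.models.hf.base import AutoHQQHFModel
from hqq.core.quantize import BaseQuantizeConfig
from hqq.utils.patching import prepare_for_inference

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16)
quant_config = BaseQuantizeConfig(nbits=4, group_size=64)
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=torch.float16, device="cuda")

# With this commit, backend="gemlite" swaps every HQQLinear for a GemLiteLinearTriton kernel.
prepare_for_inference(model, backend="gemlite", verbose=True)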
