Commit

CI job. Gpt awq 4 (#2665)
* add GPTQ and AWQ int4 support on Intel platforms

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix CI failure

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* set kv cache dtype

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine the code according to the review comments

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Simplifying conditionals + reverting integration test values.

* Unused import

* Fix redundant import.

* Revert change after rebase.

* Upgrading the tests (the TP>1 fix changes them to use different kernels).

* Update server/text_generation_server/layers/gptq/__init__.py

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Wang, Yi A <yi.a.wang@intel.com>
Narsil and sywangyi authored Oct 18, 2024
1 parent 8ec5755 commit 153ff37
Showing 11 changed files with 226 additions and 38 deletions.
@@ -11,65 +11,65 @@
},
{
"id": 3226,
"logprob": -8.9453125,
"logprob": -9.0234375,
"text": " ge"
},
{
"id": 21017,
"logprob": -8.8515625,
"logprob": -9.0859375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.21875,
"logprob": -0.25585938,
"text": "_"
},
{
"id": 6009,
"logprob": -1.2773438,
"logprob": -2.1972656,
"text": "mean"
},
{
"id": 26,
"logprob": -0.25195312,
"logprob": -0.2998047,
"text": "("
},
{
"id": 62,
"logprob": -4.8203125,
"logprob": -5.6445312,
"text": "L"
},
{
"id": 44,
"logprob": -3.7734375,
"logprob": -3.0839844,
"text": ":"
},
{
"id": 1682,
"logprob": -0.8310547,
"logprob": -0.6748047,
"text": " List"
},
{
"id": 77,
"logprob": -0.22766113,
"logprob": -0.3864746,
"text": "["
},
{
"id": 1808,
"logprob": -0.46240234,
"logprob": -0.9355469,
"text": "float"
},
{
"id": 10794,
"logprob": -3.0234375,
"logprob": -2.5371094,
"text": "]):"
}
],
"seed": null,
"tokens": [
{
"id": 284,
"logprob": -0.04626465,
"logprob": -1.1679688,
"special": false,
"text": "\n "
},
@@ -11,65 +11,65 @@
},
{
"id": 3226,
"logprob": -8.9453125,
"logprob": -9.015625,
"text": " ge"
},
{
"id": 21017,
"logprob": -8.859375,
"logprob": -9.0859375,
"text": "ometric"
},
{
"id": 81,
"logprob": -0.21984863,
"logprob": -0.25585938,
"text": "_"
},
{
"id": 6009,
"logprob": -1.2861328,
"logprob": -2.2304688,
"text": "mean"
},
{
"id": 26,
"logprob": -0.25219727,
"logprob": -0.29760742,
"text": "("
},
{
"id": 62,
"logprob": -4.8007812,
"logprob": -5.6796875,
"text": "L"
},
{
"id": 44,
"logprob": -3.7949219,
"logprob": -3.0742188,
"text": ":"
},
{
"id": 1682,
"logprob": -0.8046875,
"logprob": -0.67626953,
"text": " List"
},
{
"id": 77,
"logprob": -0.22424316,
"logprob": -0.38842773,
"text": "["
},
{
"id": 1808,
"logprob": -0.46191406,
"logprob": -0.9165039,
"text": "float"
},
{
"id": 10794,
"logprob": -3.0253906,
"logprob": -2.5527344,
"text": "]):"
}
],
"seed": 0,
"tokens": [
{
"id": 284,
"logprob": 0.0,
"logprob": -0.048583984,
"special": false,
"text": "\n "
},
8 changes: 8 additions & 0 deletions server/text_generation_server/layers/awq/quantize/__init__.py
@@ -0,0 +1,8 @@
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
    from .ipex import WQLinear
elif SYSTEM == "cuda":
    from .cuda import WQLinear

__all__ = ["WQLinear"]
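
The new `__init__.py` exposes a single `WQLinear` symbol and picks the IPEX or CUDA implementation from TGI's `SYSTEM` constant at import time, so call sites never branch on the backend. A standalone sketch of the same dispatch pattern (illustrative only; `detect_system()` and the placeholder classes are hypothetical, not TGI code):

# Sketch only (not part of this commit): resolve one public name to a
# backend-specific implementation at import time.
import importlib.util


def detect_system() -> str:
    # Hypothetical stand-in for text_generation_server.utils.import_utils.SYSTEM.
    if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
        return "ipex"
    return "cuda"


class _IpexWQLinear:  # placeholder for `from .ipex import WQLinear`
    backend = "ipex"


class _CudaWQLinear:  # placeholder for `from .cuda import WQLinear`
    backend = "cuda"


SYSTEM = detect_system()
WQLinear = _IpexWQLinear if SYSTEM == "ipex" else _CudaWQLinear

__all__ = ["WQLinear"]

if __name__ == "__main__":
    print(f"WQLinear resolved to the {WQLinear.backend} backend")
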
48 changes: 48 additions & 0 deletions server/text_generation_server/layers/awq/quantize/ipex.py
@@ -0,0 +1,48 @@
from typing import Optional
import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex


class WQLinear(nn.Module):
    def __init__(
        self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
    ):
        super().__init__()

        if w_bit not in [4]:
            raise NotImplementedError("Only 4-bit is supported for now.")

        self.in_features = qweight.shape[0]
        self.out_features = qweight.shape[1] * 32 // w_bit

        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else self.in_features
        # quick sanity check (make sure alignment)
        assert self.in_features % self.group_size == 0
        assert self.out_features % (32 // self.w_bit) == 0

        self.qweight = qweight
        self.qzeros = qzeros
        self.scales = scales
        self.bias = bias
        self.woq_linear = (
            ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
                self.qweight,
                self.scales,
                self.qzeros,
                self.in_features,
                self.out_features,
                bias=self.bias,
                group_size=self.group_size,
                quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,
                dtype=ipex.llm.quantization.QuantDtype.INT4,
            )
        )

    @torch.no_grad()
    def forward(self, x):
        out_shape = x.shape[:-1] + (self.out_features,)
        out = self.woq_linear(x.reshape(-1, x.shape[-1]))
        out = out + self.bias if self.bias is not None else out
        return out.reshape(out_shape)
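
For reference, the IPEX `WQLinear` above only stores the packed AWQ tensors and delegates the matmul to `ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear`. From the constructor, `qweight` has shape `(in_features, out_features * w_bit // 32)`; standard AWQ GEMM packing additionally puts `qzeros` at `(in_features // group_size, out_features * w_bit // 32)` and `scales` at `(in_features // group_size, out_features)`. A rough usage sketch under those assumptions (requires `intel_extension_for_pytorch`; random packed data gives meaningless outputs, and whether it runs end to end depends on the local IPEX build):

# Rough usage sketch (not from this commit); only exercises shapes and dtypes.
import torch

from text_generation_server.layers.awq.quantize.ipex import WQLinear

w_bit, group_size = 4, 128
in_features, out_features = 4096, 11008
pack = 32 // w_bit  # 8 int4 values packed per int32

qweight = torch.randint(
    torch.iinfo(torch.int32).min,
    torch.iinfo(torch.int32).max,
    (in_features, out_features // pack),
    dtype=torch.int32,
)
qzeros = torch.randint(
    torch.iinfo(torch.int32).min,
    torch.iinfo(torch.int32).max,
    (in_features // group_size, out_features // pack),
    dtype=torch.int32,
)
scales = torch.rand(in_features // group_size, out_features, dtype=torch.float16)

layer = WQLinear(w_bit, group_size, qweight, qzeros, scales, bias=None)
x = torch.rand(2, 7, in_features, dtype=torch.float16)
print(layer(x).shape)  # expected: torch.Size([2, 7, 11008])
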
22 changes: 14 additions & 8 deletions server/text_generation_server/layers/gptq/__init__.py
@@ -8,6 +8,11 @@
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

if SYSTEM == "ipex":
from .ipex import QuantLinear
elif SYSTEM == "cuda":
from .cuda import QuantLinear


@dataclass
class GPTQWeight(Weight):
@@ -36,7 +41,7 @@ def get_linear(self, bias: torch.Tensor):
"to use Exllama/GPTQ kernels for AWQ inference."
)
try:
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
from text_generation_server.layers.awq.quantize import WQLinear

return WQLinear(
w_bit=self.bits,
@@ -60,8 +65,6 @@ def get_linear(self, bias: torch.Tensor):

return ExllamaQuantLinear(self, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear

return QuantLinear(
self.qweight,
self.qzeros,
@@ -298,6 +301,7 @@ def get_weights_row(self, weights: Weights, prefix: str):
self._get_gptq_params(weights)

use_exllama = True
desc_act = self.desc_act
if self.bits != 4:
use_exllama = False

@@ -321,7 +325,8 @@ def get_weights_row(self, weights: Weights, prefix: str):
if g_idx is not None:
if (
not torch.equal(
g_idx.cpu(),
# Remove g_idx[0] to adapt the check with TP>1.
(g_idx - g_idx[0]).cpu(),
torch.tensor(
[i // self.groupsize for i in range(g_idx.shape[0])],
dtype=torch.int32,
@@ -332,6 +337,7 @@ def get_weights_row(self, weights: Weights, prefix: str):
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require to reorder input activations that are split unto several GPUs
use_exllama = False
desc_act = True

from text_generation_server.layers.gptq import (
CAN_EXLLAMA,
@@ -350,16 +356,16 @@ def get_weights_row(self, weights: Weights, prefix: str):
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")

if use_exllama and self.groupsize != -1:
if not desc_act and self.groupsize != -1:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
if g_idx is not None:
# qzeros, scales sharded, and g_idx must be adjusted accordingly
g_idx = g_idx - g_idx[0]
else:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")

if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]

if self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
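
The behavioural change in `get_weights_row` is the act-order check. With row tensor parallelism each rank loads only its shard of `g_idx`, so the entries start at that shard's first group index instead of 0; comparing the shard directly against `[i // self.groupsize for i in range(g_idx.shape[0])]` therefore failed on every rank but 0, which made the loader assume act-order (`desc_act`) and fall back from the exllama kernels. Subtracting `g_idx[0]` makes the check shard-invariant. A small illustration with hypothetical sizes (not TGI code):

# Why (g_idx - g_idx[0]) is compared instead of g_idx itself when the weight
# is sharded row-wise across ranks (hypothetical sizes, sequential g_idx).
import torch

groupsize = 4
full_in_features = 16
world_size = 2
shard = full_in_features // world_size

# Sequential (non act-order) g_idx for the full, unsharded weight.
full_g_idx = torch.tensor(
    [i // groupsize for i in range(full_in_features)], dtype=torch.int32
)

for rank in range(world_size):
    g_idx = full_g_idx[rank * shard : (rank + 1) * shard]  # this rank's shard
    expected = torch.tensor([i // groupsize for i in range(shard)], dtype=torch.int32)

    naive_check = torch.equal(g_idx.cpu(), expected)                 # old check
    shifted_check = torch.equal((g_idx - g_idx[0]).cpu(), expected)  # new check

    print(f"rank {rank}: naive={naive_check} shifted={shifted_check}")
    # rank 0: naive=True  shifted=True
    # rank 1: naive=False shifted=True  -> old check wrongly disabled exllama

On rank 0 both checks agree; on higher ranks only the shifted check recognizes the sequential layout, which is why the commit also re-applies `g_idx = g_idx - g_idx[0]` before handing the sharded tensors to the exllama kernels.
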