
Commit 8cc7364

cleanup 3

1 parent 3ecbdbe

File tree

2 files changed: +125, -53 lines


gptqmodel/nn_modules/qlinear/torch_fused_awq.py

Lines changed: 63 additions & 51 deletions
@@ -3,6 +3,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
+import math
+
 import torch
 
 from ...adapter.adapter import Adapter
@@ -21,7 +23,7 @@
 
 
 class TorchFusedAwqQuantLinear(TorchFusedQuantLinear):
-    """Torch fused AWQ variant that reuses the GPTQ fused kernels via CPU int4 packing."""
+    """Torch fused AWQ variant based on GPTQ fused kernels via CPU int4 packing."""
 
     QUANT_TYPE = "torch_fused_awq"
     SUPPORTS_BITS = TorchFusedQuantLinear.SUPPORTS_BITS
@@ -66,62 +68,72 @@ def __init__(
             bias=bias,
             pack_dtype=pack_dtype,
             adapter=adapter,
-            register_buffers=register_buffers,
+            register_buffers=False,
             **kwargs,
         )
+        if register_buffers:
+            qweight_shape = self._awq_qweight_shape()
+            group_size = max(int(self.group_size), 1)
+            group_rows = self._awq_group_count()
+            pack_cols = qweight_shape[1]
 
-    def _load_from_state_dict(
-        self,
-        state_dict,
-        prefix,
-        local_metadata,
-        strict,
-        missing_keys,
-        unexpected_keys,
-        error_msgs,
-    ):
-        qweight_key = prefix + "qweight"
-        awq_tensor = None
-        if qweight_key in state_dict:
-            candidate = state_dict[qweight_key]
-            if not torch.is_tensor(candidate):
-                raise TypeError(f"{qweight_key} must be a tensor to load AWQ weights.")
-            awq_tensor = candidate.to(self.pack_dtype).clone()
-            expected_rows = self.in_features
-            expected_cols = max(1, self.out_features // self.pack_factor)
-            if awq_tensor.shape != (expected_rows, expected_cols):
-                raise ValueError(
-                    f"{self.__class__.__name__} expects AWQ qweight shape "
-                    f"{(expected_rows, expected_cols)}, but received {tuple(awq_tensor.shape)}."
-                )
-            placeholder = getattr(self, "qweight", None)
-            if isinstance(placeholder, torch.Tensor) and placeholder.numel() == awq_tensor.numel():
-                state_dict[qweight_key] = torch.zeros_like(placeholder)
-            else:
-                rows = max(1, self.in_features // self.pack_factor)
-                cols = self.out_features
-                state_dict[qweight_key] = torch.zeros(
-                    (rows, cols),
-                    dtype=self.pack_dtype,
-                    device=awq_tensor.device,
-                )
-        super()._load_from_state_dict(
-            state_dict,
-            prefix,
-            local_metadata,
-            strict,
-            missing_keys,
-            unexpected_keys,
-            error_msgs,
-        )
-        if awq_tensor is not None:
-            state_dict[qweight_key] = awq_tensor
-            device = getattr(self, "qweight", awq_tensor).device
             self.register_buffer(
                 "qweight",
-                awq_tensor.to(device=device, dtype=self.pack_dtype).contiguous(),
-                persistent=True,
+                torch.zeros(qweight_shape, dtype=self.pack_dtype),
             )
+            self.register_buffer(
+                "qzeros",
+                torch.zeros((group_rows, pack_cols), dtype=self.pack_dtype),
+            )
+            self.register_buffer(
+                "scales",
+                torch.zeros((group_rows, self.out_features), dtype=torch.float16),
+            )
+            g_idx = torch.arange(self.in_features, dtype=torch.int32) // group_size
+            self.register_buffer("g_idx", g_idx)
+            if bias:
+                self.register_buffer("bias", torch.zeros(self.out_features, dtype=torch.float16))
+            else:
+                self.bias = None
+
+    def _awq_qweight_shape(self):
+        pack_cols = max(1, self.out_features // self.pack_factor)
+        return self.in_features, pack_cols
+
+    def _awq_group_count(self):
+        group_size = max(int(self.group_size), 1)
+        return max(1, math.ceil(self.in_features / group_size))
+
+    # def _load_from_state_dict(
+    #     self,
+    #     state_dict,
+    #     prefix,
+    #     local_metadata,
+    #     strict,
+    #     missing_keys,
+    #     unexpected_keys,
+    #     error_msgs,
+    # ):
+    #     self.register_awq_buffers()
+    #     super()._load_from_state_dict(
+    #         state_dict,
+    #         prefix,
+    #         local_metadata,
+    #         strict,
+    #         missing_keys,
+    #         unexpected_keys,
+    #         error_msgs,
+    #     )
+    #     qweight = getattr(self, "qweight", None)
+    #     if torch.is_tensor(qweight):
+    #         expected_shape = self._awq_qweight_shape()
+    #         if tuple(qweight.shape) != expected_shape:
+    #             raise ValueError(
+    #                 f"{self.__class__.__name__} only loads AWQ-formatted qweight tensors with "
+    #                 f"shape {expected_shape}, but received {tuple(qweight.shape)}."
+    #             )
+    #         if qweight.dtype != self.pack_dtype:
+    #             self.qweight = qweight.to(dtype=self.pack_dtype).contiguous()
 
     def transform_cpu_awq(self, dtype):
         src_scales = self.scales
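
For reference, the buffer shapes registered in __init__ follow directly from the two new helpers. A minimal sketch of the shape math, assuming 4-bit weights packed into an int32 pack_dtype (so pack_factor == 32 // bits == 8) and hypothetical layer sizes:

import math

# Hypothetical sizes, for illustration only.
in_features, out_features = 4096, 11008
bits, group_size = 4, 128
pack_factor = 32 // bits  # 8 int4 values per int32 word (assumed packing)

pack_cols = max(1, out_features // pack_factor)           # _awq_qweight_shape()[1]
group_rows = max(1, math.ceil(in_features / group_size))  # _awq_group_count()

print("qweight:", (in_features, pack_cols))    # (4096, 1376)
print("qzeros: ", (group_rows, pack_cols))     # (32, 1376)
print("scales: ", (group_rows, out_features))  # (32, 11008)

Note that the AWQ layout packs qweight along the output dimension, (in_features, out_features // pack_factor), whereas the GPTQ-style placeholder in the removed loader packed along the input dimension, (in_features // pack_factor, out_features).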

tests/test_kernel_output_awq.py

Lines changed: 62 additions & 2 deletions
@@ -22,6 +22,7 @@
     marlin_import_exception,
 )
 from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear
+from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear
 from gptqmodel.utils.marlin import marlin_make_workspace_new
 
 
@@ -30,6 +31,7 @@
 log = LogBar.shared()
 
 DEVICE = torch.device("cuda:0")
+CPU_DEVICE = torch.device("cpu")
 
 GREEN = "\033[32m"
 RED = "\033[31m"
@@ -50,6 +52,7 @@ class TestAwqKernelOutput(unittest.TestCase):
         (BACKEND.GEMM, torch.float16, 0.004),
         # (BACKEND.GEMM, torch.bfloat16, 0.05),
         (BACKEND.MARLIN, torch.float16, 0.006),
+        (BACKEND.TORCH_FUSED_AWQ, torch.float16, 0.004),
         # (BACKEND.MARLIN, torch.bfloat16, 0.05),
     ]
 
@@ -92,6 +95,16 @@ def setUpClass(cls) -> None:
             qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu
         )
 
+        try:
+            cls.modules[BACKEND.TORCH_FUSED_AWQ] = cls._build_torch_fused_awq_module(
+                qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu
+            )
+        except Exception as exc:
+            cls.backend_skip_reason[BACKEND.TORCH_FUSED_AWQ] = (
+                f"Torch fused AWQ kernel unavailable: {exc}"
+            )
+            cls.modules[BACKEND.TORCH_FUSED_AWQ] = None
+
         base_inputs = cls._generate_inputs()
         cls.inputs: Dict[torch.dtype, List[torch.Tensor]] = {}
         cls.reference_outputs: Dict[torch.dtype, List[torch.Tensor]] = {}
@@ -247,6 +260,35 @@ def _build_torch_awq_module(
         module.post_init()
         return module
 
+    @classmethod
+    def _build_torch_fused_awq_module(
+        cls,
+        qweight_cpu: torch.Tensor,
+        qzeros_cpu: torch.Tensor,
+        scales_cpu: torch.Tensor,
+        bias_cpu: torch.Tensor,
+    ) -> TorchFusedAwqQuantLinear:
+        module = TorchFusedAwqQuantLinear(
+            bits=cls.BITS,
+            group_size=cls.GROUP_SIZE,
+            sym=True,
+            desc_act=False,
+            in_features=cls.in_features,
+            out_features=cls.out_features,
+            bias=True,
+            adapter=None,
+            register_buffers=True,
+        ).to(CPU_DEVICE)
+
+        module.qweight.copy_(qweight_cpu.to(CPU_DEVICE))
+        module.qzeros.copy_(qzeros_cpu.to(CPU_DEVICE))
+        module.scales.copy_(scales_cpu.to(torch.float16).to(CPU_DEVICE))
+        module.bias.copy_(bias_cpu.to(torch.float16).to(CPU_DEVICE))
+
+        module.eval()
+        module.post_init()
+        return module
+
     @classmethod
     def _generate_inputs(cls) -> List[torch.Tensor]:
         large_shapes = [(4, 32), (2, 64), (1, 96)]
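
The builder above fills the freshly registered buffers in place and finishes with post_init(), after which the module runs entirely on CPU. A hedged smoke-test sketch (smoke_test is hypothetical; module stands for the builder's return value, and the batch size is arbitrary):

import torch

def smoke_test(module: torch.nn.Module) -> None:
    # One CPU forward pass in inference mode, mirroring _forward below.
    x = torch.randn(2, module.in_features, dtype=torch.float16)
    with torch.inference_mode():
        y = module(x)
    assert y.shape == (2, module.out_features)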
@@ -288,19 +330,37 @@ def _forward(
         *,
         compute_dtype: Optional[torch.dtype] = None,
         output_dtype: Optional[torch.dtype] = None,
+        target_device: Optional[torch.device] = None,
     ) -> List[torch.Tensor]:
+        if target_device is None:
+            target_device = cls._infer_module_device(module)
         outputs: List[torch.Tensor] = []
         with torch.inference_mode():
             for tensor in inputs:
                 local_tensor = tensor
-                if compute_dtype is not None and tensor.dtype != compute_dtype:
-                    local_tensor = tensor.to(dtype=compute_dtype)
+                if local_tensor.device != target_device:
+                    local_tensor = local_tensor.to(device=target_device)
+                if compute_dtype is not None and local_tensor.dtype != compute_dtype:
+                    local_tensor = local_tensor.to(dtype=compute_dtype)
                 result = module(local_tensor)
                 if output_dtype is not None and result.dtype != output_dtype:
                     result = result.to(dtype=output_dtype)
                 outputs.append(result.detach().cpu())
         return outputs
 
+    @staticmethod
+    def _infer_module_device(module: torch.nn.Module) -> torch.device:
+        try:
+            tensor = next(module.parameters())
+            return tensor.device
+        except StopIteration:
+            pass
+        try:
+            tensor = next(module.buffers())
+            return tensor.device
+        except StopIteration:
+            return torch.device("cpu")
+
     def _maybe_skip_backend(self, backend: BACKEND) -> None:
         reason = self.backend_skip_reason.get(backend)
         if reason:
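
The parameters-then-buffers fallback in _infer_module_device matters for quantized modules: a fully packed quant linear may hold no nn.Parameter at all, only registered buffers, so the device has to come from buffers(). A toy illustration (BufferOnly is a hypothetical stand-in):

import torch

class BufferOnly(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Only buffers, no parameters -- like a fully packed quant linear.
        self.register_buffer("qweight", torch.zeros(4, 4, dtype=torch.int32))

m = BufferOnly()
assert next(iter(m.parameters()), None) is None   # parameters() yields nothing
assert next(m.buffers()).device == torch.device("cpu")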
