
Commit e0da12a

AWQ Torch Fused Kernel (#2190)
* use torch.ops.aten fused ops for awq
* cleanup
* cleanup2
* float16 only
* log
* fused path
* add gptq torch fused doc on layout
* fix awq transformation
* log rtol/atol
* cleanup
* cleanup2
* cleanup 3
* remove unused
* inline methods
* remove debug logs
* avoid clone
* merge code with gptq torch fused
* cleanup, add XPU todo
* make sure to test both xpu and cpu
* xpu tests
* fix xpu transform
* cleanup
* cleanup2
* cleanup 3 plus test logs
* tabulate logs
* prepare for v5.4.0 release
1 parent be7a43b commit e0da12a

File tree

10 files changed: +1028 -49 lines changed


README.md

Lines changed: 3 additions & 2 deletions

@@ -17,6 +17,7 @@
</p>

## Latest News
+ * 11/9/2025 [5.4.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.4.0): ✨New Intel CPU and XPU hw-optimized AWQ `TorchFusedAWQ` kernel. Torch Fused kernels are now compatible with `torch.compile`. Fixed AWQ MoE model compatibility and reduced vram usage.
* 11/3/2025 [5.2.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.2.0): 🎉Minimax M2 support with [ModelCloud BF16 M2 Model](https://huggingface.co/ModelCloud/MiniMax-M2-BF16). New `VramStrategy.Balanced` quantization property for reduced memory usage for large MoE on multi-3090 (24GB) devices. ✨Marin model. New AWQ Torch reference kernel. Fix AWQ Marlin kernel for bf16. Fix GLM 4.5/4.6 MoE missing `mtp` layers on model save (HF bug). Modular refactor. 🎉AWQ support out of beta with full feature support including multi-gpu quant and MoE vram saving. ✨Brumby (attention free) model support. ✨IBM Granite Nano support. New `calibration_concat_separator` config option.
* 10/24/2025 [5.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.0.0): 🎉 Data-parallel quant support for `MoE` models on multi-gpu using `nogil` Python. `offload_to_disk` support enabled by default to massively reduce `cpu` ram usage. New `Intel` and `AMD` cpu hw-accelerated `TorchFused` kernel. Packing stage is now 4x faster and is inlined with quantization. `Vram` pressure for large models reduced during quantization.
@@ -202,8 +203,8 @@ GPT-QModel is validated for Linux, MacOS, and Windows 11:
|-----------------|---------------| --- | -------------- |-----------------------------------------------|
| 🐧 Linux | Nvidia GPU || `Ampere+` | Marlin, Exllama V2, Exllama V1, Triton, Torch |
| 🐧 Linux | AMD GPU || `7900XT+`, `ROCm 6.2+` | Exllama V2, Exllama V1, Torch |
- | 🐧 Linux | Intel XPU || `Arc`, `Datacenter Max` | Torch Fused (Python 2.8+), Torch |
- | 🐧 Linux | Intel/AMD CPU || `avx`, `amx`, `xmx` | Torch Fused (Python 2.8+), Torch |
+ | 🐧 Linux | Intel XPU || `Arc`, `Datacenter Max` | TorchFused, TorchFusedAWQ, Torch |
+ | 🐧 Linux | Intel/AMD CPU || `avx`, `amx`, `xmx` | TorchFused, TorchFusedAWQ, Torch |
| 🍎 MacOS | GPU (Metal) / CPU || `Apple Silicon`, `M1+` | Torch, MLX via conversion |
| 🪟 Windows | GPU (Nvidia) / CPU || `Nvidia` | Torch |

Lines changed: 288 additions & 0 deletions

@@ -0,0 +1,288 @@
```
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
```

# Torch Fused INT4 Transformations

This note explains what `TorchFusedQuantLinear.transform_xpu` and `transform_cpu`
do to GPTQ-format tensors before calling the fused `torch.ops.aten` kernels.
The goal is to document the exact tensor shapes, the axis permutations, and the
bit packing order expected by `aten._weight_int4pack_mm_*` so you do not need to
reverse engineer the loops in `gptqmodel/nn_modules/qlinear/torch_fused.py:175-219`.

## Terminology and starting layout

Let:
* `I` – number of input features.
* `O` – number of output features.
* `B` – quantization bits (always 4 here).
* `W` – number of bits stored per lane in `pack_dtype` (`W = 32` by default).
* `pack_factor = W / B` – how many quantized values share one lane (8 when `B = 4`).
* `group_size` – number of input channels that share one `(scale, zero)` pair.
* `G = ceil(I / group_size)` – number of groups (and rows in `scales`/`qzeros`).

Immediately after loading a GPTQ v2 checkpoint:

```
qweight : [I / pack_factor, O]   dtype = pack_dtype (int32)
qzeros  : [G, O / pack_factor]   dtype = pack_dtype (int32)
scales  : [G, O]                 dtype = fp16
g_idx   : [I]                    dtype = int32 (maps input channel -> group id)
```
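As a quick sanity check, the derived quantities and shapes above can be reproduced with a few lines of Python. The layer sizes below are hypothetical; only the arithmetic mirrors the definitions in this note.

```python
import math

in_features, out_features = 4096, 11008       # hypothetical I and O
bits, pack_bits, group_size = 4, 32, 128      # B, W, group_size

pack_factor = pack_bits // bits               # 8 nibbles per int32 lane
groups = math.ceil(in_features / group_size)  # G

# Expected GPTQ v2 checkpoint shapes for a layer of this size:
#   qweight : [in_features // pack_factor, out_features] -> [512, 11008]
#   qzeros  : [groups, out_features // pack_factor]      -> [32, 1376]
#   scales  : [groups, out_features]                     -> [32, 11008]
#   g_idx   : [in_features]                              -> [4096]
```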
Each entry of `qweight`/`qzeros` is a 32-bit lane that packs `pack_factor`
4-bit nibbles. Conceptually, a single column of `qweight` (one output channel)
looks like this before unpacking:

```
raw lane bits (int32) → [in_{k+7}] [in_{k+6}] … [in_{k+1}] [in_{k}]
bit positions         →   31..28     27..24   …    7..4      3..0
```
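The nibble order can be verified with a one-lane example. This is a standalone sketch, not the library's unpack loop, and the hex constant is arbitrary.

```python
import torch

lane = torch.tensor([0x76543210], dtype=torch.int32)  # one hypothetical int32 lane
shifts = torch.arange(0, 32, 4)                       # 0, 4, ..., 28

# bits 3..0 hold the lowest-numbered input, bits 31..28 the highest
nibbles = (lane.unsqueeze(-1) >> shifts) & 0xF        # shape [1, 8]
print(nibbles.tolist())                               # [[0, 1, 2, 3, 4, 5, 6, 7]]
```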
## `transform_xpu(dtype)`

The XPU path needs tensors that match
`aten._weight_int4pack_mm_with_scales_and_zeros`. The routine performs five
steps:

1. **Scales cast** – `self.scales = self.scales.clone().to(dtype)`. No layout changes.
2. **Unpack `qzeros`** – expand each 32-bit lane into `pack_factor` nibbles, mask
   with `0xF`, then reshape to `[G, O]`.

```
Before unpack (per group g):
  qzeros[g] = [ lane_0, lane_1, … ]         (each lane holds 8 outputs)
After unpack:
  zeros[g]  = [ z_{0}, z_{1}, …, z_{O-1} ]

lane layout
┌──────────── 32 bits ─────────────┐
│ z_{b+7} │  …  │ z_{b+1} │ z_{b}  │
└──────────────────────────────────┘ ← reshaped into consecutive columns
```
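A vectorized sketch of step 2, assuming the shapes above; the library performs the same expansion with explicit loops.

```python
import torch

G, O, pack_factor = 32, 4096, 8
int32_info = torch.iinfo(torch.int32)
qzeros = torch.randint(int32_info.min, int32_info.max,
                       (G, O // pack_factor), dtype=torch.int32)  # stand-in lanes

shifts = torch.arange(0, 32, 4, dtype=torch.int32)                # low nibble first
zeros = (qzeros.unsqueeze(-1) >> shifts) & 0xF                    # [G, O/8, 8]
zeros = zeros.reshape(G, O)                                       # consecutive columns
```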
3. **Unpack and reorder `qweight`** – identical nibble extraction produces a
   tensor shaped `[I, O]`. It is then re-indexed with `ret_idx` so that input
   rows follow the `g_idx` schedule used during quantization, and finally
   transposed to `[O, I]`. At this point every row corresponds to one output
   channel and every column corresponds to an *unpacked* input channel.

```
weight_full (after transpose):
         input columns →
     ┌───────────────────────────────────────────┐
out0 │ w00 w01 w02 w03 w04 w05 w06 w07 … w0(I-1) │
out1 │ w10 w11 w12 w13 w14 w15 w16 w17 … w1(I-1) │
     │  ⋮                                        │
     └───────────────────────────────────────────┘
```
4. **Pack rows into XPU layout** – the double `for` loop rebuilds `int32`
   lanes, but now the rows are `O` (output channels) instead of packed input
   clusters. The resulting tensor has shape `[O, I / pack_factor]`; a Python
   sketch of this repacking follows the list below.

```
packed[row=j, col=k] stores inputs (8 values) =
    weight_full[j, 8k + i]  for i = 0..7

31..28 27..24 23..20 19..16 15..12 11..8  7..4   3..0
[in+7] [in+6] [in+5] [in+4] [in+3] [in+2] [in+1] [in+0]
```
5. **Finalize buffers** – `self.qweight = packed.contiguous()` (int32) and
   `self.qzeros = zeros.contiguous()` (float, `[G, O]`). These, together with
   `self.scales`, match the signature of
   `aten._weight_int4pack_mm_with_scales_and_zeros(x, qweight, group_size, scales, qzeros)`.
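The following sketch reproduces the effect of steps 3-4 with illustrative names (it is not the library's double loop): starting from the unpacked `[O, I]` intermediate, each run of 8 columns is folded back into one int32 lane, low nibble first.

```python
import torch

O, I, pack_factor = 16, 64, 8
weight_full = torch.randint(0, 16, (O, I), dtype=torch.int32)   # stand-in [O, I] nibbles

chunks = weight_full.reshape(O, I // pack_factor, pack_factor)  # group 8 columns per lane
packed = torch.zeros(O, I // pack_factor, dtype=torch.int32)    # [O, I/8] int32 lanes
for i in range(pack_factor):
    # nibble i occupies bits 4*i+3 .. 4*i; the top nibble lands in the sign bits
    packed |= chunks[..., i] << (4 * i)

# round trip: unpacking the lanes recovers the [O, I] intermediate
shifts = torch.arange(0, 32, 4, dtype=torch.int32)
assert torch.equal((packed.unsqueeze(-1) >> shifts) & 0xF, chunks)
```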
For XPU execution, `_fused_op_forward` also permutes activations before the
matmul:

```
x = x[:, ret_idx]
```

This applies the inverse of the group-wise reordering performed in step 3,
ensuring that column `i` of `qweight` always multiplies the same logical input
channel the calibration used.
### Visual summary (XPU)

```
              ┌─────────────┐    unpack + permute      ┌─────────────┐
raw qweight → │   I/8 × O   │ ───────────────────────→ │    O × I    │
              └─────────────┘                          └─────────────┘
                                                         pack rows ↓
                                                       ┌─────────────┐
                                                       │  O × (I/8)  │ int32 lanes
                                                       └─────────────┘

raw qzeros → [G × O/8] lanes ──unpack──► zeros [G × O]
scales     → [G × O] (cast to `dtype`)
```
## `transform_cpu(dtype)`

The CPU path shares the unpack/reorder logic but delegates the final packing to
PyTorch’s helper so the layout matches
`aten._weight_int4pack_mm_for_cpu`. Steps:

1. **Scales cast** – identical to the XPU path.
2. **Unpack + reorder `qweight`** – same as step 3 above, yielding
   `weight_full = [O, I]` with 4-bit integers.
3. **Convert to int4pack** – `torch.ops.aten._convert_weight_to_int4pack_for_cpu`
   repacks that matrix into `torch.uint8` tiles of shape `[O, I * B / 8]`
   (i.e., `I/2` columns when `B=4`). Each byte stores two adjacent inputs.

```
byte layout (per output row j):
  bits 7..4 → weight_full[j, 2k+1]
  bits 3..0 → weight_full[j, 2k]
```

The helper currently requires both `O` and `I` to be multiples of 16; the op
raises `_convert_weight_to_int4pack_cpu : expect N to be dividable by 16`
otherwise.
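An illustrative sketch of the byte layout described in step 3, under the assumption of a plain two-nibbles-per-byte packing; it is not the aten helper itself, which may tile the data internally.

```python
import torch

O, I = 16, 32
weight_full = torch.randint(0, 16, (O, I), dtype=torch.uint8)  # stand-in [O, I] nibbles

even = weight_full[:, 0::2]                                    # columns 2k
odd = weight_full[:, 1::2]                                     # columns 2k+1
packed_bytes = (odd << 4) | even                               # [O, I/2] uint8

assert torch.equal(packed_bytes & 0xF, even)                   # bits 3..0 = even column
assert torch.equal(packed_bytes >> 4, odd)                     # bits 7..4 = odd column
```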
4. **Merge scales and zeros** – the fused CPU kernel expects scale and zero
   offsets in a single tensor, so `pack_scales_and_zeros` stacks them along the
   last dimension:

```
scales_and_zeros[g, o] = [ scale[g, o], zero[g, o] ]   shape = [G, O, 2]

group g
┌────────── out dimension ──────────┐
│ [ s, z ]   [ s, z ]   …   [ s, z ] │
└────────────────────────────────────┘
```

The current GPTQ fused path only uses symmetric int4, so `self.qzeros` is
zeroed before packing (`zero[g, o] = 0`). Non-symmetric per-group offsets
would require extending this block.
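A minimal sketch of what `pack_scales_and_zeros` is described to produce, with hypothetical sizes; the library function itself may differ in dtype handling.

```python
import torch

G, O = 32, 4096
scales = torch.randn(G, O, dtype=torch.float32)  # per-group scales
zeros = torch.zeros(G, O, dtype=torch.float32)   # symmetric GPTQ path: all zero

scales_and_zeros = torch.stack([scales, zeros], dim=-1)  # [G, O, 2]
assert scales_and_zeros[..., 0].equal(scales)
assert scales_and_zeros[..., 1].equal(zeros)
```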
5. **Buffers used at runtime** – `self.qweight` is now the `uint8`
   int4pack tensor, `self.scales_and_zeros` stores the merged metadata, and
   `_fused_op_forward` calls
   `aten._weight_int4pack_mm_for_cpu(x, qweight_uint8, group_size, scales_and_zeros)`.
### Visual summary (CPU)

```
weight_full (O × I, ints) ──_convert_weight_to_int4pack_for_cpu──►
┌──────────────┐                                 ┌──────────────┐
│    O × I     │                                 │  O × (I/2)   │ uint8
└──────────────┘                                 └──────────────┘
        ↑                                                ↑
        └───────── unpack & transpose from raw qweight ──┘

scales (G × O, dtype `dtype`)
qzeros (G × O, zeroed)        ──► scales_and_zeros (G × O × 2)
```
## Activation permutation and fused matmul

Both device paths rely on the same activation permutation:

1. `ret_idx` is built once from `g_idx` so that unpacked rows can be restored to
   the calibration order.
2. Before calling any fused matmul, `_fused_op_forward` applies `x = x[:, ret_idx]`.
3. The matmul then multiplies `x` with the packed `qweight`:

   * XPU: `aten._weight_int4pack_mm_with_scales_and_zeros`
     consumes `qweight[int32][O, I/8]`, `scales[G, O]`, and `qzeros[G, O]`.
   * CPU: `aten._weight_int4pack_mm_for_cpu`
     consumes `qweight[uint8][O, I/2]` and `scales_and_zeros[G, O, 2]`.

Because the same `ret_idx` is used for both the unpacked weight (during packing)
and the activation tensor (during inference), every nibble in the packed matrix
aligns with the correct logical input column.
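Put together, the dispatch looks roughly like the sketch below. The operator names and argument orders are the ones quoted in this note; the attribute names on `module` and the device check are illustrative, not the exact `_fused_op_forward` implementation.

```python
import torch

def fused_forward(x: torch.Tensor, module, group_size: int) -> torch.Tensor:
    x = x[:, module.ret_idx]                    # restore calibration input order
    if x.device.type == "xpu":
        # int32 [O, I/8] lanes with separate per-group scales and zeros
        return torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
            x, module.qweight, group_size, module.scales, module.qzeros)
    # CPU path: uint8 [O, I/2] int4pack tensor with merged scales/zeros
    return torch.ops.aten._weight_int4pack_mm_for_cpu(
        x, module.qweight, group_size, module.scales_and_zeros)
```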
## Comparing XPU vs CPU transformations

Although both device paths share the same unpack → reorder → transpose steps,
they diverge in how the packed tensors are laid out and what the fused matmul
expects afterward. The table below highlights the key differences for quick
debugging.

| Aspect | XPU (`transform_xpu`) | CPU (`transform_cpu`) |
|----------------------------|---------------------------------------------------------------|-------------------------------------------------------------------|
| Packed `qweight` shape | `[O, I / 8]`, dtype `int32` | `[O, I / 2]`, dtype `uint8` |
| Bits per storage lane | 32-bit lane packs 8 inputs; nibble order `[in+7 … in+0]` | 8-bit lane packs 2 inputs; high nibble = odd, low nibble = even |
| Packing direction | Manual double loop packs along **columns** of `weight_full` | `_convert_weight_to_int4pack_for_cpu` packs along **columns** into bytes |
| Per-group zeros | Unpacked to full `[G, O]` tensor and passed separately | Forced to zero and merged with scales via `pack_scales_and_zeros` |
| Scale format | One tensor per group (`scales[G, O]`) | Concatenated `[..., 0] = scale`, `[..., 1] = zero` (`float`) |
| Fused kernel call | `_weight_int4pack_mm_with_scales_and_zeros(x, qW, gsz, s, z)` | `_weight_int4pack_mm_for_cpu(x, qW, gsz, scales_and_zeros)` |
| Alignment requirements | Determined by manual pack loop (only needs `I % 8 == 0`) | Kernel enforces `I % 16 == 0` and `O % 16 == 0` |
| Activation permutation | `x = x[:, ret_idx]` prior to matmul (same code path) | Same permutation reused |

Visually, you can think of the difference as *row-major lane packing* (XPU)
versus *byte-tiling* (CPU):

```
XPU: | int32 lane | = [w7][w6][w5][w4][w3][w2][w1][w0]
CPU: | uint8 lane | = [w1][w0]
```

Both forms originate from the same `[O, I]` intermediate; the divergence is only
in the final storage type, accompanying metadata, and fused operator ABI.
## AWQ compatibility (`torch_fused_awq.py`)

`TorchFusedAwqQuantLinear` (`gptqmodel/nn_modules/qlinear/torch_fused_awq.py`)
reuses the CPU fused kernel while accepting checkpoints emitted by the AWQ
tooling. The module always expects `qweight` to be stored in the AWQ layout
`[in_features, out_features / pack_factor]`, meaning each row corresponds to a
single logical input channel. `transform_cpu_awq` performs a fixed shim before
the standard CPU packing runs:

1. **Unpack AWQ rows** – `unpack_awq` expands each column lane into eight
   outputs, yielding `iweight[int8][I, O]` and `izeros[int8][G, O]`. Both
   tensors are then permuted with `reverse_awq_order` (the inverse of
   `quantization.awq.utils.packing_utils.AWQ_ORDER`) so the columns match the
   logical transformer layout expected by GPTQ.
2. **Normalize zero codes** – AWQ stores integer zero points per output channel.
   `transform_cpu_awq` converts them into floating offsets compatible with the
   fused kernel using
   `zeros_fp16 = (2^{bits-1} - izeros) * scales_fp32`, keeping the result in
   `float16` so the metadata matches the original AWQ calibration statistics.
3. **Repack into GPTQ lanes** – the unpacked `iweight` matrix is reshaped to
   `[I / pack_factor, pack_factor, O]` and re-packed along the `pack_factor`
   dimension so each row once again represents eight inputs inside a 32-bit
   lane. After this step `self.qweight` is indistinguishable from a GPTQ v2
   tensor, which means the regular `transform_cpu` logic can run unchanged.
4. **Delegate to the base CPU transform** – calling `super().transform_cpu`
   converts the temporary GPTQ-formatted `qweight` into the `[O, I/2]` `uint8`
   int4pack layout and produces `scales_and_zeros` from the (temporarily zeroed)
   metadata.
5. **Restore AWQ metadata** – immediately afterward, the AWQ shim reinstates
   the real `float16` scales and the converted zero offsets, then rebuilds
   `scales_and_zeros = pack_scales_and_zeros(scales, zeros_fp16)`. This ensures
   `_weight_int4pack_mm_for_cpu` receives the same affine parameters the AWQ
   calibration solved for.
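The zero-point normalization in step 2 can be sketched directly from the quoted formula. Shapes and dtypes here follow the note; the exact cast order inside `transform_cpu_awq` may differ.

```python
import torch

bits, G, O = 4, 32, 4096
izeros = torch.randint(0, 2**bits, (G, O), dtype=torch.int8)  # AWQ integer zero points
scales = torch.rand(G, O).to(torch.float16)                   # AWQ per-group scales

# zeros_fp16 = (2^{bits-1} - izeros) * scales, computed in fp32 then stored as fp16
zeros_fp16 = ((2 ** (bits - 1)) - izeros.to(torch.float32)) * scales.to(torch.float32)
zeros_fp16 = zeros_fp16.to(torch.float16)                     # matches AWQ metadata dtype
```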
Because the shim runs entirely on the CPU path, `TorchFusedAwqQuantLinear`
currently raises `NotImplementedError` when asked to run the fused transform on
`xpu` devices. If the module has not been transformed yet (or fused ops are
unavailable), inference falls back to the dense AWQ matmul computed by
`awq_weight_dequantize`, which simply dequantizes the cached AWQ tensors on the fly.
## Quick reference

| Stage | Shape / dtype (int4) | Notes |
|--------------------------------|-----------------------------------------------------------|------------------------------------------------|
| Raw `qweight` | `[I / 8, O]`, `int32` | 8 nibbles per lane |
| After unpack + transpose | `[O, I]`, `int8` (values in `[0, 15]`) | Used by both device paths |
| Packed XPU `qweight` | `[O, I / 8]`, `int32` | Bits `[3:0]` hold the lowest-numbered channel |
| Packed CPU `qweight` | `[O, I / 2]`, `uint8` | High nibble = odd input, low nibble = even |
| `qzeros` (post-XPU transform) | `[G, O]`, matches `scales` | Passed separately to the XPU fused op |
| `scales_and_zeros` (CPU only) | `[G, O, 2]`, float | `[..., 0] = scale`, `[..., 1] = zero` |
| Raw AWQ `qweight` | `[I, O / 8]`, `int32` | Rows are single inputs packed across outputs |
| Unpacked AWQ weights/zeros | `iweight[I, O]`, `izeros[G, O]`, `int8` | Produced by `unpack_awq` + `reverse_awq_order` |
| AWQ zero offsets (final) | `[G, O]`, `float16` | `(2^{bits-1} - izeros) * scales`; merged via `pack_scales_and_zeros` |

These details mirror the expectations of the Intel XPU and CPU fused matmul
kernels, and the ASCII layouts above describe how rows/columns line up inside
every packed tensor before the fused matmul executes.
