```
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
```

# Torch Fused INT4 Transformations

This note explains what `TorchFusedQuantLinear.transform_xpu` and `transform_cpu`
do to GPTQ-format tensors before calling the fused `torch.ops.aten` kernels.
The goal is to document the exact tensor shapes, the axis permutations, and the
bit-packing order expected by `aten._weight_int4pack_mm_*` so you do not need to
reverse-engineer the loops in `gptqmodel/nn_modules/qlinear/torch_fused.py:175-219`.

## Terminology and starting layout

Let:

* `I` – number of input features.
* `O` – number of output features.
* `B` – quantization bits (always 4 here).
* `W` – number of bits stored per lane in `pack_dtype` (`W = 32` by default).
* `pack_factor = W / B` – how many quantized values share one lane (8 when `B = 4`).
* `group_size` – number of input channels that share one `(scale, zero)` pair.
* `G = ceil(I / group_size)` – number of groups (and rows in `scales`/`qzeros`).

Immediately after loading a GPTQ v2 checkpoint:

```
qweight : [I / pack_factor, O]   dtype = pack_dtype (int32)
qzeros  : [G, O / pack_factor]   dtype = pack_dtype (int32)
scales  : [G, O]                 dtype = fp16
g_idx   : [I]                    dtype = int32 (maps input channel -> group id)
```
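
To make those symbolic shapes concrete, here is a minimal sketch that derives `pack_factor`, `G`, and the buffer shapes. The layer dimensions (`I = 4096`, `O = 11008`, `group_size = 128`) are hypothetical and chosen only for illustration:

```python
import math

import torch

# Hypothetical dimensions, for illustration only.
I, O = 4096, 11008          # input / output features
B, W = 4, 32                # quant bits, bits per packed lane
group_size = 128

pack_factor = W // B                       # 8 quantized values per int32 lane
G = math.ceil(I / group_size)              # 32 groups

qweight = torch.zeros(I // pack_factor, O, dtype=torch.int32)   # [512, 11008]
qzeros  = torch.zeros(G, O // pack_factor, dtype=torch.int32)   # [32, 1376]
scales  = torch.zeros(G, O, dtype=torch.float16)                # [32, 11008]
g_idx   = torch.arange(I, dtype=torch.int32) // group_size      # [4096]
```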

Each entry of `qweight`/`qzeros` is a 32-bit lane that packs `pack_factor`
4-bit nibbles. Conceptually, a single column of `qweight` (one output channel)
looks like this before unpacking:

```
raw lane bits (int32) →  [in_{k+7}] [in_{k+6}] … [in_{k+1}] [in_{k}]
bit positions         →    31..28     27..24   …    7..4      3..0
```
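
As a minimal, hedged sketch of that nibble order (not the exact loop in `torch_fused.py`), the helper below expands each 32-bit lane into its `pack_factor` 4-bit values along the packed axis and round-trips a tiny random example; the helper name is illustrative only:

```python
import torch

def unpack_int32_nibbles(packed: torch.Tensor) -> torch.Tensor:
    """Expand int32 lanes into 4-bit values.

    packed: [I / 8, O] int32, low nibble = lowest-numbered input channel.
    returns: [I, O] int8 with values in [0, 15].
    """
    shifts = torch.arange(0, 32, 4, dtype=torch.int32).view(1, -1, 1)  # 0, 4, ..., 28
    # [I/8, 1, O] >> [1, 8, 1] -> [I/8, 8, O], then flatten the packed axis.
    nibbles = (packed.unsqueeze(1) >> shifts) & 0xF
    return nibbles.reshape(-1, packed.shape[-1]).to(torch.int8)

# Round-trip check: pack 16 input channels of 4 outputs, then unpack them again.
ref = torch.randint(0, 16, (16, 4), dtype=torch.int32)        # [I=16, O=4]
lanes = sum(ref[i::8] << (4 * i) for i in range(8))            # [2, 4] int32 lanes
assert torch.equal(unpack_int32_nibbles(lanes), ref.to(torch.int8))
```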

## `transform_xpu(dtype)`

The XPU path needs tensors that match
`aten._weight_int4pack_mm_with_scales_and_zeros`. The routine performs five
steps:

1. **Scales cast** – `self.scales = self.scales.clone().to(dtype)`. No layout changes.
2. **Unpack `qzeros`** – expand each 32-bit lane into `pack_factor` nibbles, mask
   with `0xF`, then reshape to `[G, O]`.

   ```
   Before unpack (per group g):
     qzeros[g] = [ lane_0, lane_1, … ]   (each lane holds 8 outputs)
   After unpack:
     zeros[g]  = [ z_{0}, z_{1}, …, z_{O-1} ]

     lane layout
   ┌───────────── 32 bits ─────────────┐
   | z_{b+7} |  …  | z_{b+1} |  z_{b}  |
   └───────────────────────────────────┘ ← reshaped into consecutive columns
   ```

3. **Unpack and reorder `qweight`** – identical nibble extraction produces a
   tensor shaped `[I, O]`. It is then re-indexed with `ret_idx` so that input
   rows follow the `g_idx` schedule used during quantization, and finally
   transposed to `[O, I]`. At this point every row corresponds to one output
   channel and every column corresponds to an *unpacked* input channel.

   ```
   weight_full (after transpose):
             input columns →
       ┌───────────────────────────────────────────┐
   out0│ w00 w01 w02 w03 w04 w05 w06 w07 … w0(I-1) │
   out1│ w10 w11 w12 w13 w14 w15 w16 w17 … w1(I-1) │
       │                     ⋮                     │
   ```

4. **Pack rows into XPU layout** – the double `for` loop rebuilds `int32`
   lanes, but now the rows are `O` (output channels) instead of packed input
   clusters. The resulting tensor has shape `[O, I / pack_factor]`. A
   consolidated sketch of steps 2–4 follows this list.

   ```
   packed[row=j, col=k] stores inputs (8 values) =
       weight_full[j, 8k + i]   for i = 0..7

   31..28  27..24  23..20  19..16  15..12  11..8   7..4    3..0
   [in+7]  [in+6]  [in+5]  [in+4]  [in+3]  [in+2]  [in+1]  [in+0]
   ```

5. **Finalize buffers** – `self.qweight = packed.contiguous()` (int32) and
   `self.qzeros = zeros.contiguous()` (float, `[G, O]`). These, together with
   `self.scales`, match the signature of
   `aten._weight_int4pack_mm_with_scales_and_zeros(x, qweight, group_size, scales, qzeros)`.
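
The following is a minimal sketch of steps 2–4 written as plain tensor ops rather than the module's actual loops. It assumes `ret_idx` already holds the permutation discussed below, and the names `unpack_nibbles` / `transform_xpu_sketch` are illustrative, not real APIs:

```python
import torch

def unpack_nibbles(packed: torch.Tensor) -> torch.Tensor:
    """[rows/8, cols] int32 -> [rows, cols] int8, low nibble first."""
    shifts = torch.arange(0, 32, 4, dtype=torch.int32).view(1, -1, 1)
    nibbles = (packed.unsqueeze(1) >> shifts) & 0xF
    return nibbles.reshape(-1, packed.shape[-1]).to(torch.int8)

def transform_xpu_sketch(qweight, qzeros, ret_idx, pack_factor=8):
    # Step 2: expand qzeros lanes into one zero point per output column -> [G, O].
    zeros = unpack_nibbles(qzeros.t()).t().contiguous()
    # Step 3: unpack weights to [I, O], restore calibration order, transpose to [O, I].
    weight_full = unpack_nibbles(qweight)[ret_idx.to(torch.long)].t().contiguous()
    # Step 4: repack every run of 8 input columns into one int32 lane -> [O, I/8].
    w = weight_full.to(torch.int32).reshape(weight_full.shape[0], -1, pack_factor)
    shifts = torch.arange(0, 32, 4, dtype=torch.int32).view(1, 1, -1)
    packed = (w << shifts).sum(dim=-1, dtype=torch.int32)
    return packed, zeros
```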

For XPU execution, `_fused_op_forward` also permutes activations before the
matmul:

```
x = x[:, ret_idx]
```

This applies the inverse of the group-wise reordering performed in step 3,
ensuring that every packed column of `qweight` multiplies the logical input
channels the calibration used.
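
The module builds `ret_idx` once from `g_idx`. A minimal way to compute an equivalent permutation (not necessarily the exact code in `torch_fused.py`) is a stable argsort that groups input channels by group id; applying the same gather to both the weight rows and the activation columns leaves the matmul unchanged:

```python
import torch

def build_ret_idx(g_idx: torch.Tensor) -> torch.Tensor:
    """Permutation that sorts input channels by group id (sketch)."""
    return torch.argsort(g_idx.to(torch.int64), stable=True)

# Tiny invariance check with a hypothetical desc_act-style g_idx.
g_idx = torch.tensor([1, 0, 1, 0], dtype=torch.int32)
ret_idx = build_ret_idx(g_idx)                 # e.g. tensor([1, 3, 0, 2])
x = torch.randn(2, 4)
w = torch.randn(4, 3)
assert torch.allclose(x @ w, x[:, ret_idx] @ w[ret_idx])
```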

### Visual summary (XPU)

```
              ┌───────────┐  unpack + permute   ┌───────────┐
raw qweight → │  I/8 × O  │ ──────────────────→ │   O × I   │
              └───────────┘                     └───────────┘
                                                 pack rows ↓
                                                ┌───────────┐
                                                │ O × (I/8) │ int32 lanes
                                                └───────────┘

raw qzeros → [G × O/8] lanes ──unpack──► zeros [G × O]
scales     → [G × O] (cast to `dtype`)
```
| 125 | + |
| 126 | +## `transform_cpu(dtype)` |
| 127 | + |
| 128 | +The CPU path shares the unpack/reorder logic but delegates the final packing to |
| 129 | +PyTorch’s helper so the layout matches |
| 130 | +`aten._weight_int4pack_mm_for_cpu`. Steps: |
| 131 | + |
| 132 | +1. **Scales cast** – identical to the XPU path. |
| 133 | +2. **Unpack + reorder `qweight`** – same as step 3 above, yielding |
| 134 | + `weight_full = [O, I]` with 4-bit integers. |
| 135 | +3. **Convert to int4pack** – `torch.ops.aten._convert_weight_to_int4pack_for_cpu` |
| 136 | + repacks that matrix into `torch.uint8` tiles of shape `[O, I * B / 8]` |
| 137 | + (i.e., `I/2` columns when `B=4`). Each byte stores two adjacent inputs. |
| 138 | + |
| 139 | + ``` |
| 140 | + byte layout (per output row j): |
| 141 | + bits 7..4 → weight_full[j, 2k+1] |
| 142 | + bits 3..0 → weight_full[j, 2k] |
| 143 | + ``` |
| 144 | + |
| 145 | + The helper currently requires both `O` and `I` to be multiples of 16; the op |
| 146 | + raises `_convert_weight_to_int4pack_cpu : expect N to be dividable by 16` |
| 147 | + otherwise. |
| 148 | + |
| 149 | +4. **Merge scales and zeros** – The fused CPU kernel expects scale and zero |
| 150 | + offsets in a single tensor, so `pack_scales_and_zeros` stacks them along the |
| 151 | + last dimension: |
| 152 | + |
| 153 | + ``` |
| 154 | + scales_and_zeros[g, o] = [ scale[g, o], zero[g, o] ] shape = [G, O, 2] |
| 155 | +
|
| 156 | + group g |
| 157 | + ┌──────── out dimension ────────┐ |
| 158 | + │ [ s, z ] [ s, z ] … [ s, z ] │ |
| 159 | + └─────────────────────────────────┘ |
| 160 | + ``` |
| 161 | + |
| 162 | + The current GPTQ fused path only uses symmetric int4, so `self.qzeros` is |
| 163 | + zeroed before packing (`zero[g, o] = 0`). Non-symmetric per-group offsets |
| 164 | + would require extending this block. |
| 165 | + |
| 166 | +5. **Buffers used at runtime** – `self.qweight` is now the `uint8` |
| 167 | + int4pack tensor, `self.scales_and_zeros` stores the merged metadata, and |
| 168 | + `_fused_op_forward` calls |
| 169 | + `aten._weight_int4pack_mm_for_cpu(x, qweight_uint8, group_size, scales_and_zeros)`. |
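
As a hedged illustration of steps 3–4, the snippet below reproduces the documented byte layout and the `[G, O, 2]` metadata stacking with plain tensor ops. The real path calls `aten._convert_weight_to_int4pack_for_cpu` and the module's own `pack_scales_and_zeros`; the helper names here are illustrative, so treat this only as a reference for how the bits and axes line up:

```python
import torch

def pack_rows_to_uint8(weight_full: torch.Tensor) -> torch.Tensor:
    """[O, I] values in [0, 15] -> [O, I/2] uint8; high nibble = odd column."""
    w = weight_full.to(torch.uint8).reshape(weight_full.shape[0], -1, 2)
    return (w[..., 1] << 4) | w[..., 0]

def pack_scales_and_zeros_sketch(scales, zeros) -> torch.Tensor:
    """Stack per-group scales and zero offsets into the [G, O, 2] layout."""
    return torch.stack([scales.to(torch.float32), zeros.to(torch.float32)], dim=-1)

# Tiny example with hypothetical sizes (O=16, I=32, G=2).
weight_full = torch.randint(0, 16, (16, 32))
packed = pack_rows_to_uint8(weight_full)                 # [16, 16] uint8
assert (packed[0, 0] & 0xF) == weight_full[0, 0]         # bits 3..0 -> even column
assert (packed[0, 0] >> 4) == weight_full[0, 1]          # bits 7..4 -> odd column

scales = torch.rand(2, 16, dtype=torch.float16)
zeros = torch.zeros(2, 16)                               # symmetric int4 path
scales_and_zeros = pack_scales_and_zeros_sketch(scales, zeros)   # [2, 16, 2]
```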

### Visual summary (CPU)

```
weight_full (O × I, ints) ──_convert_weight_to_int4pack_for_cpu──►
┌──────────────┐                                    ┌──────────────┐
│    O × I     │                                    │  O × (I/2)   │ uint8
└──────────────┘                                    └──────────────┘
        ↑                                                   ↑
        └───────── unpack & transpose from raw qweight ─────┘

scales (G × O, dtype `dtype`)
qzeros (G × O, zeroed)        ──► scales_and_zeros (G × O × 2)
```

## Activation permutation and fused matmul

Both device paths rely on the same activation permutation:

1. `ret_idx` is built once from `g_idx` so that unpacked rows can be restored to
   the calibration order.
2. Before calling any fused matmul, `_fused_op_forward` applies `x = x[:, ret_idx]`.
3. The matmul then multiplies `x` with the packed `qweight`:

   * XPU: `aten._weight_int4pack_mm_with_scales_and_zeros`
     consumes `qweight[int32][O, I/8]`, `scales[G, O]`, and `qzeros[G, O]`.
   * CPU: `aten._weight_int4pack_mm_for_cpu`
     consumes `qweight[uint8][O, I/2]` and `scales_and_zeros[G, O, 2]`.

Because the same `ret_idx` is used for both the unpacked weight (during packing)
and the activation tensor (during inference), every nibble in the packed matrix
aligns with the correct logical input column.
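
Putting the pieces together, the forward pass reduces to the shape below. This is a hedged paraphrase of `_fused_op_forward`, not its literal code: it assumes the module exposes `ret_idx`, `group_size`, the transformed buffers, and an optional `bias`, and it uses the fused-op signatures exactly as documented in this note (they may differ across PyTorch builds):

```python
import torch

def fused_forward_sketch(x: torch.Tensor, module, device_kind: str) -> torch.Tensor:
    """Approximate flow after transform_xpu / transform_cpu has run."""
    x = x[:, module.ret_idx]                   # align activations with the repacked weight
    if device_kind == "xpu":
        out = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
            x, module.qweight, module.group_size, module.scales, module.qzeros
        )
    else:  # cpu
        out = torch.ops.aten._weight_int4pack_mm_for_cpu(
            x, module.qweight, module.group_size, module.scales_and_zeros
        )
    return out + module.bias if module.bias is not None else out
```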

## Comparing XPU vs CPU transformations

Although both device paths share the same unpack → reorder → transpose steps,
they diverge in how the packed tensors are laid out and what the fused matmul
expects afterward. The table below highlights the key differences for quick
debugging.

| Aspect                 | XPU (`transform_xpu`)                                         | CPU (`transform_cpu`)                                                    |
|------------------------|---------------------------------------------------------------|--------------------------------------------------------------------------|
| Packed `qweight` shape | `[O, I / 8]`, dtype `int32`                                   | `[O, I / 2]`, dtype `uint8`                                              |
| Bits per storage lane  | 32-bit lane packs 8 inputs; nibble order `[in+7 … in+0]`      | 8-bit lane packs 2 inputs; high nibble = odd, low nibble = even          |
| Packing direction      | Manual double-loop packs along **columns** of `weight_full`   | `_convert_weight_to_int4pack_for_cpu` packs along **columns** into bytes |
| Per-group zeros        | Unpacked to full `[G, O]` tensor and passed separately        | Forced to zero and merged with scales via `pack_scales_and_zeros`        |
| Scale format           | One tensor per group (`scales[G, O]`)                         | Concatenated `[..., 0] = scale`, `[..., 1] = zero` (`float`)             |
| Fused kernel call      | `_weight_int4pack_mm_with_scales_and_zeros(x, qW, gsz, s, z)` | `_weight_int4pack_mm_for_cpu(x, qW, gsz, scales_and_zeros)`              |
| Alignment requirements | Determined by manual pack loop (only needs `I % 8 == 0`)      | Kernel enforces `I % 16 == 0` and `O % 16 == 0`                          |
| Activation permutation | `x = x[:, ret_idx]` prior to matmul (same code path)          | Same permutation reuse                                                   |

Visually, you can think of the difference as *row-major lane packing* (XPU)
versus *byte-tiling* (CPU):

```
XPU:  | int32 lane | = [w7][w6][w5][w4][w3][w2][w1][w0]
CPU:  | uint8 lane | = [w1][w0]
```

Both forms originate from the same `[O, I]` intermediate; the divergence is only
in the final storage type, accompanying metadata, and fused operator ABI.
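
The equivalence is easy to check numerically. The hedged snippet below packs the same eight hypothetical 4-bit values both ways and verifies that they decode to identical integers:

```python
import torch

vals = torch.tensor([3, 14, 7, 0, 9, 12, 5, 1], dtype=torch.int32)   # w0..w7

# XPU-style: one int32 lane, w0 in bits 3..0.
lane32 = int((vals << (4 * torch.arange(8))).sum())

# CPU-style: four uint8 lanes, each holding [odd][even].
lanes8 = [(int(vals[2 * k + 1]) << 4) | int(vals[2 * k]) for k in range(4)]

decoded_xpu = [(lane32 >> (4 * i)) & 0xF for i in range(8)]
decoded_cpu = [n for byte in lanes8 for n in (byte & 0xF, byte >> 4)]
assert decoded_xpu == decoded_cpu == vals.tolist()
```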

## AWQ compatibility (`torch_fused_awq.py`)

`TorchFusedAwqQuantLinear` (`gptqmodel/nn_modules/qlinear/torch_fused_awq.py`)
reuses the CPU fused kernel while accepting checkpoints emitted by the AWQ
tooling. The module always expects `qweight` to be stored in the AWQ layout
`[in_features, out_features / pack_factor]`, meaning each row corresponds to a
single logical input channel. `transform_cpu_awq` performs a fixed shim before
the standard CPU packing runs:

1. **Unpack AWQ rows** – `unpack_awq` expands each column lane into eight
   outputs, yielding `iweight[int8][I, O]` and `izeros[int8][G, O]`. Both
   tensors are then permuted with `reverse_awq_order` (the inverse of
   `quantization.awq.utils.packing_utils.AWQ_ORDER`) so the columns match the
   logical transformer layout expected by GPTQ.
2. **Normalize zero codes** – AWQ stores integer zero points per output channel.
   `transform_cpu_awq` converts them into floating offsets compatible with the
   fused kernel using
   `zeros_fp16 = (2^{bits-1} - izeros) * scales_fp32`, keeping the result in
   `float16` so the metadata matches the original AWQ calibration statistics
   (see the sketch after this list).
3. **Repack into GPTQ lanes** – The unpacked `iweight` matrix is reshaped to
   `[I / pack_factor, pack_factor, O]` and re-packed along the `pack_factor`
   dimension so each row once again represents eight inputs inside a 32-bit
   lane. After this step `self.qweight` is indistinguishable from a GPTQ v2
   tensor, which means the regular `transform_cpu` logic can run unchanged.
4. **Delegate to the base CPU transform** – Calling `super().transform_cpu`
   converts the temporary GPTQ-formatted `qweight` into the `[O, I/2]` `uint8`
   int4pack layout and produces `scales_and_zeros` from the (temporarily zeroed)
   metadata.
5. **Restore AWQ metadata** – Immediately afterward, the AWQ shim reinstates
   the real `float16` scales and the converted zero offsets, then rebuilds
   `scales_and_zeros = pack_scales_and_zeros(scales, zeros_fp16)`. This ensures
   `_weight_int4pack_mm_for_cpu` receives the same affine parameters the AWQ
   calibration solved for.
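
A hedged sketch of steps 2–3 follows. It assumes `iweight` and `izeros` have already been unpacked and de-permuted as in step 1; the zero-offset formula is the one quoted above, the repack mirrors the GPTQ lane layout documented earlier, and the function name is illustrative:

```python
import torch

def awq_shim_sketch(iweight: torch.Tensor, izeros: torch.Tensor,
                    scales: torch.Tensor, bits: int = 4, pack_factor: int = 8):
    """iweight: [I, O] int8, izeros: [G, O] int8, scales: [G, O] fp16."""
    # Step 2: integer zero codes -> floating offsets used by the fused kernel.
    zeros_fp16 = ((2 ** (bits - 1) - izeros.to(torch.float32))
                  * scales.to(torch.float32)).to(torch.float16)

    # Step 3: repack rows into GPTQ-style int32 lanes, low nibble = first input.
    w = iweight.to(torch.int32).reshape(-1, pack_factor, iweight.shape[-1])
    shifts = (4 * torch.arange(pack_factor, dtype=torch.int32)).view(1, -1, 1)
    qweight_gptq = (w << shifts).sum(dim=1, dtype=torch.int32)   # [I/8, O]

    return qweight_gptq, zeros_fp16
```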

Because the shim runs entirely on the CPU path, `TorchFusedAwqQuantLinear`
currently raises `NotImplementedError` when asked to run the fused transform on
`xpu` devices. If the module has not been transformed yet (or fused ops are
unavailable), inference falls back to the dense AWQ matmul computed by
`awq_weight_dequantize`, which simply dequantizes the cached AWQ tensors on the fly.

## Quick reference

| Stage                         | Shape / dtype (int4)                    | Notes                                                                |
|-------------------------------|-----------------------------------------|----------------------------------------------------------------------|
| Raw `qweight`                 | `[I / 8, O]`, `int32`                   | 8 nibbles per lane                                                   |
| After unpack + transpose      | `[O, I]`, `int8` (values in `[0, 15]`)  | Used by both device paths                                            |
| Packed XPU `qweight`          | `[O, I / 8]`, `int32`                   | Bits `[3:0]` hold the lowest-numbered channel                        |
| Packed CPU `qweight`          | `[O, I / 2]`, `uint8`                   | High nibble = odd input, low nibble = even                           |
| `qzeros` (post-XPU transform) | `[G, O]`, matches `scales`              | Passed separately to the XPU fused op                                |
| `scales_and_zeros` (CPU only) | `[G, O, 2]`, float                      | `[..., 0] = scale`, `[..., 1] = zero`                                |
| Raw AWQ `qweight`             | `[I, O / 8]`, `int32`                   | Rows are single inputs packed across outputs                         |
| Unpacked AWQ weights/zeros    | `iweight[I, O]`, `izeros[G, O]`, `int8` | Produced by `unpack_awq` + `reverse_awq_order`                       |
| AWQ zero offsets (final)      | `[G, O]`, `float16`                     | `(2^{bits-1} - izeros) * scales`; merged via `pack_scales_and_zeros` |

These details mirror the expectations of the Intel XPU and CPU fused matmul
kernels, and the ASCII layouts above describe how rows/columns line up inside
every packed tensor before the fused matmul executes.