
Commit 3044ee5

Add SpinQuant to generate.py (#1069)
* Only import SpinQuant when necessary: no need to import the large Hadamard matrices required for SpinQuant if it isn't needed
* Add SpinQuant to `generate.py`
* Custom op for Hadamard transform for torch.compile compatibility
* Add spinquant to arg parser info
* Add SpinQuant benchmark results to README
* Add performance testing details
* Fix broken custom op for PyTorch < 2.4
1 parent f1b4c8e commit 3044ee5

5 files changed (+91 -14 lines)

torchao/_models/llama/eval.py (+1 -1)

```diff
@@ -31,7 +31,6 @@
 from tokenizer import get_tokenizer
 import time
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
-from torchao.prototype.spinquant import apply_spinquant

 def run_evaluation(
     checkpoint_path: Path,
@@ -71,6 +70,7 @@ def run_evaluation(

     if quantization:
         if "spinquant" in quantization:
+            from torchao.prototype.spinquant import apply_spinquant
             apply_spinquant(model)
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
```
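The commit message explains why this import is deferred: `torchao.prototype.spinquant` pulls in large hardcoded Hadamard matrices (via the `_hadamard_matrices` import visible in `hadamard_utils.py` below), which is wasted work whenever SpinQuant is not requested. A rough way to observe the one-time cost the lazy import avoids (a sketch; the module path is real, the timing is machine-dependent):

```python
# Measure the one-time cost of importing the SpinQuant prototype.
# Run in a fresh interpreter so the module isn't already cached in sys.modules.
import importlib
import time

start = time.perf_counter()
importlib.import_module("torchao.prototype.spinquant")  # also loads _hadamard_matrices
print(f"spinquant import took {time.perf_counter() - start:.3f}s")
```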

torchao/_models/llama/generate.py (+4 -1)

```diff
@@ -217,6 +217,9 @@ def main(
             float8_dynamic_activation_float8_weight,
         )
         from torchao.quantization.granularity import PerTensor, PerRow
+        if "spinquant" in quantization:
+            from torchao.prototype.spinquant import apply_spinquant
+            apply_spinquant(model)
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
@@ -460,7 +463,7 @@ def callback(x):
     parser.add_argument('-q', '--quantization', type=str,
         help=(
             'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
-            +'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
+            +'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant'
         )
     )
     parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
```

torchao/prototype/spinquant/README.md (+31 -2)

````diff
@@ -4,8 +4,37 @@ Re-implementation of SpinQuant based on the official code implementation (https:

 ## Usage

-Using this implementation with CUDA requires installing the Fast Hadamard Transform CUDA package, which can be done as follows:
+For optimal performance on CUDA GPUs, install the Fast Hadamard Transform package:

 ```shell
 pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git
-```
+```
+
+## Performance
+
+See https://github.com/pytorch/ao/pull/983 for Wikitext benchmark results.
+
+Tested on:
+
+- Llama-2-7b
+- PyTorch 2.4.1
+- NVIDIA A100
+- CUDA 12.1
+
+Without `torch.compile`:
+
+| Configuration  | Average tokens/sec | Average Bandwidth (GB/s) | Peak Memory Usage (GB) | Model Size (GB) |
+|----------------|--------------------|--------------------------|------------------------|-----------------|
+| Baseline       | 27.33              | 361.21                   | 13.62                  | 13.21           |
+| Spinquant (R4) | 23.01              | 304.10                   | 14.24                  | 13.22           |
+
+With `torch.compile`:
+
+| Configuration        | Average tokens/sec | Average Bandwidth (GB/s) | Peak Memory Usage (GB) | Model Size (GB) |
+|----------------------|--------------------|--------------------------|------------------------|-----------------|
+| Baseline             | 114.08             | 1507.58                  | 13.88                  | 13.21           |
+| Spinquant (R4)       | 109.59             | 1448.61                  | 13.72                  | 13.22           |
+| Spinquant (R1+R2+R4) | 109.64             | 1449.28                  | 14.90                  | 13.22           |
+
+
+NB: R1 and R2 are fused into the linear weights before inference takes place, so it is expected that they do not lead to additional overhead at inference time.
````
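A quick consistency check on these tables: for weight-bound decoding, achieved bandwidth should be roughly model size times tokens/sec, and the reported numbers line up well (a verification sketch, not from the source):

```python
# Bandwidth ~= model size (GB) * decode throughput (tokens/sec), since each
# generated token reads every weight once in weight-bound decoding.
rows = [
    ("Baseline, eager",       13.21,  27.33,  361.21),
    ("Spinquant R4, eager",   13.22,  23.01,  304.10),
    ("Baseline, compile",     13.21, 114.08, 1507.58),
    ("Spinquant R4, compile", 13.22, 109.59, 1448.61),
]
for name, size_gb, tok_per_s, reported in rows:
    est = size_gb * tok_per_s
    print(f"{name}: estimated {est:.1f} GB/s vs reported {reported} GB/s")
```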

torchao/prototype/spinquant/hadamard_utils.py (+50 -5)

```diff
@@ -11,10 +11,12 @@

 import torch

+from torchao.ops import lib
 from torchao.prototype.spinquant._hadamard_matrices import get_had172, get_had156, get_had140, get_had108, get_had60, get_had52, get_had36, get_had28, get_had44, get_had40, get_had20, get_had12
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

 try:
-    from fast_hadamard_transform import hadamard_transform
+    from fast_hadamard_transform import hadamard_transform as _fast_hadamard_transform

     def matmul_hadU(X, hadK, K):
         if X.is_cuda:
@@ -32,16 +34,59 @@ def matmul_hadU(X, hadK, K):
         return matmul_hadU_slow(X, hadK, K)


+def register_custom_op_impl(name):
+    def decorator(func):
+        if TORCH_VERSION_AT_LEAST_2_4:
+            return torch.library.custom_op(f"{name}", mutates_args=())(func)
+        else:
+            lib.define("hadamard_transform(Tensor x, float scale = 0.0) -> Tensor")
+            return torch.library.impl(f"{name}", "cuda")(func)
+    return decorator
+
+
+def register_custom_op_abstract(name):
+    def decorator(func):
+        if TORCH_VERSION_AT_LEAST_2_4:
+            return torch.library.register_fake(f"{name}")(func)
+        else:
+            return torch.library.impl_abstract(f"{name}")(func)
+    return decorator
+
+
+@register_custom_op_impl("torchao::hadamard_transform")
+def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+    """
+    Arguments:
+        x: (..., dim)
+        scale: float. Multiply the output by this number.
+    Returns:
+        out: (..., dim)
+
+    Multiply each row of x by the Hadamard transform matrix.
+    Equivalent to F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale.
+    If dim is not a power of 2, we implicitly pad x with zero so that dim is the next power of 2.
+
+    Source: https://github.com/Dao-AILab/fast-hadamard-transform
+    """
+    return _fast_hadamard_transform(x, scale)
+
+
+@register_custom_op_abstract("torchao::hadamard_transform")
+def _(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+    torch._check(x.dim() >= 1, lambda: f"input should be at least a 1D tensor, got {x.dim()}D")
+    return torch.empty_like(x)
+
+
 class HadamardTransform(torch.autograd.Function):
     """The unnormalized Hadamard transform (i.e. without dividing by sqrt(2))"""

     @staticmethod
     def forward(ctx, u):
-        return hadamard_transform(u)
+        return _fast_hadamard_transform(u)

     @staticmethod
     def backward(ctx, grad):
-        return hadamard_transform(grad)
+        return _fast_hadamard_transform(grad)


 def is_pow2(n):
@@ -144,9 +189,9 @@ def matmul_hadU_slow(X, hadK, K):
 def matmul_hadU_fast(X, hadK, K):
     n = X.shape[-1]
     if K == 1:
-        return HadamardTransform.apply(X.contiguous()) / torch.tensor(n).sqrt()
+        return torch.ops.torchao.hadamard_transform.default(X.contiguous()) / torch.tensor(n).sqrt()
     input = X.view(-1, K, n // K)
-    input = HadamardTransform.apply(input.contiguous()) / torch.tensor(n).sqrt()
+    input = torch.ops.torchao.hadamard_transform.default(input.contiguous()) / torch.tensor(n).sqrt()
     input = hadK.to(input.device).to(input.dtype) @ input
     return input.reshape(X.shape)
```
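The version-gated registration above is what gives `torch.compile` compatibility: the real implementation runs the opaque CUDA kernel, while the fake (abstract) function lets the compiler infer output shapes without executing it. As a reference for what the op computes, here is a slow pure-PyTorch equivalent plus a hypothetical smoke test (assumes CUDA, the fast-hadamard-transform package, and a power-of-2 dim; none of this code is in the repo):

```python
import torch

def hadamard_transform_ref(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    """O(n^2) reference for power-of-2 dims: x @ H * scale, matching the custom
    op's docstring (F.linear with scipy.linalg.hadamard; H is symmetric)."""
    n = x.shape[-1]
    assert n & (n - 1) == 0, "this reference sketch only handles power-of-2 dims"
    H = torch.ones(1, 1, dtype=x.dtype, device=x.device)
    while H.shape[0] < n:  # Sylvester construction: H_{2n} = [[H, H], [H, -H]]
        H = torch.cat([torch.cat([H, H], -1), torch.cat([H, -H], -1)], -2)
    return x @ H * scale

if torch.cuda.is_available():
    x = torch.randn(4, 1024, device="cuda")
    y = torch.ops.torchao.hadamard_transform.default(x)
    torch.testing.assert_close(y, hadamard_transform_ref(x), rtol=1e-3, atol=1e-3)
    # The point of the custom op: this traces as a single node under compile.
    compiled = torch.compile(lambda t: torch.ops.torchao.hadamard_transform.default(t))
    torch.testing.assert_close(compiled(x), y)
```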

torchao/prototype/spinquant/spinquant.py (+5 -5)

```diff
@@ -103,7 +103,7 @@ def apply_spinquant_r4(model, device):
     _add_activation_wrappers_r4(model)


-@torch.inference_mode()
+@torch.no_grad()
 def _fuse_layernorm_into_linear(layernorm: RMSNorm, linear_layers: typing.Iterable[torch.nn.Linear]):
     """Fuse the linear operations in Layernorm into the adjacent linear blocks."""
     for linear in linear_layers:
@@ -127,7 +127,7 @@ def _fuse_layernorm_into_linear(layernorm: RMSNorm, linear_layers: typing.Iterab
     layernorm.weight.data = torch.ones_like(layernorm.weight.data)


-@torch.inference_mode()
+@torch.no_grad()
 def _rotate_model_r1(model, R1):
     _rotate_embeddings(model, R1)
     _rotate_head(model, R1)
@@ -139,7 +139,7 @@ def _rotate_model_r1(model, R1):
         _rotate_mlp_output(layer, R1)


-@torch.inference_mode()
+@torch.no_grad()
 def _rotate_model_r2(model, R2s):
     """Rotate the W_v and W_o weights of the multi-head self-attention modules."""

@@ -168,7 +168,7 @@ def _rotate_model_r2(model, R2s):
         attn.wqkv.weight.data = torch.cat([wq, wk, wv_mod.weight.data], dim=0)


-@torch.inference_mode()
+@torch.no_grad()
 def _rotate_model_r4(model):
     """Rotate the MLP output weights."""

@@ -193,7 +193,7 @@ def _add_activation_wrappers_r4(model):
     )


-@torch.inference_mode()
+@torch.no_grad()
 def fuse_layernorm_into_linear(model):
     """
     Fuse RMSNorm weights into the subsequent linear layers.
```
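Every weight-mutating helper above swaps `@torch.inference_mode()` for `@torch.no_grad()`. A plausible motivation (the commit itself does not state one): tensors created under inference mode are permanently flagged as inference tensors, so weights rewritten there can break later autograd or compilation, whereas `no_grad` only disables gradient tracking. A small illustration of the difference:

```python
import torch

def double_weights(linear: torch.nn.Linear) -> None:
    # Stand-in for the rotation helpers above, which rewrite .data in place.
    linear.weight.data = linear.weight.data * 2.0

x = torch.randn(1, 4, requires_grad=True)

lin = torch.nn.Linear(4, 4)
with torch.inference_mode():
    double_weights(lin)  # weight.data is now an inference tensor
try:
    lin(x).sum().backward()
except RuntimeError as err:
    print(err)  # inference tensors cannot be saved for backward

lin = torch.nn.Linear(4, 4)
with torch.no_grad():
    double_weights(lin)  # weight stays an ordinary tensor
lin(x).sum().backward()  # works
```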
