
Commit f6d56ca

Authored by HDCharles, jeromeku, jerryzh168, msaroufim, and svekars
Composing autoquant with compile (#175)
* Composing autoquant with compile

  Summary: This PR rewrites how torchao.autoquant works so that it composes with torch.compile. Previously you had to do:

      torchao.autoquant(model, input)
      mod = torch.compile(model)
      mod(input)

  Now you can do:

      torchao.autoquant(torch.compile(model))
      model(input)

  The new method works with or without compile, and the change is backward compatible, so the old path still works. We use a forward pre-hook to intercept the model call before torch.compile tracing occurs; at that point we perform the autoquantization and clean up all remaining hooks before handing off to the normal torch.compile tracing functionality.

  Note: in the case of multiple inputs, you can also call model.forward_log_only(input) to run the model forward with autoquant shape logging while preventing the torch.compile tracing/autoquant quantization from occurring.

  Test Plan: python test/integration/test_integration.py -k "autoquant"

* Fused DoRA kernels (#216): add DoRA kernels

* Allow error_on_unseen in the autoquant function

* Unified AffineQuantizedTensor subclass (#214)

  Summary: Created an `AffineQuantizedTensor` subclass that works for both weights and inputs (for dynamic quantization) and for all granularities, leveraging the recently added choose_qparams_affine, quantize_affine, and dequantize_affine ops. Only verified for 8da4w right now; we can make it work for other types of quantization (mostly the operator dispatching part) later.

  Test Plan: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w

  Co-authored-by: Mark Saroufim <marksaroufim@meta.com>

* Add expecttest to requirements.txt (#225)

* Install dev-requirements.txt in doc build (#224)

  Co-authored-by: Mark Saroufim <marksaroufim@meta.com>

* Fix an error in subclass impl (#226)

  Summary: Accidentally changed the device check code for the old subclass instead of the new one; forgot to fix before landing.

  Test Plan: CI

* Update README.md

* Try to fix the CI error on cleanup hooks

* Correct docs

* Some follow-up fixes for quant primitives (#220)

  Test Plan: python test/quantization/test_quant_primitives.py -k test_raises

---------

Co-authored-by: jeromeku <jerome.ku@gmail.com>
Co-authored-by: Jerry Zhang <jerryzh168@gmail.com>
Co-authored-by: Mark Saroufim <marksaroufim@meta.com>
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
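As a quick illustration of the call pattern described above, here is a minimal sketch assuming a CUDA device with bf16 support; the toy model and shapes mirror the README snippet updated in this commit:

```python
import torch
import torchao

# toy model and input, mirroring the README example touched by this commit
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")

# old path (still works, since the change is backward compatible):
#   torchao.autoquant(model, input)
#   model = torch.compile(model, mode="max-autotune")
#   model(input)

# new path: wrap the compiled model, then call it; autoquantization runs
# in a forward pre-hook before torch.compile tracing takes over
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))
model(input)
```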
1 parent 63c5ac5 commit f6d56ca

File tree

5 files changed (+133, -33 lines)

README.md

Lines changed: 3 additions & 6 deletions
@@ -44,12 +44,9 @@ torch._inductor.config.use_mixed_mm = True
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
 
-# perform autoquantization
-torchao.autoquant(model, (input))
-
-# compile the model to recover performance
-model = torch.compile(model, mode='max-autotune')
-model(input)
+# perform autoquantization and compilation
+q_model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
+q_model(input)
 ```
 
 ### Sparsity

test/integration/test_integration.py

Lines changed: 55 additions & 8 deletions
@@ -1388,7 +1388,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
             torch.nn.ReLU(),
         ).to(device).to(dtype)
         out = model(example_input)
-        torchao.autoquant(model, example_input)
+        torchao.autoquant(model)
         out2 = model(example_input)
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)
@@ -1400,7 +1400,9 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
            (32, 32, 128, 128),
        ]))
    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
-    def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
+    def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
+        if device != "cuda" and dtype != torch.bfloat16:
+            self.skipTest(f"autoquant currently does not support {device}")
        if device != "cuda" or not torch.cuda.is_available():
            self.skipTest(f"autoquant currently does not support {device}")
        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
@@ -1414,15 +1416,60 @@ def test_autoquant_multi_input(self, device, dtype, m1, m2, k, n):
            torch.nn.ReLU(),
        ).to(device).to(dtype)
        example_input = torch.randn(m1, k, device=device, dtype=dtype)
-        example_input2 = torch.randn(m2, k, device=device, dtype=dtype)
-        torchao.quantization.change_linears_to_autoquantizable(model)
-        out=model(example_input)
-        model(example_input2)
-        torchao.quantization.change_autoquantizable_to_quantized(model)
-        out2 = model(example_input)
+        example_input2 = torch.randn(m1, k, device=device, dtype=dtype)
+        out = model(example_input)
+
+        mod = torchao.autoquant(torch.compile(model))
+        mod.forward_log_only(example_input)
+        mod(example_input2)
+
+        out2 = mod(example_input)
        sqnr = SQNR(out, out2)
        self.assertTrue(sqnr >= 30)
 
+    @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
+        [
+            (1, 1, 128, 128),
+            (1, 32, 128, 128),
+            (32, 32, 128, 128),
+        ]))
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
+    def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n):
+        if device != "cuda" and dtype != torch.bfloat16:
+            self.skipTest(f"autoquant currently does not support {device}")
+        if device != "cuda" or not torch.cuda.is_available():
+            self.skipTest(f"autoquant currently does not support {device}")
+        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
+            if dtype == torch.bfloat16:
+                self.skipTest(f"bfloat16 requires sm80+")
+        if m1 == 1 or m2 == 1:
+            self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
+
+        class NeedsKwargs(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.rel = torch.nn.ReLU()
+                self.lin = torch.nn.Linear(k, n)
+
+            def forward(self, x, y):
+                x = self.rel(x)
+                z = self.lin(x + y)
+                return z
+
+        model = NeedsKwargs().to(device).to(dtype)
+        example_input = {
+            "x": torch.randn(m1, k, device=device, dtype=dtype),
+            "y": torch.randn(m1, k, device=device, dtype=dtype),
+        }
+        out = model(**example_input)
+
+        mod = torchao.autoquant(torch.compile(model))
+        mod.forward_log_only(**example_input)
+        mod(**example_input)
+
+        out2 = mod(**example_input)
+        sqnr = SQNR(out, out2)
+        self.assertTrue(sqnr >= 30)
 
 class TestAOTI(unittest.TestCase):
     @parameterized.expand(

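A short sketch of the multi-shape flow these new tests exercise (assuming `model`, `example_input`, and `example_input2` are set up as in the test above): `forward_log_only` records shape/dtype information without triggering quantization or torch.compile tracing, and the first regular call then finalizes both.

```python
mod = torchao.autoquant(torch.compile(model))
mod.forward_log_only(example_input)   # logs this shape only; no quantization or tracing yet
out2 = mod(example_input2)            # first real call: quantizes, compiles, then runs
out3 = mod(example_input)             # later calls use the quantized, compiled model
```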
torchao/quantization/README.md

Lines changed: 5 additions & 5 deletions
@@ -28,11 +28,11 @@ torch._inductor.config.use_mixed_mm = True
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
 
-# perform autoquantization
-torchao.autoquant(model, (input))
+# perform autoquantization and torch.compile
+model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
 
-# compile the model to improve performance
-model = torch.compile(model, mode='max-autotune')
+# pass in an input which is used in order to pick fastest quantization operations
+# and apply torch compilation.
 model(input)
 ```
@@ -167,6 +167,6 @@ model(input)
 
 ## Notes
 
-1. APIs have been hardware tested on A100 and T4(colab)
+1. APIs have been hardware tested on A100 and T4(colab)
 2. While these techniques are designed to improve model performance, in some cases the opposite can occur. This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance.
 3. Use the PyTorch nightlies so you can leverage [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) which is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.

torchao/quantization/autoquant.py

Lines changed: 67 additions & 13 deletions
@@ -74,6 +74,7 @@ def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
             res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode)
             update_cache(q_cls, shapes_and_dtype, res)
 
+    @torch.no_grad()
     def to_quantized(self, error_on_unseen, **kwargs):
         if error_on_unseen and self.logged_data == {}:
             raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
@@ -123,7 +124,7 @@ def count_shapes(self, do_print=True):
                 torch._dynamo.reset()
                 cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen
             if shape_count is not None and shape_count > 1:
-                print(f">total_time: {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms")
+                print(f">time (all shapes): {cur_time:0.3f}ms for {q_cls}, prev_best: {best_time:0.3f}ms")
             if best_time >= cur_time:
                 best_time = cur_time
                 best_cls = q_cls
@@ -176,6 +177,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         if func is aten.detach.default:
             return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
 
+@torch.no_grad()
 def do_autoquant_bench(op, *args, **kwargs):
     """
     runs benchmark op(*args, **kwargs) avoiding torch.compile overhead
@@ -335,6 +337,7 @@ def change_linears_to_autoquantizable(model, **kwargs):
     """
     from torchao.quantization.quant_api import _is_linear
     filter_fn = kwargs.pop("filter_fn", _is_linear)
+    _ = kwargs.pop("error_on_unseen", True) # same kwargs used for this and to_quantized
     kwargs["qtensor_class_list"] = kwargs.get("qtensor_class_list", DEFAULT_CLASS_LIST)
     kwargs["mode"] = kwargs.get("mode", ["relu", None])
     from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
@@ -374,20 +377,71 @@ def change_autoquantizable_to_quantized(model, **kwargs):
     torch._dynamo.reset()
 
 @torch.no_grad()
-def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **kwargs):
+def autoquant(model, example_input=None, qtensor_class_list=DEFAULT_CLASS_LIST, filter_fn=None, mode=["relu",None], **aq_kwargs):
     """
-    Runs the model with example_input to record shapes and then compares benchmark performance of the seen shape
-    across the qtensor subclasses in qtensor_class_list. Determines best performing qtensor subclass for each layer
-    and applies that type of quantization.
+    wraps model in AutoQuantWrapper, if example_input is provided, runs forward on it, otherwise returns the wrapped model.
+    AutoQuantWrapper handles instances where model is torch.compiled by first performing autoquantization on the original
+    model and then letting the torch.compile run/tracing occur.
+
+    Example usage::
+
+        torchao.autoquant(torch.compile(model))
+        model(*example_input)
+
     """
-    if filter_fn is None:
-        from torchao.quantization.quant_api import _is_linear
-        filter_fn = _is_linear
+    # the hook we will use to intercept the model forward and perform
+    # autoquantization
+    def autoquant_prehook(module, args, kwargs):
+        module.forward_log_only(*args, **kwargs)
+        change_autoquantizable_to_quantized(
+            module,
+            **aq_kwargs,
+        )
+        module.clean_up_autoquant_hooks_and_attrs()
+        return args, kwargs
+
+    # perform initial swap from linear weights
+    # to AutoQuantizableLinearWeight
+    change_linears_to_autoquantizable(
+        model,
+        filter_fn=filter_fn,
+        qtensor_class_list=qtensor_class_list,
+        mode=mode,
+        **aq_kwargs
+    )
+
+    # access actual model of torch.compile wrapper if needed
+    if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
+        real_model = model._orig_mod
+    else:
+        real_model = model
+
+    # we need a consistent way to run the model which bypasses both
+    # A) the torch.compile tracing (so we need to run the inner model directly)
+    # B) the autoquant_prehook we're about to register (so we call forward directly)
+    model.forward_log_only = lambda *args, **kwargs: real_model.forward(*args, **kwargs)
+
+    # the autoquant_prehook intercepts the forward call and performs autoquantization
+    # and then deletes the hook. if model is a torch.compile wrapper, it then
+    # does the tracing/compile since the prehook is naturally followed by the normal
+    # model run.
+    handle = model.register_forward_pre_hook(autoquant_prehook, with_kwargs=True)
+
+    # note the torch.compile wrapper eval_frame moved the assignment of any assigned
+    # attributes to the inner model, so we have to call delattr on the inner model
+    def clean_up_autoquant_hooks_and_attrs():
+        try:
+            handle.remove()
+            delattr(real_model, "clean_up_autoquant_hooks_and_attrs")
+            delattr(real_model, "forward_log_only")
+        except:
+            pass
+    model.clean_up_autoquant_hooks_and_attrs = clean_up_autoquant_hooks_and_attrs
 
-    change_linears_to_autoquantizable(model, filter_fn=filter_fn, qtensor_class_list=qtensor_class_list, mode=mode, **kwargs)
-    if not isinstance(example_input, (tuple, list)):
-        assert isinstance(example_input, torch.Tensor)
+    # if example input was provided, check it and run it
+    if isinstance(example_input, torch.Tensor):
         example_input = [example_input]
-    model(*example_input)
-    change_autoquantizable_to_quantized(model, **kwargs)
+    if isinstance(example_input, (tuple, list)):
+        model(*example_input)
+
     return model

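The implementation above relies on PyTorch's forward pre-hook mechanism to do one-time work on the first call and then get out of the way. A minimal, generic sketch of that pattern follows; the module and setup work here are illustrative stand-ins, not torchao APIs:

```python
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)

model = Toy()

def one_shot_prehook(module, args, kwargs):
    # stand-in for the one-time setup autoquant performs here (finalizing
    # quantization before torch.compile tracing would begin)
    print("one-time setup before the first forward")
    handle.remove()  # remove the hook so later calls go straight to forward
    return args, kwargs

# with_kwargs=True passes keyword arguments through to the hook as well
handle = model.register_forward_pre_hook(one_shot_prehook, with_kwargs=True)

model(torch.randn(2, 4))  # triggers the pre-hook once, then runs forward
model(torch.randn(2, 4))  # hook already removed; runs forward directly
```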
torchao/quantization/quant_api.py

Lines changed: 3 additions & 1 deletion
@@ -34,6 +34,7 @@
     Int4WeightOnlyGPTQQuantizer,
     Int4WeightOnlyQuantizer,
 )
+from .autoquant import autoquant
 
 
 __all__ = [
@@ -46,7 +47,8 @@
     "Quantizer",
     "TwoStepQuantizer",
     "Int4WeightOnlyGPTQQuantizer",
-    "Int4WeightOnlyQuantizer"
+    "Int4WeightOnlyQuantizer",
+    "autoquant"
 ]
 
 if TORCH_VERSION_AFTER_2_3: