* Composing autoquant with compile
Summary:
This PR rewrites torchao.autoquant so that it composes with
torch.compile. Previously you had to do:
torchao.autoquant(model, input)
mod = torch.compile(model)
mod(input)
Now you can do:
torchao.autoquant(torch.compile(model))
model(input)
The new method works with or without compile, and the change is
backward compatible, so the old path still works.
We use a forward pre-hook to intercept the model call before
torch.compile tracing occurs; at that point we do the autoquantization
and clean up all remaining hooks before handing off to the
normal torch.compile tracing functionality.
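In pseudocode, the mechanism looks roughly like this (a minimal sketch
using only the public register_forward_pre_hook API; not torchao's
actual implementation, and Toy/autoquant_prehook are illustrative names):
import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2

model = Toy()

def autoquant_prehook(module, args):
    # the one-time work (shape logging + kernel selection) would happen here
    handle.remove()  # clean up so torch.compile traces a hook-free module
    # returning None leaves the forward args unchanged

handle = model.register_forward_pre_hook(autoquant_prehook)
model = torch.compile(model)
model(torch.randn(4))  # first call: hook fires, removes itself, tracing proceeds
model(torch.randn(4))  # later calls: plain compiled forward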
Note: in the case of multiple inputs, you can also call
model.forward_log_only(input) to run the model forward with autoquant
shape logging while preventing the torch.compile tracing/autoquant
quantization from occurring.
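For example (a sketch of the multi-input flow under the API described
above; input1 and input2 stand in for real batches):
model = torchao.autoquant(torch.compile(model))
model.forward_log_only(input1)  # record shapes only; no quantization or tracing yet
model.forward_log_only(input2)
model(input1)  # normal call: autoquant uses the logged shapes, then compile takes over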
Test Plan: python test/integration/test_integration.py -k "autoquant"
Reviewers:
Subscribers:
Tasks:
Tags:
* Fused DoRA kernels (#216)
* add dora kernels
* allowing error_on_unseen in autoquant func
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
* Unified AffineQuantizedTensor subclass (#214)
Summary:
Created an `AffineQuantizedTensor` subclass that works for both weights and inputs (for dynamic quantization) and for all granularities, leveraging the recently added choose_qparams_affine, quantize_affine, and dequantize_affine ops.
Only verified for 8da4w right now; we can make it work for other types of quantization (mostly the operator dispatching part) later.
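For reference, a rough sketch of the three affine primitives the subclass
builds on; the exact argument names and order are assumptions inferred from
the PR description, so treat this as illustrative rather than the verified API:
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

x = torch.randn(4, 64)
block_size = (1, 32)  # one (scale, zero_point) per group of 32 along the last dim

# 4-bit value range stored in int8
scale, zero_point = choose_qparams_affine(
    x, MappingType.ASYMMETRIC, block_size, torch.int8, quant_min=-8, quant_max=7
)
xq = quantize_affine(x, block_size, scale, zero_point, torch.int8, quant_min=-8, quant_max=7)
xdq = dequantize_affine(xq, block_size, scale, zero_point, torch.int8, quant_min=-8, quant_max=7)
print((x - xdq).abs().max())  # reconstruction error introduced by quantization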
Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w
Reviewers:
Subscribers:
Tasks:
Tags:
Co-authored-by: Mark Saroufim <marksaroufim@meta.com>
* add expecttest to requirements.txt (#225)
* add expecttest to requirements.txt
* update
* Install dev-requirements.txt in doc build (#224)
Install dev-requirements.txt
---------
Co-authored-by: Mark Saroufim <marksaroufim@meta.com>
* Fix an error in subclass impl (#226)
Summary:
Accidentally changed the device check code for the old subclass instead of the new one; forgot to fix it before landing.
Test Plan:
CI
Reviewers:
Subscribers:
Tasks:
Tags:
* update readme.md
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
* trying to fix the CI error on cleanup hooks
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
* correct docs
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
* Some follow up fixes for quant primitives (#220)
Summary:
As titled.
Test Plan:
python test/quantization/test_quant_primitives.py -k test_raises
Reviewers:
Subscribers:
Tasks:
Tags:
---------
Co-authored-by: jeromeku <jerome.ku@gmail.com>
Co-authored-by: Jerry Zhang <jerryzh168@gmail.com>
Co-authored-by: Mark Saroufim <marksaroufim@meta.com>
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
--- a/README.md
+++ b/README.md
@@ -33,6 +33,6 @@ model =torchao.autoquant(torch.compile(model, mode='max-autotune'))
 
-#compile the model to improve performance
-model =torch.compile(model, mode='max-autotune')
+#pass in an input which is used in order to pick fastest quantization operations
+# and apply torch compilation.
 model(input)
 ```
 
@@ -167,6 +167,6 @@ model(input)
 
 ## Notes
 
-1. APIs have been hardware tested on A100 and T4(colab)
+1. APIs have been hardware tested on A100 and T4(colab)
 2. While these techniques are designed to improve model performance, in some cases the opposite can occur. This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance.
 3. Use the PyTorch nightlies so you can leverage [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) which is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.