Summary: Adding autoquantization functionality. Using the do_quant API,
we can test kernel speeds and pick the best quantization type (or no
quantization) for each layer.

Test Plan: python test/test.py -k "autoquant"
Also tested on SAM and SDXL:
pytorch-labs/segment-anything-fast#114
HDCharles/sdxl-fast@8d9942a

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
README.md: 50 additions & 77 deletions
# torchao: PyTorch Architecture Optimization

**Note: This repository is currently under heavy development - if you have suggestions on the API or use cases you'd like to be covered, please open a GitHub issue or reach out. We'd love to hear about how you're using the APIs.**

The `torchao` package allows you to quantize and prune your models using native PyTorch.

The repo hosts both

1. lower precision [dtypes](./torchao/dtypes) such as nf4 and uint4
2. Quantization [algorithms](./torchao/quantization) such as dynamic quant and smoothquant
3. Sparsity [algorithms](./torchao/sparsity) such as Wanda

## Success stories

Our kernels have been used to achieve SOTA inference performance on

1. Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai)
2. Language models with [gpt-fast](https://pytorch.org/blog/accelerating-generative-ai-2)
3. Diffusion models with [sd-fast](https://pytorch.org/blog/accelerating-generative-ai-3)

## Installation

**Note: this library makes liberal use of several new features in PyTorch, so it's recommended to use it with the current PyTorch nightly if you want full feature coverage. If not, the subclass APIs may not work, though the module swap APIs will still work.**

1. From PyPI:

```Shell
pip install torchao
```

2. From Source:

```Shell
git clone https://github.com/pytorch-labs/ao
cd ao
pip install -e .
```

Verify the installation:

```Shell
pip list | grep torchao
```

Expected output:

```Shell
torchao 0.0.1 <install dir>
```

## Examples

Relevant APIs can be found in `torchao.quantization.quant_api`.

Typically quantization algorithms will have different schemes for how the activations and weights are quantized, so A16W8 for instance means the activations are quantized to 16 bits whereas the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a one-line change.

The quantization APIs below use quantized [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor). By taking a linear op/module and replacing the original weight with a q-tensor subclass, we're able to convert it into a quantized version of the op. Upon replacement, these q-tensor subclasses quantize the original weight and override the dispatch for linear ops to instead use the subclass' `_quantized_op` method.
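
To make that mechanism concrete, here is a simplified, self-contained sketch of the idea. It is not torchao's implementation: the `Int8WeightTensor` class, its quantization scheme, and its `F.linear` override are illustrative assumptions about how such a weight subclass can work.

```Python
import torch
import torch.nn.functional as F

class Int8WeightTensor(torch.Tensor):
    """Illustrative weight subclass (hypothetical, not torchao's): stores int8 data
    plus scales and overrides F.linear to run against the dequantized weight."""

    @staticmethod
    def __new__(cls, int_data, scale):
        # Advertise the original weight's shape and a float dtype to the outside world.
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=scale.dtype, device=int_data.device
        )

    def __init__(self, int_data, scale):
        self.int_data = int_data  # int8 quantized values
        self.scale = scale        # per-output-channel scales

    @classmethod
    def from_float(cls, w):
        # symmetric per-channel int8 quantization of a float weight
        scale = w.abs().amax(dim=1, keepdim=True) / 127.0
        int_data = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
        return cls(int_data, scale)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w, *rest = args
            bias = rest[0] if rest else None
            # the "quantized op": here we simply dequantize and call the normal linear
            return F.linear(x, w.int_data.to(x.dtype) * w.scale, bias)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

# usage: the subclass acts as a drop-in replacement for the weight in a linear op
x = torch.randn(2, 8)
w = torch.randn(4, 8)
qw = Int8WeightTensor.from_float(w)
print(F.linear(x, qw).shape)  # torch.Size([2, 4])
```

torchao's real subclasses are more involved (they dispatch to quantized kernels rather than simply dequantizing), but the wrap-the-weight-and-override-dispatch structure is the idea described above.
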
### A8W8 Dynamic Quantization

The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to the quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.

Example

```Python
import torch
from torchao.quantization import quant_api

# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
...

model = torch.compile(model, mode='max-autotune')
model(input)
```

This technique works best when the torch._inductor.config.force_fuse_int_mm_with_mul option is enabled. This allows fusion of the int8*int8 -> int32 matmul and subsequent mul op, thereby avoiding materialization of the int32 intermediary tensor.
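
For a runnable end-to-end version of the snippet above, here is a minimal sketch. The toy model, input shapes, dtype, and device are placeholders, and it assumes `change_linear_weights_to_int8_dqtensors` takes the model and converts its linear weights in place.

```Python
import torch
import torch._inductor.config
from torchao.quantization import quant_api

# toy model and example input (placeholders for illustration)
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU()).cuda().to(torch.bfloat16)
input = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")

# fuse the int8*int8 -> int32 matmul with the subsequent mul op
torch._inductor.config.force_fuse_int_mm_with_mul = True

# convert linear weights to the Int8DynamicallyQuantizedLinearWeight subclass
# (assumed to mutate the model in place)
quant_api.change_linear_weights_to_int8_dqtensors(model)

model = torch.compile(model, mode='max-autotune')
model(input)
```
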
### A16W8 WeightOnly Quantization

The `change_linear_weights_to_int8_woqtensors` function converts the linear weights in a model to the quantized tensor subclass `Int8WeightOnlyQuantizedLinearWeight`. In practice this converts the floating point linear matmul of the original linear op to a weight-only quantized linear matmul.

This technique works best when the torch._inductor.config.use_mixed_mm option is enabled. This avoids dequantizing the weight tensor before the matmul, instead fusing the dequantization into the matmul, thereby avoiding materialization of a large floating point weight tensor.

Example
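
A minimal sketch under the same assumptions as the dynamic quantization example (placeholder model and input; the conversion function is assumed to take the model and mutate it in place):

```Python
import torch
import torch._inductor.config
from torchao.quantization import quant_api

# some user model and example input (placeholders for illustration)
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
input = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")

# fuse the dequantization into the matmul rather than materializing a float weight tensor
torch._inductor.config.use_mixed_mm = True

# convert linear modules to weight-only quantized linear modules (assumed in-place)
quant_api.change_linear_weights_to_int8_woqtensors(model)

model = torch.compile(model, mode='max-autotune')
model(input)
```
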
### A16W4 WeightOnly Quantization

The `change_linear_weights_to_int4_woqtensors` function converts the linear weights in a model to the quantized tensor subclass `Int4WeightOnlyQuantizedLinearWeight`. In practice this converts the floating point linear matmul of the original linear op to a weight-only quantized linear matmul.

Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.

Example
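
Again a minimal sketch, with the same placeholder model/input and the same in-place assumption for the conversion function:

```Python
import torch
from torchao.quantization import quant_api

# some user model and example input (placeholders for illustration)
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
input = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")

# convert linear modules to int4 weight-only quantized linear modules (assumed in-place)
quant_api.change_linear_weights_to_int4_woqtensors(model)

model = torch.compile(model, mode='max-autotune')
model(input)
```
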
### Module Swap APIs

The `apply_dynamic_quant` and `apply_weight_only_int8_quant` APIs can be used in the same manner as above to achieve dynamic and weight-only quantization using module swaps instead of quantized tensor subclasses.
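
For instance, a sketch assuming each function takes the model and swaps its linear modules in place (the model and input remain placeholders):

```Python
import torch
from torchao.quantization import quant_api

# some user model and example input (placeholders for illustration)
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
input = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")

# swap nn.Linear modules for dynamically quantized linear modules;
# quant_api.apply_weight_only_int8_quant(model) would do the weight-only variant
quant_api.apply_dynamic_quant(model)

model = torch.compile(model, mode='max-autotune')
model(input)
```
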
### A8W8 Dynamic Quantization with Smoothquant

We've also implemented a version of [smoothquant](https://arxiv.org/abs/2211.10438) with the same GEMM format as above. Due to requiring calibration, the API is slightly more complicated and currently only exists as a module swap API.

Example

```Python
import torch
from torchao.quantization.smoothquant import swap_linear_with_smooth_fq_linear, smooth_fq_linear_to_inference

# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
...

model = torch.compile(model, mode='max-autotune')
model(input)
```

Like the other dynamic quantization APIs, the torch._inductor.config.force_fuse_int_mm_with_mul option may significantly improve performance if enabled.
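
A fuller sketch of the calibration flow implied by those two functions; the calibration loop, the data, and the in-place semantics of both calls are assumptions for illustration:

```Python
import torch
import torch._inductor.config
from torchao.quantization.smoothquant import swap_linear_with_smooth_fq_linear, smooth_fq_linear_to_inference

# some user model and calibration data (placeholders for illustration)
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
calibration_batches = [torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda") for _ in range(8)]

torch._inductor.config.force_fuse_int_mm_with_mul = True

# swap linears for smooth-fq linears, then run representative data through the model
# so the swapped modules can observe activation statistics (assumed in-place)
swap_linear_with_smooth_fq_linear(model)
with torch.no_grad():
    for batch in calibration_batches:
        model(batch)

# freeze the calibrated modules for inference (assumed in-place)
smooth_fq_linear_to_inference(model)

model = torch.compile(model, mode='max-autotune')
model(calibration_batches[0])
```
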
## Sharp edges

1. While these techniques are designed to improve model performance, in some cases the opposite can occur. This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance.

2. Use the PyTorch nightlies so you can leverage [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor), which are preferred over the older module swap based methods because they don't modify the graph and are generally more composable and flexible.