Commit bc2deb7

Update on "Autoquant"
Summary: Adding autoquantization functionality. Using the do_quant api we can test kernel speeds and pick the best quantization type (or no quantization) for each layer.

Test Plan: python test/test.py -k "autoquant"

Also tested on SAM and SDXL: pytorch-labs/segment-anything-fast#114, HDCharles/sdxl-fast@8d9942a

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
2 parents 29214a9 + 2ae74d3 commit bc2deb7

27 files changed: +3047 −134 lines

.github/workflows/nightly-build.yml

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+name: PyPI Nightly Build
+
+on:
+  schedule:
+    - cron: '0 0 * * *' # Runs at midnight UTC every day
+  workflow_dispatch:
+
+jobs:
+  build-and-publish:
+    runs-on: ubuntu-latest
+    steps:
+    - uses: actions/checkout@v3
+    - name: Set up Python
+      uses: actions/setup-python@v4
+      with:
+        python-version: '3.x'
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install setuptools wheel twine
+    - name: Build package
+      run: |
+        export TORCHAO_NIGHTLY=1
+        python setup.py sdist bdist_wheel
+    - name: Publish package to PyPI
+      uses: pypa/gh-action-pypi-publish@release/v1
+      with:
+        user: __token__
+        password: ${{ secrets.PYPI_API_TOKEN }}
+        repository_url: https://upload.pypi.org/legacy/
+        packages_dir: dist/

.github/workflows/regression_test.yml

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+name: Run Regression Tests
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  test:
+    runs-on: 4-core-ubuntu-gpu-t4
+    steps:
+    - uses: actions/checkout@v2
+
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: 3.9
+
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install -r requirements.txt
+        pip install -r dev-requirements.txt
+        pip install torch
+
+
+    - name: Install package
+      run: |
+        pip install .
+
+    - name: Run tests
+      run: |
+        pytest test

.github/workflows/test_install.yml

Lines changed: 0 additions & 31 deletions
This file was deleted.

README.md

Lines changed: 49 additions & 85 deletions
@@ -1,62 +1,57 @@
-# torchao
+# torchao: PyTorch Architecture Optimization
 
-**Note: This repository is currently under heavy development - if you have suggestions on the API or use-cases you'd like to be covered, please open an github issue or reach out. We'd love to hear about how you're using the APIs.**
+**Note: This repository is currently under heavy development - if you have suggestions on the API or use-cases you'd like to be covered, please open a github issue.**
+
+The `torchao` package allows you to quantize and prune your models using native PyTorch.
+
+The repo hosts both
+1. lower precision [dtypes](./torchao/dtypes) such as nf4, uint4
+2. Quantization [algorithms](./torchao/quantization) such as dynamic quant, smoothquant
+3. Sparsity [algorithms](./torchao/sparsity) such as Wanda
+
+## Success stories
+Our kernels have been used to achieve SOTA inference performance on
+
+1. Image segmentation models with [sam-fast](pytorch.org/blog/accelerating-generative-ai)
+2. Language models with [gpt-fast](pytorch.org/blog/accelerating-generative-ai-2)
+3. Diffusion models with [sd-fast](pytorch.org/blog/accelerating-generative-ai-3)
 
-The torchao package contains apis and workflows used to apply AO techniques like quantization and pruning to models using only native pytorch.
 
 ## Installation
 
 **Note: this library makes liberal use of several new features in PyTorch; it's recommended to use it with the current PyTorch nightly if you want full feature coverage. If not, the subclass APIs may not work, though the module swap APIs will still work.**
 
 1. From PyPI:
-```
+```Shell
 pip install torchao
 ```
 
 2. From Source:
 
-```
+```Shell
 git clone https://github.com/pytorch-labs/ao
 cd ao
-python setup.py install
-```
-
-Verify Installation:
-
-```
-pip list | grep torchao
-```
-
-Expected Output
-```
-torchao 0.0.1 <install dir>
+pip install -e .
 ```
 
-## Usage
+## Examples
 
-Relevant APIs can be found in torchao.quantization.quant_api
-
-Note: While these techniques are designed to improve model performance, in some cases the opposite can occur.
-This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance.
-
-The following apis use quantized [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor). By taking a linear op/module and replacing the original weight with a q-tensor subclass, we're able to convert it into a quantized version of the op. Upon replacement, these q-tensor subclasses quantize the original weight and override the dispatch for linear ops to instead use the subclass' _quantized_op method.
-
-This tensor subclass method of quantization is preferred over older module swap based methods because it doesn't modify the graph and is generally more composable and flexible.
+Typically quantization algorithms will have different schemes for how the activations and weights are quantized, so A16W8 for instance means the activations are quantized to 16 bits whereas the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a 1-line change.
 
 ### Autoquantization
 
 The `autoquant` api can be used to quickly and accurately quantize your model. When used as in the example below, the api first identifies the shapes of the activations that the different linear layers see, then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take fusions into account where possible. Finally, once the best class is found for each layer, it swaps the linear. Currently this api chooses between no quantization, int8 dynamic quantization and int8 weight-only quantization for each layer.
 
-```
+```python
 import torch
 import torchao
 
-# inductor settings which improve torch.compile runtime for quantized modules
+# inductor settings which improve torch.compile performance for quantized modules
 torch._inductor.config.force_fuse_int_mm_with_mul
 torch._inductor.config.use_mixed_mm
 
-# some user model and example input
+# Plug in your model and example input
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
 
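The diff view elides the body of this example between the hunk above and the next one, so the actual autoquant invocation does not appear in this commit page. Based on the description above and the `model(input)` / `torch.compile` context lines in the following hunks, the continuation looks roughly like the sketch below; the `torchao.autoquant(model, input)` call is an assumption, not text from the commit:

```python
# Hypothetical continuation of the elided example (assumed, not from this diff):
# benchmark the observed activation shapes and swap each linear's weight for the
# fastest of {no quantization, int8 dynamic, int8 weight-only}.
torchao.autoquant(model, input)

# compile and run, as shown in the context lines of the later hunks
model = torch.compile(model, mode='max-autotune')
model(input)
```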

@@ -71,94 +66,59 @@ model(input)
 
 ### A8W8 Dynamic Quantization
 
-The `change_linear_weights_to_int8_dqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8DynamicallyQuantizedLinearWeight`. In practice this converts the floating point linear matmul of the original linear op to a dynamically quantized linear matmul.
-
-Example
-
-```
-# some user model and example input
-...
-
+```python
 # convert linear modules to quantized linear modules
 torchao.change_linear_weights_to_int8_dqtensors(model)
-
-# compile the model to improve performance
-...
 ```
 
-This technique works best when the torch._inductor.config.force_fuse_int_mm_with_mul option is enabled. This allows fusion of the int8*int8 -> int32 matmul and subsequent mul op, thereby avoiding materialization of the int32 intermediary tensor.
-
-
 ### A16W8 WeightOnly Quantization
 
-The `change_linear_weights_to_int8_woqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int8WeightOnlyQuantizedLinearWeight`. In practice this converts the floating point linear matmul of the original linear op to a weight-only quantized linear matmul.
-
-Example
-
-```
-# some user model and example input
-...
-
-# convert linear modules to quantized linear modules
+```python
 torchao.change_linear_weights_to_int8_woqtensors(model)
-
-# compile the model to improve performance
-...
 ```
 
 This technique works best when the torch._inductor.config.use_mixed_mm option is enabled. This avoids dequantizing the weight tensor before the matmul, instead fusing the dequantization into the matmul, thereby avoiding materialization of a large floating point weight tensor.
 
 
 ### A16W4 WeightOnly Quantization
 
-The `change_linear_weights_to_int4_woqtensors` function converts the linear weights in a model to a quantized tensor subclass `Int4WeightOnlyQuantizedLinearWeight`. In practice this converts the floating point linear matmul of the original linear op to a weight-only quantized linear matmul.
-
-Example
-
-```
-# some user model and example input
-...
-
-# convert linear modules to quantized linear modules
+```python
 torchao.change_linear_weights_to_int4_woqtensors(model)
-
-# compile the model to improve performance
-...
 ```
 
-The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
-
-## Other APIs
+Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
 
-### Module Swap APIs
-
-The `apply_dynamic_quant` and `apply_weight_only_int8_quant` apis can be used in the same formula as above to achieve dynamic and weight-only quantization using module swaps instead of quantized tensor subclasses.
 
 ### A8W8 Dynamic Quantization with Smoothquant
 
-We've also implemented a version of [smoothquant](https://arxiv.org/abs/2211.10438) with the same GEMM format as above.
-Due to requiring calibration, the API is slightly more complicated and currently only exists with a module swap api.
+We've also implemented a version of [smoothquant](https://arxiv.org/abs/2211.10438) with the same GEMM format as above. Due to requiring calibration, the API is more complicated.
 
 Example
 
-```
+```Python
 import torch
 from torchao.quantization.smoothquant import swap_linear_with_smooth_fq_linear, smooth_fq_linear_to_inference
 
-# some user model
+# Fuse the int8*int8 -> int32 matmul and subsequent mul op, avoiding materialization of the int32 intermediary tensor
+torch._inductor.config.force_fuse_int_mm_with_mul = True
+
+# plug in your model
 model = get_model()
 
 # convert linear modules to smoothquant
 # linear module in calibration mode
 swap_linear_with_smooth_fq_linear(model)
 
-# calibration
-for i in range(calibration_amount):
-    input = get_input()
-    model(input)
+# Create a data loader for calibration
+calibration_data = get_calibration_data()
+calibration_dataset = MyDataset(calibration_data)
+calibration_loader = DataLoader(calibration_dataset, batch_size=32, shuffle=True)
+
+# Calibrate the model
+model.train()
+for batch in calibration_loader:
+    inputs = batch
+    model(inputs)
 
 # set it to inference mode
 smooth_fq_linear_to_inference(model)
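The trimmed A16W8/A16W4 snippets in the hunk above no longer carry the model setup or compile step that the removed text alluded to. Pieced together, a minimal weight-only flow looks roughly like this sketch, reusing the toy model from the Autoquantization example (the sketch itself is not part of this commit):

```python
# Minimal end-to-end weight-only flow assembled from the snippets above;
# the model and input are placeholder toys, not new API.
import torch
import torchao

# fuse dequantization into the matmul (see the A16W8 note above)
torch._inductor.config.use_mixed_mm = True

model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
example_input = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')

# swap linear weights for the int8 weight-only subclass
# (or change_linear_weights_to_int4_woqtensors for A16W4)
torchao.change_linear_weights_to_int8_woqtensors(model)

model = torch.compile(model, mode='max-autotune')
model(example_input)
```

The final hunk below closes out the smoothquant example and the README.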
@@ -168,7 +128,11 @@ model = torch.compile(model, mode='max-autotune')
 model(input)
 ```
 
-like the other dynamic quantization apis, the torch._inductor.config.force_fuse_int_mm_with_mul option may significantly improve performance if enabled.
+## Sharp edges
+
+1. While these techniques are designed to improve model performance, in some cases the opposite can occur. This is because quantization adds additional overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization) or loading weights faster (weight-only quantization). If your matmuls are small enough or your non-quantized perf isn't bottlenecked by weight load time, these techniques may reduce performance.
+2. Use the PyTorch nightlies so you can leverage [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor), which are preferred over older module-swap-based methods because they don't modify the graph and are generally more composable and flexible.
+
 
 ## License
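In the new smoothquant calibration example, `get_model`, `get_calibration_data`, and `MyDataset` are user-supplied placeholders rather than torchao APIs, and `DataLoader` (from `torch.utils.data`) is never imported. A self-contained sketch of the same loop, substituting random tensors for the placeholder calibration data:

```python
# Self-contained rendering of the smoothquant example above; the random
# TensorDataset stands in for the user's MyDataset/get_calibration_data.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchao.quantization.smoothquant import (
    swap_linear_with_smooth_fq_linear,
    smooth_fq_linear_to_inference,
)

torch._inductor.config.force_fuse_int_mm_with_mul = True

model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)

# swap linears for calibration-mode smooth-fq linears
swap_linear_with_smooth_fq_linear(model)

# run calibration batches so activation statistics are recorded
calibration_dataset = TensorDataset(
    torch.randn(256, 32, dtype=torch.bfloat16, device='cuda'))
calibration_loader = DataLoader(calibration_dataset, batch_size=32)
for (batch,) in calibration_loader:
    model(batch)

# freeze scales for inference, then compile and run
smooth_fq_linear_to_inference(model)
model = torch.compile(model, mode='max-autotune')
model(torch.randn(32, 32, dtype=torch.bfloat16, device='cuda'))
```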

dev-requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+pytest
+expecttest
+packaging

requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+torch
+numpy
+sentencepiece

setup.py

Lines changed: 18 additions & 6 deletions
@@ -1,18 +1,30 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
-
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
+from datetime import datetime
 from setuptools import setup, find_packages
+current_date = datetime.now().strftime('%Y.%m.%d')
+
+
+def read_requirements(file_path):
+    with open(file_path, 'r') as file:
+        return file.read().splitlines()
+
+# Determine the package name based on the presence of an environment variable
+package_name = 'torchao-nightly' if os.environ.get('TORCHAO_NIGHTLY') else 'torchao'
+
+# Version is year.month.date if using nightlies
+version = current_date if package_name == 'torchao-nightly' else '0.0.3'
+
 
 setup(
-    name='torchao',
-    version='0.0.3',
+    name=package_name,
+    version=version,
     packages=find_packages(),
-    install_requires=[
-        'torch',
-    ],
+    install_requires=read_requirements('requirements.txt'),
     description='Package for applying ao techniques to GPU models',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
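The nightly workflow added earlier in this commit drives this selection by exporting `TORCHAO_NIGHTLY=1` before building. Re-running the new name/version logic standalone shows the effect (illustration only; the printed date is whatever day you run it):

```python
# Standalone re-run of the name/version selection added to setup.py above.
import os
from datetime import datetime

os.environ['TORCHAO_NIGHTLY'] = '1'  # what the nightly workflow's `export` does

package_name = 'torchao-nightly' if os.environ.get('TORCHAO_NIGHTLY') else 'torchao'
version = (datetime.now().strftime('%Y.%m.%d')
           if package_name == 'torchao-nightly' else '0.0.3')
print(package_name, version)  # e.g. "torchao-nightly 2024.02.14"
```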
