Skip to content

Commit 17c670a

Browse files
committed
Update on "Autoquant"
Summary: Adding autoquantization functionality. Using the do_quant API, we can test kernel speeds and pick the best quantization type (or no quantization) for each layer. Test Plan: python test/test.py -k "autoquant" also tested on SAM and SDXL pytorch-labs/segment-anything-fast#114 HDCharles/sdxl-fast@8d9942a Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D55103983](https://our.internmc.facebook.com/intern/diff/D55103983) [ghstack-poisoned]
2 parents 5f5bc8e + b79c5bf commit 17c670a

39 files changed

+6382
-1530
lines changed

.github/workflows/regression_test.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,32 @@ jobs:
2727
pip install torch
2828
2929
30+
- name: Install package
31+
run: |
32+
pip install .
33+
34+
- name: Run tests
35+
run: |
36+
pytest test
37+
38+
test-nightly:
39+
runs-on: 4-core-ubuntu-gpu-t4
40+
steps:
41+
- uses: actions/checkout@v2
42+
43+
- name: Set up Python
44+
uses: actions/setup-python@v2
45+
with:
46+
python-version: 3.9
47+
48+
- name: Install dependencies
49+
run: |
50+
python -m pip install --upgrade pip
51+
pip install -r requirements.txt
52+
pip install -r dev-requirements.txt
53+
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
54+
55+
3056
- name: Install package
3157
run: |
3258
pip install .

.lintrunner.toml

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
merge_base_with = "origin/main"
2+
3+
[[linter]]
4+
code = 'FLAKE8'
5+
include_patterns = ['**/*.py']
6+
exclude_patterns = [
7+
'third-party/**',
8+
'**/third-party/**',
9+
]
10+
command = [
11+
'python',
12+
'-m',
13+
'lintrunner_adapters',
14+
'run',
15+
'flake8_linter',
16+
'--',
17+
'@{{PATHSFILE}}'
18+
]
19+
init_command = [
20+
'python',
21+
'-m',
22+
'lintrunner_adapters',
23+
'run',
24+
'pip_init',
25+
'--dry-run={{DRYRUN}}',
26+
'--requirement=requirements-lintrunner.txt',
27+
]
28+
29+
# Black + usort
30+
[[linter]]
31+
code = 'UFMT'
32+
include_patterns = [
33+
'**/*.py',
34+
'**/*.pyi',
35+
]
36+
exclude_patterns = [
37+
'third-party/**',
38+
'**/third-party/**',
39+
]
40+
command = [
41+
'python',
42+
'-m',
43+
'lintrunner_adapters',
44+
'run',
45+
'ufmt_linter',
46+
'--',
47+
'@{{PATHSFILE}}'
48+
]
49+
init_command = [
50+
'python',
51+
'-m',
52+
'lintrunner_adapters',
53+
'run',
54+
'pip_init',
55+
'--dry-run={{DRYRUN}}',
56+
'--no-black-binary',
57+
'--requirement=requirements-lintrunner.txt',
58+
]
59+
is_formatter = true

CODEOWNERS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
msaroufim
2+
cpuhrsch

benchmarks/intmm.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import argparse
2+
import csv
3+
import itertools
4+
import math
5+
import pathlib
6+
7+
import torch
8+
import torch.nn.functional as F
9+
import torch.utils.benchmark as benchmark
10+
from torchao.kernel.intmm_triton import int_matmul, int_scaled_matmul
11+
12+
torch._dynamo.config.cache_size_limit = 128
13+
torch._dynamo.config.accumulated_cache_size_limit = 128
14+
15+
dtype = torch.float16
16+
device = "cuda"
17+
18+
19+
def benchmark_in_ms(warmup, iters, f, *args, **kwargs):
    """Return the mean time of one ``f(*args, **kwargs)`` call, timed with CUDA events.

    Args:
        warmup: Number of untimed calls made before measuring.
        iters: Number of timed calls; the total elapsed time is divided by it.
        f: Callable to benchmark.
        *args: Positional arguments forwarded to ``f``.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        ``start_event.elapsed_time(end_event) / iters`` as a float
        (``elapsed_time`` reports milliseconds).

    Raises:
        ZeroDivisionError: If ``iters`` is 0.
    """
    for _ in range(warmup):
        f(*args, **kwargs)
    # Drain all pending GPU work so warmup does not leak into the timed region.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    for _ in range(iters):
        f(*args, **kwargs)

    end_event.record()
    # elapsed_time is only valid once both events have completed.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / float(iters)
33+
34+
35+
@torch.compile(mode="max-autotune")
def compiled_mm(x, w):
    """Return ``torch.mm(x, w)``; the function is wrapped by ``torch.compile`` in max-autotune mode."""
    return torch.mm(x, w)
38+
39+
40+
@torch.compile(mode="max-autotune")
def compiled_int_mm(x, w):
    """Return ``torch._int_mm(x, w)``; the function is wrapped by ``torch.compile`` in max-autotune mode."""
    return torch._int_mm(x, w)
43+
44+
45+
def run_int_mm_benchmark(x, w, b):
    """Time ``torch.mm`` on ``x, w`` against ``int_matmul`` on their int8 casts.

    Args:
        x: Left matmul operand.
        w: Right matmul operand.
        b: Unused; accepted so both benchmark functions share one signature.

    Returns:
        A ``(fp_time, int_mm_time)`` tuple of per-call times as returned by
        ``benchmark_in_ms`` (10 warmup calls, 100 timed calls each).
    """
    fp_time = benchmark_in_ms(10, 100, torch.mm, x, w)
    # Plain dtype casts: the benchmark measures kernel speed, not numerics.
    x_int = x.to(dtype=torch.int8)
    w_int = w.to(dtype=torch.int8)
    int_mm_time = benchmark_in_ms(10, 100, int_matmul, x_int, w_int)
    return fp_time, int_mm_time
51+
52+
53+
def run_int_scaled_mm_benchmark(x, w, b):
    """Time ``torch.mm(x, w) * scales`` against ``int_scaled_matmul`` on int8 casts.

    ``scales`` is ``x.sum(-1, keepdim=True)``; it is passed unchanged to both
    the floating-point baseline and ``int_scaled_matmul``.

    Args:
        x: Left matmul operand.
        w: Right matmul operand.
        b: Unused; accepted so both benchmark functions share one signature.

    Returns:
        A ``(fp_time, int_scaled_mm_time)`` tuple of per-call times as returned
        by ``benchmark_in_ms`` (10 warmup calls, 100 timed calls each).
    """
    scales = x.sum(-1, keepdim=True)

    def fp_scaled_mm(x, w, s):
        # Floating-point baseline: matmul followed by the scaling multiply.
        return torch.mm(x, w) * s

    fp_time = benchmark_in_ms(10, 100, fp_scaled_mm, x, w, scales)
    x_int = x.to(dtype=torch.int8)
    w_int = w.to(dtype=torch.int8)
    int_scaled_mm_time = benchmark_in_ms(
        10, 100, int_scaled_matmul, x_int, w_int, scales
    )
    return fp_time, int_scaled_mm_time
62+
63+
64+
def run_benchmarks(shapes):
    """Run every benchmark function on every shape and print CSV rows to stdout.

    A header line is printed first, then one row per (benchmark function,
    shape) pair with columns ``fn,m,k,n,fp_time,int_mm_time,ratio``, where
    ``ratio`` is ``fp_time / int_mm_time``.

    Args:
        shapes: Iterable of ``(m, k, n)`` integer tuples.
    """
    print("fn,m,k,n,fp_time,int_mm_time,ratio")
    # Local settings; these shadow the module-level dtype/device.
    dtype = torch.bfloat16
    device = "cuda"
    for fn, (m, k, n) in itertools.product(
        [run_int_mm_benchmark, run_int_scaled_mm_benchmark], shapes
    ):
        x = torch.randn(m, k, dtype=dtype, device=device)
        # Allocate as (n, k) and transpose, giving a (k, n) view for the matmul.
        w = torch.randn(n, k, dtype=dtype, device=device).t()
        b = torch.randn(m, n, dtype=dtype, device=device)

        fp_time, int_mm_time = fn(x, w, b)
        ratio = fp_time / int_mm_time
        # Emit the function's name rather than its repr, which embeds a memory
        # address and therefore differs from run to run.
        row = [fn.__name__, m, k, n, fp_time, int_mm_time, ratio]
        print(",".join(map(str, row)))
80+
81+
82+
if __name__ == "__main__":
    # Script entry point: read (m, k, n) shapes from a CSV file and benchmark them.
    parser = argparse.ArgumentParser(description="integer matmul benchmarks")
    parser.add_argument("file_path", type=str, help="Path to csv file with shapes")
    args = parser.parse_args()
    # Access the file path provided as an argument
    file_path = pathlib.Path(args.file_path)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not file_path.is_file():
        parser.error(f"shapes file not found: {file_path}")

    # Format is (m, k, n); the first row is the header and is skipped.
    # The context manager guarantees the file handle is closed.
    with file_path.open("r", newline="") as shapes_file:
        rows = list(csv.reader(shapes_file))[1:]
    # Turn into list of int tuples, ignoring blank lines.
    shapes = [tuple(map(int, row)) for row in rows if row]

    run_benchmarks(shapes)

benchmarks/intmm_shapes.csv

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
m,k,n
2+
1024,1024,2304
3+
1024,1024,4608
4+
1024,8192,2304
5+
1024,8192,4608
6+
1152,1024,2048
7+
1152,2048,16384
8+
1152,2048,2048
9+
1152,3072,2048
10+
1152,4096,2048
11+
1152,8192,2048
12+
1,2048,1024
13+
1,2048,2048
14+
1,2048,4096
15+
144,2048,16384
16+
144,2048,2048
17+
144,4096,2048
18+
144,8192,2048
19+
1472,1024,154
20+
1472,1024,308
21+
1472,2048,154
22+
1472,2048,308
23+
1472,512,154
24+
1472,512,308
25+
1,512,2048
26+
154,1472,1024
27+
154,1472,2048
28+
154,1472,512
29+
18432,1024,512
30+
18432,1536,512
31+
18432,2048,512
32+
18432,512,4096
33+
18432,512,512
34+
2048,1024,1
35+
2048,1024,2
36+
2048,16384,1152
37+
2048,16384,144
38+
2048,16384,288
39+
2048,16384,576
40+
2048,2048,1
41+
2048,2048,1152
42+
2048,2048,144
43+
2048,2048,2
44+
2048,2048,288
45+
2048,2048,576
46+
2048,4096,1
47+
2048,4096,2
48+
2048,512,18432
49+
2048,512,9216
50+
2,2048,1024
51+
2,2048,2048
52+
2,2048,4096
53+
2304,1024,1024
54+
2304,1024,8192
55+
2304,1536,1024
56+
2304,2048,1024
57+
2304,3072,1024
58+
2304,4096,1024
59+
2304,512,1024
60+
231,4096,1024
61+
231,4096,2048
62+
231,4096,512
63+
231,768,1024
64+
231,768,2048
65+
231,768,512
66+
2,512,2048
67+
288,2048,16384
68+
288,2048,2048
69+
288,4096,2048
70+
288,8192,2048
71+
308,1472,1024
72+
308,1472,2048
73+
308,1472,512
74+
4096,1024,2304
75+
4096,1024,231
76+
4096,1024,4608
77+
4096,1024,462
78+
4096,2048,231
79+
4096,2048,462
80+
4096,512,231
81+
4096,512,462
82+
4608,1024,1024
83+
4608,1024,8192
84+
4608,1536,1024
85+
4608,2048,1024
86+
4608,3072,1024
87+
4608,4096,1024
88+
4608,512,1024
89+
462,4096,1024
90+
462,4096,2048
91+
462,4096,512
92+
462,768,1024
93+
462,768,2048
94+
462,768,512
95+
512,2048,1
96+
512,2048,2
97+
512,4096,18432
98+
512,4096,9216
99+
512,512,18432
100+
512,512,9216
101+
576,1024,2048
102+
576,2048,16384
103+
576,2048,2048
104+
576,3072,2048
105+
576,4096,2048
106+
576,8192,2048
107+
768,1024,231
108+
768,1024,462
109+
768,2048,231
110+
768,2048,462
111+
768,512,231
112+
768,512,462
113+
8192,2048,1152
114+
8192,2048,144
115+
8192,2048,288
116+
8192,2048,576
117+
9216,1024,512
118+
9216,1536,512
119+
9216,2048,512
120+
9216,512,4096
121+
9216,512,512
122+
32768,3072,768
123+
32768,768,2304
124+
32768,768,3072
125+
32768,768,768
126+
39200,768,2304
127+
39200,768,768

benchmarks/print_config_shapes.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torchao
2+
3+
from torchao.kernel import autotuner
4+
5+
# Print the matmul shapes recorded in the autotuner's best-config table as CSV.
configs = autotuner._load_best_configs()

print("m,k,n")
# Only the keys are needed; the stored config values are not used here.
for key in configs:
    # NOTE(review): presumably key[1] and key[4] are the shapes of the two
    # matmul operands -- confirm against the autotuner's key layout.
    a_shape = key[1]
    b_shape = key[4]
    M, K0 = a_shape
    K1, N = b_shape

    # The inner dimensions of the two operands must agree.
    assert K0 == K1

    print(f"{M},{K0},{N}")

benchmarks/sam_vit_b_shapes.csv

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
m,k,n
2+
32768,3072,768
3+
32768,768,2304
4+
32768,768,3072
5+
32768,768,768
6+
39200,768,2304
7+
39200,768,768

dev-requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
pytest
22
expecttest
3-
packaging
3+
parameterized
4+
packaging

requirements-lintrunner.txt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Lintrunner itself
2+
lintrunner==0.11.0
3+
lintrunner-adapters==0.11.0
4+
5+
# Flake 8 and its dependencies
6+
flake8==6.0.0
7+
flake8-breakpoint==1.1.0
8+
flake8-bugbear==23.6.5
9+
flake8-comprehensions==3.12.0
10+
flake8-pyi==23.5.0
11+
mccabe==0.7.0
12+
pycodestyle==2.10.0
13+
torchfix==0.1.1
14+
15+
# UFMT
16+
black==24.3.0
17+
ufmt==2.5.1
18+
usort==1.0.5
19+
20+
# Other linters
21+
clang-format==12.0.1
22+
cmakelint==1.4.1

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
torch
22
numpy
33
sentencepiece
4+
packaging

0 commit comments

Comments
 (0)