Skip to content

Autotuner for int mm Triton kernels #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 38 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
ba24af7
int_mm benchmarks
cpuhrsch Feb 29, 2024
1f5a811
Baseline results
cpuhrsch Feb 29, 2024
b8bcf27
Running shapes over night
cpuhrsch Feb 29, 2024
44c97a6
Rerun result 1 with empty cache
cpuhrsch Feb 29, 2024
579fc5e
Filter results_2
cpuhrsch Feb 29, 2024
8251780
Triton matmul
cpuhrsch Feb 29, 2024
434b92f
Store best configs
cpuhrsch Mar 2, 2024
e3dcf6d
Faster autotuner
cpuhrsch Mar 2, 2024
cacef6b
Inject intmm triton
cpuhrsch Mar 2, 2024
0f2a706
Scaled int mm
cpuhrsch Mar 3, 2024
395b93c
More scaled matmul
cpuhrsch Mar 3, 2024
f2bf0fd
evict_last
cpuhrsch Mar 5, 2024
ea635ed
Only 1 scales and better Triton code
cpuhrsch Mar 6, 2024
ffd0d66
a100 specific configs based on all configs
cpuhrsch Mar 6, 2024
75ef774
data pkl based on regular configs
cpuhrsch Mar 6, 2024
404814f
Enable all configs
cpuhrsch Mar 6, 2024
145a498
Move configs to torchao
cpuhrsch Mar 12, 2024
016a839
Read configs from library package
cpuhrsch Mar 12, 2024
36bfe59
More environment variables
cpuhrsch Mar 13, 2024
62bc110
Benchmark for sam shapes
cpuhrsch Mar 13, 2024
8ce5707
More A100 configs
cpuhrsch Mar 13, 2024
2797cf3
Revert quant primitives
cpuhrsch Mar 13, 2024
73f6671
Merge remote-tracking branch 'origin/main' into intmmbenchmarks1
cpuhrsch Mar 13, 2024
20333cf
Make autotuner work with compile
cpuhrsch Mar 14, 2024
3f6ddb4
Merge branch 'main' of github.com:pytorch-labs/ao into intmmbenchmarks1
cpuhrsch Mar 14, 2024
c846b70
Make benchmark output a bit more clear
cpuhrsch Mar 14, 2024
8f4d6cc
Make benchmark output a bit more clear
cpuhrsch Mar 14, 2024
54976d4
Merge branch 'main' of github.com:pytorch-labs/ao into intmmbenchmarks1
cpuhrsch Mar 18, 2024
e590bce
Basic test harness
cpuhrsch Mar 18, 2024
b376dae
Basic test harness
cpuhrsch Mar 18, 2024
171c491
scaled int mm tests
cpuhrsch Mar 19, 2024
5437476
dev requirements
cpuhrsch Mar 19, 2024
ef26555
Address comments
cpuhrsch Mar 19, 2024
f4e6b7f
Merge branch 'main' of github.com:pytorch-labs/ao into intmmbenchmarks1
cpuhrsch Mar 19, 2024
bc6d23f
lintrunner.toml
cpuhrsch Mar 19, 2024
9c45710
Much lint, so wow
cpuhrsch Mar 19, 2024
b4ddebc
Merge branch 'main' of github.com:pytorch-labs/ao into intmmbenchmarks1
cpuhrsch Mar 20, 2024
771fadd
Merge branch 'main' of github.com:pytorch-labs/ao into intmmbenchmarks1
cpuhrsch Mar 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
merge_base_with = "origin/main"

# FLAKE8: lint every Python file, skipping anything under a third-party directory.
[[linter]]
code = 'FLAKE8'
include_patterns = ['**/*.py']
exclude_patterns = [
'third-party/**',
'**/third-party/**',
]
# Lint invocation: runs the flake8_linter adapter through lintrunner_adapters.
# {{PATHSFILE}} is a lintrunner placeholder (presumably a file listing the
# paths to check; confirm against the lintrunner documentation).
command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'flake8_linter',
'--',
'@{{PATHSFILE}}'
]
# Setup invocation: runs the pip_init adapter against the pinned versions in
# requirements-lintrunner.txt; {{DRYRUN}} is a lintrunner placeholder.
init_command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'pip_init',
'--dry-run={{DRYRUN}}',
'--requirement=requirements-lintrunner.txt',
]

# UFMT (Black + usort): applies to Python sources and stubs (*.py, *.pyi),
# skipping anything under a third-party directory.
[[linter]]
code = 'UFMT'
include_patterns = [
'**/*.py',
'**/*.pyi',
]
exclude_patterns = [
'third-party/**',
'**/third-party/**',
]
# Invocation: runs the ufmt_linter adapter through lintrunner_adapters.
# {{PATHSFILE}} is a lintrunner placeholder (presumably a file listing the
# paths to check; confirm against the lintrunner documentation).
command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'ufmt_linter',
'--',
'@{{PATHSFILE}}'
]
# Setup invocation: same pip_init adapter and requirements file as FLAKE8,
# with the extra --no-black-binary flag passed through to it.
init_command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'pip_init',
'--dry-run={{DRYRUN}}',
'--no-black-binary',
'--requirement=requirements-lintrunner.txt',
]
# Marks this entry as a formatter rather than a plain checker.
is_formatter = true
96 changes: 96 additions & 0 deletions benchmarks/intmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import argparse
import csv
import itertools
import math
import pathlib

import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torchao.kernel.intmm_triton import int_matmul, int_scaled_matmul

torch._dynamo.config.cache_size_limit = 128
torch._dynamo.config.accumulated_cache_size_limit = 128

dtype = torch.float16
device = "cuda"


def benchmark_in_ms(warmup, iters, f, *args, **kwargs):
    """Return the average duration of one ``f(*args, **kwargs)`` call.

    The measurement uses a pair of timing-enabled CUDA events recorded
    immediately before and after the timed calls.

    Args:
        warmup: Number of untimed calls made before measurement starts.
        iters: Number of timed calls; the elapsed time is divided by it.
        f: Callable to benchmark.
        *args: Positional arguments forwarded to ``f``.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        Time elapsed between the two events divided by ``iters``
        (milliseconds, per ``torch.cuda.Event.elapsed_time``).
    """
    # TODO: move into a shared benchmark utility once a second benchmark
    # (weight-only) is added, as discussed in review.

    def _call_repeatedly(count):
        for _ in range(count):
            f(*args, **kwargs)

    # Untimed warmup, then drain the device so none of it leaks into the timing.
    _call_repeatedly(warmup)
    torch.cuda.synchronize()

    timer_start = torch.cuda.Event(enable_timing=True)
    timer_stop = torch.cuda.Event(enable_timing=True)

    timer_start.record()
    _call_repeatedly(iters)
    timer_stop.record()

    # Wait for the queued work to finish before reading the elapsed time.
    torch.cuda.synchronize()
    total_ms = timer_start.elapsed_time(timer_stop)
    return total_ms / float(iters)


@torch.compile(mode="max-autotune")
def compiled_mm(x, w):
    """Return ``torch.mm(x, w)``, compiled with ``mode="max-autotune"``."""
    product = torch.mm(x, w)
    return product


@torch.compile(mode="max-autotune")
def compiled_int_mm(x, w):
    """Return ``torch._int_mm(x, w)``, compiled with ``mode="max-autotune"``."""
    product = torch._int_mm(x, w)
    return product


def run_int_mm_benchmark(x, w, b):
    """Time ``torch.mm`` on ``x, w`` against ``int_matmul`` on their int8 casts.

    Args:
        x: Left matmul operand.
        w: Right matmul operand.
        b: Unused here; accepted so this runner can be called the same way
            as ``run_int_scaled_mm_benchmark``.

    Returns:
        Tuple ``(fp_time, int_mm_time)`` as measured by ``benchmark_in_ms``
        with 10 warmup and 100 timed iterations each.
    """
    warmup, iters = 10, 100
    fp_time = benchmark_in_ms(warmup, iters, torch.mm, x, w)

    # Cast both operands to int8 only after the floating-point run is timed.
    int8_operands = [operand.to(dtype=torch.int8) for operand in (x, w)]
    int_mm_time = benchmark_in_ms(warmup, iters, int_matmul, *int8_operands)
    return fp_time, int_mm_time


def run_int_scaled_mm_benchmark(x, w, b):
    """Time a scaled ``torch.mm`` against ``int_scaled_matmul`` on int8 casts.

    The scales are the sums of ``x`` over its last dimension (kept as a
    trailing dimension) and are used unchanged by both variants.

    Args:
        x: Left matmul operand; also the source of the scales.
        w: Right matmul operand.
        b: Unused here; accepted so this runner can be called the same way
            as ``run_int_mm_benchmark``.

    Returns:
        Tuple ``(fp_time, int_scaled_mm_time)`` as measured by
        ``benchmark_in_ms`` with 10 warmup and 100 timed iterations each.
    """
    scales = x.sum(-1, keepdim=True)

    def _fp_scaled_mm(lhs, rhs, s):
        # Reference computation: plain matmul followed by the scaling.
        return torch.mm(lhs, rhs) * s

    fp_time = benchmark_in_ms(10, 100, _fp_scaled_mm, x, w, scales)

    x_int8 = x.to(dtype=torch.int8)
    w_int8 = w.to(dtype=torch.int8)
    int_scaled_mm_time = benchmark_in_ms(
        10, 100, int_scaled_matmul, x_int8, w_int8, scales
    )
    return fp_time, int_scaled_mm_time


def run_benchmarks(shapes):
    """Run every benchmark on every shape and print the results as CSV.

    A header row is printed first, then one row per (benchmark, shape) pair
    with the columns ``fn,m,k,n,fp_time,int_mm_time,ratio``.

    Args:
        shapes: Iterable of ``(m, k, n)`` integer tuples.
    """
    print("fn,m,k,n,fp_time,int_mm_time,ratio")
    dtype = torch.bfloat16
    device = "cuda"
    runners = [run_int_mm_benchmark, run_int_scaled_mm_benchmark]
    for fn, (m, k, n) in itertools.product(runners, shapes):
        x = torch.randn(m, k, dtype=dtype, device=device)
        # Allocated as (n, k) and transposed, so w is a (k, n) view.
        w = torch.randn(n, k, dtype=dtype, device=device).t()
        b = torch.randn(m, n, dtype=dtype, device=device)

        fp_time, int_mm_time = fn(x, w, b)
        ratio = fp_time / int_mm_time
        # Emit the runner's name: str(fn) would put "<function ... at 0x...>"
        # in the CSV, which differs between runs and is awkward to parse.
        row = [fn.__name__, m, k, n, fp_time, int_mm_time, ratio]
        print(",".join(map(str, row)))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="integer matmul benchmarks")
    parser.add_argument("file_path", type=str, help="Path to csv file with shapes")
    args = parser.parse_args()

    file_path = pathlib.Path(args.file_path)
    # Validate explicitly: an assert would be stripped when running with -O.
    if not file_path.is_file():
        parser.error(f"{file_path} is not a file")

    # Format is (m, k, n); the first row is the header and is dropped.
    # The context manager closes the file as soon as the rows are read.
    with open(file_path, "r", newline="") as shapes_file:
        rows = list(csv.reader(shapes_file))[1:]

    # Turn into a list of int tuples, skipping blank lines (csv yields [] for
    # them), which would otherwise fail when unpacked as (m, k, n).
    shapes = [tuple(map(int, row)) for row in rows if row]

    run_benchmarks(shapes)
127 changes: 127 additions & 0 deletions benchmarks/intmm_shapes.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
m,k,n
1024,1024,2304
1024,1024,4608
1024,8192,2304
1024,8192,4608
1152,1024,2048
1152,2048,16384
1152,2048,2048
1152,3072,2048
1152,4096,2048
1152,8192,2048
1,2048,1024
1,2048,2048
1,2048,4096
144,2048,16384
144,2048,2048
144,4096,2048
144,8192,2048
1472,1024,154
1472,1024,308
1472,2048,154
1472,2048,308
1472,512,154
1472,512,308
1,512,2048
154,1472,1024
154,1472,2048
154,1472,512
18432,1024,512
18432,1536,512
18432,2048,512
18432,512,4096
18432,512,512
2048,1024,1
2048,1024,2
2048,16384,1152
2048,16384,144
2048,16384,288
2048,16384,576
2048,2048,1
2048,2048,1152
2048,2048,144
2048,2048,2
2048,2048,288
2048,2048,576
2048,4096,1
2048,4096,2
2048,512,18432
2048,512,9216
2,2048,1024
2,2048,2048
2,2048,4096
2304,1024,1024
2304,1024,8192
2304,1536,1024
2304,2048,1024
2304,3072,1024
2304,4096,1024
2304,512,1024
231,4096,1024
231,4096,2048
231,4096,512
231,768,1024
231,768,2048
231,768,512
2,512,2048
288,2048,16384
288,2048,2048
288,4096,2048
288,8192,2048
308,1472,1024
308,1472,2048
308,1472,512
4096,1024,2304
4096,1024,231
4096,1024,4608
4096,1024,462
4096,2048,231
4096,2048,462
4096,512,231
4096,512,462
4608,1024,1024
4608,1024,8192
4608,1536,1024
4608,2048,1024
4608,3072,1024
4608,4096,1024
4608,512,1024
462,4096,1024
462,4096,2048
462,4096,512
462,768,1024
462,768,2048
462,768,512
512,2048,1
512,2048,2
512,4096,18432
512,4096,9216
512,512,18432
512,512,9216
576,1024,2048
576,2048,16384
576,2048,2048
576,3072,2048
576,4096,2048
576,8192,2048
768,1024,231
768,1024,462
768,2048,231
768,2048,462
768,512,231
768,512,462
8192,2048,1152
8192,2048,144
8192,2048,288
8192,2048,576
9216,1024,512
9216,1536,512
9216,2048,512
9216,512,4096
9216,512,512
32768,3072,768
32768,768,2304
32768,768,3072
32768,768,768
39200,768,2304
39200,768,768
16 changes: 16 additions & 0 deletions benchmarks/print_config_shapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torchao

from torchao.kernel import autotuner

# Print the (m, k, n) matmul shape of every stored best config as CSV.
configs = autotuner._load_best_configs()

print("m,k,n")
# Only the keys are inspected; the stored config values are not needed here.
for key in configs:
    # Entries 1 and 4 of the key are read as the two operand shapes
    # (presumably the autotuner's key layout; verify against the autotuner).
    M, K0 = key[1]
    K1, N = key[4]

    # The inner dimensions of the two operands must agree.
    assert K0 == K1

    print(f"{M},{K0},{N}")
7 changes: 7 additions & 0 deletions benchmarks/sam_vit_b_shapes.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
m,k,n
32768,3072,768
32768,768,2304
32768,768,3072
32768,768,768
39200,768,2304
39200,768,768
3 changes: 2 additions & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pytest
expecttest
packaging
parameterized
packaging
22 changes: 22 additions & 0 deletions requirements-lintrunner.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Lintrunner itself
lintrunner==0.11.0
lintrunner-adapters==0.11.0

# Flake 8 and its dependencies
flake8==6.0.0
flake8-breakpoint==1.1.0
flake8-bugbear==23.6.5
flake8-comprehensions==3.12.0
flake8-pyi==23.5.0
mccabe==0.7.0
pycodestyle==2.10.0
torchfix==0.1.1

# UFMT
black==24.2.0
ufmt==2.5.1
usort==1.0.5

# Other linters
clang-format==12.0.1
cmakelint==1.4.1
27 changes: 17 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,35 @@

import os
from datetime import datetime
from setuptools import setup, find_packages
current_date = datetime.now().strftime('%Y.%m.%d')

from setuptools import find_packages, setup

current_date = datetime.now().strftime("%Y.%m.%d")


def read_requirements(file_path):
    """Return the lines of the requirements file at ``file_path``.

    Args:
        file_path: Path to a text file with one requirement per line.

    Returns:
        List of the file's lines without line terminators; empty for an
        empty file.
    """
    # A single `with` header: the stacked old/new pair made this unparsable.
    with open(file_path, "r") as file:
        return file.read().splitlines()


# Determine the package name based on the presence of an environment variable
package_name = "torchao-nightly" if os.environ.get("TORCHAO_NIGHTLY") else "torchao"

# Version is year.month.date if using nightlies
version = current_date if package_name == "torchao-nightly" else "0.0.3"


# Read the long description up front so the README handle is closed
# deterministically instead of being left open by a bare open().read().
with open("README.md") as readme_file:
    long_description = readme_file.read()

setup(
    name=package_name,
    version=version,
    packages=find_packages(),
    include_package_data=True,
    # Include the *.pkl files that live in torchao.kernel.configs.
    package_data={
        "torchao.kernel.configs": ["*.pkl"],
    },
    install_requires=read_requirements("requirements.txt"),
    description="Package for applying ao techniques to GPU models",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/pytorch-labs/ao",
)
Loading