Commit 284fc37

Add CUTLASS-based row-wise scaled sparse FP8 kernel
1 parent 7963f9c · commit 284fc37

32 files changed: 2148 additions, 499 deletions

benchmarks/benchmark_rowwise_scaled_linear_cutlass.py

Lines changed: 35 additions & 21 deletions

@@ -7,41 +7,55 @@
     rowwise_scaled_linear_cutlass_s4s4,
     rowwise_scaled_linear_cutlass_s8s4,
 )
+from torchao.quantization.quant_api import (
+    _int4_symm_cutlass_quant,
+    _int8_symm_cutlass_quant,
+)
+
+dtype = torch.bfloat16
+dtypeq = torch.int8
+dtype_scale = torch.float32
+device = torch.device("cuda")


 def benchmark_microseconds(f, *args):
     return do_bench(lambda: f(*args), return_mode="median") * 1e3


-def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int):
-    assert A_nbits in (4, 8) and B_nbits in (4, 8)
+def get_problem(m: int, n: int, k: int, Xq_nbits: int):
+    assert k % 2 == 0
+    assert Xq_nbits in [4, 8]
+
+    X_ref = torch.randn((m, k), dtype=dtype, device=device)
+    W_ref = torch.rand((n, k), dtype=dtype, device=device)

-    dev = torch.device("cuda")
-    A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev)
-    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
-    B = torch.randint(
-        -128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev
+    X_quant_func = (
+        _int4_symm_cutlass_quant if Xq_nbits == 4 else _int8_symm_cutlass_quant
     )
-    B_scale = torch.randn((n,), dtype=torch.half, device=dev)
-    C = None
+    W_quant_func = _int4_symm_cutlass_quant
+    X_aqt = X_quant_func(X_ref)
+    W_aqt = W_quant_func(W_ref)

-    return A, A_scale, B, B_scale, C
+    Xq = X_aqt.tensor_impl.int_data
+    X_scale = X_aqt.tensor_impl.scale
+    Wq = W_aqt.tensor_impl.int_data
+    W_scale = W_aqt.tensor_impl.scale
+    bias = None
+    out_dtype = dtype

+    return (X_ref, W_ref), (Xq, X_scale, Wq, W_scale, bias, out_dtype)

-def benchmark(m: int, k: int, n: int):
-    dev = torch.device("cuda")
-    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
-    B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
-    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

-    A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4)
-    rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
-        rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
+def benchmark(m: int, k: int, n: int):
+    ref_args, args = get_problem(m, n, k, 4)
+    fp16_time = benchmark_microseconds(torch.nn.functional.linear, *ref_args)
+    rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds(
+        rowwise_scaled_linear_cutlass_s4s4, *args
     )

-    A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4)
-    rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds(
-        rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C
+    _, args = get_problem(m, n, k, 8)
+    rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
+        rowwise_scaled_linear_cutlass_s8s4, *args
     )

     return {
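The refactor replaces hand-built random int8 tensors with tensors produced by the CUTLASS quantization helpers. A minimal sketch of one such kernel call, mirroring the diff above; shapes are illustrative and a CUDA device plus a torchao build with these kernels is assumed:

```python
import torch

from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
from torchao.quantization.quant_api import (
    _int4_symm_cutlass_quant,
    _int8_symm_cutlass_quant,
)

m, n, k = 16, 512, 128  # illustrative shapes; k must be even
X_ref = torch.randn((m, k), dtype=torch.bfloat16, device="cuda")
W_ref = torch.rand((n, k), dtype=torch.bfloat16, device="cuda")

# int8 activations and packed int4 weights, each with per-row scales.
X_aqt = _int8_symm_cutlass_quant(X_ref)
W_aqt = _int4_symm_cutlass_quant(W_ref)

out = rowwise_scaled_linear_cutlass_s8s4(
    X_aqt.tensor_impl.int_data,  # quantized activations
    X_aqt.tensor_impl.scale,     # activation row scales
    W_aqt.tensor_impl.int_data,  # quantized (packed) weights
    W_aqt.tensor_impl.scale,     # weight row scales
    None,                        # optional bias
    torch.bfloat16,              # output dtype
)
```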
Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+import pandas as pd
+import torch
+from tqdm import tqdm
+from triton.testing import do_bench
+
+from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
+from torchao.quantization.quant_api import (
+    _float8_cutlass_quant,
+    _float8_cutlass_quant_sparse,
+)
+from torchao.sparsity.utils import create_semi_structured_tensor
+
+dtype = torch.bfloat16
+dtypeq_X = torch.float8_e5m2
+dtypeq_W = torch.float8_e4m3fn
+device = torch.device("cuda")
+
+
+def benchmark_microseconds(f, *args):
+    return do_bench(lambda: f(*args), return_mode="median") * 1e3
+
+
+def get_problem(m: int, n: int, k: int):
+    X_ref = torch.randn((m, k), dtype=dtype, device=device)
+    W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)
+
+    X_quant_func = _float8_cutlass_quant
+    W_quant_func = _float8_cutlass_quant_sparse
+    X_aqt = X_quant_func(X_ref, dtypeq_X)
+    W_aqt = W_quant_func(W_ref, dtypeq_W)
+
+    Xq = X_aqt.tensor_impl.float8_data
+    X_scale = X_aqt.tensor_impl.scale
+    Wq_sparse = W_aqt.tensor_impl.sparse
+    W_meta = W_aqt.tensor_impl.meta
+    W_scale = W_aqt.tensor_impl.scale
+    bias = None
+    out_dtype = dtype
+
+    return (X_ref, W_ref), (Xq, X_scale, Wq_sparse, W_meta, W_scale, bias, out_dtype)
+
+
+def benchmark(m: int, k: int, n: int):
+    ref_args, args = get_problem(m, n, k)
+    fp16_time = benchmark_microseconds(torch.nn.functional.linear, *ref_args)
+    rowwise_scaled_linear_sparse_cutlass_f8f8_time = benchmark_microseconds(
+        rowwise_scaled_linear_sparse_cutlass_f8f8, *args
+    )
+
+    return {
+        "m": m,
+        "k": k,
+        "n": n,
+        "fp16_latency (ms)": fp16_time,
+        "rowwise_scaled_linear_sparse_cutlass_f8f8 latency (ms)": rowwise_scaled_linear_sparse_cutlass_f8f8_time,
+        "f8f8 speedup (d/s)": fp16_time
+        / rowwise_scaled_linear_sparse_cutlass_f8f8_time,
+    }
+
+
+if __name__ == "__main__":
+    k_vals = (8192, 8192, 8192, 28672)
+    n_vals = (8192, 10240, 57344, 8192)
+
+    results = []
+    for m in tqdm([1 << i for i in range(10)]):
+        for n, k in zip(n_vals, k_vals):
+            results.append(benchmark(m, k, n))
+
+    df = pd.DataFrame(results)
+    df.to_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv", index=False)
+    print(df.to_markdown(index=False))
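The new sparse FP8 path follows the same call pattern as the dense benchmark. A minimal sketch of a single call, assuming an SM 9.x GPU (e.g. H100) and a torchao build that includes the sm_90a CUTLASS sources; shapes are one point from the benchmark's grid:

```python
import torch

from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
from torchao.quantization.quant_api import (
    _float8_cutlass_quant,
    _float8_cutlass_quant_sparse,
)
from torchao.sparsity.utils import create_semi_structured_tensor

m, n, k = 16, 8192, 8192
X_ref = torch.randn((m, k), dtype=torch.bfloat16, device="cuda")
W_ref = create_semi_structured_tensor(n, k, dtype=torch.bfloat16).to("cuda")

# FP8 activations (e5m2) and 2:4 semi-structured FP8 weights (e4m3),
# each carrying per-row scales.
X_aqt = _float8_cutlass_quant(X_ref, torch.float8_e5m2)
W_aqt = _float8_cutlass_quant_sparse(W_ref, torch.float8_e4m3fn)

out = rowwise_scaled_linear_sparse_cutlass_f8f8(
    X_aqt.tensor_impl.float8_data,  # quantized activations
    X_aqt.tensor_impl.scale,        # activation row scales
    W_aqt.tensor_impl.sparse,       # compressed sparse weight values
    W_aqt.tensor_impl.meta,         # 2:4 sparsity metadata
    W_aqt.tensor_impl.scale,        # weight row scales
    None,                           # optional bias
    torch.bfloat16,                 # output dtype
)
```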

docs/source/api_ref_dtypes.rst

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ Layouts and Tensor Subclasses
     MarlinQQQLayout
     Int4CPULayout
     CutlassInt4PackedLayout
+    CutlassSemiSparseLayout

 Quantization techniques
 -----------------------

setup.py

Lines changed: 60 additions & 1 deletion

@@ -3,6 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+import copy
 import glob
 import os
 import subprocess

@@ -73,6 +74,7 @@ def use_debug_mode():
     BuildExtension,
     CppExtension,
     CUDAExtension,
+    _get_cuda_arch_flags,
 )

 build_torchao_experimental_mps = (

@@ -234,7 +236,12 @@ def get_extensions():
     extra_link_args = []
     extra_compile_args = {
         "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
-        "nvcc": ["-O3" if not debug_mode else "-O0", "-t=0", "-std=c++17"],
+        "nvcc": [
+            "-DNDEBUG" if not debug_mode else "-DDEBUG",
+            "-O3" if not debug_mode else "-O0",
+            "-t=0",
+            "-std=c++17",
+        ],
     }

     if not IS_WINDOWS:

@@ -269,6 +276,7 @@ def get_extensions():
         sources += cuda_sources

     use_cutlass = False
+    cutlass_90a_sources = None
     if use_cuda and not IS_WINDOWS:
         use_cutlass = True
         cutlass_dir = os.path.join(third_party_path, "cutlass")

@@ -284,8 +292,44 @@ def get_extensions():
                 "-I" + cutlass_include_dir,
                 "-I" + cutlass_tools_include_dir,
                 "-I" + cutlass_extensions_include_dir,
+                "-DCUTE_USE_PACKED_TUPLE=1",
+                "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
+                "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
+                "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
+                "--ftemplate-backtrace-limit=0",
+                # "--keep",
+                # "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage",
+                # "--resource-usage",
+                # "-lineinfo",
+                # "-DCUTLASS_ENABLE_GDC_FOR_SM90",  # https://github.com/NVIDIA/cutlass/blob/main/media/docs/dependent_kernel_launch.md
             ]
         )
+
+        cuda_arch_flags = _get_cuda_arch_flags()
+        build_for_sm90 = "-gencode=arch=compute_90,code=sm_90" in cuda_arch_flags
+        build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags
+        if build_for_sm90 and not build_for_sm90a:
+            cutlass_90a_sources = [
+                os.path.join(
+                    extensions_cuda_dir,
+                    "rowwise_scaled_linear_sparse_cutlass",
+                    "rowwise_scaled_linear_sparse_cutlass_f8f8.cu",
+                ),
+                os.path.join(
+                    extensions_cuda_dir,
+                    "to_sparse_semi_structured_cutlass_sm9x",
+                    "to_sparse_semi_structured_cutlass_sm9x_f8.cu",
+                ),
+            ]
+            for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]:
+                cutlass_90a_sources.append(
+                    os.path.join(
+                        extensions_cuda_dir,
+                        "rowwise_scaled_linear_sparse_cutlass",
+                        "rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu",
+                    )
+                )
+            sources = [s for s in sources if s not in cutlass_90a_sources]
     else:
         # Remove CUTLASS-based kernels from the cuda_sources list. An
         # assumption is that these files will have "cutlass" in its

@@ -309,6 +353,21 @@ def get_extensions():
             )
         )

+    if cutlass_90a_sources is not None and len(cutlass_90a_sources) > 0:
+        cutlass_90a_extra_compile_args = copy.deepcopy(extra_compile_args)
+        cutlass_90a_extra_compile_args["nvcc"].extend(
+            cuda_arch_flags + ["-gencode=arch=compute_90a,code=sm_90a"]
+        )
+        ext_modules.append(
+            extension(
+                "torchao._C",
+                cutlass_90a_sources,
+                py_limited_api=True,
+                extra_compile_args=cutlass_90a_extra_compile_args,
+                extra_link_args=extra_link_args,
+            )
+        )
+
     if build_torchao_experimental:
         ext_modules.append(
             CMakeExtension(
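The gating above splits the SM 9.0a-only CUTLASS sources into their own extension exactly when the build targets plain sm_90 but not sm_90a already. A small sketch of that check in isolation; `_get_cuda_arch_flags` is a private torch.utils.cpp_extension helper (imported by the diff itself) whose output depends on `TORCH_CUDA_ARCH_LIST`, and the value set here is illustrative:

```python
import os

from torch.utils.cpp_extension import _get_cuda_arch_flags

# Illustrative assumption: "9.0" maps to the plain sm_90 gencode flag
# (no "a" suffix), which is the case where the sm_90a split kicks in.
os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"

cuda_arch_flags = _get_cuda_arch_flags()
build_for_sm90 = "-gencode=arch=compute_90,code=sm_90" in cuda_arch_flags
build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags

# True here means the sparse FP8 .cu files are compiled in a separate
# extension with -gencode=arch=compute_90a,code=sm_90a appended.
print(build_for_sm90 and not build_for_sm90a)
```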
Lines changed: 105 additions & 0 deletions

@@ -0,0 +1,105 @@
+# This actually belongs to test_ops.py, extracted here for easier
+# maintenance.
+
+import itertools
+
+import pytest
+import torch
+from torch.testing._internal.optests import opcheck
+
+from torchao.quantization.quant_api import (
+    _int4_symm_cutlass_quant,
+    _int8_symm_cutlass_quant,
+)
+
+DTYPES = [torch.float16, torch.bfloat16]
+BATCH_SIZE = [1, 4, 8, 16, 32, 64]
+SIZE_MNK = [
+    (2, 512, 128),
+    (3, 2048, 2048),
+    (4, 3584, 640),
+    (13, 8704, 8576),
+    (26, 18944, 1664),
+    (67, 6656, 1408),
+]
+USE_BIAS = [False, True]
+TEST_PARAMS = list(
+    itertools.product(
+        DTYPES,
+        BATCH_SIZE,
+        SIZE_MNK,
+        USE_BIAS,
+    )
+)
+
+
+def run_test_for_op(op, dtype, batch_size, size_mnk, use_bias):
+    size_m, size_n, size_k = size_mnk
+
+    X = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
+    W = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
+    bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None
+
+    Xq_bits = 4 if op == torch.ops.torchao.rowwise_scaled_linear_cutlass_s4s4 else 8
+
+    X_quant_func = (
+        _int4_symm_cutlass_quant if Xq_bits == 4 else _int8_symm_cutlass_quant
+    )
+    W_quant_func = _int4_symm_cutlass_quant
+    X_aqt = X_quant_func(X)
+    W_aqt = W_quant_func(W)
+
+    Xq = X_aqt.tensor_impl.int_data
+    X_scale = X_aqt.tensor_impl.scale
+    Wq = W_aqt.tensor_impl.int_data
+    W_scale = W_aqt.tensor_impl.scale
+    Xq_int8, _, _ = X_aqt.tensor_impl.get_plain()
+    Wq_int8, _, _ = W_aqt.tensor_impl.get_plain()
+
+    # If torch.nn.functional.linear(X, W, bias) were used as the reference,
+    # the error would be too large. The calculation below is approximately
+    # what the rowwise_scaled_linear_cutlass kernel is doing.
+    output_ref = (Xq_int8.float() @ Wq_int8.float().T) * X_scale[..., None] * W_scale
+    if bias is not None:
+        output_ref += bias
+    output_ref = output_ref.to(dtype).reshape(X.shape[:-1] + (size_n,))
+
+    fn_inputs = (Xq, X_scale, Wq, W_scale, bias, dtype)
+    try:
+        output = op(*fn_inputs)
+    except NotImplementedError:
+        pytest.xfail("operator not implemented")
+
+    torch.testing.assert_close(output, output_ref)
+
+    # Perform opcheck.
+    test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"]
+    opcheck(
+        op,
+        fn_inputs,
+        test_utils=test_utils,
+    )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS)
+def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias):
+    run_test_for_op(
+        torch.ops.torchao.rowwise_scaled_linear_cutlass_s4s4,
+        dtype,
+        batch_size,
+        size_mnk,
+        use_bias,
+    )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS)
+def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias):
+    run_test_for_op(
+        torch.ops.torchao.rowwise_scaled_linear_cutlass_s8s4,
+        dtype,
+        batch_size,
+        size_mnk,
+        use_bias,
+    )
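Since this view does not show the new test module's file path, the tests can be selected by name instead. A hypothetical local run on a CUDA machine, from the repository checkout:

```python
# Hypothetical invocation: select the new parametrized tests by keyword
# rather than by file, since the module's path is not shown above.
import sys

import pytest

sys.exit(pytest.main(["-q", "-k", "rowwise_scaled_linear_cutlass"]))
```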
