Skip to content

Commit d1d96f7

Browse files
committed
Add CUTLASS-based row-wise scaled sparse FP8 kernel
1 parent d00ee41 commit d1d96f7

30 files changed

+1957
-420
lines changed

benchmarks/benchmark_rowwise_scaled_linear_cutlass.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@ def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int):
1818

1919
dev = torch.device("cuda")
2020
A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev)
21-
A_scale = torch.randn((m,), dtype=torch.half, device=dev)
21+
A_scale = torch.randn((m,), dtype=torch.float32, device=dev)
2222
B = torch.randint(
2323
-128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev
2424
)
25-
B_scale = torch.randn((n,), dtype=torch.half, device=dev)
26-
C = None
25+
B_scale = torch.randn((n,), dtype=torch.float32, device=dev)
26+
bias = None
27+
out_dtype = torch.bfloat16
2728

28-
return A, A_scale, B, B_scale, C
29+
return A, A_scale, B, B_scale, bias, out_dtype
2930

3031

3132
def benchmark(m: int, k: int, n: int):
@@ -34,14 +35,14 @@ def benchmark(m: int, k: int, n: int):
3435
B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
3536
fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)
3637

37-
A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4)
38+
A, A_scale, B, B_scale, bias, out_dtype = get_problem(m, n, k, 8, 4)
3839
rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
39-
rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
40+
rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, bias, out_dtype
4041
)
4142

42-
A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4)
43+
A, A_scale, B, B_scale, bias, out_dtype = get_problem(m, n, k, 4, 4)
4344
rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds(
44-
rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C
45+
rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, bias, out_dtype
4546
)
4647

4748
return {
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import pandas as pd
2+
import torch
3+
from tqdm import tqdm
4+
from triton.testing import do_bench
5+
6+
from torchao.ops import (
7+
rowwise_scaled_linear_sparse_cutlass_f8f8,
8+
to_sparse_semi_structured_cutlass_sm9x_f8,
9+
)
10+
11+
12+
def benchmark_microseconds(f, *args):
    """Time ``f(*args)`` with ``do_bench`` and return the median, scaled by 1e3.

    ``do_bench`` reports milliseconds, so the scaled result is in
    microseconds.
    """

    def _invoke():
        return f(*args)

    median_time = do_bench(_invoke, return_mode="median")
    return median_time * 1e3
14+
15+
16+
def get_problem(m: int, n: int, k: int):
    """Create random CUDA inputs for the sparse FP8 row-wise scaled benchmark.

    Returns the tuple ``(A, A_scale, B_sp, B_meta, B_scale)`` where:

    * ``A`` is an ``(m, k)`` tensor cast to ``torch.float8_e5m2``;
    * ``A_scale`` is an ``(m,)`` half-precision tensor;
    * ``B_sp`` and ``B_meta`` are the two values returned by
      ``to_sparse_semi_structured_cutlass_sm9x_f8`` for an ``(n, k)`` tensor
      cast to ``torch.float8_e4m3fn``;
    * ``B_scale`` is an ``(n,)`` half-precision tensor.
    """
    cuda = torch.device("cuda")

    def _randn_half(*shape):
        # Every input starts out as half-precision normal noise on the GPU.
        return torch.randn(shape, dtype=torch.half, device=cuda)

    # The draws are kept in this order so the random stream is consumed the
    # same way regardless of how the tensors are used afterwards.
    A = _randn_half(m, k).to(torch.float8_e5m2)
    A_scale = _randn_half(m)
    B = _randn_half(n, k).to(torch.float8_e4m3fn)
    B_sp, B_meta = to_sparse_semi_structured_cutlass_sm9x_f8(B)
    B_scale = _randn_half(n)

    return A, A_scale, B_sp, B_meta, B_scale
26+
27+
28+
def benchmark(m: int, k: int, n: int):
    """Benchmark the sparse FP8 kernel against a dense fp16 linear layer.

    Times ``torch.nn.functional.linear`` on random half-precision
    ``(m, k)`` / ``(n, k)`` operands, then times
    ``rowwise_scaled_linear_sparse_cutlass_f8f8`` on the inputs produced by
    ``get_problem``.

    Note that this function takes ``(m, k, n)`` while ``get_problem`` takes
    ``(m, n, k)``; the call below reorders the arguments accordingly.

    Returns:
        A dict with the problem sizes, both latencies in microseconds, and
        their ratio (fp16 time divided by sparse FP8 time).
    """
    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

    A, A_scale, B_sp, B_meta, B_scale = get_problem(m, n, k)
    rowwise_scaled_linear_sparse_cutlass_f8f8_time = benchmark_microseconds(
        rowwise_scaled_linear_sparse_cutlass_f8f8, A, A_scale, B_sp, B_meta, B_scale
    )

    # benchmark_microseconds() multiplies do_bench's millisecond result by
    # 1e3, so both timings are microseconds; the column labels say so rather
    # than "(ms)", which would be off by a factor of 1000.
    return {
        "m": m,
        "k": k,
        "n": n,
        "fp16_latency (us)": fp16_time,
        "rowwise_scaled_linear_sparse_cutlass_f8f8 latency (us)": rowwise_scaled_linear_sparse_cutlass_f8f8_time,
        # NOTE(review): "d/s" presumably means dense/sparse -- confirm.
        "f8f8 speedup (d/s)": fp16_time
        / rowwise_scaled_linear_sparse_cutlass_f8f8_time,
    }
48+
49+
50+
if __name__ == "__main__":
    # k_vals and n_vals are paired positionally: entry i of each tuple forms
    # one (n, k) problem size.
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    # Values of m swept by the benchmark: the powers of two 1, 2, ..., 512.
    m_vals = [2**exponent for exponent in range(10)]

    # One result row per (m, (n, k)) combination; tqdm tracks progress over m.
    results = [
        benchmark(m, k, n)
        for m in tqdm(m_vals)
        for n, k in zip(n_vals, k_vals)
    ]

    df = pd.DataFrame(results)
    df.to_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv", index=False)
    print(df.to_markdown(index=False))

docs/source/api_ref_dtypes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Layouts and Tensor Subclasses
2828
MarlinQQQLayout
2929
Int4CPULayout
3030
CutlassInt4PackedLayout
31+
CutlassSemiSparseLayout
3132

3233
Quantization techniques
3334
-----------------------

setup.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import copy
67
import glob
78
import os
89
import subprocess
@@ -73,6 +74,7 @@ def use_debug_mode():
7374
BuildExtension,
7475
CppExtension,
7576
CUDAExtension,
77+
_get_cuda_arch_flags,
7678
)
7779

7880
# Constant known variables used throughout this file
@@ -251,6 +253,7 @@ def get_extensions():
251253
sources += cuda_sources
252254

253255
use_cutlass = False
256+
cutlass_90a_sources = None
254257
if use_cuda and not IS_WINDOWS:
255258
use_cutlass = True
256259
cutlass_dir = os.path.join(third_party_path, "cutlass")
@@ -266,8 +269,46 @@ def get_extensions():
266269
"-I" + cutlass_include_dir,
267270
"-I" + cutlass_tools_include_dir,
268271
"-I" + cutlass_extensions_include_dir,
272+
"-DNDEBUG" if not debug_mode else "",
273+
"-DCUTE_USE_PACKED_TUPLE=1",
274+
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
275+
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
276+
"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
277+
"--use_fast_math",
278+
"--ftemplate-backtrace-limit=0",
279+
# "--keep",
280+
# "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage",
281+
# "--resource-usage",
282+
# "-lineinfo",
283+
# "-DCUTLASS_ENABLE_GDC_FOR_SM90", # https://github.com/NVIDIA/cutlass/blob/main/media/docs/dependent_kernel_launch.md
269284
]
270285
)
286+
287+
cuda_arch_flags = _get_cuda_arch_flags()
288+
build_for_sm90 = "-gencode=arch=compute_90,code=sm_90" in cuda_arch_flags
289+
build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags
290+
if build_for_sm90 and not build_for_sm90a:
291+
cutlass_90a_sources = [
292+
os.path.join(
293+
extensions_cuda_dir,
294+
"rowwise_scaled_linear_sparse_cutlass",
295+
"rowwise_scaled_linear_sparse_cutlass_f8f8.cu",
296+
),
297+
os.path.join(
298+
extensions_cuda_dir,
299+
"to_sparse_semi_structured_cutlass_sm9x",
300+
"to_sparse_semi_structured_cutlass_sm9x_f8.cu",
301+
),
302+
]
303+
for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]:
304+
cutlass_90a_sources.append(
305+
os.path.join(
306+
extensions_cuda_dir,
307+
"rowwise_scaled_linear_sparse_cutlass",
308+
"rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu",
309+
)
310+
)
311+
sources = [s for s in sources if s not in cutlass_90a_sources]
271312
else:
272313
# Remove CUTLASS-based kernels from the cuda_sources list. An
273314
# assumption is that these files will have "cutlass" in its
@@ -291,6 +332,21 @@ def get_extensions():
291332
)
292333
)
293334

335+
if cutlass_90a_sources is not None and len(cutlass_90a_sources) > 0:
336+
cutlass_90a_extra_compile_args = copy.deepcopy(extra_compile_args)
337+
cutlass_90a_extra_compile_args["nvcc"].extend(
338+
cuda_arch_flags + ["-gencode=arch=compute_90a,code=sm_90a"]
339+
)
340+
ext_modules.append(
341+
extension(
342+
"torchao._C",
343+
cutlass_90a_sources,
344+
py_limited_api=True,
345+
extra_compile_args=cutlass_90a_extra_compile_args,
346+
extra_link_args=extra_link_args,
347+
)
348+
)
349+
294350
if build_torchao_experimental:
295351
ext_modules.append(
296352
CMakeExtension(

test/test_rowwise_scaled_linear_cutlass.py

Lines changed: 42 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,75 +7,73 @@
77
rowwise_scaled_linear_cutlass_s4s4,
88
rowwise_scaled_linear_cutlass_s8s4,
99
)
10-
from torchao.quantization.utils import group_quantize_tensor_symmetric
10+
from torchao.quantization.quant_api import (
11+
_int4_symm_per_token_quant_cutlass,
12+
_int8_symm_per_token_quant_cutlass,
13+
)
14+
from torchao.quantization.quant_primitives import (
15+
MappingType,
16+
ZeroPointDomain,
17+
)
18+
from torchao.quantization.utils import _get_per_token_block_size
1119

12-
ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
13-
ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
14-
ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [
20+
DTYPES = [torch.float16, torch.bfloat16]
21+
BATCH_SIZE = [1, 4, 8, 16, 32, 64]
22+
SIZE_MNK = [
1523
(2, 512, 128),
1624
(3, 2048, 2048),
1725
(4, 3584, 640),
1826
(13, 8704, 8576),
1927
(26, 18944, 1664),
2028
(67, 6656, 1408),
2129
]
22-
ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True]
23-
ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list(
30+
USE_BIAS = [False, True]
31+
TEST_PARAMS = list(
2432
itertools.product(
25-
ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE,
26-
ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE,
27-
ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK,
28-
ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS,
33+
DTYPES,
34+
BATCH_SIZE,
35+
SIZE_MNK,
36+
USE_BIAS,
2937
)
3038
)
3139

3240

33-
def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias):
34-
assert xq_bits in [4, 8]
35-
assert wq_bits in [4, 8]
36-
41+
def run_test_for_op(op, dtype, batch_size, size_mnk, use_bias):
3742
size_m, size_n, size_k = size_mnk
3843

3944
x = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
4045
w = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
4146
bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None
4247

43-
x_2d = x.view(-1, x.shape[-1])
44-
xq_2d_s8, xq_2d_scales, xq_2d_zeros = group_quantize_tensor_symmetric(
45-
x_2d, xq_bits, size_k, dtype
46-
)
47-
assert torch.all(xq_2d_zeros == 0)
48-
xq_s8 = xq_2d_s8.reshape(x.shape)
49-
if xq_bits == 4:
50-
xq = (xq_s8[..., 1::2] << 4) | (xq_s8[..., 0::2] & 0xF)
51-
else:
52-
xq = xq_s8
53-
xq_scales = xq_2d_scales.reshape(x.shape[:-1])
54-
55-
wq_s8, wq_scales, wq_zeros = group_quantize_tensor_symmetric(
56-
w, wq_bits, size_n, dtype
48+
xq_bits = 4 if op == rowwise_scaled_linear_cutlass_s4s4 else 8
49+
pack_s4 = lambda x: (x[..., 1::2] << 4) | (x[..., 0::2] & 0xF)
50+
51+
x_quant_func = (
52+
_int4_symm_per_token_quant_cutlass
53+
if xq_bits == 4
54+
else _int8_symm_per_token_quant_cutlass
5755
)
58-
assert torch.all(wq_zeros == 0)
59-
if wq_bits == 4:
60-
wq = (wq_s8[:, 1::2] << 4) | (wq_s8[:, 0::2] & 0xF)
61-
else:
62-
wq = wq_s8
56+
x_aqt = x_quant_func(x)
57+
xq_s8, xq_scales, zero_points = x_aqt.tensor_impl.get_plain()
58+
assert zero_points is None
59+
xq = pack_s4(xq_s8) if xq_bits == 4 else xq_s8
60+
61+
w_quant_func = _int4_symm_per_token_quant_cutlass
62+
w_aqt = w_quant_func(w)
63+
wq_s8, wq_scales, zero_points = w_aqt.tensor_impl.get_plain()
64+
assert zero_points is None
65+
wq = pack_s4(wq_s8)
6366

6467
# If torch.nn.functional.linear(x, w, bias) were used as reference, the
6568
# error would be too big. The calculation below is approximately
6669
# what rowwise_scaled_linear_cutlass kernel is doing (except that
6770
# matrix multiplication is over integers there).
68-
size_m_2d = x_2d.shape[0]
69-
output_ref = (
70-
(xq_2d_s8.float() @ wq_s8.float().T)
71-
* xq_2d_scales.view(size_m_2d, 1)
72-
* wq_scales.view(1, size_n)
73-
)
71+
output_ref = (xq_s8.float() @ wq_s8.float().T) * xq_scales[..., None] * wq_scales
7472
if bias is not None:
7573
output_ref += bias
7674
output_ref = output_ref.to(dtype).reshape(x.shape[:-1] + (size_n,))
7775

78-
fn_inputs = (xq, xq_scales, wq, wq_scales, bias)
76+
fn_inputs = (xq, xq_scales, wq, wq_scales, bias, dtype)
7977
try:
8078
output = op(*fn_inputs)
8179
except NotImplementedError:
@@ -85,20 +83,16 @@ def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias)
8583

8684

8785
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
88-
@pytest.mark.parametrize(
89-
"dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
90-
)
86+
@pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS)
9187
def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias):
9288
run_test_for_op(
93-
rowwise_scaled_linear_cutlass_s4s4, 4, 4, dtype, batch_size, size_mnk, use_bias
89+
rowwise_scaled_linear_cutlass_s4s4, dtype, batch_size, size_mnk, use_bias
9490
)
9591

9692

9793
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
98-
@pytest.mark.parametrize(
99-
"dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS
100-
)
94+
@pytest.mark.parametrize("dtype, batch_size, size_mnk, use_bias", TEST_PARAMS)
10195
def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias):
10296
run_test_for_op(
103-
rowwise_scaled_linear_cutlass_s8s4, 8, 4, dtype, batch_size, size_mnk, use_bias
97+
rowwise_scaled_linear_cutlass_s8s4, dtype, batch_size, size_mnk, use_bias
10498
)

0 commit comments

Comments
 (0)