[Kernel] Add CUTLASS sparse support, heuristics, and torch operators #10340

Open
wants to merge 21 commits into base: main
Changes from 1 commit
Add cherry-picked configs for all data types
Faraz9877 committed Dec 2, 2024
commit 4bc043a006993bca057dc5018099934c37368ef5
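Background for the diff below: the kernels under test use 2:4 semi-structured sparsity, in which every contiguous group of four values along a row holds at most two non-zeros, so a weight matrix compresses to half its values plus a small metadata tensor (the b_compressed and e pair returned by make_rand_sparse_tensors). As a minimal sketch of the format, here is the analogous flow through PyTorch's built-in semi-structured API rather than this PR's CUTLASS operators; the shapes and mask are illustrative, and an SM80+ GPU is assumed.

import torch
from torch.sparse import to_sparse_semi_structured

# Dense fp16 weight with an explicit 2:4 pattern: two non-zeros in every
# contiguous group of four along each row (the mask here is illustrative).
mask = torch.tensor([0, 0, 1, 1], device="cuda").tile((128, 32)).bool()
w = torch.rand(128, 128, device="cuda").half() * mask

# Compress to packed values plus metadata, the roles played by
# b_compressed and e in the benchmark below.
w_sparse = to_sparse_semi_structured(w)

x = torch.rand(128, 128, device="cuda").half()
out = torch.mm(w_sparse, x)  # dispatches to a sparse kernel on SM80+ GPUs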
benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py (16 additions, 24 deletions)
@@ -36,15 +36,14 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     # Create tensors
     b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
     aT = a.t()
-    bT = b.t()
     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
 
     out = ops.cutlass_scaled_sparse_mm(b_compressed, e, aT, scale_b, scale_a, torch.bfloat16)
-    out_ref = ops.cutlass_scaled_mm(a, bT, scale_a, scale_b, torch.bfloat16)
+    out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
 
-    if not torch.allclose(out.t(), out_ref):
+    if not torch.allclose(out, out_ref):
         print("Incorrect result")
         exit()
 
@@ -96,16 +95,15 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     # Create tensors
     b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
     aT = a.t()
-    bT = b
     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
 
     out = ops.cutlass_scaled_sparse_mm(b_compressed, e, aT, scale_b, scale_a, torch.bfloat16)
-    out_ref = ops.cutlass_scaled_mm(a, bT, scale_a, scale_b, torch.bfloat16)
+    out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
 
-    if not torch.allclose(out, out_ref, rtol=1e-2, atol=1e-2):
-        print(f"Incorrect result for {m}, {k}, {n}")
+    if not torch.allclose(out, out_ref):
+        print("Incorrect result")
         exit()
 
     timers = []
@@ -114,7 +112,7 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     timers.append(
         bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
                  torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
-                 bT.to(dtype=torch.bfloat16, device="cuda")))
+                 b.to(dtype=torch.bfloat16, device="cuda")))
 
     # pytorch impl: bf16 output, without fp8 fast accum
     timers.append(
@@ -123,7 +121,7 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
             "pytorch_fp8_fp8_bf16_scaled_mm",
             torch._scaled_mm,
             a,
-            bT,
+            b,
             scale_a=scale_a,
             scale_b=scale_b,
             out_dtype=torch.bfloat16))
@@ -135,7 +133,7 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
             "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
             torch._scaled_mm,
             a,
-            bT,
+            b,
             scale_a=scale_a,
             scale_b=scale_b,
             out_dtype=torch.bfloat16,
@@ -148,7 +146,7 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
             "pytorch_fp8_fp8_fp16_scaled_mm",
             torch._scaled_mm,
             a,
-            bT,
+            b,
             scale_a=scale_a,
             scale_b=scale_b,
             out_dtype=torch.float16))
@@ -160,7 +158,7 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
             "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
             torch._scaled_mm,
             a,
-            bT,
+            b,
             scale_a=scale_a,
             scale_b=scale_b,
             out_dtype=torch.float16,
@@ -183,26 +181,19 @@ def bench_fp16(dtype: torch.dtype, m: int, k: int, n: int, label: str,
                sub_label: str) -> Iterable[TMeasurement]:
     assert dtype == torch.float16
 
-    m, k, n = 1, 128, 256
-
     # Create tensors
     b_compressed, e, a, b = make_rand_sparse_tensors(torch.float16, m, n, k)
     aT = a.t()
-    bT = b.t()
     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
 
     out = ops.cutlass_scaled_sparse_mm(b_compressed, e, aT, scale_b, scale_a, torch.bfloat16)
-    out_ref = to_bf16(a@bT)
+    out_ref = to_bf16(a@b)
 
-    if not torch.allclose(out.t(), out_ref, rtol=1e-2, atol=1e-2):
+    if not torch.allclose(out, out_ref, rtol=1e-2, atol=1e-2):
         print("Incorrect result")
-        print(out.t())
-        print(out_ref)
         exit()
-    else:
-        print("Correct result")
 
     timers = []
 
@@ -269,16 +260,17 @@ def bench_bf16(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     # Create tensors
     b_compressed, e, a, b = make_rand_sparse_tensors(torch.bfloat16, m, n, k)
     aT = a.t()
-    bT = b.t()
     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
 
     out = ops.cutlass_scaled_sparse_mm(b_compressed, e, aT, scale_b, scale_a, torch.bfloat16)
-    out_ref = to_bf16(a@bT)
+    out_ref = to_bf16(a@b)
 
-    if not torch.allclose(out.t(), out_ref):
+    if not torch.allclose(out, out_ref, rtol=1e-1, atol=1e-1):
         print("Incorrect result")
+        print(out)
+        print(out_ref)
         exit()
 
     timers = []
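Since several timers above route through the private torch._scaled_mm fp8 path, a standalone sketch of that reference call may help. It assumes an fp8-capable GPU (SM89 or newer); the shapes and unit scales are illustrative, and because the API is private its signature has shifted across PyTorch releases (older builds return an (out, amax) tuple rather than a bare tensor).

import torch

m, k, n = 16, 128, 256
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
# torch._scaled_mm expects its second operand in column-major layout,
# hence the transpose of a row-major (n, k) tensor.
b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                       out_dtype=torch.bfloat16, use_fast_accum=True)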