This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit a8b75a4

tlrmchlsmth authored and Robert Shaw committed
[Bugfix] Fix w8a8 benchmarks for int8 case (vllm-project#5643)
1 parent 010f2e8 commit a8b75a4

1 file changed: +2 -3 lines changed

benchmarks/cutlass_benchmarks/w8a8_benchmarks.py

Lines changed: 2 additions & 3 deletions
@@ -120,9 +120,8 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,

     # cutlass impl
     timers.append(
-        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
-                 torch.bfloat16, label, sub_label, cutlass_impl,
-                 "cutlass_i8_i8_bf16_scaled_mm"))
+        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
+                 cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm"))

     return timers
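
The change stops moving scale_a and scale_b to the CPU before they reach the CUTLASS implementation. The commit does not say why, so the sketch below rests on an assumption: that the kernel expects every operand on the same device, and that the benchmark creates a, b, and both scales on the GPU. It shows a guard a benchmark could run before timing. The names assert_same_device and FakeTensor are hypothetical and not part of vLLM or PyTorch; FakeTensor only mimics the .device attribute that real torch.Tensor objects expose.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; only .device matters here."""
    device: str


def assert_same_device(**tensors):
    """Raise ValueError if the named operands report different devices."""
    devices = {name: str(t.device) for name, t in tensors.items()}
    if len(set(devices.values())) > 1:
        raise ValueError(f"operands on different devices: {devices}")
    return devices


# Assumed placement after this commit: scales are passed through unchanged.
a = FakeTensor("cuda:0")
b = FakeTensor("cuda:0")
scale_a = FakeTensor("cuda:0")
scale_b = FakeTensor("cuda:0")
assert_same_device(a=a, b=b, scale_a=scale_a, scale_b=scale_b)

# Placement before this commit: both scales were moved to the CPU.
try:
    assert_same_device(a=a, b=b,
                       scale_a=FakeTensor("cpu"), scale_b=FakeTensor("cpu"))
except ValueError as err:
    mismatch = str(err)  # message lists each operand with its device
```

With real tensors the same helper could be called as assert_same_device(a=a, b=b, scale_a=scale_a, scale_b=scale_b) just before bench_fn, so a placement mistake fails loudly instead of surfacing inside the kernel call.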
