Skip to content

Commit 8641fd6

Browse files
authored
float8 matmul benchmark: hook up cublas mxfp8 gemm (#1830)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent 3afbc9e commit 8641fd6

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

benchmarks/float8/bench_matmul.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616
get_name_to_shapes_iter,
1717
)
1818

19-
from torchao.float8.config import ScalingGranularity
2019
from torchao.testing.float8.roofline_utils import get_specs
2120

2221

@@ -53,15 +52,17 @@ def do_benchmarks(
5352
@torch.inference_mode()
5453
def run(
5554
n_limit: Optional[int] = None,
56-
shape_gen_name: str = "llama",
55+
shape_gen_name: str = "pow2_extended",
5756
out_filename: Optional[str] = None,
5857
M: Optional[int] = None,
5958
K: Optional[int] = None,
6059
N: Optional[int] = None,
61-
use_gpu_kernel_time: bool = False,
62-
scaling_granularity: str = "tensorwise",
60+
use_gpu_kernel_time: bool = True,
61+
recipe: str = "tensorwise",
6362
):
6463
device = "cuda"
64+
# TODO(future PR): this is ugly
65+
assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas"), "unsupported"
6566

6667
specs = get_specs()
6768
bf16_peak_tops = specs["bf16_peak_tops"]
@@ -84,7 +85,6 @@ def run(
8485
dtype = torch.bfloat16
8586
name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
8687
fast_accum_vals = [True, False]
87-
scaling_granularity = ScalingGranularity(scaling_granularity)
8888

8989
for idx, (fast_accum, (name, (M, K, N))) in enumerate(
9090
itertools.product(fast_accum_vals, name_to_shapes)
@@ -112,13 +112,17 @@ def run(
112112
d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
113113
A = torch.zeros(M, K, device=device, dtype=d1)
114114
B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
115-
if scaling_granularity == ScalingGranularity.TENSORWISE:
115+
if recipe == "tensorwise":
116116
scale_a = torch.tensor([1.0], device=device)
117117
scale_b = torch.tensor([1.0], device=device)
118-
else:
119-
assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
118+
elif recipe == "rowwise":
120119
scale_a = torch.ones(M, 1, device=device)
121120
scale_b = torch.ones(1, N, device=device)
121+
elif recipe == "mxfp8_cublas":
122+
scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
123+
scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
124+
else:
125+
assert False, f"unknown recipe {recipe}"
122126

123127
def do_matmul(A, B):
124128
nonlocal scale_a

0 commit comments

Comments
 (0)