Commit 173f6c1

roofline estimator: add float8 rowwise and mxfp8 recipe support
Summary:

Test Plan:

```
python benchmarks/float8/float8_roofline.py ~/local/tmp/20250226_test.csv --n_limit 1 --float8_recipe_name rowwise
python benchmarks/float8/float8_roofline.py ~/local/tmp/20250226_test.csv --n_limit 1 --mx_recipe_name mxfp8_emulated
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 7857b5c
ghstack-comment-id: 2686473047
Pull Request resolved: #1789
1 parent b9c51b7 commit 173f6c1

File tree

3 files changed: +314, -75 lines changed

benchmarks/float8/float8_roofline.py

Lines changed: 133 additions & 22 deletions
```diff
@@ -47,6 +47,7 @@
 import pandas as pd
 import sympy
 import torch
+import torch.nn as nn
 import torch.utils.benchmark as benchmark
 import tqdm
 from torch.profiler import ProfilerActivity, profile
@@ -57,8 +58,11 @@
 )
 
 from torchao.float8 import (
+    Float8LinearConfig,
     convert_to_float8_training,
 )
+from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
 from torchao.testing.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
@@ -93,17 +97,19 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
     return measurement.mean
 
 
-def get_gpu_kernel_time(m, x):
+def get_gpu_kernel_time(m, x, grad_output):
     # warm up
     for _ in range(2):
-        m(x).sum().backward()
+        y = m(x)
+        y.backward(grad_output)
 
     # capture a profiling run
     activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
     n_iter = 5
     with profile(activities=activities) as prof:
         for _ in range(n_iter):
-            m(x).sum().backward()
+            y = m(x)
+            y.backward(grad_output)
         torch.cuda.synchronize()
     # get the gpu kernel time and aggregate it
     num_leaf_tensors = 1 + len(list(m.parameters()))
```
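
For readers who want the timing approach in isolation: a minimal sketch of aggregating GPU kernel time from a `torch.profiler` trace, using only the public profiler API. The repository's own kernel-name filtering helper is not shown in this diff, and the `self_cuda_time_total` attribute name is an assumption whose spelling varies across PyTorch releases.

```python
import torch
from torch.profiler import ProfilerActivity, profile


def total_gpu_kernel_time_s(fn, n_iter: int = 5) -> float:
    # profile n_iter calls of fn and sum self CUDA kernel time across all events
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(n_iter):
            fn()
        torch.cuda.synchronize()
    # event times are reported in microseconds
    total_us = sum(evt.self_cuda_time_total for evt in prof.key_averages())
    return total_us / 1e6 / n_iter
```
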
```diff
@@ -114,7 +120,22 @@ def get_gpu_kernel_time(m, x):
     return total_time_s
 
 
-def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
+def get_gemm_times(
+    M,
+    K,
+    N,
+    fast_accum,
+    bf16_memory_formats,
+    float8_recipe_name,
+    mx_recipe_name,
+    cache_filename=None,
+):
+    assert bf16_memory_formats in (
+        "row_major:col_major",
+        "row_major:row_major",
+        "col_major:row_major",
+    ), "unsupported"
+
     # Note: this is definitely not the best way to build a cache,
     # but it will do for now.
     if cache_filename is not None:
```
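
A note on the new `bf16_memory_formats` argument: the `.t().contiguous().t()` idiom used below keeps a tensor's logical shape but lays the data out column-major, which can be checked from the strides (illustrative snippet, not part of the patch):

```python
import torch

M, K = 4, 8
x_row_major = torch.randn(M, K)                  # default layout, strides (8, 1)
x_col_major = x_row_major.t().contiguous().t()   # same shape, column-major storage

print(x_row_major.shape, x_row_major.stride())   # torch.Size([4, 8]) (8, 1)
print(x_col_major.shape, x_col_major.stride())   # torch.Size([4, 8]) (1, 4)
```
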
```diff
@@ -127,23 +148,38 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
             cache = dict()
     else:
         cache = dict()
-    key = f"{M},{K},{N},{fast_accum}"
+    key = f"{M},{K},{N},{fast_accum},{bf16_memory_formats}"
     if key in cache:
         return cache[key]
 
     device = torch.device("cuda")
 
     # bf16 time
     x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
-    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
+
+    if bf16_memory_formats == "row_major:col_major":
+        w_bf16 = w_bf16.t().contiguous().t()
+    elif bf16_memory_formats == "col_major:row_major":
+        x_bf16 = x_bf16.t().contiguous().t()
+    elif bf16_memory_formats == "col_major:row_major":
+        x_bf16 = x_bf16.t().contiguous().t()
+
     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)
 
     # f8 time
     d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
     A = torch.zeros(M, K, device=device, dtype=d1)
     B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
-    scale_a = torch.tensor([1.0], device=device)
-    scale_b = torch.tensor([1.0], device=device)
+    if float8_recipe_name == "tensorwise":
+        scale_a = torch.tensor([1.0], device=device)
+        scale_b = torch.tensor([1.0], device=device)
+    elif float8_recipe_name == "rowwise":
+        scale_a = torch.ones(M, 1, device=device)
+        scale_b = torch.ones(1, N, device=device)
+    else:
+        assert False, "TODO add mx gemm here"
 
     def do_matmul(A, B):
         return torch._scaled_mm(
```
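
For context on the scale shapes added above: with tensorwise scaling `torch._scaled_mm` takes one scale per operand, while with rowwise scaling it takes an `(M, 1)` scale for the first operand and a `(1, N)` scale for the second. A minimal sketch mirroring the patch (illustrative only; running it is assumed to require a recent PyTorch build and float8-capable hardware):

```python
import torch

M, K, N = 1024, 2048, 4096  # example sizes
device = torch.device("cuda")

A = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
# the second operand is expected in column-major layout
B = torch.randn(K, N, device=device).to(torch.float8_e4m3fn).t().contiguous().t()

# tensorwise: one scale per tensor
scale_a_tensorwise = torch.tensor([1.0], device=device)
scale_b_tensorwise = torch.tensor([1.0], device=device)

# rowwise: one scale per row of A and one per column of B
scale_a_rowwise = torch.ones(M, 1, device=device)
scale_b_rowwise = torch.ones(1, N, device=device)

out = torch._scaled_mm(
    A, B, scale_a_rowwise, scale_b_rowwise, out_dtype=torch.bfloat16
)
print(out.shape)  # torch.Size([1024, 4096])
```
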
```diff
@@ -164,33 +200,52 @@ def do_matmul(A, B):
 def run(
     outfile: str,
     do_benchmarks: bool = True,
-    shape_gen_name: str = "square",
+    shape_gen_name: str = "pow2",
     gemm_cache_filename: Optional[str] = None,
     n_limit: Optional[int] = None,
+    float8_recipe_name: Optional[str] = None,
+    mx_recipe_name: Optional[str] = None,
+    enable_fusion_modeling: bool = False,
 ):
     """
     Args:
     * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
-    * `shape_gen_name`: `llama`, `square`, or `sweep`
+    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+    * `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid and models the fusion of float8 overhead
     """
 
+    assert not (
+        (float8_recipe_name is not None) and (mx_recipe_name is not None)
+    ), "unsupported"
+    if float8_recipe_name is None and mx_recipe_name is None:
+        float8_recipe_name = "tensorwise"
+
+    print(f"GPU: {torch.cuda.get_device_name(0)}")
     print(f"do_benchmarks: {do_benchmarks}")
     print(f"shape_gen_name: {shape_gen_name}")
+    print(f"float8_recipe_name: {float8_recipe_name}")
+    print(f"mx_recipe_name: {mx_recipe_name}")
+    print(f"enable_fusion_modeling: {enable_fusion_modeling}")
 
     M, K, N = sympy.symbols("M K N")
 
-    fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
+    fp8_ovhd_time_sympy = get_float8_mem_sympy(
         M,
         K,
         N,
+        float8_recipe_name,
+        mx_recipe_name,
+        enable_fusion_modeling,
+    )
+    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None)
+    fp8_gemm_time_sympy = get_gemm_time_sympy(
+        M, K, N, torch.float8_e4m3fn, mx_recipe_name
     )
-
-    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
-    fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+    print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
     print()
 
     headers = [
```
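
To show how such symbolic expressions are consumed later in `run` (they are substituted with concrete `M_val/K_val/N_val`), here is a generic sympy roofline sketch for a single gemm. The peak FLOPS and bandwidth numbers are placeholders, and this is not the repository's `get_gemm_time_sympy` implementation:

```python
import sympy

M, K, N = sympy.symbols("M K N")

# placeholder hardware peaks, for illustration only
peak_flops = 989e12  # dense bf16 tensor-core FLOP/s
peak_bw = 3.35e12    # memory bandwidth in bytes/s

gemm_flops = 2 * M * K * N                # multiply-accumulate count
gemm_bytes = 2 * (M * K + K * N + M * N)  # bf16 = 2 bytes per element
gemm_time_s = sympy.Max(gemm_flops / peak_flops, gemm_bytes / peak_bw)

# substitute a concrete shape, as the roofline script does below
print(float(gemm_time_s.subs(M, 4096).subs(K, 4096).subs(N, 4096)))
```
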
```diff
@@ -217,6 +272,9 @@ def run(
         # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple
         # we don't break them out and don't have a roofline for them.
         "b_fp8_e2e_spdp",
+        # how well benchmarked gemms match roofline predicted gemms
+        "rb_bf16_gemm_ratio",
+        "rb_fp8_gemm_ratio",
     ]
     results = []
 
@@ -237,43 +295,93 @@ def run(
 
         # if enabled, also measured observed gemm time
         b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+        rb_bf16_gemm_ratio = -1
+        rb_fp8_gemm_ratio = -1
+
         if do_benchmarks:
+            # TODO(future): make the bf16 gemm times exactly match the e2e
+            # benchmarks, there is a slight deviation, probably related to gemm
+            # operand memory formats/transpositions below not exactly matching
+            # what PyTorch core is doing for `torch.mm`
+            # input @ weight_t = output
             bf16_g1, f8_g1 = get_gemm_times(
-                M_val, K_val, N_val, True, gemm_cache_filename
+                M_val,
+                K_val,
+                N_val,
+                True,
+                "row_major:col_major",
+                float8_recipe_name,
+                mx_recipe_name,
+                gemm_cache_filename,
             )
+            # grad_output @ weight = grad_input
             bf16_g2, f8_g2 = get_gemm_times(
-                M_val, N_val, K_val, False, gemm_cache_filename
+                M_val,
+                N_val,
+                K_val,
+                False,
+                "row_major:row_major",
+                float8_recipe_name,
+                mx_recipe_name,
+                gemm_cache_filename,
             )
+            # input_t @ grad_output = grad_weight
            bf16_g3, f8_g3 = get_gemm_times(
-                K_val, M_val, N_val, False, gemm_cache_filename
+                K_val,
+                M_val,
+                N_val,
+                False,
+                "col_major:row_major",
+                float8_recipe_name,
+                mx_recipe_name,
+                gemm_cache_filename,
            )
             b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
             b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s
+            rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s
 
         # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
         r_fp8_ovhd_time_s = float(
-            fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
         )
 
         b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
         if do_benchmarks:
             # create the model
-            m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            if enable_fusion_modeling:
+                m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            else:
+                m_orig = (
+                    nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
+                )
             x = torch.randn(
                 M_val, K_val, dtype=torch.bfloat16, device="cuda"
             ).requires_grad_()
 
+            # get the gradient of the right shape
+            grad_output = torch.randn(N_val, K_val, dtype=torch.bfloat16, device="cuda")
+
             # get the bf16 gpu kernel time
             torch._dynamo.reset()
             m_bf16 = torch.compile(copy.deepcopy(m_orig))
-            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)
+            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, grad_output)
 
             # get the float8 dynamic scaling gpu kernel time
 
             torch._dynamo.reset()
-            m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
+            if float8_recipe_name is not None:
+                config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
+                m_fp8_dyn = convert_to_float8_training(
+                    copy.deepcopy(m_orig), config=config
+                )
+            else:
+                assert mx_recipe_name is not None
+                config = MXLinearConfig.from_recipe_name(mx_recipe_name)
+                m_fp8_dyn = copy.deepcopy(m_orig)
+                swap_linear_with_mx_linear(m_fp8_dyn, config=config)
             m_fp8_dyn = torch.compile(m_fp8_dyn)
-            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
+            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, grad_output)
 
         results.append(
             [
```
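
The three `get_gemm_times` calls above map to the forward and backward matmuls of a single linear layer, which is where the `(M, K, N)`, `(M, N, K)`, and `(K, M, N)` argument orders come from. A small shape check (plain bf16, no scaling, illustrative only):

```python
import torch

M, K, N = 16, 32, 64  # batch, in_features, out_features

x = torch.randn(M, K)          # input
w = torch.randn(N, K)          # nn.Linear stores weight as (out_features, in_features)
grad_out = torch.randn(M, N)   # gradient w.r.t. the linear output

out = x @ w.t()             # input   @ weight_t    = output       (M, K) @ (K, N) -> (M, N)
grad_x = grad_out @ w       # grad_output @ weight  = grad_input   (M, N) @ (N, K) -> (M, K)
grad_w = x.t() @ grad_out   # input_t @ grad_output = grad_weight  (K, M) @ (M, N) -> (K, N)

print(out.shape, grad_x.shape, grad_w.shape)
```
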
```diff
@@ -295,6 +403,9 @@ def run(
                 b_bf16_e2e_time_s,
                 b_fp8_e2e_time_s,
                 b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
+                # gemm ratios
+                rb_bf16_gemm_ratio,
+                rb_fp8_gemm_ratio,
             ]
         )
 
```

benchmarks/float8/utils.py

Lines changed: 17 additions & 3 deletions
```diff
@@ -152,18 +152,32 @@ def get_name_to_shapes_iter(
         }
         return name_to_shapes_70b.items()
 
-    elif shape_gen_name == "square":
+    elif shape_gen_name == "pow2":
         assert (
             M == K == N == None
         ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}"
         name_to_shapes = {}
-        min_power_of_2 = 8  # 256
-        max_power_of_2 = 15  # 32,768
+        min_power_of_2 = 10  # 1024
+        max_power_of_2 = 14  # 16,384
         for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
             val = 2**power_of_2
             name_to_shapes[idx] = val, val, val
         return name_to_shapes.items()
 
+    elif shape_gen_name == "pow2_extended":
+        assert (
+            M == K == N == None
+        ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}"
+        name_to_shapes = {}
+        min_power_of_2 = 10  # 1024
+        max_power_of_2 = 14  # 16,384
+        for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+            val1 = 2**power_of_2
+            name_to_shapes[idx * 2] = val1, val1, val1
+            val2 = 2**power_of_2 + 2 ** (power_of_2 - 1)
+            name_to_shapes[idx * 2 + 1] = val2, val2, val2
+        return name_to_shapes.items()
+
     elif shape_gen_name == "sweep":
         assert (
             M == K == N == None
```
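
As a quick sanity check, reproducing the `pow2_extended` loop standalone shows the square shapes it generates: each power of two from 1024 to 16,384 plus the midpoint to the next power:

```python
# standalone reproduction of the pow2_extended branch above
min_power_of_2 = 10  # 1024
max_power_of_2 = 14  # 16,384

name_to_shapes = {}
for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
    val1 = 2**power_of_2
    name_to_shapes[idx * 2] = val1, val1, val1
    val2 = 2**power_of_2 + 2 ** (power_of_2 - 1)
    name_to_shapes[idx * 2 + 1] = val2, val2, val2

print(sorted(m for (m, k, n) in name_to_shapes.values()))
# [1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 24576]
```
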
