
Commit b928266

roofline estimator: add float8 rowwise and mxfp8 recipe support
Summary:

Test Plan:

```
python benchmarks/float8/float8_roofline.py ~/local/tmp/20250226_test.csv --n_limit 1 --float8_recipe_name rowwise
python benchmarks/float8/float8_roofline.py ~/local/tmp/20250226_test.csv --n_limit 1 --mx_recipe_name mxfp8_emulated
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: a34af93
ghstack-comment-id: 2686473047
Pull Request resolved: #1789
1 parent b9c51b7 commit b928266
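
For context on the two new flags in the test plan: `--float8_recipe_name` selects a float8 training recipe (e.g. `rowwise`, with `tensorwise` as the default when neither flag is given) and `--mx_recipe_name` selects an MX recipe (e.g. `mxfp8_emulated`); the two are mutually exclusive. Below is a minimal sketch of how the benchmark maps each flag to a model conversion, condensed from the diff further down; `convert_for_recipe` is a hypothetical helper name, not part of the commit.

```
import copy

import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear


def convert_for_recipe(m_orig, float8_recipe_name=None, mx_recipe_name=None):
    # at most one recipe may be specified; fall back to tensorwise float8 scaling
    assert not (float8_recipe_name is not None and mx_recipe_name is not None)
    if float8_recipe_name is None and mx_recipe_name is None:
        float8_recipe_name = "tensorwise"
    if float8_recipe_name is not None:
        # float8 path: tensorwise or rowwise scaling
        config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
        return convert_to_float8_training(copy.deepcopy(m_orig), config=config)
    # MX path, e.g. mxfp8_emulated
    config = MXLinearConfig.from_recipe_name(mx_recipe_name)
    m_mx = copy.deepcopy(m_orig)
    swap_linear_with_mx_linear(m_mx, config=config)
    return m_mx


# hypothetical usage with the single-linear model used by the benchmark
m = nn.Sequential(nn.Linear(4096, 4096, bias=False)).cuda().bfloat16()
m_fp8_rowwise = convert_for_recipe(m, float8_recipe_name="rowwise")
m_mxfp8 = convert_for_recipe(m, mx_recipe_name="mxfp8_emulated")
```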

3 files changed (+290, -73 lines)


benchmarks/float8/float8_roofline.py

Lines changed: 102 additions & 20 deletions
@@ -47,6 +47,7 @@
 import pandas as pd
 import sympy
 import torch
+import torch.nn as nn
 import torch.utils.benchmark as benchmark
 import tqdm
 from torch.profiler import ProfilerActivity, profile
@@ -57,8 +58,11 @@
 )

 from torchao.float8 import (
+    Float8LinearConfig,
     convert_to_float8_training,
 )
+from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
 from torchao.testing.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
@@ -93,17 +97,19 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
     return measurement.mean


-def get_gpu_kernel_time(m, x):
+def get_gpu_kernel_time(m, x, grad_output):
     # warm up
     for _ in range(2):
-        m(x).sum().backward()
+        y = m(x)
+        y.backward(grad_output)

     # capture a profiling run
     activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
     n_iter = 5
     with profile(activities=activities) as prof:
         for _ in range(n_iter):
-            m(x).sum().backward()
+            y = m(x)
+            y.backward(grad_output)
         torch.cuda.synchronize()
     # get the gpu kernel time and aggregate it
     num_leaf_tensors = 1 + len(list(m.parameters()))
@@ -114,7 +120,20 @@ def get_gpu_kernel_time(m, x):
     return total_time_s


-def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
+def get_gemm_times(
+    M,
+    K,
+    N,
+    fast_accum,
+    bf16_memory_formats,
+    cache_filename=None,
+):
+    assert bf16_memory_formats in (
+        "row_major:col_major",
+        "row_major:row_major",
+        "col_major:row_major",
+    ), "unsupported"
+
     # Note: this is definitely not the best way to build a cache,
     # but it will do for now.
     if cache_filename is not None:
@@ -127,15 +146,24 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
             cache = dict()
     else:
         cache = dict()
-    key = f"{M},{K},{N},{fast_accum}"
+    key = f"{M},{K},{N},{fast_accum},{bf16_memory_formats}"
     if key in cache:
         return cache[key]

     device = torch.device("cuda")

     # bf16 time
     x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
-    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
+
+    if bf16_memory_formats == "row_major:col_major":
+        w_bf16 = w_bf16.t().contiguous().t()
+    elif bf16_memory_formats == "col_major:row_major":
+        x_bf16 = x_bf16.t().contiguous().t()
+    elif bf16_memory_formats == "col_major:row_major":
+        x_bf16 = x_bf16.t().contiguous().t()
+
     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)

     # f8 time
@@ -164,33 +192,52 @@ def do_matmul(A, B):
 def run(
     outfile: str,
     do_benchmarks: bool = True,
-    shape_gen_name: str = "square",
+    shape_gen_name: str = "pow2",
     gemm_cache_filename: Optional[str] = None,
     n_limit: Optional[int] = None,
+    float8_recipe_name: Optional[str] = None,
+    mx_recipe_name: Optional[str] = None,
+    enable_fusion_modeling: bool = False,
 ):
     """
     Args:
     * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
-    * `shape_gen_name`: `llama`, `square`, or `sweep`
+    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+    * `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid and models the fusion of float8 overhead
     """

+    assert not (
+        (float8_recipe_name is not None) and (mx_recipe_name is not None)
+    ), "unsupported"
+    if float8_recipe_name is None and mx_recipe_name is None:
+        float8_recipe_name = "tensorwise"
+
+    print(f"GPU: {torch.cuda.get_device_name(0)}")
     print(f"do_benchmarks: {do_benchmarks}")
     print(f"shape_gen_name: {shape_gen_name}")
+    print(f"float8_recipe_name: {float8_recipe_name}")
+    print(f"mx_recipe_name: {mx_recipe_name}")
+    print(f"enable_fusion_modeling: {enable_fusion_modeling}")

     M, K, N = sympy.symbols("M K N")

-    fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
+    fp8_ovhd_time_sympy = get_float8_mem_sympy(
         M,
         K,
         N,
+        float8_recipe_name,
+        mx_recipe_name,
+        enable_fusion_modeling,
+    )
+    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None)
+    fp8_gemm_time_sympy = get_gemm_time_sympy(
+        M, K, N, torch.float8_e4m3fn, mx_recipe_name
     )
-
-    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
-    fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+    print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
     print()

     headers = [
@@ -217,6 +264,9 @@ def run(
         # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple
         # we don't break them out and don't have a roofline for them.
         "b_fp8_e2e_spdp",
+        # how well benchmarked gemms match roofline predicted gemms
+        "rb_bf16_gemm_ratio",
+        "rb_fp8_gemm_ratio",
     ]
     results = []

@@ -237,43 +287,72 @@ def run(

         # if enabled, also measured observed gemm time
         b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+        rb_bf16_gemm_ratio = -1
+        rb_fp8_gemm_ratio = -1
+
         if do_benchmarks:
+            # TODO(future): make the bf16 gemm times exactly match the e2e
+            # benchmarks, there is a slight deviation, probably related to gemm
+            # operand memory formats/transpositions below not exactly matching
+            # what PyTorch core is doing for `torch.mm`
+            # input @ weight_t = output
             bf16_g1, f8_g1 = get_gemm_times(
-                M_val, K_val, N_val, True, gemm_cache_filename
+                M_val, K_val, N_val, True, "row_major:col_major", gemm_cache_filename
             )
+            # grad_output @ weight = grad_input
             bf16_g2, f8_g2 = get_gemm_times(
-                M_val, N_val, K_val, False, gemm_cache_filename
+                M_val, N_val, K_val, False, "row_major:row_major", gemm_cache_filename
            )
+            # input_t @ grad_output = grad_weight
             bf16_g3, f8_g3 = get_gemm_times(
-                K_val, M_val, N_val, False, gemm_cache_filename
+                K_val, M_val, N_val, False, "col_major:row_major", gemm_cache_filename
             )
             b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
             b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s
+            rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s

         # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
         r_fp8_ovhd_time_s = float(
-            fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
         )

         b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
         if do_benchmarks:
             # create the model
-            m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            if enable_fusion_modeling:
+                m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            else:
+                m_orig = (
+                    nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
+                )
             x = torch.randn(
                 M_val, K_val, dtype=torch.bfloat16, device="cuda"
             ).requires_grad_()

+            # get the gradient of the right shape
+            grad_output = torch.randn(N_val, K_val, dtype=torch.bfloat16, device="cuda")
+
             # get the bf16 gpu kernel time
             torch._dynamo.reset()
             m_bf16 = torch.compile(copy.deepcopy(m_orig))
-            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)
+            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, grad_output)

             # get the float8 dynamic scaling gpu kernel time

             torch._dynamo.reset()
-            m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
+            if float8_recipe_name is not None:
+                config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
+                m_fp8_dyn = convert_to_float8_training(
+                    copy.deepcopy(m_orig), config=config
+                )
+            else:
+                assert mx_recipe_name is not None
+                config = MXLinearConfig.from_recipe_name(mx_recipe_name)
+                m_fp8_dyn = copy.deepcopy(m_orig)
+                swap_linear_with_mx_linear(m_fp8_dyn, config=config)
             m_fp8_dyn = torch.compile(m_fp8_dyn)
-            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
+            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, grad_output)

         results.append(
             [
@@ -295,6 +374,9 @@ def run(
                 b_bf16_e2e_time_s,
                 b_fp8_e2e_time_s,
                 b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
+                # gemm ratios
+                rb_bf16_gemm_ratio,
+                rb_fp8_gemm_ratio,
             ]
         )
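
The three `get_gemm_times` calls above mirror the three matmuls in a linear layer's forward and backward pass, and the new memory-format strings describe how each operand is laid out before it reaches `torch.mm`. Below is a small shape check of those three products, with hypothetical sizes that are not taken from the benchmark:

```
import torch

M, K, N = 256, 512, 1024  # hypothetical sizes

x = torch.randn(M, K)         # input
w = torch.randn(N, K)         # weight as stored by nn.Linear (out_features, in_features)
grad_out = torch.randn(M, N)  # gradient of the output

# forward: input @ weight_t = output              -> (M, K) @ (K, N) = (M, N)
out = x @ w.t()
# backward 1: grad_output @ weight = grad_input   -> (M, N) @ (N, K) = (M, K)
grad_x = grad_out @ w
# backward 2: input_t @ grad_output = grad_weight -> (K, M) @ (M, N) = (K, N),
# i.e. the weight gradient in the (K, N) orientation used by the benchmark's third gemm
grad_w_t = x.t() @ grad_out

assert out.shape == (M, N)
assert grad_x.shape == (M, K)
assert grad_w_t.shape == (K, N)
```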

benchmarks/float8/utils.py

Lines changed: 17 additions & 3 deletions
@@ -152,18 +152,32 @@ def get_name_to_shapes_iter(
         }
         return name_to_shapes_70b.items()

-    elif shape_gen_name == "square":
+    elif shape_gen_name == "pow2":
         assert (
             M == K == N == None
         ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}"
         name_to_shapes = {}
-        min_power_of_2 = 8  # 256
-        max_power_of_2 = 15  # 32,768
+        min_power_of_2 = 10  # 1024
+        max_power_of_2 = 14  # 16,384
         for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
             val = 2**power_of_2
             name_to_shapes[idx] = val, val, val
         return name_to_shapes.items()

+    elif shape_gen_name == "pow2_extended":
+        assert (
+            M == K == N == None
+        ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}"
+        name_to_shapes = {}
+        min_power_of_2 = 10  # 1024
+        max_power_of_2 = 14  # 16,384
+        for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+            val1 = 2**power_of_2
+            name_to_shapes[idx * 2] = val1, val1, val1
+            val2 = 2**power_of_2 + 2 ** (power_of_2 - 1)
+            name_to_shapes[idx * 2 + 1] = val2, val2, val2
+        return name_to_shapes.items()
+
     elif shape_gen_name == "sweep":
         assert (
             M == K == N == None
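
For reference, `pow2` now sweeps square shapes at powers of two from 1024 through 16,384, and the new `pow2_extended` interleaves the midpoint between consecutive powers of two (1.5x each power). A quick way to list the resulting sizes:

```
# reproduces the sizes generated by the `pow2` and `pow2_extended` branches above
pow2 = [2**p for p in range(10, 15)]
pow2_extended = []
for p in range(10, 15):
    pow2_extended.append(2**p)
    pow2_extended.append(2**p + 2 ** (p - 1))

print(pow2)           # [1024, 2048, 4096, 8192, 16384]
print(pow2_extended)  # [1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 24576]
```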
