Commit d8af7d7

roofline estimator: add float8 rowwise and mxfp8 recipe support (#1789)
1 parent 81a2813 commit d8af7d7

File tree

3 files changed: +397 -106 lines changed

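For orientation, a minimal usage sketch of the updated run() entry point with the new recipe arguments. The argument names come from the diff below; calling run() directly (rather than through the script's CLI wrapper) and the specific mx recipe name are assumptions, not part of this commit.

    # hedged usage sketch; run() arguments come from the diff below
    from float8_roofline import run  # assumes benchmarks/float8 is on sys.path

    # float8 rowwise recipe, plain nn.Linear (no fusion modeling)
    run("roofline_rowwise.csv", float8_recipe_name="rowwise")

    # float8 rowwise with high-precision grad_weight gemm, with fusion modeling
    run("roofline_rowwise_gw_hp.csv", float8_recipe_name="rowwise_with_gw_hp", enable_fusion_modeling=True)

    # mxfp8 recipe; the recipe name is a guess at what MXLinearConfig.from_recipe_name accepts
    run("roofline_mxfp8.csv", mx_recipe_name="mxfp8_cublas", shape_gen_name="pow2_extended")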

benchmarks/float8/float8_roofline.py

Lines changed: 151 additions & 31 deletions
@@ -47,6 +47,7 @@
 import pandas as pd
 import sympy
 import torch
+import torch.nn as nn
 import torch.utils.benchmark as benchmark
 import tqdm
 from torch.profiler import ProfilerActivity, profile
@@ -57,8 +58,11 @@
 )
 
 from torchao.float8 import (
+    Float8LinearConfig,
     convert_to_float8_training,
 )
+from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
 from torchao.testing.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
@@ -93,17 +97,19 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
     return measurement.mean
 
 
-def get_gpu_kernel_time(m, x):
+def get_gpu_kernel_time(m, x, grad_output):
     # warm up
     for _ in range(2):
-        m(x).sum().backward()
+        y = m(x)
+        y.backward(grad_output)
 
     # capture a profiling run
     activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
     n_iter = 5
     with profile(activities=activities) as prof:
         for _ in range(n_iter):
-            m(x).sum().backward()
+            y = m(x)
+            y.backward(grad_output)
     torch.cuda.synchronize()
     # get the gpu kernel time and aggregate it
     num_leaf_tensors = 1 + len(list(m.parameters()))
@@ -114,10 +120,28 @@ def get_gpu_kernel_time(m, x):
     return total_time_s
 
 
-def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
+def get_gemm_times(
+    gemm_role: str,
+    M: int,
+    K: int,
+    N: int,
+    fast_accum: bool,
+    bf16_memory_formats: str,
+    float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str],
+    cache_filename=None,
+):
+    assert gemm_role in ("output", "grad_input", "grad_weight"), "unsupported"
+    assert bf16_memory_formats in (
+        "row_major:col_major",
+        "row_major:row_major",
+        "col_major:row_major",
+    ), "unsupported"
+
     # Note: this is definitely not the best way to build a cache,
     # but it will do for now.
     if cache_filename is not None:
+        assert False, "TODO retest this for new arguments"
         if os.path.isfile(cache_filename):
             # cache already exists, use it
             with open(cache_filename, "r") as f:
@@ -127,30 +151,48 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
             cache = dict()
     else:
         cache = dict()
-    key = f"{M},{K},{N},{fast_accum}"
+    key = f"{M},{K},{N},{fast_accum},{bf16_memory_formats}"
     if key in cache:
         return cache[key]
 
     device = torch.device("cuda")
 
     # bf16 time
     x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
-    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
+
+    if bf16_memory_formats == "row_major:col_major":
+        w_bf16 = w_bf16.t().contiguous().t()
+    elif bf16_memory_formats == "col_major:row_major":
+        x_bf16 = x_bf16.t().contiguous().t()
+    elif bf16_memory_formats == "col_major:row_major":
+        x_bf16 = x_bf16.t().contiguous().t()
+
     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)
 
     # f8 time
-    d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
-    A = torch.zeros(M, K, device=device, dtype=d1)
-    B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
-    scale_a = torch.tensor([1.0], device=device)
-    scale_b = torch.tensor([1.0], device=device)
-
-    def do_matmul(A, B):
-        return torch._scaled_mm(
-            A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-        )
+    if float8_recipe_name == "rowwise_with_gw_hp" and gemm_role == "grad_weight":
+        f8_time_s = bf16_time_s
+    else:
+        d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
+        A = torch.zeros(M, K, device=device, dtype=d1)
+        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        if float8_recipe_name == "tensorwise":
+            scale_a = torch.tensor([1.0], device=device)
+            scale_b = torch.tensor([1.0], device=device)
+        elif float8_recipe_name in ("rowwise", "rowwise_with_gw_hp"):
+            scale_a = torch.ones(M, 1, device=device)
+            scale_b = torch.ones(1, N, device=device)
+        else:
+            assert False, "TODO add mx gemm here"
+
+        def do_matmul(A, B):
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+            )
 
-    f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+        f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
 
     # save to cache if needed
     if cache_filename is not None:
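The key change in this hunk is the per-recipe scale shapes passed to torch._scaled_mm: the tensorwise recipe uses one scalar scale per operand, while the rowwise recipes use an (M, 1) scale for A and a (1, N) scale for B. Below is a self-contained sketch of the rowwise call with placeholder scales of ones, mirroring the benchmark; it assumes a recent PyTorch and a CUDA GPU with hardware float8 support (e.g. H100).

    import torch

    M, K, N = 1024, 1024, 1024
    device = torch.device("cuda")

    A = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
    # B must be laid out column-major for torch._scaled_mm, hence .t().contiguous().t()
    B = torch.randn(K, N, device=device).to(torch.float8_e4m3fn).t().contiguous().t()

    scale_a = torch.ones(M, 1, device=device)  # one float32 scale per row of A
    scale_b = torch.ones(1, N, device=device)  # one float32 scale per column of B

    out = torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True)
    print(out.shape, out.dtype)  # torch.Size([1024, 1024]) torch.bfloat16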
@@ -164,33 +206,52 @@ def do_matmul(A, B):
 def run(
     outfile: str,
     do_benchmarks: bool = True,
-    shape_gen_name: str = "square",
+    shape_gen_name: str = "pow2",
     gemm_cache_filename: Optional[str] = None,
     n_limit: Optional[int] = None,
+    float8_recipe_name: Optional[str] = None,
+    mx_recipe_name: Optional[str] = None,
+    enable_fusion_modeling: bool = False,
 ):
     """
     Args:
     * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
-    * `shape_gen_name`: `llama`, `square`, or `sweep`
+    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+    * `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid and models the fusion of float8 overhead
     """
 
+    assert not (
+        (float8_recipe_name is not None) and (mx_recipe_name is not None)
+    ), "unsupported"
+    if float8_recipe_name is None and mx_recipe_name is None:
+        float8_recipe_name = "tensorwise"
+
+    print(f"GPU: {torch.cuda.get_device_name(0)}")
     print(f"do_benchmarks: {do_benchmarks}")
     print(f"shape_gen_name: {shape_gen_name}")
+    print(f"float8_recipe_name: {float8_recipe_name}")
+    print(f"mx_recipe_name: {mx_recipe_name}")
+    print(f"enable_fusion_modeling: {enable_fusion_modeling}")
 
     M, K, N = sympy.symbols("M K N")
 
-    fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
+    fp8_ovhd_time_sympy = get_float8_mem_sympy(
         M,
         K,
         N,
+        float8_recipe_name,
+        mx_recipe_name,
+        enable_fusion_modeling,
+    )
+    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None, None)
+    fp8_gemm_time_sympy = get_gemm_time_sympy(
+        M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name
     )
-
-    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
-    fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+    print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
     print()
 
     headers = [
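The roofline quantities above are plain sympy expressions in M, K, N that get evaluated later with .subs(...). A toy sketch of that pattern; the peak-throughput constant and the expression are illustrative placeholders, not the formulas from torchao.testing.float8.roofline_utils.

    import sympy

    M, K, N = sympy.symbols("M K N")

    # hypothetical peak bf16 throughput in FLOP/s; the real, GPU-dependent
    # numbers live in roofline_utils
    BF16_PEAK_FLOPS = 989e12

    # fwd + grad_input + grad_weight are three gemms of 2*M*K*N FLOPs each
    bf16_gemm_time_sympy = 3 * 2 * M * K * N / BF16_PEAK_FLOPS

    # evaluate for a concrete shape, mirroring the .subs(...) chain in the benchmark
    t = float(bf16_gemm_time_sympy.subs(M, 4096).subs(K, 4096).subs(N, 4096))
    print(f"{t * 1e6:.1f} us")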
@@ -217,6 +278,9 @@ def run(
         # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple
         # we don't break them out and don't have a roofline for them.
         "b_fp8_e2e_spdp",
+        # how well benchmarked gemms match roofline predicted gemms
+        "rb_bf16_gemm_ratio",
+        "rb_fp8_gemm_ratio",
     ]
     results = []
 
@@ -237,43 +301,96 @@ def run(
 
         # if enabled, also measured observed gemm time
         b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+        rb_bf16_gemm_ratio = -1
+        rb_fp8_gemm_ratio = -1
+
         if do_benchmarks:
+            # TODO(future): make the bf16 gemm times exactly match the e2e
+            # benchmarks, there is a slight deviation, probably related to gemm
+            # operand memory formats/transpositions below not exactly matching
+            # what PyTorch core is doing for `torch.mm`
+            # input @ weight_t = output
             bf16_g1, f8_g1 = get_gemm_times(
-                M_val, K_val, N_val, True, gemm_cache_filename
+                "output",
+                M_val,
+                K_val,
+                N_val,
+                True,
+                "row_major:col_major",
+                float8_recipe_name,
+                mx_recipe_name,
+                gemm_cache_filename,
             )
+            # grad_output @ weight = grad_input
             bf16_g2, f8_g2 = get_gemm_times(
-                M_val, N_val, K_val, False, gemm_cache_filename
+                "grad_input",
+                M_val,
+                N_val,
+                K_val,
+                False,
+                "row_major:row_major",
+                float8_recipe_name,
+                mx_recipe_name,
+                gemm_cache_filename,
             )
+            # input_t @ grad_output = grad_weight
             bf16_g3, f8_g3 = get_gemm_times(
-                K_val, M_val, N_val, False, gemm_cache_filename
+                "grad_weight",
+                K_val,
+                M_val,
+                N_val,
+                False,
+                "col_major:row_major",
+                float8_recipe_name,
+                mx_recipe_name,
+                gemm_cache_filename,
             )
             b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
             b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s
+            rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s
 
         # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
         r_fp8_ovhd_time_s = float(
-            fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
         )
 
         b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
         if do_benchmarks:
             # create the model
-            m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            if enable_fusion_modeling:
+                m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            else:
+                m_orig = (
+                    nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
+                )
             x = torch.randn(
                 M_val, K_val, dtype=torch.bfloat16, device="cuda"
             ).requires_grad_()
 
+            # get the gradient of the right shape
+            grad_output = torch.randn(N_val, K_val, dtype=torch.bfloat16, device="cuda")
+
             # get the bf16 gpu kernel time
             torch._dynamo.reset()
             m_bf16 = torch.compile(copy.deepcopy(m_orig))
-            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)
+            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, grad_output)
 
             # get the float8 dynamic scaling gpu kernel time
 
             torch._dynamo.reset()
-            m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
+            if float8_recipe_name is not None:
+                config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
+                m_fp8_dyn = convert_to_float8_training(
+                    copy.deepcopy(m_orig), config=config
+                )
+            else:
+                assert mx_recipe_name is not None
+                config = MXLinearConfig.from_recipe_name(mx_recipe_name)
+                m_fp8_dyn = copy.deepcopy(m_orig)
+                swap_linear_with_mx_linear(m_fp8_dyn, config=config)
             m_fp8_dyn = torch.compile(m_fp8_dyn)
-            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
+            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, grad_output)
 
         results.append(
             [
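Outside of the benchmark, the same two conversion paths look roughly like the sketch below, based on the calls in the hunk above; the mx recipe name is a placeholder, not something this diff pins down.

    import copy
    import torch.nn as nn

    from torchao.float8 import Float8LinearConfig, convert_to_float8_training
    from torchao.prototype.mx_formats.config import MXLinearConfig
    from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

    m_orig = nn.Sequential(nn.Linear(4096, 4096, bias=False)).cuda().bfloat16()

    # float8 path: build a config from a recipe name, swap Linear -> Float8Linear
    f8_config = Float8LinearConfig.from_recipe_name("rowwise")
    m_f8 = convert_to_float8_training(copy.deepcopy(m_orig), config=f8_config)

    # mx path: swap Linear -> MXLinear in place ("mxfp8_cublas" is an assumed recipe name)
    mx_config = MXLinearConfig.from_recipe_name("mxfp8_cublas")
    m_mx = copy.deepcopy(m_orig)
    swap_linear_with_mx_linear(m_mx, config=mx_config)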
@@ -295,6 +412,9 @@ def run(
                 b_bf16_e2e_time_s,
                 b_fp8_e2e_time_s,
                 b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
+                # gemm ratios
+                rb_bf16_gemm_ratio,
+                rb_fp8_gemm_ratio,
             ]
         )
 
benchmarks/float8/utils.py

Lines changed: 17 additions & 3 deletions
@@ -152,18 +152,32 @@ def get_name_to_shapes_iter(
         }
         return name_to_shapes_70b.items()
 
-    elif shape_gen_name == "square":
+    elif shape_gen_name == "pow2":
         assert (
             M == K == N == None
         ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}"
         name_to_shapes = {}
-        min_power_of_2 = 8  # 256
-        max_power_of_2 = 15  # 32,768
+        min_power_of_2 = 10  # 1024
+        max_power_of_2 = 14  # 16,384
         for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
             val = 2**power_of_2
             name_to_shapes[idx] = val, val, val
         return name_to_shapes.items()
 
+    elif shape_gen_name == "pow2_extended":
+        assert (
+            M == K == N == None
+        ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}"
+        name_to_shapes = {}
+        min_power_of_2 = 10  # 1024
+        max_power_of_2 = 14  # 16,384
+        for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+            val1 = 2**power_of_2
+            name_to_shapes[idx * 2] = val1, val1, val1
+            val2 = 2**power_of_2 + 2 ** (power_of_2 - 1)
+            name_to_shapes[idx * 2 + 1] = val2, val2, val2
+        return name_to_shapes.items()
+
     elif shape_gen_name == "sweep":
         assert (
             M == K == N == None
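For reference, the new pow2_extended generator interleaves each power of two with its midpoint to the next power of two. Re-running the same loop standalone shows the M=K=N values it produces:

    # standalone re-run of the pow2_extended loop above
    name_to_shapes = {}
    for idx, power_of_2 in enumerate(range(10, 15)):  # 2**10 = 1024 .. 2**14 = 16,384
        val1 = 2**power_of_2
        name_to_shapes[idx * 2] = val1, val1, val1
        val2 = 2**power_of_2 + 2 ** (power_of_2 - 1)  # midpoint to the next power of two
        name_to_shapes[idx * 2 + 1] = val2, val2, val2

    print([v[0] for v in name_to_shapes.values()])
    # [1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 24576]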
