Commit 9bedb01 (1 parent: 6a5718d)

Update
[ghstack-poisoned]

2 files changed: +168 -116 lines
benchmarks/float8/float8_roofline.py

Lines changed: 33 additions & 24 deletions
@@ -121,15 +121,17 @@ def get_gpu_kernel_time(m, x, grad_output):
 
 
 def get_gemm_times(
-    M,
-    K,
-    N,
-    fast_accum,
-    bf16_memory_formats,
-    float8_recipe_name,
-    mx_recipe_name,
+    gemm_role: str,
+    M: int,
+    K: int,
+    N: int,
+    fast_accum: bool,
+    bf16_memory_formats: str,
+    float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str],
     cache_filename=None,
 ):
+    assert gemm_role in ("output", "grad_input", "grad_weight"), "unsupported"
     assert bf16_memory_formats in (
         "row_major:col_major",
         "row_major:row_major",
@@ -139,6 +141,7 @@ def get_gemm_times(
     # Note: this is definitely not the best way to build a cache,
     # but it will do for now.
     if cache_filename is not None:
+        assert False, "TODO retest this for new arguments"
         if os.path.isfile(cache_filename):
             # cache already exists, use it
             with open(cache_filename, "r") as f:
@@ -169,24 +172,27 @@ def get_gemm_times(
     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)
 
     # f8 time
-    d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
-    A = torch.zeros(M, K, device=device, dtype=d1)
-    B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
-    if float8_recipe_name == "tensorwise":
-        scale_a = torch.tensor([1.0], device=device)
-        scale_b = torch.tensor([1.0], device=device)
-    elif float8_recipe_name == "rowwise":
-        scale_a = torch.ones(M, 1, device=device)
-        scale_b = torch.ones(1, N, device=device)
+    if float8_recipe_name == "rowwise_with_gw_hp" and gemm_role == "grad_weight":
+        f8_time_s = bf16_time_s
     else:
-        assert False, "TODO add mx gemm here"
+        d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
+        A = torch.zeros(M, K, device=device, dtype=d1)
+        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        if float8_recipe_name == "tensorwise":
+            scale_a = torch.tensor([1.0], device=device)
+            scale_b = torch.tensor([1.0], device=device)
+        elif float8_recipe_name in ("rowwise", "rowwise_with_gw_hp"):
+            scale_a = torch.ones(M, 1, device=device)
+            scale_b = torch.ones(1, N, device=device)
+        else:
+            assert False, "TODO add mx gemm here"
 
-    def do_matmul(A, B):
-        return torch._scaled_mm(
-            A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-        )
+        def do_matmul(A, B):
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+            )
 
-    f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+        f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
 
     # save to cache if needed
     if cache_filename is not None:
@@ -239,9 +245,9 @@ def run(
         mx_recipe_name,
         enable_fusion_modeling,
     )
-    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None)
+    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None, None)
     fp8_gemm_time_sympy = get_gemm_time_sympy(
-        M, K, N, torch.float8_e4m3fn, mx_recipe_name
+        M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name
     )
     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
@@ -305,6 +311,7 @@ def run(
             # what PyTorch core is doing for `torch.mm`
            # input @ weight_t = output
            bf16_g1, f8_g1 = get_gemm_times(
+                "output",
                M_val,
                K_val,
                N_val,
@@ -316,6 +323,7 @@ def run(
            )
            # grad_output @ weight = grad_input
            bf16_g2, f8_g2 = get_gemm_times(
+                "grad_input",
                M_val,
                N_val,
                K_val,
@@ -327,6 +335,7 @@ def run(
            )
            # input_t @ grad_output = grad_weight
            bf16_g3, f8_g3 = get_gemm_times(
+                "grad_weight",
                K_val,
                M_val,
                N_val,
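
For reference, the float8 path exercised by the new `else` branch reduces to a rowwise-scaled torch._scaled_mm call. The sketch below shows that call outside the benchmark harness; the shapes, scale values, device, and use_fast_accum setting are illustrative assumptions, not taken from the commit, and running it requires a GPU and PyTorch build with float8 support.

# Minimal sketch of the rowwise-scaled float8 GEMM that get_gemm_times measures.
# All concrete values below are assumptions for illustration only.
import torch

device = "cuda"
M, K, N = 1024, 2048, 4096

# float8_e4m3fn operands; B is made column-major, which torch._scaled_mm expects.
A = torch.zeros(M, K, device=device, dtype=torch.float8_e4m3fn)
B = torch.zeros(K, N, device=device, dtype=torch.float8_e4m3fn).t().contiguous().t()

# "rowwise" recipe: one scale per row of A and one per column of B.
scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)

out = torch._scaled_mm(
    A, B, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
)
print(out.shape)  # torch.Size([1024, 4096])

With the rowwise_with_gw_hp recipe, the commit skips this float8 call for the grad_weight GEMM and reuses the measured bf16 time instead.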
