
Commit 0ffe550

roofline estimator: add float8 rowwise and mxfp8 recipe support
Summary:

Test Plan:

```
python benchmarks/float8/float8_roofline.py ~/local/tmp/20250226_test.csv --n_limit 1 --float8_recipe_name rowwise
python benchmarks/float8/float8_roofline.py ~/local/tmp/20250226_test.csv --n_limit 1 --mx_recipe_name mxfp8_emulated
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 01f4cf1
ghstack-comment-id: 2686473047
Pull Request resolved: #1789
1 parent 0b00253 commit 0ffe550
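For reference, the two test-plan commands map one-to-one onto the new `run()` keyword arguments. A programmatic equivalent might look like the sketch below; the import path is an assumption, since the benchmark is normally invoked as a script, as in the test plan above:

```python
# hypothetical programmatic equivalent of the test plan; the import path
# is an assumption, as the benchmark is normally run as a script
from benchmarks.float8.float8_roofline import run

# float8 rowwise recipe, limited to one shape
run("~/local/tmp/20250226_test.csv", n_limit=1, float8_recipe_name="rowwise")
# emulated mxfp8 recipe, limited to one shape
run("~/local/tmp/20250226_test.csv", n_limit=1, mx_recipe_name="mxfp8_emulated")
```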

File tree

2 files changed: +102, -25 lines


benchmarks/float8/float8_roofline.py

Lines changed: 30 additions & 5 deletions
```diff
@@ -57,8 +57,11 @@
 )
 
 from torchao.float8 import (
+    Float8LinearConfig,
     convert_to_float8_training,
 )
+from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
 from torchao.testing.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
@@ -167,6 +170,8 @@ def run(
     shape_gen_name: str = "square",
     gemm_cache_filename: Optional[str] = None,
     n_limit: Optional[int] = None,
+    float8_recipe_name: Optional[str] = None,
+    mx_recipe_name: Optional[str] = None,
 ):
     """
     Args:
@@ -176,21 +181,32 @@ def run(
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
     """
 
+    assert not (
+        (float8_recipe_name is not None) and (mx_recipe_name is not None)
+    ), "unsupported"
+    if float8_recipe_name is None and mx_recipe_name is None:
+        float8_recipe_name = "tensorwise"
+
+    print(f"GPU: {torch.cuda.get_device_name(0)}")
     print(f"do_benchmarks: {do_benchmarks}")
     print(f"shape_gen_name: {shape_gen_name}")
+    print(f"float8_recipe_name: {float8_recipe_name}")
+    print(f"mx_recipe_name: {mx_recipe_name}")
 
     M, K, N = sympy.symbols("M K N")
 
-    fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
+    fp8_ovhd_time_sympy = get_float8_mem_sympy(
         M,
         K,
         N,
+        float8_recipe_name,
+        mx_recipe_name,
     )
-
     bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
-    print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
     fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
+    print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+    print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
     print()
 
     headers = [
@@ -252,7 +268,7 @@ def run(
 
         # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
         r_fp8_ovhd_time_s = float(
-            fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
         )
 
         b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
@@ -271,7 +287,16 @@ def run(
             # get the float8 dynamic scaling gpu kernel time
 
            torch._dynamo.reset()
-            m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
+            if float8_recipe_name is not None:
+                config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
+                m_fp8_dyn = convert_to_float8_training(
+                    copy.deepcopy(m_orig), config=config
+                )
+            else:
+                assert mx_recipe_name is not None
+                config = MXLinearConfig.from_recipe_name(mx_recipe_name)
+                m_fp8_dyn = copy.deepcopy(m_orig)
+                swap_linear_with_mx_linear(m_fp8_dyn, config=config)
             m_fp8_dyn = torch.compile(m_fp8_dyn)
             b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
```
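Read in isolation, the conversion dispatch added in this file boils down to the sketch below. It uses only the torchao APIs visible in the diff; the helper name `convert_for_recipe` is hypothetical, and actually executing the converted model (as opposed to merely swapping modules) expects a CUDA device:

```python
import copy

import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear


def convert_for_recipe(m_orig, float8_recipe_name=None, mx_recipe_name=None):
    # mirrors the dispatch in run(): at most one recipe family may be set,
    # and tensorwise float8 is the default when neither is given
    assert not (float8_recipe_name is not None and mx_recipe_name is not None)
    if float8_recipe_name is None and mx_recipe_name is None:
        float8_recipe_name = "tensorwise"
    if float8_recipe_name is not None:
        config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
        return convert_to_float8_training(copy.deepcopy(m_orig), config=config)
    # mx path: swap_linear_with_mx_linear mutates the module in place
    config = MXLinearConfig.from_recipe_name(mx_recipe_name)
    m = copy.deepcopy(m_orig)
    swap_linear_with_mx_linear(m, config=config)
    return m


# usage: swap the linears of a toy model for the rowwise float8 recipe
m = nn.Sequential(nn.Linear(1024, 1024, bias=False))
m_fp8 = convert_for_recipe(m, float8_recipe_name="rowwise")
```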

torchao/testing/float8/roofline_utils.py

Lines changed: 72 additions & 20 deletions
```diff
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Optional
+
 import torch
 
 BYTES_PER_EL_FLOAT8 = 1
@@ -55,29 +57,67 @@ def get_specs():
 def get_tensor_memory_traffic_bytes(
     dim0,
     dim1,
+    float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str],
     fuse_with_prev=False,
 ):
     # assumes input bf16, output f8
     numel = dim0 * dim1
 
-    # x_bf16 = ...
-    # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
-    # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
-    # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8
+    if float8_recipe_name == "tensorwise":
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
+        # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
+        # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8
+
+        if fuse_with_prev:
+            kernel_1_rw = 0
+        else:
+            # kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel
+
+        # kernel 3: read in bf16, write twice in float8 (row-major and col-major)
+        kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel
+
+        return kernel_1_rw + kernel_3_rw
+
+    elif float8_recipe_name == "rowwise":
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_float8_dim0
+        # kernel 2: x_bf16 -> x_float8_dim1
+
+        # assume that we can't fuse 1 and 2 because that would require loading
+        # the entire tensor to shared memory
+
+        if fuse_with_prev:
+            # assume we can fuse one of the reads with previous op
+            kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel
+        else:
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+
+        kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel
+
+        return kernel_1_rw + kernel_2_rw
 
-    if fuse_with_prev:
-        kernel_1_rw = 0
     else:
-        # kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
-        kernel_1_rw = BYTES_PER_EL_BF16 * numel
+        assert mx_recipe_name in ("mxfp8_emulated", "mxfp8_cutlass"), "unsupported"
 
-    # kernel 3: read in bf16, write twice in float8 (row-major and col-major)
-    kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel
+        # x_bf16 = ...
+        # kernel 1: x_bf16 -> x_mxfp8_dim0, x_mxfp8_dim1
 
-    return kernel_1_rw + kernel_3_rw
+        if fuse_with_prev:
+            kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel * 2
+        else:
+            kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel * 2
+
+        return kernel_1_rw
 
 
 def get_gemm_time_sympy(M, K, N, dtype):
+    # currently this assumes gemm is compute bound
+    # TODO(future): maybe make more accurate for small shapes by taking max of
+    # time to read/write and time to do the dot product, this might also
+    # slightly differ for MX since scales are larger
     specs = get_specs()
     gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N
     if dtype is torch.bfloat16:
```
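As a sanity check on the model above, the three branches can be evaluated by hand for an unfused 4096 x 4096 cast. This assumes `BYTES_PER_EL_BF16 = 2` (the bf16 element size; only `BYTES_PER_EL_FLOAT8 = 1` is visible in this diff):

```python
# worked example of get_tensor_memory_traffic_bytes with
# dim0 = dim1 = 4096 and fuse_with_prev=False
BYTES_PER_EL_BF16 = 2  # assumption: bf16 element size
BYTES_PER_EL_FLOAT8 = 1  # matches the diff context above
numel = 4096 * 4096

# tensorwise: kernel 1 reads bf16 for max-abs, kernel 3 reads bf16 and
# writes float8 twice (row-major and col-major) -> 6 bytes/element
tensorwise = BYTES_PER_EL_BF16 * numel + (
    BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel
)

# rowwise: two unfusable kernels, each reads bf16 and writes float8 once
# -> also 6 bytes/element
rowwise = 2 * (BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel)

# mxfp8: one kernel reads bf16 once and writes both float8 layouts
# -> 4 bytes/element, saving a full bf16 read relative to rowwise
mxfp8 = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel

print(tensorwise / 1e6, rowwise / 1e6, mxfp8 / 1e6)
# -> ~100.7, ~100.7, and ~67.1 MB of modeled traffic, respectively
```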
```diff
@@ -89,9 +129,7 @@ def get_gemm_time_sympy(M, K, N, dtype):
 
 
 def get_float8_mem_sympy(
-    M,
-    K,
-    N,
+    M, K, N, float8_recipe_name: Optional[str], mx_recipe_name: Optional[str]
 ):
     specs = get_specs()
 
@@ -112,11 +150,15 @@ def get_float8_mem_sympy(
     fwd_fp8_input_mem = get_tensor_memory_traffic_bytes(
         M,
         K,
+        float8_recipe_name,
+        mx_recipe_name,
         fuse_with_prev=True,
     )
     fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes(
         K,
         N,
+        float8_recipe_name,
+        mx_recipe_name,
         fuse_with_prev=False,
     )
     fwd_fp8_total_mem = fwd_fp8_input_mem + fwd_fp8_weight_mem
@@ -127,6 +169,8 @@ def get_float8_mem_sympy(
     gi_fp8_grad_output_mem = get_tensor_memory_traffic_bytes(
         M,
         N,
+        float8_recipe_name,
+        mx_recipe_name,
         fuse_with_prev=True,
     )
     # already casted, assuming that we save weight from fw to bw
@@ -158,12 +202,20 @@ def get_float8_mem_sympy(
     # kernel overhead in the units of seconds, and the per-gemm-input memory
     # estimations are in the units of bytes.
     num_extra_kernels = 0
-    # second stage of max-abs reduction for input
-    num_extra_kernels += 1
-    # second stage of max-abs reduction for weight
-    num_extra_kernels += 1
-    # second stage of max-abs reduction for grad_output
-    num_extra_kernels += 1
+    if float8_recipe_name == "tensorwise":
+        # second stage of max-abs reduction for input
+        num_extra_kernels += 1
+        # second stage of max-abs reduction for weight
+        num_extra_kernels += 1
+        # second stage of max-abs reduction for grad_output
+        num_extra_kernels += 1
+    elif float8_recipe_name == "rowwise":
+        # for simplicity, assume all rowwise kernels are large and bandwidth bound
+        pass
+    else:
+        assert mx_recipe_name in ("mxfp8_emulated", "mxfp8_cutlass"), "unsupported"
+        # for simplicity, assume all mxfp8 kernels are large and bandwidth bound
+        pass
 
     extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC
```
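Downstream, `float8_roofline.py` turns these symbolic expressions into concrete numbers via sympy substitution. Below is a self-contained sketch of that pattern, reusing the `gemm_ops` formula from `get_gemm_time_sympy` with an assumed peak throughput (the real per-GPU value comes from `get_specs()`, not from this PR):

```python
import sympy

M, K, N = sympy.symbols("M K N")

# fwd, grad_input, and grad_weight gemms at 2*M*K*N flops each, matching
# gemm_ops in get_gemm_time_sympy
gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N

# assumed H100-class peak fp8 throughput in FLOP/s; roofline_utils reads
# the real number from get_specs()
ASSUMED_PEAK_FP8_FLOPS = 1979e12
fp8_gemm_time_sympy = gemm_ops / ASSUMED_PEAK_FP8_FLOPS

# the same .subs() pattern used for r_fp8_ovhd_time_s in float8_roofline.py
t = float(fp8_gemm_time_sympy.subs(M, 16384).subs(K, 4096).subs(N, 4096))
print(f"modeled fp8 gemm time: {t * 1e6:.0f} us")
```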
