roofline estimator: add float8 rowwise and mxfp8 recipe support #1789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 29 commits · Mar 4, 2025
Changes from all commits
182 changes: 151 additions & 31 deletions benchmarks/float8/float8_roofline.py
@@ -47,6 +47,7 @@
import pandas as pd
import sympy
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark
import tqdm
from torch.profiler import ProfilerActivity, profile
@@ -57,8 +58,11 @@
)

from torchao.float8 import (
Float8LinearConfig,
convert_to_float8_training,
)
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
from torchao.testing.float8.roofline_utils import (
get_float8_mem_sympy,
get_gemm_time_sympy,
@@ -93,17 +97,19 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
return measurement.mean


def get_gpu_kernel_time(m, x):
def get_gpu_kernel_time(m, x, grad_output):
# warm up
for _ in range(2):
m(x).sum().backward()
y = m(x)
y.backward(grad_output)

# capture a profiling run
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
n_iter = 5
with profile(activities=activities) as prof:
for _ in range(n_iter):
m(x).sum().backward()
y = m(x)
y.backward(grad_output)
torch.cuda.synchronize()
# get the gpu kernel time and aggregate it
num_leaf_tensors = 1 + len(list(m.parameters()))
@@ -114,10 +120,28 @@ def get_gpu_kernel_time(m, x):
return total_time_s
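Note on the change above: the helper now takes an explicit `grad_output` and calls `y.backward(grad_output)` instead of `m(x).sum().backward()`, so the timed region contains only the module's own forward and backward kernels. A minimal sketch of the measured pattern, assuming a plain bf16 `nn.Linear` (all names and shapes here are illustrative):

```python
# Sketch only: fwd + bwd driven by a pre-allocated gradient, so no sum-reduction
# kernel (or its backward) ends up in the profiled region.
import torch
import torch.nn as nn

m = nn.Linear(1024, 1024, bias=False).cuda().bfloat16()
x = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
grad_output = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)

y = m(x)                 # forward: output shape (4096, 1024)
y.backward(grad_output)  # backward driven by a fixed gradient of the output's shape
```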


def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
def get_gemm_times(
gemm_role: str,
M: int,
K: int,
N: int,
fast_accum: bool,
bf16_memory_formats: str,
float8_recipe_name: Optional[str],
mx_recipe_name: Optional[str],
cache_filename=None,
):
assert gemm_role in ("output", "grad_input", "grad_weight"), "unsupported"
assert bf16_memory_formats in (
"row_major:col_major",
"row_major:row_major",
"col_major:row_major",
), "unsupported"

# Note: this is definitely not the best way to build a cache,
# but it will do for now.
if cache_filename is not None:
assert False, "TODO retest this for new arguments"
if os.path.isfile(cache_filename):
# cache already exists, use it
with open(cache_filename, "r") as f:
@@ -127,30 +151,48 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
cache = dict()
else:
cache = dict()
key = f"{M},{K},{N},{fast_accum}"
key = f"{M},{K},{N},{fast_accum},{bf16_memory_formats}"
if key in cache:
return cache[key]

device = torch.device("cuda")

# bf16 time
x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
# w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)

if bf16_memory_formats == "row_major:col_major":
w_bf16 = w_bf16.t().contiguous().t()
elif bf16_memory_formats == "col_major:row_major":
x_bf16 = x_bf16.t().contiguous().t()

bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)

# f8 time
d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
A = torch.zeros(M, K, device=device, dtype=d1)
B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)

def do_matmul(A, B):
return torch._scaled_mm(
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
)
if float8_recipe_name == "rowwise_with_gw_hp" and gemm_role == "grad_weight":
f8_time_s = bf16_time_s
else:
d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
A = torch.zeros(M, K, device=device, dtype=d1)
B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
if float8_recipe_name == "tensorwise":
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)
elif float8_recipe_name in ("rowwise", "rowwise_with_gw_hp"):
scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)
else:
assert False, "TODO add mx gemm here"

def do_matmul(A, B):
return torch._scaled_mm(
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
)

f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

# save to cache if needed
if cache_filename is not None:
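For the rowwise recipes, the benchmark passes per-row and per-column scales to `torch._scaled_mm`, versus a single scalar scale for tensorwise. A standalone sketch of the rowwise-scaled gemm exercised above (placeholder values; assumes a GPU with float8 support):

```python
# Sketch only: rowwise-scaled float8 gemm, mirroring the shapes/dtypes used above.
import torch

M, K, N = 1024, 1024, 1024
device = torch.device("cuda")
A = torch.zeros(M, K, device=device, dtype=torch.float8_e4m3fn)                       # row-major
B = torch.zeros(K, N, device=device, dtype=torch.float8_e4m3fn).t().contiguous().t()  # col-major
scale_a = torch.ones(M, 1, device=device)  # one scale per row of A
scale_b = torch.ones(1, N, device=device)  # one scale per column of B
out = torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True)
```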
@@ -164,33 +206,52 @@ def do_matmul(A, B):
def run(
outfile: str,
do_benchmarks: bool = True,
shape_gen_name: str = "square",
shape_gen_name: str = "pow2",
gemm_cache_filename: Optional[str] = None,
n_limit: Optional[int] = None,
float8_recipe_name: Optional[str] = None,
mx_recipe_name: Optional[str] = None,
enable_fusion_modeling: bool = False,
):
"""
Args:
* `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
* `shape_gen_name`: `llama`, `square`, or `sweep`
* `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
* `gemm_cache_filename (optional)`: file to cache gemm benchmark results
* `n_limit (optional)`: if specified, only runs `n_limit` iterations
* `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid and models the fusion of float8 overhead
"""

assert not (
(float8_recipe_name is not None) and (mx_recipe_name is not None)
), "unsupported"
if float8_recipe_name is None and mx_recipe_name is None:
float8_recipe_name = "tensorwise"

print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"do_benchmarks: {do_benchmarks}")
print(f"shape_gen_name: {shape_gen_name}")
print(f"float8_recipe_name: {float8_recipe_name}")
print(f"mx_recipe_name: {mx_recipe_name}")
print(f"enable_fusion_modeling: {enable_fusion_modeling}")

M, K, N = sympy.symbols("M K N")

fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
fp8_ovhd_time_sympy = get_float8_mem_sympy(
M,
K,
N,
float8_recipe_name,
mx_recipe_name,
enable_fusion_modeling,
)
bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None, None)
fp8_gemm_time_sympy = get_gemm_time_sympy(
M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name
)

bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
print()

headers = [
@@ -217,6 +278,9 @@ def run(
# the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple
# we don't break them out and don't have a roofline for them.
"b_fp8_e2e_spdp",
# how well benchmarked gemms match roofline predicted gemms
"rb_bf16_gemm_ratio",
"rb_fp8_gemm_ratio",
]
results = []
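The roofline quantities defined above (`bf16_gemm_time_sympy`, `fp8_gemm_time_sympy`, `fp8_ovhd_time_sympy`) are symbolic in M, K, N and are evaluated per shape with `.subs(...)` in the loop below. A toy stand-in showing just that evaluation pattern (the expression and the peak-FLOPS number are illustrative, not the real roofline model):

```python
# Sketch only: build a symbolic time estimate once, evaluate it per shape.
import sympy

M, K, N = sympy.symbols("M K N")
peak_flops = 989e12  # assumed bf16 dense peak, for illustration only
gemm_time_s = 3 * (2 * M * K * N) / peak_flops  # output + grad_input + grad_weight gemms

M_val, K_val, N_val = 4096, 4096, 4096
t = float(gemm_time_s.subs(M, M_val).subs(K, K_val).subs(N, N_val))
print(f"predicted gemm time: {t * 1e6:.1f} us")
```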

@@ -237,43 +301,96 @@

# if enabled, also measure observed gemm time
b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
rb_bf16_gemm_ratio = -1
rb_fp8_gemm_ratio = -1

if do_benchmarks:
# TODO(future): make the bf16 gemm times exactly match the e2e
# benchmarks, there is a slight deviation, probably related to gemm
# operand memory formats/transpositions below not exactly matching
# what PyTorch core is doing for `torch.mm`
# input @ weight_t = output
bf16_g1, f8_g1 = get_gemm_times(
M_val, K_val, N_val, True, gemm_cache_filename
"output",
M_val,
K_val,
N_val,
True,
"row_major:col_major",
float8_recipe_name,
mx_recipe_name,
gemm_cache_filename,
)
# grad_output @ weight = grad_input
bf16_g2, f8_g2 = get_gemm_times(
M_val, N_val, K_val, False, gemm_cache_filename
"grad_input",
M_val,
N_val,
K_val,
False,
"row_major:row_major",
float8_recipe_name,
mx_recipe_name,
gemm_cache_filename,
)
# input_t @ grad_output = grad_weight
bf16_g3, f8_g3 = get_gemm_times(
K_val, M_val, N_val, False, gemm_cache_filename
"grad_weight",
K_val,
M_val,
N_val,
False,
"col_major:row_major",
float8_recipe_name,
mx_recipe_name,
gemm_cache_filename,
)
b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s
rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s

# note: cast from sympy.core.numbers.Float to float to make pandas formatting work
r_fp8_ovhd_time_s = float(
fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
)

b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
if do_benchmarks:
# create the model
m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
if enable_fusion_modeling:
m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
else:
m_orig = (
nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
)
x = torch.randn(
M_val, K_val, dtype=torch.bfloat16, device="cuda"
).requires_grad_()

# get the gradient of the right shape
grad_output = torch.randn(M_val, N_val, dtype=torch.bfloat16, device="cuda")

# get the bf16 gpu kernel time
torch._dynamo.reset()
m_bf16 = torch.compile(copy.deepcopy(m_orig))
b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)
b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, grad_output)

# get the float8 dynamic scaling gpu kernel time

torch._dynamo.reset()
m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
if float8_recipe_name is not None:
config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
m_fp8_dyn = convert_to_float8_training(
copy.deepcopy(m_orig), config=config
)
else:
assert mx_recipe_name is not None
config = MXLinearConfig.from_recipe_name(mx_recipe_name)
m_fp8_dyn = copy.deepcopy(m_orig)
swap_linear_with_mx_linear(m_fp8_dyn, config=config)
m_fp8_dyn = torch.compile(m_fp8_dyn)
b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, grad_output)

results.append(
[
@@ -295,6 +412,9 @@ def run(
b_bf16_e2e_time_s,
b_fp8_e2e_time_s,
b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
# gemm ratios
rb_bf16_gemm_ratio,
rb_fp8_gemm_ratio,
]
)

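With the new arguments, the estimator can be pointed at a specific scaling recipe. A hedged usage sketch calling `run` directly (the import path and output file are illustrative; exactly one of `float8_recipe_name` / `mx_recipe_name` may be set, and when both are None the script defaults to tensorwise):

```python
# Sketch only: run the roofline estimator for the float8 rowwise recipe.
from float8_roofline import run  # assumes benchmarks/float8 is on PYTHONPATH

run(
    outfile="/tmp/float8_rowwise_roofline.csv",  # illustrative output path
    do_benchmarks=True,
    shape_gen_name="pow2_extended",
    float8_recipe_name="rowwise",
)
```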
20 changes: 17 additions & 3 deletions benchmarks/float8/utils.py
@@ -152,18 +152,32 @@ def get_name_to_shapes_iter(
}
return name_to_shapes_70b.items()

elif shape_gen_name == "square":
elif shape_gen_name == "pow2":
assert (
M == K == N == None
), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}"
name_to_shapes = {}
min_power_of_2 = 8 # 256
max_power_of_2 = 15 # 32,768
min_power_of_2 = 10 # 1024
max_power_of_2 = 14 # 16,384
for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
val = 2**power_of_2
name_to_shapes[idx] = val, val, val
return name_to_shapes.items()

elif shape_gen_name == "pow2_extended":
assert (
M == K == N == None
), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}"
name_to_shapes = {}
min_power_of_2 = 10 # 1024
max_power_of_2 = 14 # 16,384
for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
val1 = 2**power_of_2
name_to_shapes[idx * 2] = val1, val1, val1
val2 = 2**power_of_2 + 2 ** (power_of_2 - 1)
name_to_shapes[idx * 2 + 1] = val2, val2, val2
return name_to_shapes.items()

elif shape_gen_name == "sweep":
assert (
M == K == N == None
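For reference, both generators above produce square shapes (used as M = K = N) over the bounds shown; a small arithmetic mirror of that logic (not an import of the benchmark code):

```python
# Sketch only: shapes emitted by "pow2" and "pow2_extended" with the bounds above.
pow2 = [2**p for p in range(10, 15)]
# -> [1024, 2048, 4096, 8192, 16384]

pow2_extended = []
for p in range(10, 15):
    pow2_extended.append(2**p)               # 1024, 2048, ...
    pow2_extended.append(2**p + 2**(p - 1))  # 1536, 3072, ...
# -> [1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 24576]
```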