
float8 training axiswise scaling support with per-gemm-argument configuration #940


Merged on Oct 7, 2024
51 commits:
183dbec  Update  (vkuzo, Sep 23, 2024)
241f815  Update  (vkuzo, Sep 23, 2024)
d0b1002  Update  (vkuzo, Sep 23, 2024)
f15c2a0  Update  (vkuzo, Sep 23, 2024)
c816ce9  Update  (vkuzo, Sep 23, 2024)
9150b4f  Update  (vkuzo, Sep 23, 2024)
40279fb  Update  (vkuzo, Sep 23, 2024)
459e92c  Update  (vkuzo, Sep 23, 2024)
732b231  Update  (vkuzo, Sep 23, 2024)
e7c15d1  Update  (vkuzo, Sep 24, 2024)
1d01df3  Update  (vkuzo, Sep 25, 2024)
62acdaf  Update  (vkuzo, Sep 25, 2024)
381e16e  Update  (vkuzo, Sep 27, 2024)
afdf660  Update  (vkuzo, Sep 27, 2024)
d53f2ce  Update  (vkuzo, Sep 27, 2024)
0737eb8  Update  (vkuzo, Sep 27, 2024)
2791eb3  Update  (vkuzo, Sep 27, 2024)
6cfd1cd  Update  (vkuzo, Sep 27, 2024)
94907e5  Update  (vkuzo, Sep 27, 2024)
423760a  Update  (vkuzo, Sep 27, 2024)
552db23  Update  (vkuzo, Sep 27, 2024)
953bc2f  Update  (vkuzo, Sep 27, 2024)
c5d19e0  Update  (vkuzo, Sep 27, 2024)
24f0f3b  Update  (vkuzo, Sep 27, 2024)
7e0fe97  Update  (vkuzo, Sep 27, 2024)
10f2e0f  Update  (vkuzo, Sep 27, 2024)
4437054  Update  (vkuzo, Sep 30, 2024)
31a017b  Update  (vkuzo, Sep 30, 2024)
743b4c1  Update  (vkuzo, Sep 30, 2024)
0c473c4  Update  (vkuzo, Oct 2, 2024)
c1be278  Update  (vkuzo, Oct 2, 2024)
179e3b3  Update  (vkuzo, Oct 2, 2024)
fc8d4ef  Update  (vkuzo, Oct 2, 2024)
76322de  Update  (vkuzo, Oct 2, 2024)
4eec2bd  Update  (vkuzo, Oct 2, 2024)
ac6f768  Update  (vkuzo, Oct 2, 2024)
02b9fca  Update  (vkuzo, Oct 2, 2024)
5756aa5  Update  (vkuzo, Oct 2, 2024)
1f01df9  Update  (vkuzo, Oct 4, 2024)
b9bcc30  Update  (vkuzo, Oct 4, 2024)
e595f30  Update  (vkuzo, Oct 4, 2024)
4bb59a6  Update  (vkuzo, Oct 4, 2024)
c1c218f  Update  (vkuzo, Oct 4, 2024)
024fe94  Update  (vkuzo, Oct 4, 2024)
4027694  Update  (vkuzo, Oct 4, 2024)
076de91  Update  (vkuzo, Oct 4, 2024)
ca127f0  Update  (vkuzo, Oct 4, 2024)
712fd5d  Update  (vkuzo, Oct 5, 2024)
d70326c  Update  (vkuzo, Oct 7, 2024)
f2d104a  Update  (vkuzo, Oct 7, 2024)
b536435  Update  (vkuzo, Oct 7, 2024)
53 changes: 43 additions & 10 deletions benchmarks/float8/float8_roofline.py
@@ -70,6 +70,7 @@
ScalingType,
CastConfig,
)
from torchao.float8.config import recipe_name_to_linear_config, Float8LinearRecipeName


class LNLinearSigmoid(torch.nn.Module):
@@ -129,6 +130,8 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
else:
# cache does not exist yet, create it
cache = dict()
else:
cache = dict()
key = f"{M},{K},{N},{fast_accum}"
if key in cache:
return cache[key]
@@ -153,13 +156,18 @@ def do_matmul(A, B):
)
f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)
fast_accum = True # for axiswise
f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

# save to cache if needed
if cache_filename is not None:
cache[key] = [bf16_time_s, f8_time_s]
cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s]
with open(cache_filename, 'w') as f:
json.dump(cache, f)

return bf16_time_s, f8_time_s
return bf16_time_s, f8_time_s, f8_axs_time_s

def run(
outfile: str,
@@ -231,13 +239,15 @@ def run(
headers = [
'fwd_M', 'fwd_K', 'fwd_N',
# gemm microbenchmarks
'bf16_gemm_s', 'fp8_gemm_s',
'bf16_gemm_s', 'fp8_gemm_s', 'fp8_axs_gemm_time_s',
# roofline memory overhead estimates
'fp8_oh_dyn_limit', 'fp8_oh_dyn_nolimit',
'fp8_oh_del_limit', 'fp8_oh_del_nolimit',
# actual e2e measurements
'bf16_e2e_s', 'fp8_dyn_e2e_s', 'fp8_del_e2e_s',
'fp8_dyn_speedup', 'fp8_del_speedup',
'bf16_s', 'fp8_dyn_s', 'fp8_del_s', 'fp8_dyn_axs_s',
# 'fp8_lw_s',
'fp8_dyn_sp', 'fp8_del_sp', 'fp8_dyn_axs_sp',
# 'fp8_lw_sp',
]
results = []

@@ -248,15 +258,18 @@
break

if gemm_time_strategy == "benchmarks":
bf16_g1, f8_g1 = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
bf16_g2, f8_g2 = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
bf16_g3, f8_g3 = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs
else:
assert gemm_time_strategy == "roofline", "unsupported"
bf16_time_val = bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
fp8_gemm_time_s = fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
# for now, assume axiswise gemm is similar to tensorwise
fp8_axs_gemm_time_s = fp8_gemm_time_s

fp8_mem_time_dyn_limit_s = \
fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
@@ -291,23 +304,43 @@ def run(
cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
)
m_fp8_del = convert_to_float8_training(m_orig)
m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_del = torch.compile(m_fp8_del)
fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x)

# get the float8 dynamic axiswise scaling gpu kernel time
torch._dynamo.reset()
config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)

# get the lw recipe scaling gpu kernel time
# TODO(future PR): enable below once basic performance issues
# are fixed
# torch._dynamo.reset()
# config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
# m_fp8_lw = convert_to_float8_training(m_orig, config=config)
# m_fp8_lw = torch.compile(m_fp8_lw)
# fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)

results.append([
M_val, K_val, N_val,
# gemm microbenchmarks
bf16_time_val, fp8_gemm_time_s,
bf16_time_val, fp8_gemm_time_s, fp8_axs_gemm_time_s,
# roofline overhead estimates
fp8_mem_time_dyn_limit_s,
fp8_mem_time_dyn_nolimit_s,
fp8_mem_time_del_limit_s,
fp8_mem_time_del_nolimit_s,
# e2e numbers
bf16_time_actual_s, fp8_dyn_time_actual_s, fp8_del_time_actual_s,
fp8_dyn_axs_time_actual_s,
# fp8_lw_time_actual_s,
bf16_time_actual_s / fp8_dyn_time_actual_s,
bf16_time_actual_s / fp8_del_time_actual_s,
bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
# bf16_time_actual_s / fp8_lw_time_actual_s,
])

df = pd.DataFrame(results, columns=headers)
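The new axiswise measurement path added to the roofline script reduces to a small user-facing pattern. Below is a minimal sketch of that pattern, assuming a CUDA GPU with float8 support and using only the APIs visible in the diff (`recipe_name_to_linear_config`, `Float8LinearRecipeName.ALL_AXISWISE`, `convert_to_float8_training`); the model and tensor sizes are illustrative, not taken from the benchmark.

```python
import copy

import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config
from torchao.float8.float8_linear_utils import convert_to_float8_training

# illustrative sizes and model; the benchmark sweeps (M, K, N) instead
M, K, N = 4096, 8192, 7168
m_orig = nn.Sequential(nn.Linear(K, N, bias=False)).bfloat16().cuda()
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

# build the all-axiswise dynamic-scaling config from its recipe name
config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)

# swap eligible nn.Linear modules for float8 training linears, then compile
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)

# one forward + backward exercises all three gemms
# (output, grad_input, grad_weight)
out = m_fp8_dyn_axs(x)
out.sum().backward()
```

The converted and compiled module built this way is what the script above passes to `get_gpu_kernel_time` to collect the `fp8_dyn_axs_s` column.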
53 changes: 13 additions & 40 deletions benchmarks/float8/profile_linear_float8.py
@@ -27,12 +27,15 @@
Float8LinearConfig,
ScalingType,
ScalingGranularity,
Float8LinearRecipeName,
recipe_name_to_linear_config,
)
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
linear_requires_sync,
sync_float8_amax_and_scale_history,
)
from torchao.testing.float8.test_utils import get_test_float8_linear_config
from torch.profiler import profile, ProfilerActivity, record_function
from utils import (
kernel_name_to_category,
@@ -257,7 +260,7 @@ def main(
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
scaling_granularity: str = "tensorwise",
recipe_name: Optional[str] = None,
model_type: str = "linear",
dtype_filter: str = "both",
add_inductor_metadata_to_trace: bool = True,
@@ -269,47 +272,17 @@
scaling_type_input = ScalingType(scaling_type_input)
scaling_type_weight = ScalingType(scaling_type_weight)
scaling_type_grad_output = ScalingType(scaling_type_grad_output)
scaling_granularity = ScalingGranularity(scaling_granularity)

if scaling_type_input is ScalingType.STATIC:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
if recipe_name is None:
config = get_test_float8_linear_config(
scaling_type_input,
scaling_type_weight,
scaling_type_grad_output,
emulate=False,
)
else:
cast_config_input=CastConfig(
scaling_type=scaling_type_input,
scaling_granularity=scaling_granularity,
)
if scaling_type_weight is ScalingType.STATIC:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_weight=CastConfig(
scaling_type=scaling_type_weight,
scaling_granularity=scaling_granularity,
)
if scaling_type_grad_output is ScalingType.STATIC:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
static_scale=torch.tensor([1.0], device="cuda"),
scaling_granularity=scaling_granularity,
)
else:
cast_config_grad_output=CastConfig(
scaling_type=scaling_type_grad_output,
scaling_granularity=scaling_granularity,
)

config = Float8LinearConfig(
cast_config_input=cast_config_input,
cast_config_weight=cast_config_weight,
cast_config_grad_output=cast_config_grad_output,
)
elif recipe_name is not None:
recipe_name = Float8LinearRecipeName(recipe_name)
config = recipe_name_to_linear_config(recipe_name)

scaling_repr = "_".join(
[
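For reference, the recipe helpers that the profiling script now relies on are shorthand for composing one `CastConfig` per tensor (input, weight, grad_output) by hand, which is roughly what the removed branch above did. The following is a hedged sketch of that manual composition for all-axiswise dynamic scaling, under the assumption that `ScalingType.DYNAMIC` and `ScalingGranularity.AXISWISE` are the relevant enum members; only constructor arguments shown in the diff are used, and the result is not necessarily identical to what `Float8LinearRecipeName.ALL_AXISWISE` produces.

```python
import torch
import torch.nn as nn

from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    ScalingGranularity,
    ScalingType,
)
from torchao.float8.float8_linear_utils import convert_to_float8_training


def axiswise_cast() -> CastConfig:
    # dynamic scaling with one scale per row/column instead of one per tensor
    return CastConfig(
        scaling_type=ScalingType.DYNAMIC,
        scaling_granularity=ScalingGranularity.AXISWISE,
    )


# one CastConfig per tensor feeding the linear layer's gemms
config = Float8LinearConfig(
    cast_config_input=axiswise_cast(),
    cast_config_weight=axiswise_cast(),
    cast_config_grad_output=axiswise_cast(),
)

m = nn.Sequential(nn.Linear(2048, 4096, bias=False)).bfloat16().cuda()
m = convert_to_float8_training(m, config=config)
m = torch.compile(m)
```

Recipe names such as `LW_AXISWISE_WITH_GW_HP` in the roofline script suggest that the per-gemm-argument configuration in the PR title goes further than this sketch, allowing different cast behavior for the same tensor depending on which gemm consumes it; that wiring is left to `recipe_name_to_linear_config`.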