
Commit dec0313

float8 training axiswise scaling support with per-gemm-argument configuration (#940)
Summary:

This PR finalizes the UX for axiswise scaling for float8 training, and introduces per-gemm-argument configurability to Float8Linear to enable exploration of future recipes. Not all combinations are supported yet. Specifically, the additional combination we now support and test is a recipe from @lw, where we do the following:

```
output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
grad_weight_hp = input_t_hp @ grad_output_hp
```

Key characteristics of this recipe:

1. increased accuracy for `grad_weight`, which is important for real workloads
2. `output` and `weight` now only need to be scaled axiswise across a single dim, compared to vanilla all-axiswise, which is more amenable to fast kernels

Here is how a user can configure this:

```python
#
# short form
#
config = torchao.float8.config.recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)

#
# or, long form
#

# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w_go = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

# grad_weight_hp = input_t_hp @ grad_output_hp
cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)

# ensure fast_accum is on to get fast kernels
gc_o = Float8GemmConfig(use_fast_accum=True)
gc_gi = Float8GemmConfig(use_fast_accum=True)
gc_gw = Float8GemmConfig(use_fast_accum=True)

config = Float8LinearConfig(
    cast_config_input=cc_i,
    cast_config_weight=cc_w,
    cast_config_grad_output=cc_go,
    cast_config_input_for_grad_weight=cc_i_gw,
    cast_config_weight_for_grad_output=cc_w_go,
    cast_config_grad_output_for_grad_weight=cc_go_gw,
    gemm_config_output=gc_o,
    gemm_config_grad_input=gc_gi,
    gemm_config_grad_weight=gc_gw,
)
```
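For orientation, here is a minimal sketch (not part of this PR's diff) of applying such a config to a model. `Float8LinearRecipeName`, `recipe_name_to_linear_config`, and `convert_to_float8_training` appear in this PR's diffs; the toy model, shapes, and dtypes below are illustrative assumptions:

```python
import copy

import torch
import torch.nn as nn

from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config
from torchao.float8.float8_linear_utils import convert_to_float8_training

# toy model for illustration only; any stack of nn.Linear layers with
# dimensions divisible by 16 should behave the same way
m = nn.Sequential(
    nn.Linear(4096, 4096, bias=False),
    nn.Linear(4096, 4096, bias=False),
).cuda().bfloat16()

# short form of the recipe described above
config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)

# swap nn.Linear modules for Float8Linear, using the per-gemm cast configs
m_fp8 = convert_to_float8_training(copy.deepcopy(m), config=config)
m_fp8 = torch.compile(m_fp8)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
m_fp8(x).sum().backward()
```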
# performance

Below we provide basic performance characteristics of axiswise scaling in general, and of the all-axiswise and lw recipes.

## gemm performance of torch._scaled_mm

baseline: tensorwise scaling

```
> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True
    fast_accum  name      M      K      N  ref_time_s  fp8_time_s  fp8_speedup
0         True     0    256    256    256    0.000004    0.000006     0.573115
1         True     1    512    512    512    0.000005    0.000007     0.659333
2         True     2   1024   1024   1024    0.000011    0.000010     1.080664
3         True     3   2048   2048   2048    0.000028    0.000017     1.596239
4         True     4   4096   4096   4096    0.000210    0.000082     2.551705
5         True     5   8192   8192   8192    0.001671    0.000680     2.457972
6         True     6  16384  16384  16384    0.015030    0.006498     2.313032
7         True     7  32768  32768  32768    0.103236    0.048097     2.146411
8        False     0    256    256    256    0.000004    0.000006     0.630061
9        False     1    512    512    512    0.000005    0.000007     0.767236
10       False     2   1024   1024   1024    0.000012    0.000008     1.391347
11       False     3   2048   2048   2048    0.000029    0.000020     1.457922
12       False     4   4096   4096   4096    0.000211    0.000101     2.100081
13       False     5   8192   8192   8192    0.001676    0.000788     2.128628
14       False     6  16384  16384  16384    0.014933    0.006351     2.351209
15       False     7  32768  32768  32768    0.103457    0.049498     2.090134
```

experiment: axiswise scaling

```
> python benchmarks/float8/bench_matmul.py --shape_gen_name square --use_gpu_kernel_time True --scaling_granularity axiswise
    fast_accum  name      M      K      N  ref_time_s  fp8_time_s  fp8_speedup
0         True     0    256    256    256    0.000004    0.000004     0.966772
1         True     1    512    512    512    0.000005    0.000004     1.095791
2         True     2   1024   1024   1024    0.000011    0.000006     1.988363
3         True     3   2048   2048   2048    0.000027    0.000015     1.890065
4         True     4   4096   4096   4096    0.000210    0.000082     2.552356
5         True     5   8192   8192   8192    0.001674    0.001092     1.533132
6         True     6  16384  16384  16384    0.015114    0.008785     1.720480
7         True     7  32768  32768  32768    0.103286    0.071456     1.445439
8        False     0    256    256    256    0.000004    0.000004     0.899054
9        False     1    512    512    512    0.000005    0.000005     1.005340
10       False     2   1024   1024   1024    0.000011    0.000006     1.692868
11       False     3   2048   2048   2048    0.000028    0.000049     0.567655
12       False     4   4096   4096   4096    0.000210    0.000341     0.616193
13       False     5   8192   8192   8192    0.001678    0.002640     0.635541
14       False     6  16384  16384  16384    0.015051    0.021557     0.698212
15       False     7  32768  32768  32768    0.103497    0.169797     0.609533
```

## performance on microbenchmark of ln -> linear -> sigmoid

Note: for large square shapes, performance tends to be fp8_delayed_tensorwise > fp8_dynamic_tensorwise > fp8_dynamic_axiswise > custom_recipe. For fp8_dynamic_axiswise, the gap versus tensorwise seems to be mostly due to the axiswise gemm being slower than the tensorwise gemm.

```
> python benchmarks/float8/float8_roofline.py ~/local/tmp/20241004_roofline.csv
   fwd_M  fwd_K  fwd_N  bf16_gemm_s  fp8_gemm_s  fp8_axs_gemm_time_s      fp8_oh_dyn_limit  ...  fp8_del_s  fp8_dyn_axs_s  fp8_lw_s  fp8_dyn_sp  fp8_del_sp  fp8_dyn_axs_sp  fp8_lw_sp
0    256    256    256     0.000011    0.000018             0.000012   6.50457971014493e-6  ...   0.000043       0.000049  0.000030    0.465634    0.457907        0.398357   0.643088
1    512    512    512     0.000014    0.000020             0.000013   8.01831884057971e-6  ...   0.000047       0.000054  0.000034    0.489556    0.493467        0.432643   0.685842
2   1024   1024   1024     0.000033    0.000026             0.000017   1.40732753623188e-5  ...   0.000060       0.000063  0.000050    0.734123    0.741467        0.705941   0.891199
3   2048   2048   2048     0.000081    0.000055             0.000044   3.82931014492754e-5  ...   0.000147       0.000159  0.000142    0.815678    0.800811        0.739865   0.827441
4   4096   4096   4096     0.000632    0.000274             0.000247  0.000135172405797101  ...   0.000602       0.000622  0.000662    1.236320    1.261848        1.221755   1.147678
5   8192   8192   8192     0.005027    0.002216             0.003292  0.000522689623188406  ...   0.003665       0.004776  0.005720    1.432213    1.513035        1.161130   0.969448
6  16384  16384  16384     0.045113    0.018975             0.025706   0.00207275849275362  ...   0.024664       0.032254  0.038051    1.803456    1.883291        1.440118   1.220738
7  32768  32768  32768     0.312459    0.147255             0.214492   0.00827303397101449  ...   0.182645       0.240962  0.270973    1.696376    1.766307        1.338827   1.190552
```
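For context on the microbenchmark above, a minimal sketch of the kind of module being measured is shown below. This assumes the `ln -> linear -> sigmoid` structure implied by the `LNLinearSigmoid` class referenced in `benchmarks/float8/float8_roofline.py`; the exact constructor signature, dimensions, and dtype here are assumptions:

```python
import torch
import torch.nn as nn

class LNLinearSigmoid(nn.Module):
    # assumed structure: LayerNorm -> Linear -> Sigmoid, matching the
    # "ln -> linear -> sigmoid" microbenchmark name
    def __init__(self, fc_dim1: int, fc_dim2: int):
        super().__init__()
        self.ln = nn.LayerNorm(fc_dim1, elementwise_affine=False)
        self.fc = nn.Linear(fc_dim1, fc_dim2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.fc(self.ln(x)))

# one of the square shapes from the table above
m = LNLinearSigmoid(4096, 4096).cuda().bfloat16()
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
y = m(x)
```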
## performance on torchtitan LLaMa 3 8B

On 8 H100 GPUs, float8 compute only:

* baseline (bf16 + compile): 6,294 wps
* f8 all-tensorwise: 7,359 wps (1.17x vs baseline)
* f8 all-axiswise: 7,135 wps (1.13x vs baseline - surprising that this is close to all-tensorwise)
* LW_AXISWISE_WITH_GW_HP: 6,506 wps (1.03x vs baseline)

So, it looks like we have performance work to do with `LW_AXISWISE_WITH_GW_HP` in future PRs.

# accuracy

I did a very quick check that loss curves on torchtitan LLaMa 3 8B pretraining with 8 H100 GPUs look good for bf16/f8_tensorwise/f8_axiswise/f8_lw on 0.5k iterations. I will leave longer accuracy verifications for future work.

<img width="973" alt="Screenshot 2024-10-04 at 10 05 24 PM" src="https://github.com/user-attachments/assets/0d682183-41ef-4f04-992f-cd0d0fc8a65c">

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent e76db70 commit dec0313

File tree

10 files changed: +566 -429 lines changed


benchmarks/float8/float8_roofline.py

Lines changed: 43 additions & 10 deletions
```diff
@@ -70,6 +70,7 @@
     ScalingType,
     CastConfig,
 )
+from torchao.float8.config import recipe_name_to_linear_config, Float8LinearRecipeName


 class LNLinearSigmoid(torch.nn.Module):
@@ -129,6 +130,8 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
         else:
             # cache does not exist yet, create it
             cache = dict()
+    else:
+        cache = dict()
     key = f"{M},{K},{N},{fast_accum}"
     if key in cache:
         return cache[key]
@@ -153,13 +156,18 @@ def do_matmul(A, B):
     )
     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

+    scale_a = torch.ones(M, 1, device=device)
+    scale_b = torch.ones(1, N, device=device)
+    fast_accum = True # for axiswise
+    f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+
     # save to cache if needed
     if cache_filename is not None:
-        cache[key] = [bf16_time_s, f8_time_s]
+        cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s]
         with open(cache_filename, 'w') as f:
             json.dump(cache, f)

-    return bf16_time_s, f8_time_s
+    return bf16_time_s, f8_time_s, f8_axs_time_s

 def run(
     outfile: str,
@@ -231,13 +239,15 @@ def run(
     headers = [
         'fwd_M', 'fwd_K', 'fwd_N',
         # gemm microbenchmarks
-        'bf16_gemm_s', 'fp8_gemm_s',
+        'bf16_gemm_s', 'fp8_gemm_s', 'fp8_axs_gemm_time_s',
         # roofline memory overhead estimates
         'fp8_oh_dyn_limit', 'fp8_oh_dyn_nolimit',
         'fp8_oh_del_limit', 'fp8_oh_del_nolimit',
         # actual e2e measurements
-        'bf16_e2e_s', 'fp8_dyn_e2e_s', 'fp8_del_e2e_s',
-        'fp8_dyn_speedup', 'fp8_del_speedup',
+        'bf16_s', 'fp8_dyn_s', 'fp8_del_s', 'fp8_dyn_axs_s',
+        # 'fp8_lw_s',
+        'fp8_dyn_sp', 'fp8_del_sp', 'fp8_dyn_axs_sp',
+        # 'fp8_lw_sp',
     ]
     results = []

@@ -248,15 +258,18 @@ def run(
             break

         if gemm_time_strategy == "benchmarks":
-            bf16_g1, f8_g1 = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
-            bf16_g2, f8_g2 = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
-            bf16_g3, f8_g3 = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
+            bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
+            bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
+            bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
             bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
             fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs
         else:
             assert gemm_time_strategy == "roofline", "unsupported"
             bf16_time_val = bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
             fp8_gemm_time_s = fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            # for now, assume axiswise gemm is similar to tensorwise
+            fp8_axs_gemm_time_s = fp8_gemm_time_s

         fp8_mem_time_dyn_limit_s = \
             fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
@@ -291,23 +304,43 @@ def run(
             cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
             cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
         )
-        m_fp8_del = convert_to_float8_training(m_orig)
+        m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
         m_fp8_del = torch.compile(m_fp8_del)
         fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x)

+        # get the float8 dynamic axiswise scaling gpu kernel time
+        torch._dynamo.reset()
+        config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+        m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
+        m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
+        fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
+
+        # get the lw recipe scaling gpu kernel time
+        # TODO(future PR): enable below once basic performance issues
+        # are fixed
+        # torch._dynamo.reset()
+        # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
+        # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
+        # m_fp8_lw = torch.compile(m_fp8_lw)
+        # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
+
         results.append([
             M_val, K_val, N_val,
             # gemm microbenchmarks
-            bf16_time_val, fp8_gemm_time_s,
+            bf16_time_val, fp8_gemm_time_s, fp8_axs_gemm_time_s,
             # roofline overhead estimates
             fp8_mem_time_dyn_limit_s,
             fp8_mem_time_dyn_nolimit_s,
             fp8_mem_time_del_limit_s,
             fp8_mem_time_del_nolimit_s,
             # e2e numbers
             bf16_time_actual_s, fp8_dyn_time_actual_s, fp8_del_time_actual_s,
+            fp8_dyn_axs_time_actual_s,
+            # fp8_lw_time_actual_s,
             bf16_time_actual_s / fp8_dyn_time_actual_s,
             bf16_time_actual_s / fp8_del_time_actual_s,
+            bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
+            # bf16_time_actual_s / fp8_lw_time_actual_s,
         ])

     df = pd.DataFrame(results, columns=headers)
```
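A side note on the axiswise gemm timing added above: a standalone axiswise-scaled float8 matmul through `torch._scaled_mm` looks roughly like the sketch below. This is illustrative only; `torch._scaled_mm` is a private PyTorch API, and the sketch assumes an H100-class GPU, dimensions divisible by 16, and a recent PyTorch build that accepts per-row / per-column scales.

```python
import torch

M, K, N = 4096, 4096, 4096
device = "cuda"

# random high-precision inputs, cast to float8; the second operand is expected
# to be in column-major layout for torch._scaled_mm
A = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
B = torch.randn(K, N, device=device).t().contiguous().to(torch.float8_e4m3fn).t()

# axiswise scales: one scale per row of A and one per column of B (tensorwise
# scaling would use scalar tensors instead); ones are placeholders here
scale_a = torch.ones(M, 1, device=device, dtype=torch.float32)
scale_b = torch.ones(1, N, device=device, dtype=torch.float32)

out = torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True)
print(out.shape)  # torch.Size([4096, 4096])
```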

benchmarks/float8/profile_linear_float8.py

Lines changed: 13 additions & 40 deletions
```diff
@@ -27,12 +27,15 @@
     Float8LinearConfig,
     ScalingType,
     ScalingGranularity,
+    Float8LinearRecipeName,
+    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
     sync_float8_amax_and_scale_history,
 )
+from torchao.testing.float8.test_utils import get_test_float8_linear_config
 from torch.profiler import profile, ProfilerActivity, record_function
 from utils import (
     kernel_name_to_category,
@@ -257,7 +260,7 @@ def main(
     scaling_type_input: str = "dynamic",
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
-    scaling_granularity: str = "tensorwise",
+    recipe_name: Optional[str] = None,
     model_type: str = "linear",
     dtype_filter: str = "both",
     add_inductor_metadata_to_trace: bool = True,
@@ -269,47 +272,17 @@ def main(
     scaling_type_input = ScalingType(scaling_type_input)
     scaling_type_weight = ScalingType(scaling_type_weight)
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
-    scaling_granularity = ScalingGranularity(scaling_granularity)

-    if scaling_type_input is ScalingType.STATIC:
-        cast_config_input=CastConfig(
-            scaling_type=scaling_type_input,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
+    if recipe_name is None:
+        config = get_test_float8_linear_config(
+            scaling_type_input,
+            scaling_type_weight,
+            scaling_type_grad_output,
+            emulate=False,
         )
-    else:
-        cast_config_input=CastConfig(
-            scaling_type=scaling_type_input,
-            scaling_granularity=scaling_granularity,
-        )
-    if scaling_type_weight is ScalingType.STATIC:
-        cast_config_weight=CastConfig(
-            scaling_type=scaling_type_weight,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
-        )
-    else:
-        cast_config_weight=CastConfig(
-            scaling_type=scaling_type_weight,
-            scaling_granularity=scaling_granularity,
-        )
-    if scaling_type_grad_output is ScalingType.STATIC:
-        cast_config_grad_output=CastConfig(
-            scaling_type=scaling_type_grad_output,
-            static_scale=torch.tensor([1.0], device="cuda"),
-            scaling_granularity=scaling_granularity,
-        )
-    else:
-        cast_config_grad_output=CastConfig(
-            scaling_type=scaling_type_grad_output,
-            scaling_granularity=scaling_granularity,
-        )
-
-    config = Float8LinearConfig(
-        cast_config_input=cast_config_input,
-        cast_config_weight=cast_config_weight,
-        cast_config_grad_output=cast_config_grad_output,
-    )
+    elif recipe_name is not None:
+        recipe_name = Float8LinearRecipeName(recipe_name)
+        config = recipe_name_to_linear_config(recipe_name)

     scaling_repr = "_".join(
         [
```
0 commit comments

Comments
 (0)