
Commit 11b2a2f

[float8] Re-enable slow-accum in the bwd of axis-wise scaling schemes
And circumvent the issue with the slow CUTLASS kernel by using the cuBLAS kernel + manual scaling.

ghstack-source-id: 54eb6ce
ghstack-comment-id: 2517855458
Pull Request resolved: #1377
1 parent 1a0dbf1 commit 11b2a2f
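
The fix relies on the fact that row-wise and column-wise scales factor out of a matmul, so the product can be computed unscaled by the (fast) cuBLAS kernel and the scales applied afterwards. A minimal sketch of that identity in plain PyTorch (shapes and the tolerance check are illustrative, not taken from the commit):

import torch

# Illustrative shapes; any (M, K) @ (K, N) works.
M, K, N = 4, 8, 3
a = torch.randn(M, K)
b = torch.randn(K, N)
a_inverse_scale = torch.rand(M, 1) + 0.5  # row-wise scale for `a`
b_inverse_scale = torch.rand(1, N) + 0.5  # column-wise scale for `b`

# Scaling the operands inside the matmul ...
scaled_inside = (a * a_inverse_scale) @ (b * b_inverse_scale)
# ... matches scaling the unscaled product afterwards, which is what lets the
# commit fall back to the unscaled gemm + manual scaling.
scaled_after = (a @ b) * (a_inverse_scale * b_inverse_scale)

torch.testing.assert_close(scaled_inside, scaled_after)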

2 files changed: 24 additions & 35 deletions


torchao/float8/config.py

Lines changed: 0 additions & 23 deletions
@@ -170,7 +170,6 @@ class Float8LinearConfig:
     #
     # Per-gemm configuration for gemms calculating `output`, `grad_input` and
     # `grad_weight`
-    # TODO(this PR): throw warning if fast_accum False is used with axiswise scaling
     #
     gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
     gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig()
@@ -317,21 +316,10 @@ def recipe_name_to_linear_config(
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
         cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

-        # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only
-        # fast with `use_fast_accum=True`. Note that rowwise scaling is more
-        # accurate than tensorwise scaling, so the overall impact on accuracy
-        # of tensorwise vs rowwise taking this flag into account will vary.
-        gc_o = Float8GemmConfig(use_fast_accum=True)
-        gc_gi = Float8GemmConfig(use_fast_accum=True)
-        gc_gw = Float8GemmConfig(use_fast_accum=True)
-
         return Float8LinearConfig(
             cast_config_input=cc_i,
             cast_config_weight=cc_w,
             cast_config_grad_output=cc_go,
-            gemm_config_output=gc_o,
-            gemm_config_grad_input=gc_gi,
-            gemm_config_grad_weight=gc_gw,
         )

     elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
@@ -359,24 +347,13 @@ def recipe_name_to_linear_config(
         cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
         cc_go_gw = CastConfig(scaling_type=ScalingType.DISABLED)

-        # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only
-        # fast with `use_fast_accum=True`. Note that rowwise scaling is more
-        # accurate than tensorwise scaling, so the overall impact on accuracy
-        # of tensorwise vs rowwise taking this flag into account will vary.
-        gc_o = Float8GemmConfig(use_fast_accum=True)
-        gc_gi = Float8GemmConfig(use_fast_accum=True)
-        gc_gw = Float8GemmConfig(use_fast_accum=True)
-
         return Float8LinearConfig(
             cast_config_input=cc_i,
             cast_config_weight=cc_w,
             cast_config_grad_output=cc_go,
             cast_config_input_for_grad_weight=cc_i_gw,
             cast_config_weight_for_grad_input=cc_w_gi,
             cast_config_grad_output_for_grad_weight=cc_go_gw,
-            gemm_config_output=gc_o,
-            gemm_config_grad_input=gc_gi,
-            gemm_config_grad_weight=gc_gw,
         )

     else:
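
With the hard-coded fast-accum gemm configs removed, these recipes now fall back to the Float8LinearConfig defaults, where only the output gemm keeps fast-accum and the backward gemms use slow accumulation again. A hedged sketch of how one could inspect this (the ALL_AXISWISE recipe name and the grad_weight default are assumptions not visible in this diff):

from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config

# Assumed recipe name for the first modified branch above; only
# LW_AXISWISE_WITH_GW_HP is visible in the diff itself.
config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)

print(config.gemm_config_output.use_fast_accum)       # True  (class default, per hunk 1)
print(config.gemm_config_grad_input.use_fast_accum)   # False (slow accum re-enabled)
print(config.gemm_config_grad_weight.use_fast_accum)  # False (assumed to share the same default)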

torchao/float8/float8_python_api.py

Lines changed: 24 additions & 12 deletions
@@ -37,19 +37,25 @@ def addmm_float8_unwrapped(
     a_inverse_scale = a_scale.reciprocal()
     b_inverse_scale = b_scale.reciprocal()

-    if output_dtype == torch.float32 and bias is not None:
+    post_inverse_scale = None
+    if (
+        a_scale.shape == (a_data.shape[0], 1)
+        and b_scale.shape == (1, b_data.shape[1])
+        and not use_fast_accum
+    ):
+        # The rowwise CUTLASS-based kernel is so slow without fast-accum that
+        # we'd rather use the tensorwise cuBLAS-based kernel and do the scaling
+        # manually afterwards (hoping Inductor will be able to fuse it).
+        post_inverse_scale = a_inverse_scale * b_inverse_scale
+        a_inverse_scale = a_inverse_scale.new_ones(())
+        b_inverse_scale = a_inverse_scale.new_ones(())
+
+    post_bias = None
+    if output_dtype == torch.float32:
         # Bias is not supported by _scaled_mm when output is fp32
-        output = torch._scaled_mm(
-            a_data,
-            b_data,
-            scale_a=a_inverse_scale,
-            scale_b=b_inverse_scale,
-            scale_result=output_scale,
-            out_dtype=output_dtype,
-            use_fast_accum=use_fast_accum,
-        )
-        output += bias
-        return output
+        post_bias = bias
+        bias = None
+
     output = torch._scaled_mm(
         a_data,
         b_data,
@@ -60,4 +66,10 @@ def addmm_float8_unwrapped(
         out_dtype=output_dtype,
         use_fast_accum=use_fast_accum,
     )
+
+    if post_inverse_scale is not None:
+        output *= post_inverse_scale
+    if post_bias is not None:
+        output += post_bias
+
     return output
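
The new fallback only kicks in when both scales are axis-wise (one scale per row of `a` and per column of `b`) and fast-accum is off; tensorwise scales or fast-accum calls still go straight through torch._scaled_mm. A small standalone check mirroring that condition (the helper name and shapes are illustrative, not part of the commit):

import torch

def uses_manual_post_scaling(a_data, b_data, a_scale, b_scale, use_fast_accum):
    # Mirrors the condition added in the diff above: row-wise scale for `a`,
    # column-wise scale for `b`, and slow (non-fast-accum) accumulation.
    return (
        a_scale.shape == (a_data.shape[0], 1)
        and b_scale.shape == (1, b_data.shape[1])
        and not use_fast_accum
    )

a = torch.empty(16, 32)
b = torch.empty(32, 8)
print(uses_manual_post_scaling(a, b, torch.empty(16, 1), torch.empty(1, 8), use_fast_accum=False))  # True
print(uses_manual_post_scaling(a, b, torch.empty(1, 1), torch.empty(1, 1), use_fast_accum=False))   # False: tensorwise scales
print(uses_manual_post_scaling(a, b, torch.empty(16, 1), torch.empty(1, 8), use_fast_accum=True))   # False: fast-accum path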
