
Commit 0a45ccd

y-sq authored and facebook-github-bot committed
test rowwise fp32
Summary: Running rowwise scaling on fp32 tensors fails with the error below (P1794222725):

```
RuntimeError: Only bf16 high precision output types are supported for row-wise scaling.
```

This PR adds an option to explicitly use bfloat16 as the output of the rowwise-scaled mm and cast it back to the original precision. It can be enabled by setting:

```
config = dataclasses.replace(config, convert_dtypes_for_rowwise_scaled_mm=True)
```

Differential Revision: D73552660
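To make the new flag concrete, here is a minimal end-to-end sketch of enabling it for an fp32 model. The recipe lookup via `Float8LinearConfig.from_recipe_name("rowwise")`, the `convert_to_float8_training` entry point, and the toy model are assumptions for illustration and may vary across torchao versions:

```python
import dataclasses

import torch
import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Rowwise-scaling recipe; from_recipe_name("rowwise") is assumed here and may
# differ depending on the torchao version.
config = Float8LinearConfig.from_recipe_name("rowwise")

# Enable the workaround added by this commit: run the rowwise _scaled_mm with a
# bf16 output and cast the result back to the module's original dtype.
config = dataclasses.replace(config, convert_dtypes_for_rowwise_scaled_mm=True)

# Hypothetical fp32 model; without the flag, rowwise scaling on fp32 raises
# "Only bf16 high precision output types are supported for row-wise scaling."
model = nn.Sequential(nn.Linear(1024, 1024)).cuda()
convert_to_float8_training(model, config=config)

out = model(torch.randn(16, 1024, device="cuda"))
print(out.dtype)  # torch.float32 -- the result is cast back from bf16
```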
1 parent d506cc7 commit 0a45ccd

4 files changed: +22 -3 lines changed

torchao/float8/config.py

Lines changed: 5 additions & 0 deletions
@@ -204,6 +204,11 @@ class Float8LinearConfig:
     # same value in the forward pass as the backward passes.
     round_scales_to_power_of_2: bool = False
 
+    # This is a workaround for using rowwise_scaled_mm with non-bf16 tensors.
+    # Currently, rowwise_scaled_mm only supports bf16 outputs.
+    # We work around this by using bf16 as the rowwise_scaled_mm output and casting back to the original precision.
+    convert_dtypes_for_rowwise_scaled_mm: bool = False
+
     def __post_init__(self):
         # Populate the additional cast overrides, if the user did not specify them
         # Note: this hacks around the frozen-ness of this dataclass

torchao/float8/float8_linear.py

Lines changed: 3 additions & 0 deletions
@@ -281,20 +281,23 @@ def __init__(self, *args, **kwargs):
                 self.config.gemm_config_output.use_fast_accum,
                 False,
                 self.config.pad_inner_dim,
+                config.convert_dtypes_for_rowwise_scaled_mm,
             ),
             # grad_input
             ScaledMMConfig(
                 config.emulate,
                 self.config.gemm_config_grad_input.use_fast_accum,
                 False,
                 self.config.pad_inner_dim,
+                config.convert_dtypes_for_rowwise_scaled_mm,
             ),
             # grad_weight
             ScaledMMConfig(
                 config.emulate,
                 self.config.gemm_config_grad_weight.use_fast_accum,
                 False,
                 self.config.pad_inner_dim,
+                config.convert_dtypes_for_rowwise_scaled_mm,
             ),
         )
 
torchao/float8/float8_ops.py

Lines changed: 10 additions & 0 deletions
@@ -31,6 +31,7 @@ def addmm_float8_unwrapped(
     output_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
     use_fast_accum: bool = False,
+    convert_dtypes_for_rowwise_scaled_mm: bool = False,
 ) -> torch.Tensor:
     """
     This is the unwrapped version of addmm_float8, which does not take in Float8Tensors
@@ -54,6 +55,11 @@ def addmm_float8_unwrapped(
         a_inverse_scale = a_inverse_scale.new_ones(())
         b_inverse_scale = a_inverse_scale.new_ones(())
 
+    orig_dtype = output_dtype
+
+    if convert_dtypes_for_rowwise_scaled_mm and is_rowwise_scaling:
+        output_dtype = torch.bfloat16
+
     post_bias = None
     if output_dtype == torch.float32:
         # Bias is not supported by _scaled_mm when output is fp32
@@ -76,6 +82,9 @@ def addmm_float8_unwrapped(
     if post_bias is not None:
         output += post_bias
 
+    if convert_dtypes_for_rowwise_scaled_mm and is_rowwise_scaling:
+        output = output.to(orig_dtype)
+
     return output
 
 
@@ -379,6 +388,7 @@ def float8_mm(aten_op, args, kwargs=None):
         output_scale=None,
         bias=None,
         use_fast_accum=scaled_mm_config.use_fast_accum,
+        convert_dtypes_for_rowwise_scaled_mm=scaled_mm_config.convert_dtypes_for_rowwise_scaled_mm,
     )
     return tensor_out
 
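The float8_ops.py change above amounts to a dtype round-trip around the scaled mm. Below is a minimal sketch of that logic in isolation; `run_scaled_mm` is a hypothetical stand-in for the underlying scaled-mm call, not a torchao or PyTorch API. Note that when the round-trip is taken, an fp32 result only carries bf16 precision, since the matmul itself produced a bf16 output:

```python
import torch

def bf16_roundtrip(run_scaled_mm, output_dtype, is_rowwise_scaling,
                   convert_dtypes_for_rowwise_scaled_mm):
    """Sketch of the dtype handling added to addmm_float8_unwrapped."""
    orig_dtype = output_dtype
    if convert_dtypes_for_rowwise_scaled_mm and is_rowwise_scaling:
        # rowwise _scaled_mm only supports bf16 high-precision outputs
        output_dtype = torch.bfloat16
    # run_scaled_mm stands in for the actual scaled mm, invoked with the
    # (possibly overridden) output dtype
    output = run_scaled_mm(output_dtype)
    if convert_dtypes_for_rowwise_scaled_mm and is_rowwise_scaling:
        # cast back to the dtype the caller originally requested (e.g. fp32)
        output = output.to(orig_dtype)
    return output
```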

torchao/float8/float8_tensor.py

Lines changed: 4 additions & 3 deletions
@@ -59,6 +59,7 @@ class ScaledMMConfig(NamedTuple):
     use_fast_accum: bool = False
     fp8_output: bool = False
     pad_inner_dim: bool = False
+    convert_dtypes_for_rowwise_scaled_mm: bool = False
 
 
 class LinearMMConfig(NamedTuple):
@@ -75,9 +76,9 @@ class LinearMMConfig(NamedTuple):
         grad_weight (ScaledMMConfig): Configuration for the grad_weight gemm.
     """
 
-    output: ScaledMMConfig = ScaledMMConfig(False, True, False, False)
-    grad_input: ScaledMMConfig = ScaledMMConfig(False, False, False, False)
-    grad_weight: ScaledMMConfig = ScaledMMConfig(False, False, False, False)
+    output: ScaledMMConfig = ScaledMMConfig(False, True, False, False, False)
+    grad_input: ScaledMMConfig = ScaledMMConfig(False, False, False, False, False)
+    grad_weight: ScaledMMConfig = ScaledMMConfig(False, False, False, False, False)
 
 
 class GemmInputRole(enum.Enum):
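Since ScaledMMConfig is a NamedTuple, the new field rides along as the fifth positional entry. A small sketch of building the per-gemm configs with the flag enabled (the values are illustrative, not the library defaults):

```python
from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig

# Positional order: emulate, use_fast_accum, fp8_output, pad_inner_dim,
# convert_dtypes_for_rowwise_scaled_mm
output_cfg = ScaledMMConfig(False, True, False, False, True)
grad_input_cfg = ScaledMMConfig(False, False, False, False, True)
grad_weight_cfg = ScaledMMConfig(False, False, False, False, True)

linear_mm_config = LinearMMConfig(output_cfg, grad_input_cfg, grad_weight_cfg)
assert linear_mm_config.output.convert_dtypes_for_rowwise_scaled_mm
```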
