test rowwise fp32 #2431

Open · wants to merge 1 commit into base: main
5 changes: 5 additions & 0 deletions torchao/float8/config.py
@@ -204,6 +204,11 @@ class Float8LinearConfig:
# same value in the forward pass as the backward passes.
round_scales_to_power_of_2: bool = False

# This is a workaround for using rowwise_scaled_mm with non-bf16 tensors.
# Currently, rowwise_scaled_mm only supports bf16 outputs.
# We work around this by using bf16 as the rowwise_scaled_mm output dtype and
# casting the result back to the original precision.
convert_dtypes_for_rowwise_scaled_mm: bool = False

def __post_init__(self):
# Populate the additional cast overrides, if the user did not specify them
# Note: this hacks around the frozen-ness of this dataclass
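For context, a minimal usage sketch of the new flag (not part of this diff; the rest of the rowwise-scaling setup is assumed to follow the existing torchao.float8 configuration and is elided here):

from torchao.float8 import Float8LinearConfig

config = Float8LinearConfig(
    # ... rowwise-scaling cast configuration as usual ...
    convert_dtypes_for_rowwise_scaled_mm=True,  # run the gemm in bf16, cast back afterwards
)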
3 changes: 3 additions & 0 deletions torchao/float8/float8_linear.py
@@ -281,20 +281,23 @@ def __init__(self, *args, **kwargs):
self.config.gemm_config_output.use_fast_accum,
False,
self.config.pad_inner_dim,
config.convert_dtypes_for_rowwise_scaled_mm,
),
# grad_input
ScaledMMConfig(
config.emulate,
self.config.gemm_config_grad_input.use_fast_accum,
False,
self.config.pad_inner_dim,
config.convert_dtypes_for_rowwise_scaled_mm,
),
# grad_weight
ScaledMMConfig(
config.emulate,
self.config.gemm_config_grad_weight.use_fast_accum,
False,
self.config.pad_inner_dim,
config.convert_dtypes_for_rowwise_scaled_mm,
),
)

10 changes: 10 additions & 0 deletions torchao/float8/float8_ops.py
@@ -31,6 +31,7 @@ def addmm_float8_unwrapped(
output_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_fast_accum: bool = False,
convert_dtypes_for_rowwise_scaled_mm: bool = False,
) -> torch.Tensor:
"""
This is the unwrapped version of addmm_float8, which does not take in Float8Tensors
@@ -54,6 +55,11 @@ def addmm_float8_unwrapped(
a_inverse_scale = a_inverse_scale.new_ones(())
b_inverse_scale = a_inverse_scale.new_ones(())

orig_dtype = output_dtype

if convert_dtypes_for_rowwise_scaled_mm and is_rowwise_scaling:
    output_dtype = torch.bfloat16

@vkuzo (Contributor) · Jun 24, 2025
Instead of adding a flag, TBH I think we can just enable this on by default, like this:

file an issue in PyTorch core to add float32 output to scaled_mm

output_dtype_to_use = output_dtype
if is_rowwise_scaling:
    # work around torch._scaled_mm not having float32 output type
    # TODO(issue number): remove this once torch._scaled_mm supports float32 output
    output_dtype_to_use = torch.bfloat16
output = torch._scaled_mm(..., output_dtype_to_use, ...)
...
if is_rowwise_scaling and output_dtype == torch.float32:
    # work around torch._scaled_mm not having float32 output type
    # TODO(issue number): remove this once torch._scaled_mm supports float32 output
    output = output.to(orig_dtype)

Contributor Author:

makes sense, I'll change to enable by default and file an issue.

post_bias = None
if output_dtype == torch.float32:
# Bias is not supported by _scaled_mm when output is fp32
@@ -76,6 +82,9 @@ def addmm_float8_unwrapped(
if post_bias is not None:
    output += post_bias

if convert_dtypes_for_rowwise_scaled_mm and is_rowwise_scaling:
    output = output.to(orig_dtype)

return output


@@ -379,6 +388,7 @@ def float8_mm(aten_op, args, kwargs=None):
output_scale=None,
bias=None,
use_fast_accum=scaled_mm_config.use_fast_accum,
convert_dtypes_for_rowwise_scaled_mm=scaled_mm_config.convert_dtypes_for_rowwise_scaled_mm,
)
return tensor_out

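The dtype handling added to addmm_float8_unwrapped above boils down to the following standalone sketch; the actual scaled matmul is abstracted behind a caller-supplied callable so no assumptions are made about the private torch._scaled_mm signature:

import torch
from typing import Callable

def rowwise_output_dtype_workaround(
    run_scaled_mm: Callable[[torch.dtype], torch.Tensor],  # performs the scaled mm with the given output dtype
    output_dtype: torch.dtype,
    is_rowwise_scaling: bool,
) -> torch.Tensor:
    # Rowwise-scaled _scaled_mm currently only supports bf16 outputs, so compute
    # in bf16 when the caller asked for something else ...
    orig_dtype = output_dtype
    if is_rowwise_scaling and output_dtype != torch.bfloat16:
        output_dtype = torch.bfloat16
    output = run_scaled_mm(output_dtype)
    # ... and cast back to the originally requested precision afterwards.
    if output.dtype != orig_dtype:
        output = output.to(orig_dtype)
    return output

In the diff this logic is gated by convert_dtypes_for_rowwise_scaled_mm; per the review discussion above it may instead be enabled unconditionally.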
7 changes: 4 additions & 3 deletions torchao/float8/float8_tensor.py
@@ -59,6 +59,7 @@ class ScaledMMConfig(NamedTuple):
use_fast_accum: bool = False
fp8_output: bool = False
pad_inner_dim: bool = False
convert_dtypes_for_rowwise_scaled_mm: bool = False


class LinearMMConfig(NamedTuple):
@@ -75,9 +76,9 @@ class LinearMMConfig(NamedTuple):
grad_weight (ScaledMMConfig): Configuration for the grad_weight gemm.
"""

output: ScaledMMConfig = ScaledMMConfig(False, True, False, False)
grad_input: ScaledMMConfig = ScaledMMConfig(False, False, False, False)
grad_weight: ScaledMMConfig = ScaledMMConfig(False, False, False, False)
output: ScaledMMConfig = ScaledMMConfig(False, True, False, False, False)
grad_input: ScaledMMConfig = ScaledMMConfig(False, False, False, False, False)
grad_weight: ScaledMMConfig = ScaledMMConfig(False, False, False, False, False)


class GemmInputRole(enum.Enum):
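With the extra field, the defaults above keep convert_dtypes_for_rowwise_scaled_mm off; a sketch of constructing a LinearMMConfig that opts every gemm into the workaround, using the positional field order shown in this file:

from torchao.float8.float8_tensor import LinearMMConfig, ScaledMMConfig

# Field order: (emulate, use_fast_accum, fp8_output, pad_inner_dim,
#               convert_dtypes_for_rowwise_scaled_mm)
mm_config = LinearMMConfig(
    output=ScaledMMConfig(False, True, False, False, True),
    grad_input=ScaledMMConfig(False, False, False, False, True),
    grad_weight=ScaledMMConfig(False, False, False, False, True),
)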