Initial prototype of differentiable _scaled_grouped_mm function #1969


Merged · 64 commits · Apr 2, 2025
Commits
134242b
grouped_mm forward pass
danielvegamyhre Mar 26, 2025
2113753
add unit test
danielvegamyhre Mar 26, 2025
0a90f0b
only support float8
danielvegamyhre Mar 26, 2025
a761549
rowwise scaling test passing
danielvegamyhre Mar 27, 2025
8d15a8a
add 3Dx3D test
danielvegamyhre Mar 27, 2025
cced381
numeric unit tests passing
danielvegamyhre Mar 27, 2025
46d7e42
lint
danielvegamyhre Mar 27, 2025
e32d528
update 3Dx3D case
danielvegamyhre Mar 27, 2025
c42af73
lint
danielvegamyhre Mar 27, 2025
e61c71d
lint
danielvegamyhre Mar 27, 2025
94a0cba
change func name
danielvegamyhre Mar 27, 2025
fce469b
lint
danielvegamyhre Mar 27, 2025
3899bb2
B must be 3D
danielvegamyhre Mar 27, 2025
5099838
add docstring
danielvegamyhre Mar 27, 2025
4e04022
allow other axiswise dims so we can pass in 3D B tensor tranposed
danielvegamyhre Mar 27, 2025
4117a9e
clean up
danielvegamyhre Mar 27, 2025
61f0ee4
add todo
danielvegamyhre Mar 27, 2025
dc40622
lint
danielvegamyhre Mar 27, 2025
80b7630
rename var
danielvegamyhre Mar 27, 2025
dc013a3
check input dims are compatible
danielvegamyhre Mar 27, 2025
4f385e5
add detailed comments
danielvegamyhre Mar 27, 2025
72a9b9f
update comments
danielvegamyhre Mar 28, 2025
c4c6c99
update comments
danielvegamyhre Mar 28, 2025
cf42af1
add backward pass
danielvegamyhre Mar 28, 2025
dc6bcf3
add detailed comments
danielvegamyhre Mar 28, 2025
4c5e9db
2d-2d working
danielvegamyhre Mar 28, 2025
c9d30b6
backward working for everything except 2d-3d
danielvegamyhre Mar 28, 2025
c19bc88
all test cases working
danielvegamyhre Mar 28, 2025
90b99ba
docstring
danielvegamyhre Mar 28, 2025
526d88c
update test
danielvegamyhre Mar 28, 2025
25fa1c8
handle jagged 2d tensors
danielvegamyhre Mar 28, 2025
281950c
lint
danielvegamyhre Mar 29, 2025
9f15ac4
work on test for gradients
danielvegamyhre Apr 1, 2025
10a9823
grad is none
danielvegamyhre Apr 1, 2025
f20ddf3
add assert
danielvegamyhre Apr 1, 2025
922b842
grads not none
danielvegamyhre Apr 1, 2025
4b3ca69
outputs mismatch
danielvegamyhre Apr 1, 2025
5d367df
forward matches but _scaled_mm has no backward
danielvegamyhre Apr 1, 2025
7d21bbb
all outputs match
danielvegamyhre Apr 1, 2025
7dc7c73
gradients match
danielvegamyhre Apr 1, 2025
5703cfd
cleanup
danielvegamyhre Apr 1, 2025
6f65dae
use quant primitives manually in forward
danielvegamyhre Apr 1, 2025
93c2692
clean up
danielvegamyhre Apr 1, 2025
d7949c4
lint
danielvegamyhre Apr 1, 2025
212b47f
don't change float8 tensor
danielvegamyhre Apr 1, 2025
4b42be3
lint
danielvegamyhre Apr 1, 2025
fa708fd
lint
danielvegamyhre Apr 1, 2025
fad9d36
fix test
danielvegamyhre Apr 1, 2025
c54b528
improve test readability
danielvegamyhre Apr 1, 2025
b571442
remove old comment
danielvegamyhre Apr 1, 2025
302b554
reorganize
danielvegamyhre Apr 1, 2025
2864068
only support rowwise scaling
danielvegamyhre Apr 1, 2025
fb48868
lint
danielvegamyhre Apr 1, 2025
1cd3658
explicit re-export
danielvegamyhre Apr 1, 2025
c154222
lint
danielvegamyhre Apr 1, 2025
e9f2174
reorganize
danielvegamyhre Apr 1, 2025
4ba8453
add tests for invalid dims
danielvegamyhre Apr 1, 2025
c2e5d42
validate group sizes are multiples of 16
danielvegamyhre Apr 1, 2025
7466ce4
use save_for_backward for offs
danielvegamyhre Apr 1, 2025
527525b
remove group size assert to avoid device-host sync
danielvegamyhre Apr 1, 2025
3ea7455
precompute B_non_transposed_fp8_col_major for backward to save memory
danielvegamyhre Apr 1, 2025
a1e7c53
inline configs in impl
danielvegamyhre Apr 2, 2025
d405950
update naming
danielvegamyhre Apr 2, 2025
300db8b
add assert
danielvegamyhre Apr 2, 2025
3 changes: 3 additions & 0 deletions torchao/prototype/scaled_grouped_mm/__init__.py
@@ -0,0 +1,3 @@
from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
    _scaled_grouped_mm as _scaled_grouped_mm,
)
361 changes: 361 additions & 0 deletions torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py
@@ -0,0 +1,361 @@
from typing import Optional, Tuple

import torch

from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated


def _scaled_grouped_mm(
    A: torch.Tensor,
    B_t: torch.Tensor,
    offs: torch.Tensor,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    This function performs dynamic float8 quantization with row-wise scaling
    on the input tensors A and B_t, then performs a scaled grouped GEMM and returns the result.

    Args:
        A (bf16/float32 torch.Tensor): The first high-precision input tensor, which must be a 2D tensor
            of shape (M * num_groups, K) and in row-major memory layout.
        B_t (bf16/float32 torch.Tensor): The second high-precision input tensor, which must be 3D with
            shape (B, K, N) and in column-major memory layout.
        offs (int32 torch.Tensor): The offsets used to mark the end index of each group along dim0 of the A tensor.
        out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
    """
    return _Float8GroupedMM.apply(
        A,
        B_t,
        offs,
        out_dtype,
    )
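
For orientation, a minimal usage sketch (editor's addition, not part of this file). It assumes a CUDA device whose PyTorch build provides torch._scaled_grouped_mm with float8 support; all sizes below are illustrative.

import torch
from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm

num_groups, M, K, N = 4, 64, 256, 128

# A stacks the per-group inputs along dim0 and must be row-major.
A = torch.randn(M * num_groups, K, dtype=torch.bfloat16, device="cuda")

# B_t must be column-major: materialize (G, N, K) contiguously, then transpose the last two dims.
B = torch.randn(num_groups, N, K, dtype=torch.bfloat16, device="cuda")
B_t = B.transpose(-2, -1)  # shape (G, K, N), column-major strides

# offs marks the end row of each group along dim0 of A; group sizes should be multiples of 16.
offs = torch.arange(M, M * num_groups + 1, M, dtype=torch.int32, device="cuda")

out = _scaled_grouped_mm(A, B_t, offs, out_dtype=torch.bfloat16)  # shape (M * num_groups, N)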


class _Float8GroupedMM(torch.autograd.Function):
"""Differentiable implementation of grouped GEMM with dynamic float8 quantization."""

@staticmethod
def forward(
ctx,
A: torch.Tensor,
B_t: torch.Tensor,
offs: torch.Tensor,
out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
# torchao _scaled_grouped_mm only supports A=2D, B=3D.
assert A.ndim == 2, "A must be 2D"
assert B_t.ndim == 3, "B must be 3D"

assert (

Contributor:
Are these assertions redundant with what is in torch._scaled_grouped_mm? Ideally we only assert in one place for each condition.

danielvegamyhre (Contributor Author), Apr 2, 2025:
Some are specific to this implementation (we only support 2D A and 3D B for now, the primary use case), but some are redundant, yes.

I found the device-side assertions to be opaque sometimes; I often had to read the kernel code to see the exact condition that was failing when the error was a bit ambiguous. So my goal here was to make the requirements more transparent and easier to debug.

(I think since the scaled grouped mm is a kernel executing on the GPU, if a check fails, the CPU-side error message can't show the actual line of code with the condition that failed; it just points to the entrypoint torch._scaled_grouped_mm.)

Reviewer:
The ndim assertions are not device-side, so they have a correct stack trace. Device-side assertions are put in when it's impossible to check the same thing on the host without introducing a host-device sync, so they happen for a reason and unfortunately can't be improved. To get a correct stack trace, run with CUDA_LAUNCH_BLOCKING=1.

            A.size(-1) % 16 == 0
        ), f"A must have a last dim divisible by 16, but got shape: {A.shape}"
        assert (
            B_t.size(-2) % 16 == 0 and B_t.size(-1) % 16 == 0
        ), f"B must have last 2 dims divisible by 16, but got shape: {B_t.shape}"

        # Assert input tensors are in high-precision dtypes.
        assert (
            A.dtype == torch.float32 or A.dtype == torch.bfloat16
        ), "A must be float32 or bfloat16"
        assert (
            B_t.dtype == torch.float32 or B_t.dtype == torch.bfloat16
        ), "B must be float32 or bfloat16"
        assert offs.dtype == torch.int32, "offs must be int32"

        # Assert A and B dims are compatible for a scaled grouped GEMM.
        assert A.size(-1) == B_t.size(
            -2
        ), f"shape {A.shape} and {B_t.shape} are not compatible for _scaled_grouped_mm"

        # The left operand in the scaled grouped GEMM must be row-major due to hardware requirements.
        assert not _is_column_major(A), "A must be row-major"

        # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
        assert _is_column_major(B_t), "B must be column-major"

        # Convert the high-precision input tensor to float8, row-major for the left operand of the grouped GEMM.
        # A shape: (M, K)
        # A_scales shape: (M, 1)
        A_scales = tensor_to_scale(
            A,
            torch.float8_e4m3fn,
            scaling_granularity=ScalingGranularity.AXISWISE,
            axiswise_dim=-1,
            round_scales_to_power_of_2=True,
        )
        A_scaled = A.to(torch.float32) * A_scales
        A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)

        # Convert B_t to float8, column-major for the right operand of the grouped GEMM.
        # B_t shape: (B, K, N)
        # B_t scales must be computed rowwise keeping the outer/final dim, so:
        # B_t_scales shape: (B, 1, N)
        B_t_scales = tensor_to_scale(
            B_t,
            torch.float8_e4m3fn,
            scaling_granularity=ScalingGranularity.AXISWISE,
            axiswise_dim=-2,
            round_scales_to_power_of_2=True,
        )
        B_t_scaled = B_t.to(torch.float32) * B_t_scales
        B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)

        # Precompute the non-transposed B in column-major layout for the backward pass, to save memory by
        # storing the low-precision B tensor instead of the high-precision B tensor.
        # In the backward this is needed for grad_A: grad_output @ B.
        B = B_t.contiguous().transpose(-2, -1)

        # - B shape: (B, N, K)
        # - B scales must be computed rowwise keeping the outer/final dim, so:
        # - B_scales shape: (B, 1, K)
        B_scales = tensor_to_scale(
            B,
            torch.float8_e4m3fn,
            scaling_granularity=ScalingGranularity.AXISWISE,
            axiswise_dim=-2,
            round_scales_to_power_of_2=True,
        )
        B_scaled = B.to(torch.float32) * B_scales
        B_fp8_col_major = to_fp8_saturated(B_scaled, torch.float8_e4m3fn)

        # Store what we need for backward.
        ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs)
        ctx.out_dtype = out_dtype

        # Perform the scaled grouped GEMM and return the result.
        # output shape: scaled grouped mm of (M, K) @ (B, K, N) = (M, N)
        return torch._scaled_grouped_mm(
            A_fp8_row_major,
            B_t_fp8_col_major,
            A_scales.squeeze().reciprocal(),
            B_t_scales.squeeze().reciprocal(),
            offs,
            out_dtype=out_dtype,
            use_fast_accum=True,
        )
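
A small aside (editor's addition) on the convention above: tensor_to_scale returns a multiplier, i.e. x_fp8 ≈ x * scale, so its reciprocal is the dequantization scale that torch._scaled_grouped_mm expects. A minimal round-trip sketch, runnable wherever float8 casts are available:

import torch
from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated

x = torch.randn(8, 32)
scale = tensor_to_scale(
    x,
    torch.float8_e4m3fn,
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=-1,
)  # shape (8, 1): one multiplier per row
x_fp8 = to_fp8_saturated(x * scale, torch.float8_e4m3fn)      # quantize: multiply by scale
x_back = x_fp8.to(torch.float32) * scale.reciprocal()         # dequantize: multiply by 1/scale
torch.testing.assert_close(x_back, x, atol=1e-1, rtol=1e-1)   # equal up to float8 rounding error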

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors
        out_dtype = ctx.out_dtype

        # Convert grad_output to float8, row-major for the left operand of the grouped GEMM,
        # needed for grad_A: grad_output @ B.
        #
        # grad_output shape: (M, N)
        # grad_output_scales shape: (M, 1)
        grad_output_scales = tensor_to_scale(
            grad_output,
            torch.float8_e4m3fn,
            scaling_granularity=ScalingGranularity.AXISWISE,
            axiswise_dim=-1,
            round_scales_to_power_of_2=True,
        )
        grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales
        grad_output_fp8_row_major = to_fp8_saturated(
            grad_output_scaled, torch.float8_e4m3fn
        )

        # Compute grad_A.
        #
        # grad_A = grad_output @ B
        # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
        grad_A = torch._scaled_grouped_mm(
            grad_output_fp8_row_major,
            B_fp8_col_major,
            grad_output_scales.squeeze().reciprocal(),
            B_scales.squeeze().reciprocal(),
            offs,
            out_dtype=out_dtype,
            use_fast_accum=True,
        )

        # Convert the transpose of grad_output to float8, row-major for the left operand of the grouped GEMM,
        # needed for grad_B: grad_output_t @ A.
        grad_output_t_row_major = grad_output.transpose(-2, -1).contiguous()

        # Convert A to float8, column-major for the right operand of the grouped GEMM,
        # needed for grad_B: grad_output_t @ A.
        A_col_major = A.transpose(-2, -1).contiguous().transpose(-2, -1)

        # grad_B is a special case: both operands of the grouped GEMM will be 2D, with offsets determining the "groups."
        # Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups.
        grad_output_t_fp8_row_major, grad_output_t_scales = (
            _to_2d_jagged_float8_tensor_rowwise(
                grad_output_t_row_major,
                offs,
                target_dtype=torch.float8_e4m3fn,
                round_scales_to_power_of_2=True,
            )
        )
        A_fp8_col_major, A_scales = _to_2d_jagged_float8_tensor_colwise(
            A_col_major,
            offs,
            target_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )

        # Compute grad_B = grad_output_t @ A.
        # grad_B = (N,M) @ (M,K) = (N,K)
        grad_B = torch._scaled_grouped_mm(
            grad_output_t_fp8_row_major,
            A_fp8_col_major,
            grad_output_t_scales.reciprocal(),
            A_scales.reciprocal(),
            offs,
            out_dtype=out_dtype,
            use_fast_accum=True,
        )
        return grad_A, grad_B.transpose(-2, -1), None, None, None, None
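
A hedged end-to-end sketch (editor's addition, not part of this PR) of exercising the autograd path against a bf16 reference. It assumes a CUDA GPU with float8 support and a PyTorch build that provides torch._scaled_grouped_mm; the helper name ref_grouped_mm and the sizes are illustrative.

import torch
from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm

def ref_grouped_mm(A, B_t, offs):
    # bf16 reference: slice A along dim0 at the group end offsets, matmul per group, concatenate.
    out, start = [], 0
    for g, end in enumerate(offs.tolist()):
        out.append(A[start:end] @ B_t[g])
        start = end
    return torch.cat(out, dim=0)

G, M, K, N = 4, 64, 256, 128
A = torch.randn(M * G, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B = torch.randn(G, N, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B_t = B.transpose(-2, -1)  # (G, K, N), column-major, as required above
offs = torch.arange(M, M * G + 1, M, dtype=torch.int32, device="cuda")

out = _scaled_grouped_mm(A, B_t, offs, out_dtype=torch.bfloat16)
out_ref = ref_grouped_mm(A, B_t, offs)
print("forward rel. error:", ((out - out_ref).float().norm() / out_ref.float().norm()).item())

out.sum().backward()
assert A.grad is not None and B.grad is not None  # gradients flow through both operands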


def _to_2d_jagged_float8_tensor_colwise(
    A_col_major: torch.Tensor,
    offs: torch.Tensor,
    target_dtype: torch.dtype = torch.float8_e4m3fn,
    round_scales_to_power_of_2: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    This function converts the 2D input tensor A to a jagged float8 tensor,
    with scales computed along *logical columns* for each group individually,
    where groups are determined based on the offsets.

    For the right operand of a normal scaled GEMM, the rowwise scales are computed over logical columns
    (i.e., a tensor of shape (K, N) will have scales of shape (1, N)).

    However, for a 2D right operand of a grouped GEMM, these logical columns go through multiple distinct
    groups/subtensors, for which we want to compute scales individually. So we cannot take one set of scales
    along the logical columns and apply it to the entire tensor.

    Instead, we need to compute scales for each subtensor individually. For a tensor of shape (K, N) this results
    in scales of shape (1, N * num_groups).

    Args:
        A_col_major (torch.Tensor): The input tensor to be converted to a jagged float8 tensor.

    Returns:
        A tuple containing the jagged float8 tensor and the scales used for the conversion.
    """
    assert A_col_major.ndim == 2, "A must be 2D"

    num_groups = offs.numel()
    A_fp8_col_major = torch.empty_like(A_col_major, dtype=target_dtype)
    A_scales = torch.empty(
        A_fp8_col_major.size(1) * num_groups,
        dtype=torch.float32,
        device=A_fp8_col_major.device,
    )

    start_idx = 0
    next_scale_idx = 0
    for end_idx in offs.tolist():

Reviewer:
Here also it would be better to have a triton kernel computing scales that could read offs on the device, to avoid syncs.

danielvegamyhre (Contributor Author), Apr 1, 2025:
Yeah, this implementation has more room for perf optimization; my first goal was to get accurate numerics. As a follow-up I can write a triton kernel to avoid this device-host sync and for loop.

        # Get the subtensor of A for this group, fetching the next group of rows, with all columns for each.
        subtensor = A_col_major[start_idx:end_idx, :]  # (local_group_size, K)

        # Compute local rowwise scales for this subtensor, which are along logical columns for the right operand.
        subtensor_scales = tensor_to_scale(
            subtensor,
            target_dtype,
            scaling_granularity=ScalingGranularity.AXISWISE,
            axiswise_dim=0,
            round_scales_to_power_of_2=round_scales_to_power_of_2,
        )

        # Apply scales to subtensor and convert to float8.
        tensor_scaled = subtensor.to(torch.float32) * subtensor_scales
        float8_subtensor = to_fp8_saturated(tensor_scaled, target_dtype)

        # Store this portion of the resulting float8 tensor and scales.
        A_fp8_col_major[start_idx:end_idx, :] = float8_subtensor
        A_scales[next_scale_idx : next_scale_idx + subtensor_scales.numel()] = (
            subtensor_scales.squeeze()
        )

        # Update start index for next group.
        start_idx = end_idx
        next_scale_idx += subtensor_scales.numel()

    return A_fp8_col_major, A_scales
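
For intuition, a small shape check for the helper above (editor's addition; assumes float8 casts are available on the device, and the group boundaries in offs are illustrative):

import torch

x = torch.randn(64, 32)                               # (M, K), groups along dim0
x_col_major = x.t().contiguous().t()                  # same values, column-major strides
offs = torch.tensor([16, 48, 64], dtype=torch.int32)  # three jagged groups of rows

x_fp8, scales = _to_2d_jagged_float8_tensor_colwise(x_col_major, offs)
assert x_fp8.shape == (64, 32)                        # same shape, float8 dtype
assert scales.shape == (32 * offs.numel(),)           # K column scales per group -> K * num_groups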


def _to_2d_jagged_float8_tensor_rowwise(
    x: torch.Tensor,
    offs: torch.Tensor,
    target_dtype: torch.dtype,
    round_scales_to_power_of_2: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    This function converts the 2D input tensor to a jagged float8 tensor,
    with scales computed along *logical rows* for each group individually,
    where groups are determined based on the offsets.

    For a 2D *left* operand of a normal scaled GEMM, the rowwise scales are computed over logical rows
    (i.e., a tensor of shape (M, K) will have scales of shape (M, 1)).

    However, for a 2D left operand of a grouped GEMM, these logical rows go through multiple distinct
    groups/subtensors, for which we want to compute scales individually. So we cannot take one set of scales
    along the logical rows and apply it to the entire tensor.

    Instead, we need to compute scales for each subtensor individually. For a tensor of shape (M, K) this results
    in scales of shape (M * num_groups, 1).

    Args:
        x (torch.Tensor): The input tensor to be converted to a jagged float8 tensor.

    Returns:
        A tuple containing the jagged float8 tensor and the scales used for the conversion.
    """
    assert x.ndim == 2, "input tensor must be 2D"

    num_groups = offs.numel()
    x_fp8 = torch.empty_like(x, dtype=target_dtype)
    x_scales = torch.empty(
        x_fp8.size(0) * num_groups, dtype=torch.float32, device=x_fp8.device
    )

    start_idx = 0
    next_scale_idx = 0
    for end_idx in offs.tolist():
        # Get the subtensor of A for this group, fetching all rows and the next group of columns.
        subtensor = x[:, start_idx:end_idx]  # (M, local_group_size)

        # Compute local rowwise scales for this subtensor, which are along logical rows for the left operand.
        subtensor_scales = tensor_to_scale(
            subtensor,
            target_dtype,
            scaling_granularity=ScalingGranularity.AXISWISE,
            axiswise_dim=-1,
            round_scales_to_power_of_2=round_scales_to_power_of_2,
        )

        # Apply scales to subtensor and convert to float8.
        tensor_scaled = subtensor.to(torch.float32) * subtensor_scales
        float8_subtensor = to_fp8_saturated(tensor_scaled, target_dtype)

        # Store this portion of the resulting float8 tensor and scales.
        x_fp8[:, start_idx:end_idx] = float8_subtensor
        x_scales[next_scale_idx : next_scale_idx + subtensor_scales.numel()] = (
            subtensor_scales.squeeze()
        )

        # Update start index for next group.
        start_idx = end_idx
        next_scale_idx += subtensor_scales.numel()

    return x_fp8, x_scales
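
And the rowwise counterpart, mirroring how the transposed grad_output is handled in the backward above (editor's addition; same assumptions as the previous sketch):

import torch

g_t = torch.randn(128, 64)                            # (N, M): transposed grad_output, groups along dim1
offs = torch.tensor([16, 48, 64], dtype=torch.int32)

g_fp8, g_scales = _to_2d_jagged_float8_tensor_rowwise(
    g_t, offs, target_dtype=torch.float8_e4m3fn
)
assert g_fp8.shape == (128, 64)
assert g_scales.shape == (128 * offs.numel(),)        # N row scales per group -> N * num_groups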


def _is_column_major(x: torch.Tensor) -> bool:
"""
This function checks if the input tensor is column-major.

Args:
x (torch.Tensor): The input tensor to be checked.

Returns:
A boolean indicating whether the input tensor is column-major.
"""
assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D"
return x.stride(-2) == 1 and x.stride(-1) > 1
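
A quick illustration of the stride check (editor's addition):

import torch

x = torch.randn(4, 8)              # contiguous: strides (8, 1) -> row-major
assert not _is_column_major(x)

x_col = x.t().contiguous().t()     # strides (1, 4) -> column-major
assert _is_column_major(x_col)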

Contributor:
does this work for 4d/5d/etc tensors? if not, maybe assert that rank is 3?
