Commit a761549

rowwise scaling test passing
1 parent 0a90f0b commit a761549

3 files changed (+46, -16 lines changed)


torchao/float8/float8_ops.py

Lines changed: 0 additions & 3 deletions
@@ -151,9 +151,6 @@ def float8_transpose(aten_op, args, kwargs=None):
     else:
         new_scale = args[0]._scale
 
-    if aten_op == aten.transpose.int:
-        _assert_tensorwise_scale(aten_op, args[0]._scale)
-
     old_axiswise_dim = args[0]._axiswise_dim
     new_axiswise_dim = old_axiswise_dim
     if old_axiswise_dim is not None:
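For context: the deleted assertion restricted aten.transpose.int to tensorwise-scaled tensors. With axiswise (rowwise) scaling, the scale can simply be transposed along with the data, which is what the surrounding _axiswise_dim handling does. A minimal sketch of that idea with plain torch tensors (the data and scale names here are illustrative, not the Float8Tensor internals):

import torch

# A rowwise-scaled (M, K) tensor keeps one scale per row, shape (M, 1).
M, K = 4, 8
data = torch.randn(M, K)
scale = data.abs().amax(dim=-1, keepdim=True)  # (M, 1)

# Transposing swaps the scaled axis; the scale transposes right along with the data.
data_t = data.transpose(0, 1)    # (K, M)
scale_t = scale.transpose(0, 1)  # (1, M)

# The dequantized values agree either way.
assert torch.allclose((data_t / scale_t).transpose(0, 1), data / scale)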

torchao/prototype/grouped_mm/__init__.py

Lines changed: 20 additions & 5 deletions
@@ -2,6 +2,7 @@
 from typing import Optional
 
 import torch
+from torchao import float8
 from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic, get_maybe_axiswise_dim
 from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
 from torchao.float8.float8_tensor import GemmInputRole
@@ -36,7 +37,9 @@ def forward(
         out_dtype: Optional[torch.dtype] = None,
         use_fast_accum: bool = False,
     ) -> torch.Tensor:
-
+        # torch._scaled_grouped_mm only supports rowwise scaling currently.
+        assert float8_recipe_name == Float8LinearRecipeName.ROWWISE, "Only rowwise scaling is supported by torch._scaled_grouped_mm."
+
         # perform dynamic float8 quantization using the given recipe, if specified
         assert 2 <= A.ndim <= 3, "A must be 2D or 3D"
         assert 2 <= B.ndim <= 3, "B must be 2D or 3D"
@@ -68,19 +71,31 @@ def forward(
                 -1, float8_config.cast_config_input.scaling_granularity
             ),
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
-        )
+        )
+        B_fp8_t = B_fp8.transpose(-2, -1)
 
         # Store what we need for backward.
         ctx.save_for_backward(A, B)
         ctx.float_config = float8_config
         ctx.offs = offs
 
+        # Scale shape adjustments for compatibility with torch._scaled_grouped_mm.
+        # For tensorwise scaling, torch._scaled_grouped_mm requires 1D scales, not 0D.
+        if float8_recipe_name == Float8LinearRecipeName.TENSORWISE:
+            A_fp8._scale = A_fp8._scale.unsqueeze(0)
+            B_fp8_t._scale = B_fp8_t._scale.unsqueeze(0)
+
+        # For rowwise scaling, torch._scaled_grouped_mm requires scales without any empty dims.
+        elif float8_recipe_name == Float8LinearRecipeName.ROWWISE:
+            A_fp8._scale = A_fp8._scale.squeeze()
+            B_fp8_t._scale = B_fp8_t._scale.squeeze()
+
         # Perform scaled grouped GEMM and return result.
         return torch._scaled_grouped_mm(
             A_fp8._data,
-            B_fp8._data,
-            A_fp8._scale,
-            B_fp8._scale,
+            B_fp8_t._data,
+            A_fp8._scale,
+            B_fp8_t._scale,
             offs,
             out_dtype=out_dtype,
             use_fast_accum=use_fast_accum,
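For context on the scale-shape fixups above: with rowwise scaling, the dynamic cast keeps a singleton dimension in the scale (roughly (M, 1) for a 2D A and (n_groups, 1, N) for the transposed 3D B; the exact layout here is an assumption for illustration), while torch._scaled_grouped_mm, per the comment in the diff, wants scales without empty dims. A small shape-only sketch of what the squeeze() calls accomplish:

import torch

M, K, N, n_groups = 64, 16, 32, 4

# Assumed rowwise scale layouts before the fixup (singleton dims included):
A_scale = torch.rand(M, 1)               # one scale per row of A
B_t_scale = torch.rand(n_groups, 1, N)   # one scale per output column of each group

# squeeze() drops the singleton dims, giving the flat shapes the kernel expects.
assert A_scale.squeeze().shape == (M,)
assert B_t_scale.squeeze().shape == (n_groups, N)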

torchao/prototype/grouped_mm/test_grouped_mm.py

Lines changed: 26 additions & 8 deletions
@@ -5,19 +5,37 @@
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize("float8_recipe", [Float8LinearRecipeName.TENSORWISE, Float8LinearRecipeName.ROWWISE])
 @pytest.mark.parametrize("use_fast_accum", [True, False])
-def test_grouped_gemm(float8_recipe, use_fast_accum):
+@pytest.mark.parametrize("strided", [True, False])
+def test_grouped_gemm_2d_3d(use_fast_accum, strided):
+    # unit test ensuring parity between torchao and pytorch core grouped_gemm
+    # https://github.com/pytorch/pytorch/blob/87bfd66c3c7061db6d36d8daa62f08f507f90e39/test/test_matmul_cuda.py#L1204
     device = "cuda"
-    m, n, k, n_groups = 16, 16, 16, 4
-    a = torch.randn(m, k * n_groups + k, device=device)
-    b = torch.randn(n, k * n_groups + k, device=device)
-    offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
+    s_int = int(strided)
+    m, n, k, n_groups = 16, 32, 16, 4
+    a = torch.randn(m * n_groups, k * (1 + s_int), device=device)[:, :k]
+    b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device)[::(1 + s_int), :, :k]
+    offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
     result = grouped_mm(
-        a, b.t(),
+        a, b,
         offs=offs,
-        float8_recipe=float8_recipe,
+        float8_recipe=Float8LinearRecipeName.ROWWISE,
         out_dtype=torch.bfloat16,
         use_fast_accum=use_fast_accum
     )
     assert isinstance(result, torch.Tensor)
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_tensorwise_scaling_not_supported():
+    device = "cuda"
+    m, n, k, n_groups = 16, 32, 16, 4
+    a = torch.randn(m * n_groups, k, device=device)[:, :k]
+    b = torch.randn(n_groups, n, k, device=device)[::1, :, :k]
+    offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
+    with pytest.raises(AssertionError):
+        result = grouped_mm(
+            a, b,
+            offs=offs,
+            float8_recipe=Float8LinearRecipeName.TENSORWISE,
+            out_dtype=torch.bfloat16,
+        )
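A note on the new strided parametrization in the 2D-3D test: allocating an oversized buffer and slicing it produces non-contiguous views, so the stride handling of the grouped GEMM gets exercised alongside the contiguous path. A small shape-only sketch of the strided=True case (runs on CPU, no float8 casting involved):

import torch

m, n, k, n_groups = 16, 32, 16, 4
s_int = 1  # strided=True

a = torch.randn(m * n_groups, k * (1 + s_int))[:, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int))[::(1 + s_int), :, :k]

print(a.shape, a.is_contiguous())  # torch.Size([64, 16]) False
print(b.shape, b.is_contiguous())  # torch.Size([4, 32, 16]) False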
