Commit 7afbe08

add 3Dx3D test
1 parent: a761549

File tree

1 file changed (+18, -2 lines)


torchao/prototype/grouped_mm/test_grouped_mm.py

Lines changed: 18 additions & 2 deletions
@@ -8,8 +8,6 @@
 @pytest.mark.parametrize("use_fast_accum", [True, False])
 @pytest.mark.parametrize("strided", [True, False])
 def test_grouped_gemm_2d_3d(use_fast_accum, strided):
-    # unit test ensuring parity between torchao and pytorch core grouped_gemm
-    # https://github.com/pytorch/pytorch/blob/87bfd66c3c7061db6d36d8daa62f08f507f90e39/test/test_matmul_cuda.py#L1204
     device = "cuda"
     s_int = int(strided)
     m, n, k, n_groups = 16, 32, 16, 4
@@ -25,6 +23,24 @@ def test_grouped_gemm_2d_3d(use_fast_accum, strided):
     )
     assert isinstance(result, torch.Tensor)
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("use_fast_accum", [True, False])
+@pytest.mark.parametrize("strided", [True, False])
+def test_grouped_gemm_3d_3d(use_fast_accum, strided):
+    device = "cuda"
+    s_int = int(strided)
+    m, n, k, n_groups = 16, 32, 16, 4
+    a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device)[::(1 + s_int), :, :k]
+    b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device)[::(1 + s_int), :, :k]
+    result = grouped_mm(
+        a, b,
+        float8_recipe=Float8LinearRecipeName.ROWWISE,
+        out_dtype=torch.bfloat16,
+        use_fast_accum=use_fast_accum
+    )
+    assert isinstance(result, torch.Tensor)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_tensorwise_scaling_not_supported():
     device = "cuda"
