@pytest.mark.parametrize("use_fast_accum", [True, False])
@pytest.mark.parametrize("strided", [True, False])
def test_grouped_gemm_2d_3d(use_fast_accum, strided):
-    # unit test ensuring parity between torchao and pytorch core grouped_gemm
-    # https://github.com/pytorch/pytorch/blob/87bfd66c3c7061db6d36d8daa62f08f507f90e39/test/test_matmul_cuda.py#L1204
    device = "cuda"
    s_int = int(strided)
    m, n, k, n_groups = 16, 32, 16, 4
@@ -25,6 +23,24 @@ def test_grouped_gemm_2d_3d(use_fast_accum, strided):
    )
    assert isinstance(result, torch.Tensor)

+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("use_fast_accum", [True, False])
+@pytest.mark.parametrize("strided", [True, False])
+def test_grouped_gemm_3d_3d(use_fast_accum, strided):
+    device = "cuda"
+    s_int = int(strided)
+    m, n, k, n_groups = 16, 32, 16, 4
+    a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device)[::(1 + s_int), :, :k]
+    b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device)[::(1 + s_int), :, :k]
+    result = grouped_mm(
+        a, b,
+        float8_recipe=Float8LinearRecipeName.ROWWISE,
+        out_dtype=torch.bfloat16,
+        use_fast_accum=use_fast_accum
+    )
+    assert isinstance(result, torch.Tensor)
+
+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_tensorwise_scaling_not_supported():
    device = "cuda"
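
Side note on the construction in the tests above: the [::(1 + s_int), :, :k] indexing is just a way to make the inputs non-contiguous when strided=True; an oversized tensor is allocated and then sliced back down to (n_groups, m, k) and (n_groups, n, k). The sketch below illustrates that trick and the per-group reference semantics one would expect for the 3D x 3D case, assuming the op computes A[g] @ B[g]^T per group; the float8 rowwise-scaling path is omitted, and the snippet is illustrative only, not part of the test file.

import torch

# Standalone illustration (hypothetical, CPU-friendly): build non-contiguous
# grouped operands the same way the test does, with strided=True.
s_int = 1
m, n, k, n_groups = 16, 32, 16, 4
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int))[::(1 + s_int), :, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int))[::(1 + s_int), :, :k]
assert a.shape == (n_groups, m, k) and b.shape == (n_groups, n, k)
assert not a.is_contiguous() and not b.is_contiguous()

# Assumed reference semantics: each group multiplies A[g] (m x k) by B[g]^T
# (k x n), giving an (n_groups, m, n) output in high precision.
ref = torch.bmm(a, b.transpose(-2, -1))
assert ref.shape == (n_groups, m, n)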