 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize("float8_recipe", [Float8LinearRecipeName.TENSORWISE, Float8LinearRecipeName.ROWWISE])
 @pytest.mark.parametrize("use_fast_accum", [True, False])
-def test_grouped_gemm(float8_recipe, use_fast_accum):
+@pytest.mark.parametrize("strided", [True, False])
+def test_grouped_gemm_2d_3d(use_fast_accum, strided):
+    # unit test ensuring parity between torchao and pytorch core grouped_gemm
+    # https://github.com/pytorch/pytorch/blob/87bfd66c3c7061db6d36d8daa62f08f507f90e39/test/test_matmul_cuda.py#L1204
     device = "cuda"
-    m, n, k, n_groups = 16, 16, 16, 4
-    a = torch.randn(m, k * n_groups + k, device=device)
-    b = torch.randn(n, k * n_groups + k, device=device)
-    offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
+    s_int = int(strided)
+    m, n, k, n_groups = 16, 32, 16, 4
+    a = torch.randn(m * n_groups, k * (1 + s_int), device=device)[:, :k]
+    b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device)[::(1 + s_int), :, :k]
+    offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
     result = grouped_mm(
-        a, b.t(),
+        a, b,
         offs=offs,
-        float8_recipe=float8_recipe,
+        float8_recipe=Float8LinearRecipeName.ROWWISE,
         out_dtype=torch.bfloat16,
         use_fast_accum=use_fast_accum
     )
     assert isinstance(result, torch.Tensor)
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_tensorwise_scaling_not_supported():
+    device = "cuda"
+    m, n, k, n_groups = 16, 32, 16, 4
+    a = torch.randn(m * n_groups, k, device=device)[:, :k]
+    b = torch.randn(n_groups, n, k, device=device)[::1, :, :k]
+    offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
+    with pytest.raises(AssertionError):
+        result = grouped_mm(
+            a, b,
+            offs=offs,
+            float8_recipe=Float8LinearRecipeName.TENSORWISE,
+            out_dtype=torch.bfloat16,
+        )
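
For context on the 2D-3D layout the new test exercises, here is a minimal bf16 reference sketch (not part of this PR; `grouped_mm_reference` is a hypothetical name): `offs` holds cumulative end offsets into the rows of the 2D `a`, so `offs = [16, 32, 48, 64]` splits `a`'s 64 rows into four groups of 16, each multiplied against the corresponding `b[i]` from the 3D weight tensor of shape `(n_groups, n, k)`. The exact transpose convention belongs to torchao's `grouped_mm` wrapper; this sketch assumes each group computes `a_group @ b[i].T`, and the second test only checks that the TENSORWISE recipe is rejected with an `AssertionError`.

```python
import torch

def grouped_mm_reference(a, b, offs, out_dtype=torch.bfloat16):
    # Hypothetical reference, not the torchao implementation.
    # a:    (total_m, k) 2D activations
    # b:    (n_groups, n, k) 3D per-group weights
    # offs: (n_groups,) cumulative end row offsets into `a`, e.g. [16, 32, 48, 64]
    outs = []
    start = 0
    for i, end in enumerate(offs.tolist()):
        # Group i multiplies rows [start, end) of `a` by b[i].T -> (end - start, n).
        outs.append(a[start:end] @ b[i].t())
        start = end
    return torch.cat(outs, dim=0).to(out_dtype)
```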