
Commit b9c03c0

[Intel] Enable Gluon tests
Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
Co-authored-by: Ilya Enkovich <ilya.enkovich@intel.com>
1 parent b0f9b78 · commit b9c03c0

File tree: 7 files changed, +86 -43 lines


.github/workflows/build-test-reusable.yml

Lines changed: 6 additions & 0 deletions
@@ -196,6 +196,7 @@ jobs:
         suite:
           - minicore
           - scaled_dot
+          - gluon
           - rest
           - tutorial-fa-64
           - tutorial-fa-128-fwdfp8
@@ -306,6 +307,11 @@ jobs:
         run: |
           ${{ env.TRITON_TEST_CMD }} --scaled-dot

+      - name: Run gluon tests
+        if: matrix.suite == 'gluon' && inputs.driver_version == 'rolling'
+        run: |
+          ${{ env.TRITON_TEST_CMD }} --gluon
+
       - name: Run interpreter tests
         if: matrix.suite == 'rest'
         run: |

python/test/gluon/test_consan.py

Lines changed: 9 additions & 9 deletions
@@ -85,7 +85,7 @@ def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr, FAILURE: ttgl.constexpr
     tma.store_wait(0)


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
 @pytest.mark.parametrize("FAILURE", [True, False])
 def test_async_tma_kernel(FAILURE, device, run_wrapper):
     if run_wrapper:
@@ -141,7 +141,7 @@ def tma_interleave_kernel(input_desc, XBLOCK: ttgl.constexpr, FAILURE: ttgl.cons
     tma.store_wait(0)


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer")
 @pytest.mark.parametrize("FAILURE", [True, False])
 def test_tma_interleave_kernel(FAILURE, device, run_wrapper):
     if run_wrapper:
@@ -190,7 +190,7 @@ def async_copy_kernel(input, XBLOCK: ttgl.constexpr, FAILURE: ttgl.constexpr):
     ampere.async_copy.wait_group(0)


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires ampere or newer")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires ampere or newer")
 @pytest.mark.parametrize("FAILURE", [True, False])
 def test_async_copy(FAILURE, device, run_wrapper):
     if run_wrapper:
@@ -252,7 +252,7 @@ def tcgen5_mma_kernel(input_desc, XBLOCK: ttgl.constexpr, FAILURE: ttgl.constexp
     mbarrier.invalidate(bar.index(1))


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer")
 @pytest.mark.parametrize("FAILURE", [True, False])
 @pytest.mark.parametrize("MEM_ACCESS_KIND", ["tma_cp", "local_store", "tmem_load", "tmem_store"])
 def test_tcgen5_mma(FAILURE, MEM_ACCESS_KIND, device, run_wrapper):
@@ -305,7 +305,7 @@ def warpgroup_mma_kernel(input, XBLOCK: ttgl.constexpr, FAILURE: ttgl.constexpr)
     smemA.store(ttgl.full([XBLOCK, XBLOCK], 42, ttgl.float16, blocked_layout))


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper")
 @pytest.mark.parametrize("FAILURE", [True, False])
 def test_warpgroup_mma(FAILURE, device, run_wrapper):
     if run_wrapper:
@@ -353,7 +353,7 @@ def warpgroup_mma_kernel2(input, XBLOCK: ttgl.constexpr, FAILURE: ttgl.constexpr
     smemA.store(ttgl.full([XBLOCK, XBLOCK], 42, ttgl.float16, blocked_layout))


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper")
 @pytest.mark.parametrize("FAILURE", [True, False])
 def test_warpgroup_mma2(FAILURE, device, run_wrapper):
     if run_wrapper:
@@ -406,7 +406,7 @@ def tcgen5_mma_multibar_kernel(input_desc, XBLOCK: ttgl.constexpr, BUF_IDX: ttgl
     mbarrier.invalidate(bar.index(i))


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer")
 @pytest.mark.parametrize("BUF_IDX", [0, 1])
 @pytest.mark.parametrize("BAR_IDX", [0, 1, 2, 3])
 def test_tcgen5_mma_multibar(BUF_IDX, BAR_IDX, device, run_wrapper):
@@ -529,7 +529,7 @@ def multibuffered_loop_tma_kernel(input_desc, XBLOCK: ttgl.constexpr, FAILURE: t
     mbarrier.invalidate(barMMA.index(i))


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] < 10, reason="Requires blackwell or newer")
 @pytest.mark.parametrize("FAILURE", [True, False])
 def test_multibuffered_loop(FAILURE, device, run_wrapper):
     if run_wrapper:
@@ -611,7 +611,7 @@ def multibuffered_loop_wgmma_kernel(input_desc, XBLOCK: ttgl.constexpr, FAILURE:
     mbarrier.invalidate(barLoadB.index(i))


-@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper")
+@pytest.mark.xfail(not is_cuda() or torch.cuda.get_device_capability()[0] != 9, reason="Requires hopper")
 @pytest.mark.parametrize("FAILURE", [True, False])
 def test_multibuffered_wgmma_loop(FAILURE, device, run_wrapper):
     if run_wrapper:
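
A note on the pattern above: skipif prevents a test body from running at all, while a conditional xfail executes the body and tolerates failure. A minimal, self-contained sketch of the two markers (the test names, UNSUPPORTED flag, and run_kernel helper are illustrative, not from this commit):

    import pytest

    # Stand-in for a hardware check such as `not is_cuda() or capability < 9`.
    UNSUPPORTED = True

    def run_kernel():
        return 42  # placeholder for real GPU work

    @pytest.mark.skipif(UNSUPPORTED, reason="reported SKIPPED; body never executes")
    def test_skipif_variant():
        assert run_kernel() == 42

    @pytest.mark.xfail(UNSUPPORTED, reason="body executes; failure is XFAIL, success is XPASS")
    def test_xfail_variant():
        assert run_kernel() == 42

With xfail, a backend that can in fact run these kernels surfaces as XPASS rather than disappearing into the skip count, which fits the commit's goal of enabling the Gluon suite on Intel XPU.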

python/test/gluon/test_core.py

Lines changed: 31 additions & 27 deletions
@@ -14,6 +14,7 @@
     is_hip_cdna4,
     is_hopper_or_newer,
     is_hopper,
+    is_xpu,
 )
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
@@ -55,8 +56,8 @@ def copy_kernel(Out, In, numel, XBLOCK: ttgl.constexpr, layout: ttgl.constexpr):
     ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[8], order=[0]),
 ])
 @pytest.mark.parametrize("XBLOCK", [128, 256, 512, 1024, 2048])
-def test_copy_kernel(layout, XBLOCK):
-    inp = torch.randn(XBLOCK * 4 - 7, device="cuda")
+def test_copy_kernel(layout, XBLOCK, device):
+    inp = torch.randn(XBLOCK * 4 - 7, device=device)
     out = torch.empty_like(inp)

     copy_kernel[(4, )](out, inp, inp.numel(), XBLOCK, layout, num_warps=layout.warps_per_cta[0])
@@ -73,7 +74,7 @@ def tma_kernel(desc):
     alloc._keep_alive()


-@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper")
+@pytest.mark.xfail(not is_hopper_or_newer(), reason="Requires Hopper")
 def test_tma():
     out = torch.ones((16, 16), dtype=torch.float16, device="cuda")
     layout = ttgl.NVMMASharedLayout(
@@ -112,9 +113,9 @@ def async_copy_mbarrier_kernel(out, inp, xnumel, XBLOCK: ttgl.constexpr, YBLOCK:
     ttgl.store(out + xindex * YBLOCK + yindex, val)


-@pytest.mark.skipif(not is_ampere_or_newer(), reason="Requires Ampere")
-def test_async_copy_mbarrier():
-    tensor_opts = dict(dtype=torch.float, device="cuda")
+@pytest.mark.xfail(not is_ampere_or_newer(), reason="Requires Ampere")
+def test_async_copy_mbarrier(device):
+    tensor_opts = dict(dtype=torch.float, device=device)
     out = torch.empty((32, 32), **tensor_opts)
     inp = torch.randn((20, 32), **tensor_opts)
     async_copy_mbarrier_kernel[(1, )](out, inp, inp.shape[0], XBLOCK=32, YBLOCK=32)
@@ -153,7 +154,7 @@ def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttg
     ttgl.store(out + out_offs_m * N + out_offs_n, acc)


-@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper")
+@pytest.mark.xfail(not is_hopper(), reason="Requires Hopper")
 @pytest.mark.parametrize("ASYNC", [True, False])
 def test_warpgroup_mma(ASYNC):
     torch.manual_seed(0)
@@ -168,7 +169,7 @@ def test_warpgroup_mma(ASYNC):
     torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-1)


-@pytest.mark.skipif(not is_hip_cdna4(), reason="Requires CDNA4")
+@pytest.mark.xfail(not is_hip_cdna4(), reason="Requires CDNA4")
 @pytest.mark.parametrize("use_buffer_load", [True, False])
 def test_amd_direct_load_to_shared(use_buffer_load):

@@ -204,7 +205,7 @@ def kernel(a_ptr, b_ptr, use_buffer_load: ttgl.constexpr):
     assert 'vmcnt(0)' in pgm.asm['amdgcn']


-@pytest.mark.skipif(not (is_hip_gfx11() or is_hip_gfx12()), reason="Requires RDNA3 or RDNA4")
+@pytest.mark.xfail(not (is_hip_gfx11() or is_hip_gfx12()), reason="Requires RDNA3 or RDNA4")
 @pytest.mark.parametrize("M, N, K", [(64, 64, 64)])
 @pytest.mark.parametrize("in_dtype", ['float16', 'bfloat16'])
 def test_amd_wmma(M, N, K, in_dtype):
@@ -270,6 +271,8 @@ def kernel(a_ptr, b_ptr, c_ptr, #
 @pytest.mark.parametrize("num_warps", [4, 8])
 @pytest.mark.parametrize("cdna_version", [3, 4])
 def test_amd_mfma(M, N, K, in_dtype, num_warps, cdna_version):
+    if is_xpu():
+        pytest.xfail("XPU does not support AMD MFMA")

     @gluon.jit
     def kernel(a_ptr, b_ptr, c_ptr, stride_am, stride_ak, #
@@ -328,7 +331,7 @@ def kernel(a_ptr, b_ptr, c_ptr, stride_am, stride_ak, #
     torch.testing.assert_close(ref, triton_output)


-@pytest.mark.skipif(not is_hip_cdna4(), reason="Requires CDNA4")
+@pytest.mark.xfail(not is_hip_cdna4(), reason="Requires CDNA4")
 @pytest.mark.parametrize("M, N, K, rhs_scale, mxfp_type, normal_type", [(32, 32, 128, rhs_scale, mxfp_type, normal_type)
                                                                         for rhs_scale in [True, False]
                                                                         for mxfp_type in ["e2m1"]
@@ -470,7 +473,7 @@ def make_finite(x, dtype):
     torch.testing.assert_close(z, z_ref, rtol=1e-5, atol=1e-5)


-def test_math_fast_expf():
+def test_math_fast_expf(device):

     @gluon.jit
     def fast_expf_kernel(x_ptr, y_ptr, warp_size: ttgl.constexpr, num_warps: ttgl.constexpr):
@@ -484,13 +487,13 @@ def fast_expf_kernel(x_ptr, y_ptr, warp_size: ttgl.constexpr, num_warps: ttgl.co
     num_warps = 4

     torch.manual_seed(0)
-    x = torch.randn(THREADS_PER_WARP * num_warps, device="cuda", dtype=torch.float32)
+    x = torch.randn(THREADS_PER_WARP * num_warps, device=device, dtype=torch.float32)
     y = torch.empty_like(x)
     fast_expf_kernel[(1, )](x, y, THREADS_PER_WARP, num_warps)
     torch.testing.assert_close(y, torch.exp(x), atol=1e-5, rtol=1e-4)


-def test_math_fast_dividef():
+def test_math_fast_dividef(device):

     @gluon.jit
     def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warps: ttgl.constexpr):
@@ -505,7 +508,7 @@ def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warp
     num_warps = 4

     torch.manual_seed(0)
-    x = torch.randn(THREADS_PER_WARP * num_warps, device="cuda", dtype=torch.float32)
+    x = torch.randn(THREADS_PER_WARP * num_warps, device=device, dtype=torch.float32)
     y = torch.randn_like(x)
     z = torch.empty_like(x)
     y[y == 0] = 1.0
@@ -514,7 +517,7 @@ def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warp


 @pytest.mark.xfail(reason="copy to tmem with scale layout is currently broken in Gluon.")
-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.xfail(not is_blackwell(), reason="Requires Blackwell")
 def test_tmem_copy_2d():
     device = "cuda"

@@ -563,7 +566,7 @@ def kernel(in_ptr, out_ptr, smem_h: ttgl.constexpr, smem_w: ttgl.constexpr, num_
     assert torch.equal(x[m * 32:(m + 1) * 32], z_tri[32 * i:32 * (i + 1), col_offset:(col_offset + 4)])


-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.xfail(not is_blackwell(), reason="Requires Blackwell")
 def test_tmem_subslice_block_m_64():

     @gluon.jit
@@ -643,7 +646,7 @@ def kernel(s_ptr, out_ptr):
     torch.testing.assert_close(out_ref, out_tri, atol=0, rtol=0)


-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.xfail(not is_blackwell(), reason="Requires Blackwell")
 def test_block_m_64_mma():

     @gluon.jit
@@ -734,7 +737,7 @@ def kernel(a_ptr, b_ptr, c_ptr, d_ptr):
     torch.testing.assert_close(d_ref, d_tri, rtol=0.08, atol=0)


-def test_slice_reinterpret():
+def test_slice_reinterpret(device):
     BLOCK = ttgl.constexpr(2048)
     SPLIT_BLOCK = ttgl.constexpr(BLOCK // 2)
     XBLOCK = ttgl.constexpr(32)
@@ -759,13 +762,13 @@ def kernel(in_ptr, out_ptr):
         value = smem_slice1.load(blocked)
         ttgl.store(ttgl.set_auto_layout(out_ptr + offs, blocked), value)

-    input = torch.randint(0, 100, (XBLOCK, YBLOCK), dtype=torch.int32, device="cuda")
+    input = torch.randint(0, 100, (XBLOCK, YBLOCK), dtype=torch.int32, device=device)
     output = torch.empty_like(input)
     kernel[(1, )](input, output)
     torch.testing.assert_close(input, output, atol=0, rtol=0)


-@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper")
+@pytest.mark.xfail(not is_hopper_or_newer(), reason="Requires Hopper")
 def test_tma_slice():
     XBLOCK = YBLOCK = ttgl.constexpr(128)

@@ -802,7 +805,7 @@ def kernel(in_desc, out_desc):
 @pytest.mark.parametrize("swizzle", [32, 64, 128])
 @pytest.mark.parametrize("num_warps", [4, 8])
 @pytest.mark.parametrize("M, N, BLOCK_N", [(128, 128, 128), (256, 128, 64), (128, 128, 16)])
-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.xfail(not is_blackwell(), reason="Requires Blackwell")
 def test_tmem_copy_no_scales(M, N, BLOCK_N, num_warps, swizzle):

     @gluon.jit
@@ -856,7 +859,7 @@ def early_return_kernel(x):
     return x


-def test_2d_tensor_early_return():
+def test_2d_tensor_early_return(device):
     warp_size = ttgl.constexpr(THREADS_PER_WARP)

     @gluon.jit
@@ -871,12 +874,12 @@ def kernel(N, out):
         x += early_return_kernel(x)
     ttgl.store(out, x.sum(0).sum(0))

-    out = torch.empty(1, dtype=torch.int32, device="cuda")
+    out = torch.empty(1, dtype=torch.int32, device=device)
     compiled_kernel = kernel.warmup(N=100, out=out, grid=(1, ))
     assert compiled_kernel.asm["llir"].count("define") == 1


-@pytest.mark.skipif(not is_hip_cdna3() and not is_hip_cdna4(), reason="Requires CDNA3 or CDNA4")
+@pytest.mark.xfail(not is_hip_cdna3() and not is_hip_cdna4(), reason="Requires CDNA3 or CDNA4")
 def test_inline_with_amdgpu_dialect():

     @gluon.jit
@@ -906,7 +909,8 @@ def kernel(x, y):
     {"offsets": [[0, 1], [0, 2], [0, 8], [0, 4], [0, 16], [0, 32], [2, 0], [1, 0], [4, 0], [8, 0], [16, 0], [32, 0]]}])
 @pytest.mark.parametrize("slice_m_offset, slice_n_offset, slice_m, slice_n", [(48, 16, 16, 16), (32, 48, 32, 16),
                                                                               (48, 32, 16, 32)])
-def test_padded_shared_layout_subslice(interval_pairs, shared_layout, slice_m_offset, slice_n_offset, slice_m, slice_n):
+def test_padded_shared_layout_subslice(interval_pairs, shared_layout, slice_m_offset, slice_n_offset, slice_m, slice_n,
+                                       device):
     m = 64
     n = 64
     num_warps = 1
@@ -945,8 +949,8 @@ def kernel(in_ptr, out_ptr, M: ttgl.constexpr, N: ttgl.constexpr, SLICE_M_OFFSET
     out_offs = offs_m_store[:, None] * SLICE_N + offs_n_store[None, :]
     ttgl.store(out_ptr + out_offs, out_data)

-    input = torch.arange(m * n, device="cuda").reshape(m, n).to(torch.int32)
-    output = torch.zeros((slice_m, slice_n), dtype=torch.int32, device="cuda")
+    input = torch.arange(m * n, device=device).reshape(m, n).to(torch.int32)
+    output = torch.zeros((slice_m, slice_n), dtype=torch.int32, device=device)
     ref_output = input[slice_m_offset:slice_m_offset + slice_m, slice_n_offset:slice_n_offset + slice_n]

     kernel[(1, )](input, output, m, n, slice_m_offset, slice_n_offset, slice_m, slice_n, num_warps=num_warps)
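
The recurring change in this file replaces hard-coded device="cuda" tensors with a device parameter that pytest supplies through a fixture. A minimal conftest.py sketch of such a fixture (the --device option name and its default are assumptions for illustration; the actual repository may wire this up differently):

    # conftest.py (hypothetical sketch)
    import pytest

    def pytest_addoption(parser):
        # Assumed CLI option so one test file can drive both CUDA and XPU runs.
        parser.addoption("--device", action="store", default="cuda",
                         help="device on which to allocate test tensors, e.g. cuda or xpu")

    @pytest.fixture
    def device(request):
        return request.config.getoption("--device")

Any test that declares a device argument then receives the configured string. The is_xpu() guard added to test_amd_mfma uses the imperative counterpart, pytest.xfail(...), which marks the already-running test as an expected failure from inside its body.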
