    is_hip_cdna4,
    is_hopper_or_newer,
    is_hopper,
+    is_xpu,
)
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
@@ -55,8 +56,8 @@ def copy_kernel(Out, In, numel, XBLOCK: ttgl.constexpr, layout: ttgl.constexpr):
    ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[8], order=[0]),
])
@pytest.mark.parametrize("XBLOCK", [128, 256, 512, 1024, 2048])
-def test_copy_kernel(layout, XBLOCK):
-    inp = torch.randn(XBLOCK * 4 - 7, device="cuda")
+def test_copy_kernel(layout, XBLOCK, device):
+    inp = torch.randn(XBLOCK * 4 - 7, device=device)
    out = torch.empty_like(inp)

    copy_kernel[(4, )](out, inp, inp.numel(), XBLOCK, layout, num_warps=layout.warps_per_cta[0])
@@ -73,7 +74,7 @@ def tma_kernel(desc):
    alloc._keep_alive()


-@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper")
+@pytest.mark.xfail(not is_hopper_or_newer(), reason="Requires Hopper")
def test_tma():
    out = torch.ones((16, 16), dtype=torch.float16, device="cuda")
    layout = ttgl.NVMMASharedLayout(
@@ -112,9 +113,9 @@ def async_copy_mbarrier_kernel(out, inp, xnumel, XBLOCK: ttgl.constexpr, YBLOCK:
    ttgl.store(out + xindex * YBLOCK + yindex, val)


-@pytest.mark.skipif(not is_ampere_or_newer(), reason="Requires Ampere")
-def test_async_copy_mbarrier():
-    tensor_opts = dict(dtype=torch.float, device="cuda")
+@pytest.mark.xfail(not is_ampere_or_newer(), reason="Requires Ampere")
+def test_async_copy_mbarrier(device):
+    tensor_opts = dict(dtype=torch.float, device=device)
    out = torch.empty((32, 32), **tensor_opts)
    inp = torch.randn((20, 32), **tensor_opts)
    async_copy_mbarrier_kernel[(1, )](out, inp, inp.shape[0], XBLOCK=32, YBLOCK=32)
@@ -153,7 +154,7 @@ def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttg
    ttgl.store(out + out_offs_m * N + out_offs_n, acc)


-@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper")
+@pytest.mark.xfail(not is_hopper(), reason="Requires Hopper")
@pytest.mark.parametrize("ASYNC", [True, False])
def test_warpgroup_mma(ASYNC):
    torch.manual_seed(0)
@@ -168,7 +169,7 @@ def test_warpgroup_mma(ASYNC):
    torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-1)


-@pytest.mark.skipif(not is_hip_cdna4(), reason="Requires CDNA4")
+@pytest.mark.xfail(not is_hip_cdna4(), reason="Requires CDNA4")
@pytest.mark.parametrize("use_buffer_load", [True, False])
def test_amd_direct_load_to_shared(use_buffer_load):

@@ -204,7 +205,7 @@ def kernel(a_ptr, b_ptr, use_buffer_load: ttgl.constexpr):
    assert 'vmcnt(0)' in pgm.asm['amdgcn']


-@pytest.mark.skipif(not (is_hip_gfx11() or is_hip_gfx12()), reason="Requires RDNA3 or RDNA4")
+@pytest.mark.xfail(not (is_hip_gfx11() or is_hip_gfx12()), reason="Requires RDNA3 or RDNA4")
@pytest.mark.parametrize("M, N, K", [(64, 64, 64)])
@pytest.mark.parametrize("in_dtype", ['float16', 'bfloat16'])
def test_amd_wmma(M, N, K, in_dtype):
@@ -270,6 +271,8 @@ def kernel(a_ptr, b_ptr, c_ptr, #
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.parametrize("cdna_version", [3, 4])
def test_amd_mfma(M, N, K, in_dtype, num_warps, cdna_version):
+    if is_xpu():
+        pytest.xfail("XPU does not support AMD MFMA")

    @gluon.jit
    def kernel(a_ptr, b_ptr, c_ptr, stride_am, stride_ak, #
@@ -328,7 +331,7 @@ def kernel(a_ptr, b_ptr, c_ptr, stride_am, stride_ak, #
    torch.testing.assert_close(ref, triton_output)


-@pytest.mark.skipif(not is_hip_cdna4(), reason="Requires CDNA4")
+@pytest.mark.xfail(not is_hip_cdna4(), reason="Requires CDNA4")
@pytest.mark.parametrize("M, N, K, rhs_scale, mxfp_type, normal_type", [(32, 32, 128, rhs_scale, mxfp_type, normal_type)
                                                                        for rhs_scale in [True, False]
                                                                        for mxfp_type in ["e2m1"]
@@ -470,7 +473,7 @@ def make_finite(x, dtype):
    torch.testing.assert_close(z, z_ref, rtol=1e-5, atol=1e-5)


-def test_math_fast_expf():
+def test_math_fast_expf(device):

    @gluon.jit
    def fast_expf_kernel(x_ptr, y_ptr, warp_size: ttgl.constexpr, num_warps: ttgl.constexpr):
@@ -484,13 +487,13 @@ def fast_expf_kernel(x_ptr, y_ptr, warp_size: ttgl.constexpr, num_warps: ttgl.co
    num_warps = 4

    torch.manual_seed(0)
-    x = torch.randn(THREADS_PER_WARP * num_warps, device="cuda", dtype=torch.float32)
+    x = torch.randn(THREADS_PER_WARP * num_warps, device=device, dtype=torch.float32)
    y = torch.empty_like(x)
    fast_expf_kernel[(1, )](x, y, THREADS_PER_WARP, num_warps)
    torch.testing.assert_close(y, torch.exp(x), atol=1e-5, rtol=1e-4)


-def test_math_fast_dividef():
+def test_math_fast_dividef(device):

    @gluon.jit
    def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warps: ttgl.constexpr):
@@ -505,7 +508,7 @@ def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warp
    num_warps = 4

    torch.manual_seed(0)
-    x = torch.randn(THREADS_PER_WARP * num_warps, device="cuda", dtype=torch.float32)
+    x = torch.randn(THREADS_PER_WARP * num_warps, device=device, dtype=torch.float32)
    y = torch.randn_like(x)
    z = torch.empty_like(x)
    y[y == 0] = 1.0
@@ -514,7 +517,7 @@ def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warp


@pytest.mark.xfail(reason="copy to tmem with scale layout is currently broken in Gluon.")
-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.xfail(not is_blackwell(), reason="Requires Blackwell")
def test_tmem_copy_2d():
    device = "cuda"

@@ -563,7 +566,7 @@ def kernel(in_ptr, out_ptr, smem_h: ttgl.constexpr, smem_w: ttgl.constexpr, num_
        assert torch.equal(x[m * 32:(m + 1) * 32], z_tri[32 * i:32 * (i + 1), col_offset:(col_offset + 4)])


-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.xfail(not is_blackwell(), reason="Requires Blackwell")
def test_tmem_subslice_block_m_64():

    @gluon.jit
@@ -643,7 +646,7 @@ def kernel(s_ptr, out_ptr):
    torch.testing.assert_close(out_ref, out_tri, atol=0, rtol=0)


-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.xfail(not is_blackwell(), reason="Requires Blackwell")
def test_block_m_64_mma():

    @gluon.jit
@@ -734,7 +737,7 @@ def kernel(a_ptr, b_ptr, c_ptr, d_ptr):
    torch.testing.assert_close(d_ref, d_tri, rtol=0.08, atol=0)


-def test_slice_reinterpret():
+def test_slice_reinterpret(device):
    BLOCK = ttgl.constexpr(2048)
    SPLIT_BLOCK = ttgl.constexpr(BLOCK // 2)
    XBLOCK = ttgl.constexpr(32)
@@ -759,13 +762,13 @@ def kernel(in_ptr, out_ptr):
        value = smem_slice1.load(blocked)
        ttgl.store(ttgl.set_auto_layout(out_ptr + offs, blocked), value)

-    input = torch.randint(0, 100, (XBLOCK, YBLOCK), dtype=torch.int32, device="cuda")
+    input = torch.randint(0, 100, (XBLOCK, YBLOCK), dtype=torch.int32, device=device)
    output = torch.empty_like(input)
    kernel[(1, )](input, output)
    torch.testing.assert_close(input, output, atol=0, rtol=0)


-@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper")
+@pytest.mark.xfail(not is_hopper_or_newer(), reason="Requires Hopper")
def test_tma_slice():
    XBLOCK = YBLOCK = ttgl.constexpr(128)

@@ -802,7 +805,7 @@ def kernel(in_desc, out_desc):
@pytest.mark.parametrize("swizzle", [32, 64, 128])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.parametrize("M, N, BLOCK_N", [(128, 128, 128), (256, 128, 64), (128, 128, 16)])
-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.xfail(not is_blackwell(), reason="Requires Blackwell")
def test_tmem_copy_no_scales(M, N, BLOCK_N, num_warps, swizzle):

    @gluon.jit
@@ -856,7 +859,7 @@ def early_return_kernel(x):
    return x


-def test_2d_tensor_early_return():
+def test_2d_tensor_early_return(device):
    warp_size = ttgl.constexpr(THREADS_PER_WARP)

    @gluon.jit
@@ -871,12 +874,12 @@ def kernel(N, out):
        x += early_return_kernel(x)
        ttgl.store(out, x.sum(0).sum(0))

-    out = torch.empty(1, dtype=torch.int32, device="cuda")
+    out = torch.empty(1, dtype=torch.int32, device=device)
    compiled_kernel = kernel.warmup(N=100, out=out, grid=(1, ))
    assert compiled_kernel.asm["llir"].count("define") == 1


-@pytest.mark.skipif(not is_hip_cdna3() and not is_hip_cdna4(), reason="Requires CDNA3 or CDNA4")
+@pytest.mark.xfail(not is_hip_cdna3() and not is_hip_cdna4(), reason="Requires CDNA3 or CDNA4")
def test_inline_with_amdgpu_dialect():

    @gluon.jit
@@ -906,7 +909,8 @@ def kernel(x, y):
    {"offsets": [[0, 1], [0, 2], [0, 8], [0, 4], [0, 16], [0, 32], [2, 0], [1, 0], [4, 0], [8, 0], [16, 0], [32, 0]]}])
@pytest.mark.parametrize("slice_m_offset, slice_n_offset, slice_m, slice_n", [(48, 16, 16, 16), (32, 48, 32, 16),
                                                                              (48, 32, 16, 32)])
-def test_padded_shared_layout_subslice(interval_pairs, shared_layout, slice_m_offset, slice_n_offset, slice_m, slice_n):
+def test_padded_shared_layout_subslice(interval_pairs, shared_layout, slice_m_offset, slice_n_offset, slice_m, slice_n,
+                                       device):
    m = 64
    n = 64
    num_warps = 1
@@ -945,8 +949,8 @@ def kernel(in_ptr, out_ptr, M: ttgl.constexpr, N: ttgl.constexpr, SLICE_M_OFFSET
        out_offs = offs_m_store[:, None] * SLICE_N + offs_n_store[None, :]
        ttgl.store(out_ptr + out_offs, out_data)

-    input = torch.arange(m * n, device="cuda").reshape(m, n).to(torch.int32)
-    output = torch.zeros((slice_m, slice_n), dtype=torch.int32, device="cuda")
+    input = torch.arange(m * n, device=device).reshape(m, n).to(torch.int32)
+    output = torch.zeros((slice_m, slice_n), dtype=torch.int32, device=device)
    ref_output = input[slice_m_offset:slice_m_offset + slice_m, slice_n_offset:slice_n_offset + slice_n]

    kernel[(1, )](input, output, m, n, slice_m_offset, slice_n_offset, slice_m, slice_n, num_warps=num_warps)