@@ -55,8 +55,8 @@ def copy_kernel(Out, In, numel, XBLOCK: ttgl.constexpr, layout: ttgl.constexpr):
     ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[8], order=[0]),
 ])
 @pytest.mark.parametrize("XBLOCK", [128, 256, 512, 1024, 2048])
-def test_copy_kernel(layout, XBLOCK):
-    inp = torch.randn(XBLOCK * 4 - 7, device="cuda")
+def test_copy_kernel(layout, XBLOCK, device):
+    inp = torch.randn(XBLOCK * 4 - 7, device=device)
     out = torch.empty_like(inp)
 
     copy_kernel[(4, )](out, inp, inp.numel(), XBLOCK, layout, num_warps=layout.warps_per_cta[0])
@@ -113,8 +113,8 @@ def async_copy_mbarrier_kernel(out, inp, xnumel, XBLOCK: ttgl.constexpr, YBLOCK:
 
 
 @pytest.mark.skipif(not is_ampere_or_newer(), reason="Requires Ampere")
-def test_async_copy_mbarrier():
-    tensor_opts = dict(dtype=torch.float, device="cuda")
+def test_async_copy_mbarrier(device):
+    tensor_opts = dict(dtype=torch.float, device=device)
     out = torch.empty((32, 32), **tensor_opts)
     inp = torch.randn((20, 32), **tensor_opts)
     async_copy_mbarrier_kernel[(1, )](out, inp, inp.shape[0], XBLOCK=32, YBLOCK=32)
@@ -470,7 +470,7 @@ def make_finite(x, dtype):
     torch.testing.assert_close(z, z_ref, rtol=1e-5, atol=1e-5)
 
 
-def test_math_fast_expf():
+def test_math_fast_expf(device):
 
     @gluon.jit
     def fast_expf_kernel(x_ptr, y_ptr, warp_size: ttgl.constexpr, num_warps: ttgl.constexpr):
@@ -484,13 +484,13 @@ def fast_expf_kernel(x_ptr, y_ptr, warp_size: ttgl.constexpr, num_warps: ttgl.co
     num_warps = 4
 
     torch.manual_seed(0)
-    x = torch.randn(THREADS_PER_WARP * num_warps, device="cuda", dtype=torch.float32)
+    x = torch.randn(THREADS_PER_WARP * num_warps, device=device, dtype=torch.float32)
     y = torch.empty_like(x)
     fast_expf_kernel[(1, )](x, y, THREADS_PER_WARP, num_warps)
     torch.testing.assert_close(y, torch.exp(x), atol=1e-5, rtol=1e-4)
 
 
-def test_math_fast_dividef():
+def test_math_fast_dividef(device):
 
     @gluon.jit
     def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warps: ttgl.constexpr):
@@ -505,7 +505,7 @@ def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warp
     num_warps = 4
 
     torch.manual_seed(0)
-    x = torch.randn(THREADS_PER_WARP * num_warps, device="cuda", dtype=torch.float32)
+    x = torch.randn(THREADS_PER_WARP * num_warps, device=device, dtype=torch.float32)
     y = torch.randn_like(x)
     z = torch.empty_like(x)
     y[y == 0] = 1.0
@@ -734,7 +734,7 @@ def kernel(a_ptr, b_ptr, c_ptr, d_ptr):
     torch.testing.assert_close(d_ref, d_tri, rtol=0.08, atol=0)
 
 
-def test_slice_reinterpret():
+def test_slice_reinterpret(device):
     BLOCK = ttgl.constexpr(2048)
     SPLIT_BLOCK = ttgl.constexpr(BLOCK // 2)
     XBLOCK = ttgl.constexpr(32)
@@ -759,7 +759,7 @@ def kernel(in_ptr, out_ptr):
         value = smem_slice1.load(blocked)
         ttgl.store(ttgl.set_auto_layout(out_ptr + offs, blocked), value)
 
-    input = torch.randint(0, 100, (XBLOCK, YBLOCK), dtype=torch.int32, device="cuda")
+    input = torch.randint(0, 100, (XBLOCK, YBLOCK), dtype=torch.int32, device=device)
     output = torch.empty_like(input)
     kernel[(1, )](input, output)
     torch.testing.assert_close(input, output, atol=0, rtol=0)
@@ -856,7 +856,7 @@ def early_return_kernel(x):
     return x
 
 
-def test_2d_tensor_early_return():
+def test_2d_tensor_early_return(device):
     warp_size = ttgl.constexpr(THREADS_PER_WARP)
 
     @gluon.jit
@@ -871,7 +871,7 @@ def kernel(N, out):
         x += early_return_kernel(x)
         ttgl.store(out, x.sum(0).sum(0))
 
-    out = torch.empty(1, dtype=torch.int32, device="cuda")
+    out = torch.empty(1, dtype=torch.int32, device=device)
     compiled_kernel = kernel.warmup(N=100, out=out, grid=(1, ))
     assert compiled_kernel.asm["llir"].count("define") == 1
 
@@ -906,7 +906,8 @@ def kernel(x, y):
                                            {"offsets": [[0, 1], [0, 2], [0, 8], [0, 4], [0, 16], [0, 32], [2, 0], [1, 0], [4, 0], [8, 0], [16, 0], [32, 0]]}])
 @pytest.mark.parametrize("slice_m_offset, slice_n_offset, slice_m, slice_n", [(48, 16, 16, 16), (32, 48, 32, 16),
                                                                               (48, 32, 16, 32)])
-def test_padded_shared_layout_subslice(interval_pairs, shared_layout, slice_m_offset, slice_n_offset, slice_m, slice_n):
+def test_padded_shared_layout_subslice(interval_pairs, shared_layout, slice_m_offset, slice_n_offset, slice_m, slice_n,
+                                       device):
     m = 64
     n = 64
     num_warps = 1
@@ -945,8 +946,8 @@ def kernel(in_ptr, out_ptr, M: ttgl.constexpr, N: ttgl.constexpr, SLICE_M_OFFSET
         out_offs = offs_m_store[:, None] * SLICE_N + offs_n_store[None, :]
         ttgl.store(out_ptr + out_offs, out_data)
 
-    input = torch.arange(m * n, device="cuda").reshape(m, n).to(torch.int32)
-    output = torch.zeros((slice_m, slice_n), dtype=torch.int32, device="cuda")
+    input = torch.arange(m * n, device=device).reshape(m, n).to(torch.int32)
+    output = torch.zeros((slice_m, slice_n), dtype=torch.int32, device=device)
     ref_output = input[slice_m_offset:slice_m_offset + slice_m, slice_n_offset:slice_n_offset + slice_n]
 
     kernel[(1, )](input, output, m, n, slice_m_offset, slice_n_offset, slice_m, slice_n, num_warps=num_warps)
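For context: every test touched by this diff now takes device as a pytest fixture instead of hardcoding "cuda" when allocating tensors. The fixture itself is not part of this diff; below is a minimal sketch of what such a fixture could look like in conftest.py (hypothetical; the actual fixture in the repository may be driven by a command-line option or support additional backends).

# conftest.py (hypothetical sketch, not part of this diff)
import pytest
import torch


@pytest.fixture
def device():
    # Assume CUDA when available so the Gluon tests keep running on GPU;
    # fall back to CPU otherwise.
    return "cuda" if torch.cuda.is_available() else "cpu"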