Skip to content

Commit 95fa529

Browse files
committed
add test
1 parent bed2858 commit 95fa529

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import pytest
2+
import torch
3+
4+
import tilelang
5+
import tilelang.testing
6+
import tilelang.language as T
7+
8+
9+
@tilelang.jit
10+
def get_kernel(reduce_op: str, dtype: str):
11+
12+
assert reduce_op in ["sum", "max", "min", "bitand", "bitor"]
13+
14+
@T.prim_func
15+
def main(
16+
x: T.Tensor((32), dtype)
17+
18+
):
19+
with T.Kernel(1, threads=32):
20+
tx = T.get_thread_binding(0)
21+
local_val = T.alloc_local([1], dtype)
22+
local_val[0] = x[tx]
23+
reduced_val = T.alloc_local([1], dtype)
24+
if reduce_op == "sum":
25+
reduced_val[0] = T.warp_reduce_sum(local_val[0])
26+
elif reduce_op == "max":
27+
reduced_val[0] = T.warp_reduce_max(local_val[0])
28+
elif reduce_op == "min":
29+
reduced_val[0] = T.warp_reduce_min(local_val[0])
30+
elif reduce_op == "bitand":
31+
reduced_val[0] = T.warp_reduce_bitand(local_val[0])
32+
elif reduce_op == "bitor":
33+
reduced_val[0] = T.warp_reduce_bitor(local_val[0])
34+
x[tx] = reduced_val[0]
35+
return main
36+
37+
38+
def test_warp_reduce_sum():
39+
a = torch.randn((32,), dtype=torch.float32, device='cuda')
40+
kernel = get_kernel('sum', 'float32')
41+
ref = torch.full_like(a, a.sum())
42+
kernel(a)
43+
torch.testing.assert_close(a, ref)
44+
45+
46+
def test_warp_reduce_max():
47+
a = torch.randn((32,), dtype=torch.float32, device='cuda')
48+
kernel = get_kernel("max", 'float32')
49+
print(kernel.get_kernel_source())
50+
ref = torch.full_like(a, a.max())
51+
kernel(a)
52+
torch.testing.assert_close(a, ref)
53+
54+
55+
def test_warp_reduce_min():
56+
a = torch.randn((32,), dtype=torch.float32, device='cuda')
57+
kernel = get_kernel("min", 'float32')
58+
ref = torch.full_like(a, a.min())
59+
kernel(a)
60+
torch.testing.assert_close(a, ref)
61+
62+
63+
def test_warp_reduce_bitand():
64+
a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
65+
kernel = get_kernel("bitand", 'int32')
66+
ref_val = a[0]
67+
for i in range(1, a.shape[0]):
68+
ref_val = ref_val & a[i]
69+
ref = torch.full_like(a, ref_val)
70+
kernel(a)
71+
torch.testing.assert_close(a, ref)
72+
73+
74+
def test_warp_reduce_bitor():
75+
a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
76+
kernel = get_kernel("bitor", 'int32')
77+
ref_val = a[0]
78+
for i in range(1, a.shape[0]):
79+
ref_val = ref_val | a[i]
80+
ref = torch.full_like(a, ref_val)
81+
kernel(a)
82+
torch.testing.assert_close(a, ref)
83+
84+
85+
if __name__ == "__main__":
86+
tilelang.testing.main()

0 commit comments

Comments
 (0)