Skip to content

Commit dc9784c

Browse files
committed
add test
1 parent 4e3843f commit dc9784c

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import tilelang
2+
import tilelang.testing
3+
from tilelang import language as T
4+
5+
6+
@tilelang.jit
7+
def _compile_kernel_without_inplace():
8+
num_tokens = T.symbolic("num_tokens")
9+
10+
@T.prim_func
11+
def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]):
12+
with T.Kernel(num_tokens, threads=32) as pid:
13+
read = T.alloc_var("int")
14+
read = x[pid]
15+
16+
write = T.alloc_var("int")
17+
write = read * 2
18+
x[pid] = write
19+
20+
return buggy_kernel
21+
22+
23+
@tilelang.jit(
24+
pass_configs={
25+
tilelang.PassConfigKey.TL_STORAGE_REWRITE_DETECT_INPLACE: True,
26+
},
27+
)
28+
def _compile_kernel_with_inplace():
29+
num_tokens = T.symbolic("num_tokens")
30+
31+
@T.prim_func
32+
def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]):
33+
with T.Kernel(num_tokens, threads=32) as pid:
34+
read = T.alloc_var("int")
35+
read = x[pid]
36+
37+
write = T.alloc_var("int")
38+
write = read * 2
39+
x[pid] = write
40+
41+
return buggy_kernel
42+
43+
44+
def _get_device_kernel_script(detect_inplace: bool) -> str:
45+
if detect_inplace:
46+
kernel = _compile_kernel_with_inplace()
47+
else:
48+
kernel = _compile_kernel_without_inplace()
49+
source = kernel.get_kernel_source()
50+
return source
51+
52+
def test_storage_rewrite_detect_inplace_toggle():
53+
script_off = _get_device_kernel_script(detect_inplace=False)
54+
script_on = _get_device_kernel_script(detect_inplace=True)
55+
56+
assert script_off.count("read = (read * 2);") == 0
57+
assert script_on.count("read = (read * 2);") > 0
58+
59+
60+
61+
if __name__ == "__main__":
62+
tilelang.testing.main()

0 commit comments

Comments
 (0)