|
| 1 | +import tilelang |
| 2 | +import tilelang.testing |
| 3 | +from tilelang import language as T |
| 4 | + |
| 5 | + |
| 6 | +@tilelang.jit |
| 7 | +def _compile_kernel_without_inplace(): |
| 8 | + num_tokens = T.symbolic("num_tokens") |
| 9 | + |
| 10 | + @T.prim_func |
| 11 | + def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]): |
| 12 | + with T.Kernel(num_tokens, threads=32) as pid: |
| 13 | + read = T.alloc_var("int") |
| 14 | + read = x[pid] |
| 15 | + |
| 16 | + write = T.alloc_var("int") |
| 17 | + write = read * 2 |
| 18 | + x[pid] = write |
| 19 | + |
| 20 | + return buggy_kernel |
| 21 | + |
| 22 | + |
| 23 | +@tilelang.jit( |
| 24 | + pass_configs={ |
| 25 | + tilelang.PassConfigKey.TL_STORAGE_REWRITE_DETECT_INPLACE: True, |
| 26 | + }, |
| 27 | +) |
| 28 | +def _compile_kernel_with_inplace(): |
| 29 | + num_tokens = T.symbolic("num_tokens") |
| 30 | + |
| 31 | + @T.prim_func |
| 32 | + def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]): |
| 33 | + with T.Kernel(num_tokens, threads=32) as pid: |
| 34 | + read = T.alloc_var("int") |
| 35 | + read = x[pid] |
| 36 | + |
| 37 | + write = T.alloc_var("int") |
| 38 | + write = read * 2 |
| 39 | + x[pid] = write |
| 40 | + |
| 41 | + return buggy_kernel |
| 42 | + |
| 43 | + |
| 44 | +def _get_device_kernel_script(detect_inplace: bool) -> str: |
| 45 | + if detect_inplace: |
| 46 | + kernel = _compile_kernel_with_inplace() |
| 47 | + else: |
| 48 | + kernel = _compile_kernel_without_inplace() |
| 49 | + source = kernel.get_kernel_source() |
| 50 | + return source |
| 51 | + |
| 52 | +def test_storage_rewrite_detect_inplace_toggle(): |
| 53 | + script_off = _get_device_kernel_script(detect_inplace=False) |
| 54 | + script_on = _get_device_kernel_script(detect_inplace=True) |
| 55 | + |
| 56 | + assert script_off.count("read = (read * 2);") == 0 |
| 57 | + assert script_on.count("read = (read * 2);") > 0 |
| 58 | + |
| 59 | + |
| 60 | + |
| 61 | +if __name__ == "__main__": |
| 62 | + tilelang.testing.main() |
0 commit comments