Skip to content

Commit d3c37e3

Browse files
committed
[Fix] init var with complex expression
1 parent 60567ba commit d3c37e3

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import tilelang
2+
import tilelang.language as T
3+
import tilelang.testing
4+
5+
def test_var_assign() -> None:
6+
@tilelang.jit(out_idx=-1)
7+
def jit_kernel():
8+
@T.prim_func
9+
def test_var_assign(A: T.Tensor((2,), 'int32')):
10+
with T.Kernel(1) as _:
11+
a = T.alloc_var('int32', init=1)
12+
b = T.alloc_var('int32', init=a) # b gets value of a
13+
a = 2
14+
d = T.alloc_var('int32', init=a) # c gets new value of a
15+
A[0] = b
16+
A[1] = d
17+
print(test_var_assign)
18+
return test_var_assign
19+
kernel = jit_kernel()
20+
print(kernel.get_kernel_source())
21+
res = kernel()
22+
assert res[0] == 1
23+
assert res[1] == 2
24+
25+
26+
if __name__ == '__main__':
27+
tilelang.testing.main()

tilelang/language/allocate.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
"""
1616

1717
from __future__ import annotations
18+
from typing import overload
1819
from tilelang import tvm as tvm
1920
from tvm.script import tir as T
2021
from tvm.tir import PrimExpr
2122
from tvm.script.parser.tir import block_attr
23+
from tvm.tir.buffer import Buffer
24+
from tvm.tir.expr import FloatImm, IntImm
2225

2326

2427
def alloc_shared(shape, dtype, scope="shared.dyn"):
@@ -67,6 +70,19 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
6770
return T.alloc_buffer(shape, dtype, scope=scope)
6871

6972

73+
@overload
74+
def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = 'local.var') -> Buffer:
75+
...
76+
77+
78+
@overload
79+
def alloc_var(dtype: str,
80+
scope: str = 'local.var',
81+
*,
82+
init: PrimExpr | int | float | None = None) -> Buffer:
83+
...
84+
85+
7086
def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
7187
"""Allocate a single-element variable buffer.
7288
@@ -82,7 +98,12 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
8298
init (PrimExpr, optional): The optional initializer value. When provided,
8399
the generated code will initialize the variable with this value instead
84100
of defaulting to zero.
85-
101+
Examples:
102+
a = T.alloc_var('int32', 1) # var with init 1
103+
a = T.alloc_var('int32', 'local.var') # var with local.var scope
104+
a = T.alloc_var('int32', 1, 'local.var') # var with init 1 and local.var scope
105+
a = T.alloc_var('int32', 'local.var', init=1) # var with init 1 and local.var scope
106+
a = T.alloc_var('int32', init=1) # var with init 1 and local.var scope
86107
Returns:
87108
T.Buffer: A TVM buffer object allocated as a single-element variable
88109
"""
@@ -113,7 +134,10 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
113134

114135
buffer = T.alloc_buffer([1], dtype, scope=parsed_scope)
115136
if parsed_init is not None:
116-
block_attr({"tl.local_var_init": {buffer.data: parsed_init}})
137+
if isinstance(parsed_init, (int, float, IntImm, FloatImm)):
138+
block_attr({"tl.local_var_init": {buffer.data: parsed_init}})
139+
else:
140+
T.buffer_store(buffer, parsed_init, 0)
117141
return buffer
118142

119143

0 commit comments

Comments
 (0)