1515"""
1616
1717from __future__ import annotations
18+ from typing import overload
1819from tilelang import tvm as tvm
1920from tvm .script import tir as T
2021from tvm .tir import PrimExpr
2122from tvm .script .parser .tir import block_attr
23+ from tvm .tir .buffer import Buffer
24+ from tvm .tir .expr import FloatImm , IntImm
2225
2326
2427def alloc_shared (shape , dtype , scope = "shared.dyn" ):
@@ -67,6 +70,19 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
6770 return T .alloc_buffer (shape , dtype , scope = scope )
6871
6972
73+ @overload
74+ def alloc_var (dtype : str , init : PrimExpr | int | float , scope : str = 'local.var' ) -> Buffer :
75+ ...
76+
77+
78+ @overload
79+ def alloc_var (dtype : str ,
80+ scope : str = 'local.var' ,
81+ * ,
82+ init : PrimExpr | int | float | None = None ) -> Buffer :
83+ ...
84+
85+
7086def alloc_var (dtype , * args , scope = "local.var" , init : PrimExpr | None = None ):
7187 """Allocate a single-element variable buffer.
7288
@@ -82,7 +98,12 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
8298 init (PrimExpr, optional): The optional initializer value. When provided,
8399 the generated code will initialize the variable with this value instead
84100 of defaulting to zero.
85-
101+ Examples:
102+ a = T.alloc_var('int32', 1) # var with init 1
103+ a = T.alloc_var('int32', 'local.var') # var with local.var scope
104+ a = T.alloc_var('int32', 1, 'local.var') # var with init 1 and local.var scope
105+ a = T.alloc_var('int32', 'local.var', init=1) # var with init 1 and local.var scope
106+ a = T.alloc_var('int32', init=1) # var with init 1 and local.var scope
86107 Returns:
87108 T.Buffer: A TVM buffer object allocated as a single-element variable
88109 """
@@ -113,7 +134,10 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
113134
114135 buffer = T .alloc_buffer ([1 ], dtype , scope = parsed_scope )
115136 if parsed_init is not None :
116- block_attr ({"tl.local_var_init" : {buffer .data : parsed_init }})
137+ if isinstance (parsed_init , (int , float , IntImm , FloatImm )):
138+ block_attr ({"tl.local_var_init" : {buffer .data : parsed_init }})
139+ else :
140+ T .buffer_store (buffer , parsed_init , 0 )
117141 return buffer
118142
119143
0 commit comments