Skip to content

Commit 77e31e5

Browse files
authored
[Language] Enhance T.alloc_var for AugAssign and AnnAsign (#979)
* feat: add parser overrides for local.var aug assign. * lint fix
1 parent 747381a commit 77e31e5

File tree

3 files changed

+100
-0
lines changed

3 files changed

+100
-0
lines changed

tilelang/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# TODO(lei): remove this import once the
88
# upstream tir script is fully compatible
99
from tvm.script.parser.tir import *
10+
from . import overrides as _overrides # noqa: F401
1011
from .tir import (
1112
prim_func, # noqa: F401
1213
)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""TileLang-specific runtime overrides.
2+
3+
Importing this package registers custom handlers that extend or override
4+
behaviour from upstream TVMScript for TileLang semantics.
5+
"""
6+
7+
# Register parser overrides upon import.
8+
from . import parser # noqa: F401
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""TVMScript parser overrides tailored for TileLang."""
2+
3+
from functools import partial
4+
from typing import Tuple
5+
6+
from tvm.script.ir_builder import tir as T
7+
from tvm.script.parser._core import dispatch, doc
8+
from tvm.tir import BufferLoad, Var
9+
10+
from tvm.script.parser.tir import parser as tvm_tir_parser
11+
12+
13+
def _get_node_span(node: doc.AST) -> Tuple[int, int, int, int]:
14+
"""Return the span (lineno, col, end_lineno, end_col) for a doc node."""
15+
return (node.lineno, node.col_offset, node.end_lineno, node.end_col_offset)
16+
17+
18+
# Original implementation located at
19+
# 3rdparty/tvm/python/tvm/script/parser/tir/parser.py (visit_aug_assign).
20+
@dispatch.register(token="tir", type_name="AugAssign")
21+
def tilelang_visit_aug_assign(self, node: doc.AugAssign) -> None: # pylint: disable=unused-argument
22+
"""Override `AugAssign` to support writes into `local.var` buffers."""
23+
lhs_pos = _get_node_span(node.target)
24+
rhs_pos = _get_node_span(node.value)
25+
26+
node.target.ctx = doc.Load()
27+
with self.var_table.with_frame():
28+
lhs_name = "__tvm_tmp_value_aug_assign_lhs"
29+
rhs_name = "__tvm_tmp_value_aug_assign_rhs"
30+
lhs_expr = self.eval_expr(node.target)
31+
rhs_expr = self.eval_expr(node.value)
32+
self.var_table.add(lhs_name, lhs_expr)
33+
self.var_table.add(rhs_name, rhs_expr)
34+
op = doc.BinOp(
35+
doc.Name(lhs_name, doc.Load(), *lhs_pos),
36+
node.op,
37+
doc.Name(rhs_name, doc.Load(), *rhs_pos),
38+
*lhs_pos,
39+
)
40+
rhs = self.eval_expr(op)
41+
42+
lhs = node.target
43+
lhs.ctx = doc.Store()
44+
if isinstance(lhs, doc.Subscript):
45+
if isinstance(lhs.slice, doc.Tuple):
46+
indices = [self.eval_expr(index) for index in lhs.slice.elts]
47+
else:
48+
indices = [self.eval_expr(lhs.slice)]
49+
T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
50+
return
51+
52+
if isinstance(lhs, doc.Name) and lhs.id in self.var_table.get():
53+
load_ctx = doc.Load()
54+
store_ctx = doc.Store()
55+
lhs.ctx = load_ctx
56+
lhs_value = self.eval_expr(lhs)
57+
lhs.ctx = store_ctx
58+
if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and
59+
len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0):
60+
T.buffer_store(lhs_value.buffer, rhs, indices=[0])
61+
return
62+
63+
self.eval_assign(target=lhs, source=rhs, bind_value=tvm_tir_parser.bind_assign_value)
64+
65+
66+
# Original implementation located at
67+
# 3rdparty/tvm/python/tvm/script/parser/tir/parser.py (visit_ann_assign).
68+
@dispatch.register(token="tir", type_name="AnnAssign")
69+
def tilelang_visit_ann_assign(self, node: doc.AnnAssign) -> None: # pylint: disable=unused-argument
70+
"""Override `AnnAssign` to support writes into `local.var` buffers."""
71+
lhs = node.target
72+
rhs = self.eval_expr(node.value)
73+
ann_var = self.visit_tvm_annotation(node.annotation)
74+
if not isinstance(ann_var, Var):
75+
self.report_error(node.annotation, "Annotation should be Var")
76+
77+
if isinstance(lhs, doc.Name) and lhs.id in self.var_table.get():
78+
load_ctx = doc.Load()
79+
store_ctx = doc.Store()
80+
lhs.ctx = load_ctx
81+
lhs_value = self.eval_expr(lhs)
82+
lhs.ctx = store_ctx
83+
if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and
84+
len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0):
85+
T.buffer_store(lhs_value.buffer, rhs, indices=[0])
86+
return
87+
88+
self.eval_assign(target=lhs, source=ann_var, bind_value=tvm_tir_parser.bind_assign_value)
89+
frame = T.LetStmt(rhs, var=ann_var)
90+
frame.add_callback(partial(frame.__exit__, None, None, None))
91+
frame.__enter__()

0 commit comments

Comments
 (0)