Skip to content

Commit ebe9a7c

Browse files
Krzysztof Parzyszekjunrushao
authored andcommitted
[TIR] Implement TIR macros (apache#15260)
* [TIR] Implement TIR macros This patch introduces two new symbols: `T.macro` and `T.insert`. `T.macro` is a decorator that, when applied to a function, turns the body of that function into a piece of TIR that can be inserted via `T.insert` into a PrimFunc. For example: ```python @T.macro def copy_backwards(dst, src, size): with T.block("backwards"): for i in T.serial(size): ai = T.axis.remap("S", [i]) T.reads(src[0:size]) T.writes(dst[0:size]) dst[ai] = src[size - ai - 1] @T.prim_func def foo_int32(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")): T.insert(copy_backwards, A, B, 128) @T.prim_func def foo_int8(A: T.Buffer((128,), "int8"), B: T.Buffer((128,), "int8")): T.insert(copy_backwards, A, B, 128) ``` The above will generate two PrimFuncs that do the same backwards copy, but applied to buffers with different data types. Semantics: - Function that is decorated with @T.macro can have any parameters that follow Python syntax, i.e. positional, keyword, etc. Type annotations are not required, but are allowed. - The arguments to `T.insert` are macro name followed by the argument list. For `T.insert(arg1, arg2, arg3, ...)`, the values are substituted into the body of the macro as in the call `arg1(arg2, arg3, ...)`. The body with the substituted values is then inserted at the point where the `T.insert` is located. * Fix linter * Fix linter again One linter suggested something that the other didn't like... * Get rid of T.insert, apply macro via function-call syntax * Store closure vars in TIRMacro * ast.parse always returns ast.Module, hence doc is doc.Module * Simplify `expand_macro`, capture environment variables * Implement macro hygiene * Fix linter * Make T.macro work same as T.macro() The previous commit inadvertently made T.macro (without parentheses) illegal, only abbreviated form allowed was T.macro(). Restore T.macro as a valid decorator use. * Edit comment: insertion -> expansion * Add import pytest * One more typo... * Remove stale testcase
1 parent 4e9ca2a commit ebe9a7c

File tree

4 files changed

+264
-4
lines changed

4 files changed

+264
-4
lines changed

python/tvm/script/parser/_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,5 @@
1818
# pylint: disable=unused-import
1919
from .core import dispatch, doc, utils
2020
from .core.dispatch import OpMethod, register_op
21-
from .core.entry import parse
21+
from .core.entry import parse, parse_macro
2222
from .core.parser import Parser

python/tvm/script/parser/tir/entry.py

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
# under the License.
1717
"""The entry point of TVM parser for tir."""
1818
import inspect
19-
from typing import Callable, Union
19+
from typing import Any, Callable, Dict, Union
2020

2121
from tvm.ir.base import deprecated
2222
from tvm.tir import Buffer, PrimFunc
2323

2424
from ...ir_builder.tir import buffer, ptr
25-
from .._core import parse, utils
25+
from .._core import doc, parse, parse_macro, utils
2626

2727

2828
def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
@@ -50,6 +50,101 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
5050
setattr(prim_func, "dispatch_token", "tir")
5151

5252

53+
# Semantics of TIR macros:
54+
# - Function that is decorated with @T.macro can have any parameters that
55+
# follow Python syntax, i.e. positional, keyword, etc. Type annotations
56+
# are not required, but are allowed.
57+
# - Macro use follows the same syntax as a function call.
58+
# For `macro_name(arg1, arg2, arg3, ...)`, the values are substituted into
59+
# the body of the macro, and the body with the substituted values is then
60+
# inserted at the point where the call to the macro is located.
61+
62+
63+
class TIRMacro:
64+
"""Representation of T.macro."""
65+
66+
def __init__(
67+
self,
68+
source_ast: doc.AST,
69+
source_txt: str,
70+
closure_vars: Dict[str, Any],
71+
func: Callable,
72+
hygienic: bool,
73+
) -> None:
74+
self.source_ast = source_ast
75+
self.source_txt = source_txt
76+
self.closure_vars = closure_vars
77+
self.func = func
78+
self.hygienic = hygienic
79+
80+
def __repr__(self):
81+
return self.source_txt
82+
83+
84+
def macro(*args, hygienic: bool = True) -> Callable:
85+
"""Decorator for macro definitions.
86+
87+
Parameters
88+
----------
89+
hygienic: bool
90+
Specifies whether the macro is hygienic or not.
91+
A macro is hygienic if all symbols used in the macro's body are resolved
92+
to values from the location of the macro definition. A non-hygienic macro
93+
will have its symbols resolved to values at the time of the macro's use.
94+
95+
Example:
96+
```
97+
import tvm
98+
from tvm.script import tir as T
99+
100+
x_value = 128
101+
102+
@T.macro(hygienic=True)
103+
def static_capture(A, B):
104+
B[()] = A[x_value] ### x_value binds to 128
105+
106+
@T.macro(hygienic=False)
107+
def dynamic_capture(A, B):
108+
B[()] = A[x_value] ### x_value will bind at the time of use
109+
110+
111+
@T.prim_func
112+
def use1(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
113+
for x_value in T.serial(10):
114+
static_capture(A, B) ### Produces B[()] = A[128]
115+
116+
@T.prim_func
117+
def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
118+
for x_value in T.serial(10):
119+
dynamic_capture(A, B) ### Produces B[()] = A[x_value]
120+
```
121+
"""
122+
123+
def _decorator(func: Callable) -> TIRMacro:
124+
source_ast, source_txt, closure_vars = parse_macro(
125+
func, utils.inspect_function_capture(func)
126+
)
127+
obj = TIRMacro(source_ast, source_txt, closure_vars, func, hygienic)
128+
obj.__name__ = func.__name__
129+
# We don't need to explicitly store the return value anywhere.
130+
# This function is a decorator, so the return value will replace
131+
# the function definition (to which the decorator it is applied)
132+
# in that function's name space.
133+
return obj
134+
135+
if len(args) == 0:
136+
return _decorator
137+
if len(args) == 1 and inspect.isfunction(args[0]):
138+
return _decorator(args[0])
139+
140+
raise ValueError(
141+
"Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])"
142+
)
143+
144+
145+
# There is no dispatch_token for macro, because macro doesn't invoke parser.
146+
147+
53148
class BufferProxy:
54149
"""Buffer proxy class for constructing tir buffer."""
55150

python/tvm/script/parser/tir/parser.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
"""The base parser for tir"""
1818

1919
import contextlib
20+
import inspect
2021
from functools import partial
21-
from typing import Any
22+
from typing import Any, Union
2223

2324
import tvm
2425
from tvm.ir import GlobalVar, PrimType
@@ -29,6 +30,8 @@
2930
from ...ir_builder.base import IRBuilder
3031
from ...ir_builder.base import IRBuilderFrame as Frame
3132
from .._core import Parser, dispatch, doc
33+
from ..core.parser import VarTable
34+
from .entry import TIRMacro
3235

3336

3437
def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
@@ -427,6 +430,12 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
427430
node : doc.Expr
428431
The doc AST Expr node.
429432
"""
433+
434+
if isinstance(node.value, doc.Call):
435+
callee = self.eval_expr(node.value.func)
436+
if isinstance(callee, TIRMacro):
437+
return expand_macro(self, callee, node.value)
438+
430439
res = self.eval_expr(node.value)
431440
if res is None:
432441
pass
@@ -447,6 +456,7 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
447456
pass
448457
else:
449458
self.report_error(node, f"Parsing resulted in unexpected type {type(res)}")
459+
return None # For pylint
450460

451461

452462
@dispatch.register(token="tir", type_name="If")
@@ -528,3 +538,51 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar
528538
# Only ret_type is needed for func_signature.
529539
func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type)
530540
return I.decl_function(node.name, func_signature)
541+
542+
543+
def expand_macro(self: Parser, callee: TIRMacro, call: doc.Call) -> None:
544+
"""Bind arguments to the macro invocation to the parameters in the macro definition,
545+
and pass the macro body for further parsing.
546+
"""
547+
548+
assert isinstance(callee, TIRMacro), f"Unexpected macro type {type(callee)}"
549+
550+
def find_macro_def(name: str, decl_list: doc.AST) -> Union[doc.FunctionDef, Any]:
551+
for decl in decl_list:
552+
if isinstance(decl, doc.FunctionDef) and decl.name == name:
553+
return decl
554+
return None
555+
556+
macro_def = find_macro_def(callee.__name__, callee.source_ast.body)
557+
assert macro_def is not None, f"Invalid macro AST for {callee.__name__}"
558+
# `macro_def` is the FunctionDef of the macro.
559+
560+
args = [self.eval_expr(arg) for arg in call.args]
561+
kwargs = {kw.arg: self.eval_expr(kw.value) for kw in call.keywords}
562+
param_binding = inspect.signature(callee.func).bind(*args, **kwargs)
563+
param_binding.apply_defaults()
564+
local_vars = param_binding.arguments
565+
566+
if callee.hygienic:
567+
# If the macro was hygienic, construct new var_table with a single frame that
568+
# contains the captured environment, and process the macro's body with that
569+
# frame.
570+
saved_var_table = self.var_table
571+
self.var_table = VarTable()
572+
with self.var_table.with_frame():
573+
for k, v in callee.closure_vars.items():
574+
self.var_table.add(k, v)
575+
for k, v in local_vars.items():
576+
self.var_table.add(k, v)
577+
578+
self.visit_body(macro_def.body)
579+
580+
self.var_table = saved_var_table
581+
582+
else:
583+
# Otherwise, dynamically resolve symbols in the macro's body.
584+
with self.var_table.with_frame():
585+
for k, v in local_vars.items():
586+
self.var_table.add(k, v)
587+
588+
self.visit_body(macro_def.body)

tests/python/unittest/test_tvmscript_parser_tir.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
"""Unittests for tvm.script.parser.tir"""
1818

19+
import pytest
1920
import tvm.testing
2021
from tvm.script.parser import tir as T
2122
from tvm import ir, tir
@@ -71,5 +72,111 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
7172
assert matmul.__name__ == "matmul"
7273

7374

75+
def test_tir_macro_decorator_signature():
76+
@T.prim_func
77+
def evaluate0():
78+
T.evaluate(0)
79+
80+
# Ok, no parentheses
81+
@T.macro
82+
def func1():
83+
T.evaluate(0)
84+
85+
assert func1.hygienic
86+
87+
@T.prim_func
88+
def use1():
89+
func1()
90+
91+
tvm.ir.assert_structural_equal(use1, evaluate0)
92+
93+
# Ok, empty parentheses
94+
@T.macro()
95+
def func2():
96+
T.evaluate(0)
97+
98+
assert func2.hygienic
99+
100+
@T.prim_func
101+
def use2():
102+
func2()
103+
104+
tvm.ir.assert_structural_equal(use1, evaluate0)
105+
106+
with pytest.raises(ValueError):
107+
# Wrong: non-keyword argument
108+
@T.macro(True)
109+
def func3():
110+
T.evaluate()
111+
112+
113+
def test_tir_macro_signature():
114+
@T.macro
115+
def assign(i, *args, t1, **kwargs):
116+
vi, vj, vk = T.axis.remap("SSR", [i, args[0], args[1]])
117+
kwargs["t3"][vi, vj] = kwargs["t3"][vi, vj] + t1[vi, vk] * kwargs["t2"][vj, vk]
118+
119+
@T.prim_func
120+
def matmul_w_macro(a: T.handle, b: T.handle, c: T.handle) -> None:
121+
A = T.match_buffer(a, [128, 128])
122+
B = T.match_buffer(b, [128, 128])
123+
C = T.match_buffer(c, [128, 128])
124+
for i, j, k in T.grid(128, 128, 128):
125+
with T.block("update"):
126+
assign(i, j, k, t1=A, t2=B, t3=C)
127+
128+
@T.prim_func
129+
def matmul_no_macro(a: T.handle, b: T.handle, c: T.handle) -> None:
130+
A = T.match_buffer(a, [128, 128])
131+
B = T.match_buffer(b, [128, 128])
132+
C = T.match_buffer(c, [128, 128])
133+
for i, j, k in T.grid(128, 128, 128):
134+
with T.block("update"):
135+
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
136+
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
137+
138+
tvm.ir.assert_structural_equal(matmul_no_macro, matmul_w_macro)
139+
140+
141+
def test_tir_macro_hygienic():
142+
x_value = 128
143+
144+
@T.macro(hygienic=True)
145+
def static_capture(A, B):
146+
B[()] = A[x_value]
147+
148+
@T.prim_func
149+
def use_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
150+
for x_value in T.serial(10):
151+
static_capture(A, B)
152+
153+
@T.prim_func
154+
def expected_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
155+
for x_value in range(10):
156+
B[()] = A[128]
157+
158+
tvm.ir.assert_structural_equal(use_hygienic, expected_hygienic)
159+
160+
161+
def test_tir_macro_non_hygienic():
162+
x_value = 128
163+
164+
@T.macro(hygienic=False)
165+
def dynamic_capture(A, B):
166+
B[()] = A[x_value]
167+
168+
@T.prim_func
169+
def use_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
170+
for x_value in T.serial(10):
171+
dynamic_capture(A, B)
172+
173+
@T.prim_func
174+
def expected_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
175+
for x_value in range(10):
176+
B[()] = A[x_value]
177+
178+
tvm.ir.assert_structural_equal(use_non_hygienic, expected_non_hygienic)
179+
180+
74181
if __name__ == "__main__":
75182
tvm.testing.main()

0 commit comments

Comments
 (0)