|
16 | 16 | # under the License. |
17 | 17 | """The entry point of TVM parser for tir.""" |
18 | 18 | import inspect |
19 | | -from typing import Callable, Union |
| 19 | +from typing import Any, Callable, Dict, Union |
20 | 20 |
|
21 | 21 | from tvm.ir.base import deprecated |
22 | 22 | from tvm.tir import Buffer, PrimFunc |
23 | 23 |
|
24 | 24 | from ...ir_builder.tir import buffer, ptr |
25 | | -from .._core import parse, utils |
| 25 | +from .._core import doc, parse, parse_macro, utils |
26 | 26 |
|
27 | 27 |
|
28 | 28 | def prim_func(func: Callable) -> Union[PrimFunc, Callable]: |
@@ -50,6 +50,101 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]: |
50 | 50 | setattr(prim_func, "dispatch_token", "tir") |
51 | 51 |
|
52 | 52 |
|
| 53 | +# Semantics of TIR macros: |
| 54 | +# - Function that is decorated with @T.macro can have any parameters that |
| 55 | +# follow Python syntax, i.e. positional, keyword, etc. Type annotations |
| 56 | +# are not required, but are allowed. |
| 57 | +# - Macro use follows the same syntax as a function call. |
| 58 | +# For `macro_name(arg1, arg2, arg3, ...)`, the values are substituted into |
| 59 | +# the body of the macro, and the body with the substituted values is then |
| 60 | +# inserted at the point where the call to the macro is located. |
| 61 | + |
| 62 | + |
| 63 | +class TIRMacro: |
| 64 | + """Representation of T.macro.""" |
| 65 | + |
| 66 | + def __init__( |
| 67 | + self, |
| 68 | + source_ast: doc.AST, |
| 69 | + source_txt: str, |
| 70 | + closure_vars: Dict[str, Any], |
| 71 | + func: Callable, |
| 72 | + hygienic: bool, |
| 73 | + ) -> None: |
| 74 | + self.source_ast = source_ast |
| 75 | + self.source_txt = source_txt |
| 76 | + self.closure_vars = closure_vars |
| 77 | + self.func = func |
| 78 | + self.hygienic = hygienic |
| 79 | + |
| 80 | + def __repr__(self): |
| 81 | + return self.source_txt |
| 82 | + |
| 83 | + |
| 84 | +def macro(*args, hygienic: bool = True) -> Callable: |
| 85 | + """Decorator for macro definitions. |
| 86 | +
|
| 87 | + Parameters |
| 88 | + ---------- |
| 89 | + hygienic: bool |
| 90 | + Specifies whether the macro is hygienic or not. |
| 91 | + A macro is hygienic if all symbols used in the macro's body are resolved |
| 92 | + to values from the location of the macro definition. A non-hygienic macro |
| 93 | + will have its symbols resolved to values at the time of the macro's use. |
| 94 | +
|
| 95 | + Example: |
| 96 | + ``` |
| 97 | + import tvm |
| 98 | + from tvm.script import tir as T |
| 99 | +
|
| 100 | + x_value = 128 |
| 101 | +
|
| 102 | + @T.macro(hygienic=True) |
| 103 | + def static_capture(A, B): |
| 104 | + B[()] = A[x_value] ### x_value binds to 128 |
| 105 | +
|
| 106 | + @T.macro(hygienic=False) |
| 107 | + def dynamic_capture(A, B): |
| 108 | + B[()] = A[x_value] ### x_value will bind at the time of use |
| 109 | +
|
| 110 | +
|
| 111 | + @T.prim_func |
| 112 | + def use1(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: |
| 113 | + for x_value in T.serial(10): |
| 114 | + static_capture(A, B) ### Produces B[()] = A[128] |
| 115 | +
|
| 116 | + @T.prim_func |
| 117 | + def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None: |
| 118 | + for x_value in T.serial(10): |
| 119 | + dynamic_capture(A, B) ### Produces B[()] = A[x_value] |
| 120 | + ``` |
| 121 | + """ |
| 122 | + |
| 123 | + def _decorator(func: Callable) -> TIRMacro: |
| 124 | + source_ast, source_txt, closure_vars = parse_macro( |
| 125 | + func, utils.inspect_function_capture(func) |
| 126 | + ) |
| 127 | + obj = TIRMacro(source_ast, source_txt, closure_vars, func, hygienic) |
| 128 | + obj.__name__ = func.__name__ |
| 129 | + # We don't need to explicitly store the return value anywhere. |
| 130 | + # This function is a decorator, so the return value will replace |
| 131 | + # the function definition (to which the decorator it is applied) |
| 132 | + # in that function's name space. |
| 133 | + return obj |
| 134 | + |
| 135 | + if len(args) == 0: |
| 136 | + return _decorator |
| 137 | + if len(args) == 1 and inspect.isfunction(args[0]): |
| 138 | + return _decorator(args[0]) |
| 139 | + |
| 140 | + raise ValueError( |
| 141 | + "Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])" |
| 142 | + ) |
| 143 | + |
| 144 | + |
| 145 | +# There is no dispatch_token for macro, because macro doesn't invoke parser. |
| 146 | + |
| 147 | + |
53 | 148 | class BufferProxy: |
54 | 149 | """Buffer proxy class for constructing tir buffer.""" |
55 | 150 |
|
|
0 commit comments