|  | 
|  | 1 | +"""TVMScript parser overrides tailored for TileLang.""" | 
|  | 2 | + | 
|  | 3 | +from functools import partial | 
|  | 4 | +from typing import Tuple | 
|  | 5 | + | 
|  | 6 | +from tvm.script.ir_builder import tir as T | 
|  | 7 | +from tvm.script.parser._core import dispatch, doc | 
|  | 8 | +from tvm.tir import BufferLoad, Var | 
|  | 9 | + | 
|  | 10 | +from tvm.script.parser.tir import parser as tvm_tir_parser | 
|  | 11 | + | 
|  | 12 | + | 
|  | 13 | +def _get_node_span(node: doc.AST) -> Tuple[int, int, int, int]: | 
|  | 14 | +    """Return the span (lineno, col, end_lineno, end_col) for a doc node.""" | 
|  | 15 | +    return (node.lineno, node.col_offset, node.end_lineno, node.end_col_offset) | 
|  | 16 | + | 
|  | 17 | + | 
|  | 18 | +# Original implementation located at | 
|  | 19 | +# 3rdparty/tvm/python/tvm/script/parser/tir/parser.py (visit_aug_assign). | 
|  | 20 | +@dispatch.register(token="tir", type_name="AugAssign") | 
|  | 21 | +def tilelang_visit_aug_assign(self, node: doc.AugAssign) -> None:  # pylint: disable=unused-argument | 
|  | 22 | +    """Override `AugAssign` to support writes into `local.var` buffers.""" | 
|  | 23 | +    lhs_pos = _get_node_span(node.target) | 
|  | 24 | +    rhs_pos = _get_node_span(node.value) | 
|  | 25 | + | 
|  | 26 | +    node.target.ctx = doc.Load() | 
|  | 27 | +    with self.var_table.with_frame(): | 
|  | 28 | +        lhs_name = "__tvm_tmp_value_aug_assign_lhs" | 
|  | 29 | +        rhs_name = "__tvm_tmp_value_aug_assign_rhs" | 
|  | 30 | +        lhs_expr = self.eval_expr(node.target) | 
|  | 31 | +        rhs_expr = self.eval_expr(node.value) | 
|  | 32 | +        self.var_table.add(lhs_name, lhs_expr) | 
|  | 33 | +        self.var_table.add(rhs_name, rhs_expr) | 
|  | 34 | +        op = doc.BinOp( | 
|  | 35 | +            doc.Name(lhs_name, doc.Load(), *lhs_pos), | 
|  | 36 | +            node.op, | 
|  | 37 | +            doc.Name(rhs_name, doc.Load(), *rhs_pos), | 
|  | 38 | +            *lhs_pos, | 
|  | 39 | +        ) | 
|  | 40 | +        rhs = self.eval_expr(op) | 
|  | 41 | + | 
|  | 42 | +    lhs = node.target | 
|  | 43 | +    lhs.ctx = doc.Store() | 
|  | 44 | +    if isinstance(lhs, doc.Subscript): | 
|  | 45 | +        if isinstance(lhs.slice, doc.Tuple): | 
|  | 46 | +            indices = [self.eval_expr(index) for index in lhs.slice.elts] | 
|  | 47 | +        else: | 
|  | 48 | +            indices = [self.eval_expr(lhs.slice)] | 
|  | 49 | +        T.buffer_store(self.eval_expr(lhs.value), rhs, indices) | 
|  | 50 | +        return | 
|  | 51 | + | 
|  | 52 | +    if isinstance(lhs, doc.Name) and lhs.id in self.var_table.get(): | 
|  | 53 | +        load_ctx = doc.Load() | 
|  | 54 | +        store_ctx = doc.Store() | 
|  | 55 | +        lhs.ctx = load_ctx | 
|  | 56 | +        lhs_value = self.eval_expr(lhs) | 
|  | 57 | +        lhs.ctx = store_ctx | 
|  | 58 | +        if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and | 
|  | 59 | +                len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0): | 
|  | 60 | +            T.buffer_store(lhs_value.buffer, rhs, indices=[0]) | 
|  | 61 | +            return | 
|  | 62 | + | 
|  | 63 | +    self.eval_assign(target=lhs, source=rhs, bind_value=tvm_tir_parser.bind_assign_value) | 
|  | 64 | + | 
|  | 65 | + | 
|  | 66 | +# Original implementation located at | 
|  | 67 | +# 3rdparty/tvm/python/tvm/script/parser/tir/parser.py (visit_ann_assign). | 
|  | 68 | +@dispatch.register(token="tir", type_name="AnnAssign") | 
|  | 69 | +def tilelang_visit_ann_assign(self, node: doc.AnnAssign) -> None:  # pylint: disable=unused-argument | 
|  | 70 | +    """Override `AnnAssign` to support writes into `local.var` buffers.""" | 
|  | 71 | +    lhs = node.target | 
|  | 72 | +    rhs = self.eval_expr(node.value) | 
|  | 73 | +    ann_var = self.visit_tvm_annotation(node.annotation) | 
|  | 74 | +    if not isinstance(ann_var, Var): | 
|  | 75 | +        self.report_error(node.annotation, "Annotation should be Var") | 
|  | 76 | + | 
|  | 77 | +    if isinstance(lhs, doc.Name) and lhs.id in self.var_table.get(): | 
|  | 78 | +        load_ctx = doc.Load() | 
|  | 79 | +        store_ctx = doc.Store() | 
|  | 80 | +        lhs.ctx = load_ctx | 
|  | 81 | +        lhs_value = self.eval_expr(lhs) | 
|  | 82 | +        lhs.ctx = store_ctx | 
|  | 83 | +        if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and | 
|  | 84 | +                len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0): | 
|  | 85 | +            T.buffer_store(lhs_value.buffer, rhs, indices=[0]) | 
|  | 86 | +            return | 
|  | 87 | + | 
|  | 88 | +    self.eval_assign(target=lhs, source=ann_var, bind_value=tvm_tir_parser.bind_assign_value) | 
|  | 89 | +    frame = T.LetStmt(rhs, var=ann_var) | 
|  | 90 | +    frame.add_callback(partial(frame.__exit__, None, None, None)) | 
|  | 91 | +    frame.__enter__() | 
0 commit comments