Skip to content

Commit 7ab574b

Browse files
committed
[COMPILER] Upgrade to meet latest TVM IR pragma convention (apache#32)
1 parent 012697a commit 7ab574b

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

vta/python/vta/ir_pass.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,21 @@
88
from .environment import get_env
99

1010

11+
def _match_pragma(stmt, key):
12+
"""Internal helper to match stmt to pragma stmt.
13+
14+
Parameters
15+
----------
16+
stmt : Stmt
17+
The AttrStmt
18+
19+
key : str
20+
The pragma key
21+
"""
22+
return ((stmt.attr_key == "pragma_" + key) or
23+
(stmt.attr_key == "pragma_scope" and stmt.value.value == key))
24+
25+
1126
def fold_uop_loop(stmt_in):
1227
"""Detect and fold uop loop.
1328
@@ -255,7 +270,7 @@ def inject_skip_copy(stmt_in):
255270
Transformed statement
256271
"""
257272
def _do_fold(stmt):
258-
if (stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_dma_copy"):
273+
if _match_pragma(stmt, "skip_dma_copy"):
259274
return tvm.make.Evaluate(0)
260275
return None
261276
return tvm.ir_pass.IRTransform(
@@ -277,12 +292,12 @@ def inject_coproc_sync(stmt_in):
277292
"""
278293
success = [False]
279294
def _do_fold(stmt):
280-
if stmt.attr_key == "pragma_scope" and stmt.value.value == "coproc_sync":
295+
if _match_pragma(stmt, "coproc_sync"):
281296
success[0] = True
282297
sync = tvm.make.Call(
283298
"int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0)
284299
return tvm.make.Block(stmt.body, tvm.make.Evaluate(sync))
285-
elif stmt.attr_key == "pragma_scope" and stmt.value.value == "trim_loop":
300+
elif _match_pragma(stmt, "trim_loop"):
286301
op = stmt.body
287302
assert isinstance(op, tvm.stmt.For)
288303
return tvm.make.For(
@@ -561,15 +576,15 @@ def annotate_alu_coproc_scope(stmt_in):
561576
"""
562577
env = get_env()
563578
def _do_fold(stmt):
564-
if (stmt.attr_key == "pragma_scope" and stmt.value.value == "alu"):
579+
if _match_pragma(stmt, "alu"):
565580
irb = tvm.ir_builder.create()
566581
irb.scope_attr(env.dev.vta_axis, "coproc_scope",
567582
env.dev.get_task_qid(env.dev.QID_COMPUTE))
568583
irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope",
569584
tvm.make.StringImm("VTAPushALUOp"))
570585
irb.emit(stmt)
571586
return irb.get()
572-
elif (stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_alu"):
587+
elif _match_pragma(stmt, "skip_alu"):
573588
return tvm.make.Evaluate(0)
574589
return stmt
575590

@@ -631,7 +646,7 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
631646

632647
return rev_src_coeff, rev_dst_coeff, rev_extents
633648

634-
if (stmt.attr_key == "pragma_scope" and stmt.value.value == "alu"):
649+
if _match_pragma(stmt, "alu"):
635650
# Get to the innermost loop body
636651
loop_body = stmt.body
637652
nest_size = 0

0 commit comments

Comments
 (0)