88from .environment import get_env
99
1010
11+ def _match_pragma (stmt , key ):
12+ """Internal helper to match stmt to pragma stmt.
13+
14+ Parameters
15+ ----------
16+ stmt : Stmt
17+ The AttrStmt
18+
19+ key : str
20+ The pragma key
21+ """
22+ return ((stmt .attr_key == "pragma_" + key ) or
23+ (stmt .attr_key == "pragma_scope" and stmt .value .value == key ))
24+
25+
1126def fold_uop_loop (stmt_in ):
1227 """Detect and fold uop loop.
1328
@@ -255,7 +270,7 @@ def inject_skip_copy(stmt_in):
255270 Transformed statement
256271 """
257272 def _do_fold (stmt ):
258- if (stmt . attr_key == "pragma_scope" and stmt . value . value == "skip_dma_copy" ):
273+ if _match_pragma (stmt , "skip_dma_copy" ):
259274 return tvm .make .Evaluate (0 )
260275 return None
261276 return tvm .ir_pass .IRTransform (
@@ -277,12 +292,12 @@ def inject_coproc_sync(stmt_in):
277292 """
278293 success = [False ]
279294 def _do_fold (stmt ):
280- if stmt . attr_key == "pragma_scope" and stmt . value . value == " coproc_sync" :
295+ if _match_pragma ( stmt , " coproc_sync") :
281296 success [0 ] = True
282297 sync = tvm .make .Call (
283298 "int32" , "vta.coproc_sync" , [], tvm .expr .Call .Intrinsic , None , 0 )
284299 return tvm .make .Block (stmt .body , tvm .make .Evaluate (sync ))
285- elif stmt . attr_key == "pragma_scope" and stmt . value . value == " trim_loop" :
300+ elif _match_pragma ( stmt , " trim_loop") :
286301 op = stmt .body
287302 assert isinstance (op , tvm .stmt .For )
288303 return tvm .make .For (
@@ -561,15 +576,15 @@ def annotate_alu_coproc_scope(stmt_in):
561576 """
562577 env = get_env ()
563578 def _do_fold (stmt ):
564- if (stmt . attr_key == "pragma_scope" and stmt . value . value == "alu" ):
579+ if _match_pragma (stmt , "alu" ):
565580 irb = tvm .ir_builder .create ()
566581 irb .scope_attr (env .dev .vta_axis , "coproc_scope" ,
567582 env .dev .get_task_qid (env .dev .QID_COMPUTE ))
568583 irb .scope_attr (env .dev .vta_axis , "coproc_uop_scope" ,
569584 tvm .make .StringImm ("VTAPushALUOp" ))
570585 irb .emit (stmt )
571586 return irb .get ()
572- elif (stmt . attr_key == "pragma_scope" and stmt . value . value == "skip_alu" ):
587+ elif _match_pragma (stmt , "skip_alu" ):
573588 return tvm .make .Evaluate (0 )
574589 return stmt
575590
@@ -631,7 +646,7 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
631646
632647 return rev_src_coeff , rev_dst_coeff , rev_extents
633648
634- if (stmt . attr_key == "pragma_scope" and stmt . value . value == "alu" ):
649+ if _match_pragma (stmt , "alu" ):
635650 # Get to the innermost loop body
636651 loop_body = stmt .body
637652 nest_size = 0
0 commit comments