Skip to content

Commit fdee3d1

Browse files
committed
Fix VTA to fit the new IR Pattern
1 parent 016d03a commit fdee3d1

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

vta/python/vta/ir_pass.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ def inject_conv2d_transpose_skip(stmt_in):
638638
selects = []
639639

640640
def _find_basics(op):
641-
if isinstance(op, tvm.tir.Call):
641+
if isinstance(op, tvm.tir.BufferLoad):
642642
calls.append(op)
643643
elif isinstance(op, tvm.tir.Select):
644644
selects.append(op)
@@ -664,7 +664,7 @@ def _do_fold(op):
664664
body = op.body.body
665665
while isinstance(body, tvm.tir.IfThenElse):
666666
body = body.then_case
667-
args = body.args
667+
args = body.indices
668668
res_tensor = body.func.output(0)
669669
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
670670
inner = tvm.tir.AttrStmt(
@@ -696,19 +696,19 @@ def _do_fold(op):
696696
0, 0, 0))
697697
inner = irb.get()
698698

699-
args = conv_call.args
699+
args = conv_call.indices
700700
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
701701
1, 0, 1, 0, env.BLOCK_OUT)
702702
inner = tvm.tir.AttrStmt(
703703
[dout, res_tensor], 'buffer_bind_scope',
704704
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
705-
args = kernel_call.args
705+
args = kernel_call.indices
706706
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
707707
1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
708708
inner = tvm.tir.AttrStmt(
709709
[dwgt, kernel_tensor], 'buffer_bind_scope',
710710
tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
711-
args = data_call.args
711+
args = data_call.indices
712712
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
713713
1, 0, 1, 0, env.BLOCK_IN)
714714
inner = tvm.tir.AttrStmt(

0 commit comments

Comments
 (0)