@@ -638,7 +638,7 @@ def inject_conv2d_transpose_skip(stmt_in):
638638 selects = []
639639
640640 def _find_basics (op ):
641- if isinstance (op , tvm .tir .Call ):
641+ if isinstance (op , tvm .tir .BufferLoad ):
642642 calls .append (op )
643643 elif isinstance (op , tvm .tir .Select ):
644644 selects .append (op )
@@ -664,7 +664,7 @@ def _do_fold(op):
664664 body = op .body .body
665665 while isinstance (body , tvm .tir .IfThenElse ):
666666 body = body .then_case
667- args = body .args
667+ args = body .indices
668668 res_tensor = body .func .output (0 )
669669 tpl = (args [0 ], 1 , args [1 ], 1 , args [2 ], 1 , args [3 ], 1 , 0 , 1 , 0 , env .BLOCK_OUT )
670670 inner = tvm .tir .AttrStmt (
@@ -696,19 +696,19 @@ def _do_fold(op):
696696 0 , 0 , 0 ))
697697 inner = irb .get ()
698698
699- args = conv_call .args
699+ args = conv_call .indices
700700 tpl = (args [0 ], 1 , args [1 ], 1 , args [2 ], 1 , args [3 ],
701701 1 , 0 , 1 , 0 , env .BLOCK_OUT )
702702 inner = tvm .tir .AttrStmt (
703703 [dout , res_tensor ], 'buffer_bind_scope' ,
704704 tvm .tir .call_intrin ('handle' , 'tvm_tuple' , * tpl ), inner )
705- args = kernel_call .args
705+ args = kernel_call .indices
706706 tpl = (args [0 ], 1 , args [1 ], 1 , args [2 ], 1 , args [3 ],
707707 1 , 0 , env .BLOCK_OUT , 0 , env .BLOCK_IN )
708708 inner = tvm .tir .AttrStmt (
709709 [dwgt , kernel_tensor ], 'buffer_bind_scope' ,
710710 tvm .tir .call_intrin ('handle' , 'tvm_tuple' , * tpl ), inner )
711- args = data_call .args
711+ args = data_call .indices
712712 tpl = (args [0 ], 1 , args [1 ], 1 , args [2 ], 1 , args [3 ],
713713 1 , 0 , 1 , 0 , env .BLOCK_IN )
714714 inner = tvm .tir .AttrStmt (
0 commit comments