@@ -78,8 +78,7 @@ def _visit(op):
78
78
if not fail [0 ]:
79
79
begin = tvm .call_extern (
80
80
"int32" , "VTAUopLoopBegin" , stmt .extent , * gemm_offsets )
81
- end = tvm .call_extern (
82
- "int32" , "VTAUopLoopEnd" , stmt .extent , * gemm_offsets )
81
+ end = tvm .call_extern ("int32" , "VTAUopLoopEnd" )
83
82
return [begin , ret , end ]
84
83
raise ValueError ("Failed to fold the GEMM instructions.." )
85
84
@@ -683,8 +682,14 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
683
682
else :
684
683
raise RuntimeError (
685
684
"Function call not recognized %s" % (loop_body .value .name ))
685
+ elif isinstance (loop_body .value , tvm .expr .Load ):
686
+ alu_opcode = env .dev .ALU_OPCODE_SHR
687
+ lhs = loop_body .value
688
+ rhs = tvm .const (0 )
686
689
else :
687
- raise RuntimeError ("Expression not recognized %s" % (type (loop_body .value )))
690
+ raise RuntimeError (
691
+ "Expression not recognized %s, %s, %s" % (
692
+ type (loop_body .value ), str (loop_body .value ), str (stmt )))
688
693
689
694
# Derive array index coefficients
690
695
dst_coeff = tvm .arith .DetectLinearEquation (dst_idx , indices )
@@ -772,7 +777,9 @@ def _flatten_loop(src_coeff, dst_coeff, extents):
772
777
irb = tvm .ir_builder .create ()
773
778
for idx , extent in enumerate (extents ):
774
779
irb .emit (tvm .call_extern (
775
- "int32" , "VTAUopLoopBegin" , extent , dst_coeff [idx ], src_coeff [idx ]))
780
+ "int32" , "VTAUopLoopBegin" ,
781
+ extent , dst_coeff [idx ], src_coeff [idx ], 0 ))
782
+ use_imm = int (use_imm )
776
783
irb .emit (tvm .call_extern (
777
784
"int32" , "VTAUopPush" ,
778
785
1 , 0 ,
@@ -804,5 +811,6 @@ def debug_print(stmt):
804
811
stmt : Stmt
805
812
The
806
813
"""
814
+ # pylint: disable=superfluous-parens
807
815
print (stmt )
808
816
return stmt
0 commit comments