@@ -810,13 +810,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
810
810
OpBuilder::InsertionGuard g (rewriter);
811
811
rewriter.setInsertionPointAfter (hoistedPackedTensor.getDefiningOp ());
812
812
813
- std::optional<unsigned > maybeOperandNumber =
814
- forOp.getIterArgNumberForOpOperand (*pUse);
815
- assert (maybeOperandNumber.has_value () && " expected a proper iter arg number" );
816
-
817
- int64_t operandNumber = maybeOperandNumber.value ();
813
+ unsigned iterArgNumber = forOp.getResultForOpOperand (*pUse).getResultNumber ();
818
814
auto yieldOp = cast<scf::YieldOp>(forOp.getBody (0 )->getTerminator ());
819
- auto yieldingExtractSliceOp = yieldOp->getOperand (operandNumber )
815
+ auto yieldingExtractSliceOp = yieldOp->getOperand (iterArgNumber )
820
816
.getDefiningOp <tensor::ExtractSliceOp>();
821
817
if (!yieldingExtractSliceOp)
822
818
return tensor::ExtractSliceOp ();
@@ -829,9 +825,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
829
825
return tensor::ExtractSliceOp ();
830
826
831
827
SmallVector<Value> initArgs = forOp.getInitArgs ();
832
- initArgs[operandNumber ] = hoistedPackedTensor;
828
+ initArgs[iterArgNumber ] = hoistedPackedTensor;
833
829
SmallVector<Value> yieldOperands = yieldOp.getOperands ();
834
- yieldOperands[operandNumber ] = yieldingExtractSliceOp.getSource ();
830
+ yieldOperands[iterArgNumber ] = yieldingExtractSliceOp.getSource ();
835
831
836
832
int64_t numOriginalForOpResults = initArgs.size ();
837
833
LLVM_DEBUG (DBGS () << " numOriginalForOpResults: " << numOriginalForOpResults
@@ -844,7 +840,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
844
840
hoistedPackedTensor.getLoc (), hoistedPackedTensor,
845
841
outerSliceOp.getMixedOffsets (), outerSliceOp.getMixedSizes (),
846
842
outerSliceOp.getMixedStrides ());
847
- rewriter.replaceAllUsesWith (forOp.getResult (operandNumber ), extracted);
843
+ rewriter.replaceAllUsesWith (forOp.getResult (iterArgNumber ), extracted);
848
844
}
849
845
scf::ForOp newForOp =
850
846
replaceLoopWithNewYields (rewriter, forOp, initArgs, yieldOperands);
@@ -853,20 +849,20 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
853
849
<< " \n " );
854
850
LLVM_DEBUG (DBGS () << " replace source of: " << extracted << " \n " );
855
851
LLVM_DEBUG (DBGS () << " with result #"
856
- << numOriginalForOpResults + operandNumber
852
+ << numOriginalForOpResults + iterArgNumber
857
853
<< " of forOp, giving us: " << extracted << " \n " );
858
854
rewriter.startRootUpdate (extracted);
859
855
extracted.getSourceMutable ().assign (
860
- newForOp.getResult (numOriginalForOpResults + operandNumber ));
856
+ newForOp.getResult (numOriginalForOpResults + iterArgNumber ));
861
857
rewriter.finalizeRootUpdate (extracted);
862
858
863
859
LLVM_DEBUG (DBGS () << " replace uses of: " << paddedValueBeforeHoisting
864
860
<< " \n " );
865
861
LLVM_DEBUG (DBGS () << " with region iter arg #"
866
- << numOriginalForOpResults + operandNumber << " \n " );
862
+ << numOriginalForOpResults + iterArgNumber << " \n " );
867
863
rewriter.replaceAllUsesWith (
868
864
paddedValueBeforeHoisting,
869
- newForOp.getRegionIterArg (numOriginalForOpResults + operandNumber ));
865
+ newForOp.getRegionIterArg (numOriginalForOpResults + iterArgNumber ));
870
866
871
867
return extracted;
872
868
}
0 commit comments