Skip to content

Commit 30b28ac

Browse files
[mlir][SCF] ForOp: Remove getIterArgNumberForOpOperand
This function was inconsistent with the remaining API because it accepted `OpOperand &` that do not belong to the op. All the other functions assert. This helper function is also not really necessary, as the iter_arg number is identical to the result number.
1 parent 97495d3 commit 30b28ac

File tree

3 files changed

+13
-26
lines changed

3 files changed

+13
-26
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -263,16 +263,6 @@ def ForOp : SCF_Op<"for",
263263
}
264264
/// Number of operands controlling the loop: lb, ub, step
265265
unsigned getNumControlOperands() { return 3; }
266-
/// Get the iter arg number for an operand. If it isnt an iter arg
267-
/// operand return std::nullopt.
268-
std::optional<unsigned> getIterArgNumberForOpOperand(OpOperand &opOperand) {
269-
if (opOperand.getOwner() != getOperation())
270-
return std::nullopt;
271-
unsigned operandNumber = opOperand.getOperandNumber();
272-
if (operandNumber < getNumControlOperands())
273-
return std::nullopt;
274-
return operandNumber - getNumControlOperands();
275-
}
276266

277267
/// Get the region iter arg that corresponds to an OpOperand.
278268
/// This helper prevents internal op implementation detail leakage to

mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -810,13 +810,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
810810
OpBuilder::InsertionGuard g(rewriter);
811811
rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
812812

813-
std::optional<unsigned> maybeOperandNumber =
814-
forOp.getIterArgNumberForOpOperand(*pUse);
815-
assert(maybeOperandNumber.has_value() && "expected a proper iter arg number");
816-
817-
int64_t operandNumber = maybeOperandNumber.value();
813+
unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
818814
auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
819-
auto yieldingExtractSliceOp = yieldOp->getOperand(operandNumber)
815+
auto yieldingExtractSliceOp = yieldOp->getOperand(iterArgNumber)
820816
.getDefiningOp<tensor::ExtractSliceOp>();
821817
if (!yieldingExtractSliceOp)
822818
return tensor::ExtractSliceOp();
@@ -829,9 +825,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
829825
return tensor::ExtractSliceOp();
830826

831827
SmallVector<Value> initArgs = forOp.getInitArgs();
832-
initArgs[operandNumber] = hoistedPackedTensor;
828+
initArgs[iterArgNumber] = hoistedPackedTensor;
833829
SmallVector<Value> yieldOperands = yieldOp.getOperands();
834-
yieldOperands[operandNumber] = yieldingExtractSliceOp.getSource();
830+
yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();
835831

836832
int64_t numOriginalForOpResults = initArgs.size();
837833
LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
@@ -844,7 +840,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
844840
hoistedPackedTensor.getLoc(), hoistedPackedTensor,
845841
outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
846842
outerSliceOp.getMixedStrides());
847-
rewriter.replaceAllUsesWith(forOp.getResult(operandNumber), extracted);
843+
rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
848844
}
849845
scf::ForOp newForOp =
850846
replaceLoopWithNewYields(rewriter, forOp, initArgs, yieldOperands);
@@ -853,20 +849,20 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
853849
<< "\n");
854850
LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
855851
LLVM_DEBUG(DBGS() << "with result #"
856-
<< numOriginalForOpResults + operandNumber
852+
<< numOriginalForOpResults + iterArgNumber
857853
<< " of forOp, giving us: " << extracted << "\n");
858854
rewriter.startRootUpdate(extracted);
859855
extracted.getSourceMutable().assign(
860-
newForOp.getResult(numOriginalForOpResults + operandNumber));
856+
newForOp.getResult(numOriginalForOpResults + iterArgNumber));
861857
rewriter.finalizeRootUpdate(extracted);
862858

863859
LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
864860
<< "\n");
865861
LLVM_DEBUG(DBGS() << "with region iter arg #"
866-
<< numOriginalForOpResults + operandNumber << "\n");
862+
<< numOriginalForOpResults + iterArgNumber << "\n");
867863
rewriter.replaceAllUsesWith(
868864
paddedValueBeforeHoisting,
869-
newForOp.getRegionIterArg(numOriginalForOpResults + operandNumber));
865+
newForOp.getRegionIterArg(numOriginalForOpResults + iterArgNumber));
870866

871867
return extracted;
872868
}

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -569,8 +569,9 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
569569
scf::ForOp outerMostLoop = loops.front();
570570
if (destinationInitArg &&
571571
(*destinationInitArg)->getOwner() == outerMostLoop) {
572-
std::optional<unsigned> iterArgNumber =
573-
outerMostLoop.getIterArgNumberForOpOperand(**destinationInitArg);
572+
unsigned iterArgNumber =
573+
outerMostLoop.getResultForOpOperand(**destinationInitArg)
574+
.getResultNumber();
574575
int64_t resultNumber = fusableProducer.getResultNumber();
575576
if (auto dstOp =
576577
dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
@@ -584,7 +585,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
584585
scf::ForOp innerMostLoop = loops.back();
585586
updateDestinationOperandsForTiledOp(
586587
rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
587-
innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
588+
innerMostLoop.getRegionIterArgs()[iterArgNumber]);
588589
}
589590
}
590591
return scf::SCFFuseProducerOfSliceResult{fusableProducer,

0 commit comments

Comments
 (0)