Skip to content

Commit 6923a31

Browse files
[mlir][IR] Change MutableArrayRange to enumerate OpOperand & (#66622)
In line with #66515, change `MutableArrayRange::begin`/`end` to enumerate `OpOperand &` instead of `Value`. Also remove `ForOp::getIterOpOperands`/`setIterArg`, which are now redundant. Note: `MutableOperandRange` cannot be made a derived class of `indexed_accessor_range_base` (like `OperandRange`), because `MutableOperandRange::assign` can change the number of operands in the range.
1 parent 45bb45f commit 6923a31

File tree

9 files changed

+53
-42
lines changed

9 files changed

+53
-42
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -250,17 +250,10 @@ def ForOp : SCF_Op<"for",
250250
"expected an index less than the number of region iter args");
251251
return getBody()->getArguments().drop_front(getNumInductionVars())[index];
252252
}
253-
MutableArrayRef<OpOperand> getIterOpOperands() {
254-
return
255-
getOperation()->getOpOperands().drop_front(getNumControlOperands());
256-
}
257253

258254
void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
259255
void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
260256
void setStep(Value step) { getOperation()->setOperand(2, step); }
261-
void setIterArg(unsigned iterArgNum, Value iterArgValue) {
262-
getOperation()->setOperand(iterArgNum + getNumControlOperands(), iterArgValue);
263-
}
264257

265258
/// Number of induction variables, always 1 for scf::ForOp.
266259
unsigned getNumInductionVars() { return 1; }

mlir/include/mlir/IR/ValueRange.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,9 @@ class MutableOperandRange {
165165
/// Returns the OpOperand at the given index.
166166
OpOperand &operator[](unsigned index) const;
167167

168-
OperandRange::iterator begin() const {
169-
return static_cast<OperandRange>(*this).begin();
170-
}
171-
172-
OperandRange::iterator end() const {
173-
return static_cast<OperandRange>(*this).end();
174-
}
168+
/// Iterators enumerate OpOperands.
169+
MutableArrayRef<OpOperand>::iterator begin() const;
170+
MutableArrayRef<OpOperand>::iterator end() const;
175171

176172
private:
177173
/// Update the length of this range to the one provided.

mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
4747

4848
static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
4949

50+
static bool isMemrefOperand(OpOperand &operand) {
51+
return isMemref(operand.get());
52+
}
53+
5054
//===----------------------------------------------------------------------===//
5155
// Backedges analysis
5256
//===----------------------------------------------------------------------===//
@@ -937,7 +941,7 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
937941

938942
// Add an additional operand for every MemRef for the ownership indicator.
939943
if (!funcWithoutDynamicOwnership) {
940-
unsigned numMemRefs = llvm::count_if(operands, isMemref);
944+
unsigned numMemRefs = llvm::count_if(operands, isMemrefOperand);
941945
SmallVector<Value> newOperands{OperandRange(operands)};
942946
auto ownershipValues =
943947
deallocOp.getUpdatedConditions().take_front(numMemRefs);

mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,12 @@ struct CondBranchOpInterface
9696
mapping[retained] = ownership;
9797
}
9898
SmallVector<Value> replacements, ownerships;
99-
for (Value operand : destOperands) {
100-
replacements.push_back(operand);
101-
if (isMemref(operand)) {
102-
assert(mapping.contains(operand) &&
99+
for (OpOperand &operand : destOperands) {
100+
replacements.push_back(operand.get());
101+
if (isMemref(operand.get())) {
102+
assert(mapping.contains(operand.get()) &&
103103
"Should be contained at this point");
104-
ownerships.push_back(mapping[operand]);
104+
ownerships.push_back(mapping[operand.get()]);
105105
}
106106
}
107107
replacements.append(ownerships);

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,7 @@ replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
932932
assert(operand.get().getType() != replacement.getType() &&
933933
"Expected a different type");
934934
SmallVector<Value> newIterOperands;
935-
for (OpOperand &opOperand : forOp.getIterOpOperands()) {
935+
for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
936936
if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
937937
newIterOperands.push_back(replacement);
938938
continue;
@@ -1015,7 +1015,7 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
10151015

10161016
LogicalResult matchAndRewrite(ForOp op,
10171017
PatternRewriter &rewriter) const override {
1018-
for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
1018+
for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
10191019
OpOperand &iterOpOperand = std::get<0>(it);
10201020
auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
10211021
if (!incomingCast ||

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
325325
/// Helper function for loop bufferization. Return the bufferized values of the
326326
/// given OpOperands. If an operand is not a tensor, return the original value.
327327
static FailureOr<SmallVector<Value>>
328-
getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
328+
getBuffers(RewriterBase &rewriter, MutableOperandRange operands,
329329
const BufferizationOptions &options) {
330330
SmallVector<Value> result;
331331
for (OpOperand &opOperand : operands) {
@@ -598,7 +598,7 @@ struct ForOpInterface
598598

599599
// The new memref init_args of the loop.
600600
FailureOr<SmallVector<Value>> maybeInitArgs =
601-
getBuffers(rewriter, forOp.getIterOpOperands(), options);
601+
getBuffers(rewriter, forOp.getInitArgsMutable(), options);
602602
if (failed(maybeInitArgs))
603603
return failure();
604604
SmallVector<Value> initArgs = *maybeInitArgs;
@@ -816,7 +816,7 @@ struct WhileOpInterface
816816

817817
// The new memref init_args of the loop.
818818
FailureOr<SmallVector<Value>> maybeInitArgs =
819-
getBuffers(rewriter, whileOp->getOpOperands(), options);
819+
getBuffers(rewriter, whileOp.getInitsMutable(), options);
820820
if (failed(maybeInitArgs))
821821
return failure();
822822
SmallVector<Value> initArgs = *maybeInitArgs;

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
500500
MutableArrayRef<scf::ForOp> loops) {
501501
// 1. Get the producer of the source (potentially walking through
502502
// `iter_args` of nested `scf.for`)
503-
auto [fusableProducer, destinationIterArg] =
503+
auto [fusableProducer, destinationInitArg] =
504504
getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
505505
loops);
506506
if (!fusableProducer)
@@ -567,17 +567,15 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
567567
// TODO: This can be modeled better if the `DestinationStyleOpInterface`.
568568
// Update to use that when it does become available.
569569
scf::ForOp outerMostLoop = loops.front();
570-
std::optional<unsigned> iterArgNumber;
571-
if (destinationIterArg) {
572-
iterArgNumber =
573-
outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
574-
}
575-
if (iterArgNumber) {
570+
if (destinationInitArg &&
571+
(*destinationInitArg)->getOwner() == outerMostLoop) {
572+
std::optional<unsigned> iterArgNumber =
573+
outerMostLoop.getIterArgNumberForOpOperand(**destinationInitArg);
576574
int64_t resultNumber = fusableProducer.getResultNumber();
577575
if (auto dstOp =
578576
dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
579-
outerMostLoop.setIterArg(iterArgNumber.value(),
580-
dstOp.getTiedOpOperand(fusableProducer)->get());
577+
(*destinationInitArg)
578+
->set(dstOp.getTiedOpOperand(fusableProducer)->get());
581579
}
582580
for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
583581
auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);

mlir/lib/IR/OperationSupport.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,14 @@ OpOperand &MutableOperandRange::operator[](unsigned index) const {
522522
return owner->getOpOperand(start + index);
523523
}
524524

525+
MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
526+
return owner->getOpOperands().slice(start, length).begin();
527+
}
528+
529+
MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
530+
return owner->getOpOperands().slice(start, length).end();
531+
}
532+
525533
//===----------------------------------------------------------------------===//
526534
// MutableOperandRangeRange
527535

mlir/lib/Transforms/Utils/CFGToSCF.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,13 @@ getMutableSuccessorOperands(Block *block, unsigned successorIndex) {
137137
return succOps.getMutableForwardedOperands();
138138
}
139139

140+
/// Return the operand range used to transfer operands from `block` to its
141+
/// successor with the given index.
142+
static OperandRange getSuccessorOperands(Block *block,
143+
unsigned successorIndex) {
144+
return getMutableSuccessorOperands(block, successorIndex);
145+
}
146+
140147
/// Appends all the block arguments from `other` to the block arguments of
141148
/// `block`, copying their types and locations.
142149
static void addBlockArgumentsFromOther(Block *block, Block *other) {
@@ -175,8 +182,14 @@ class Edge {
175182

176183
/// Returns the arguments of this edge that are passed to the block arguments
177184
/// of the successor.
178-
MutableOperandRange getSuccessorOperands() const {
179-
return getMutableSuccessorOperands(fromBlock, successorIndex);
185+
MutableOperandRange getMutableSuccessorOperands() const {
186+
return ::getMutableSuccessorOperands(fromBlock, successorIndex);
187+
}
188+
189+
/// Returns the arguments of this edge that are passed to the block arguments
190+
/// of the successor.
191+
OperandRange getSuccessorOperands() const {
192+
return ::getSuccessorOperands(fromBlock, successorIndex);
180193
}
181194
};
182195

@@ -262,7 +275,7 @@ class EdgeMultiplexer {
262275
assert(result != blockArgMapping.end() &&
263276
"Edge was not originally passed to `create` method.");
264277

265-
MutableOperandRange successorOperands = edge.getSuccessorOperands();
278+
MutableOperandRange successorOperands = edge.getMutableSuccessorOperands();
266279

267280
// Extra arguments are always appended at the end of the block arguments.
268281
unsigned extraArgsBeginIndex =
@@ -666,7 +679,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
666679
// invalidated when mutating the operands through a different
667680
// `MutableOperandRange` of the same operation.
668681
SmallVector<Value> loopHeaderSuccessorOperands =
669-
llvm::to_vector(getMutableSuccessorOperands(latch, loopHeaderIndex));
682+
llvm::to_vector(getSuccessorOperands(latch, loopHeaderIndex));
670683

671684
// Add all values used in the next iteration to the exit block. Replace
672685
// any uses that are outside the loop with the newly created exit block.
@@ -742,7 +755,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
742755

743756
loopHeaderSuccessorOperands.push_back(argument);
744757
for (Edge edge : successorEdges(latch))
745-
edge.getSuccessorOperands().append(argument);
758+
edge.getMutableSuccessorOperands().append(argument);
746759
}
747760

748761
use.set(blockArgument);
@@ -939,9 +952,8 @@ static FailureOr<SmallVector<Block *>> transformToStructuredCFBranches(
939952
if (regionEntry->getNumSuccessors() == 1) {
940953
// Single successor we can just splice together.
941954
Block *successor = regionEntry->getSuccessor(0);
942-
for (auto &&[oldValue, newValue] :
943-
llvm::zip(successor->getArguments(),
944-
getMutableSuccessorOperands(regionEntry, 0)))
955+
for (auto &&[oldValue, newValue] : llvm::zip(
956+
successor->getArguments(), getSuccessorOperands(regionEntry, 0)))
945957
oldValue.replaceAllUsesWith(newValue);
946958
regionEntry->getTerminator()->erase();
947959

0 commit comments

Comments
 (0)