Skip to content

Commit 86f609f

Browse files
matthias-springerzahiraam
authored andcommitted
[mlir][IR] Change MutableOperandRange::operator[] to return an OpOperand & (llvm#66515)
`operator[]` returns `OpOperand &` instead of `Value`. * This allows users to get OpOperands by name instead of "magic" number. E.g., `extractSliceOp->getOpOperand(0)` can be written as `extractSliceOp.getSourceMutable()[0]`. * `OperandRange` provides a read-only API to operands: `operator[]` returns `Value`. `MutableOperandRange` now provides a mutable API: `operator[]` returns `OpOperand &`, which can be used to set operands. Note: The TableGen code generator could be changed to return `OpOperand &` (instead of `MutableOperandRange`) for non-variadic and non-optional arguments in a subsequent change. Then the `[0]` part in the above example would no longer be necessary.
1 parent 38d5da9 commit 86f609f

File tree

8 files changed

+21
-21
lines changed

8 files changed

+21
-21
lines changed

mlir/include/mlir/IR/ValueRange.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,8 @@ class MutableOperandRange {
162162
/// elements attribute, which contains the sizes of the sub ranges.
163163
MutableOperandRangeRange split(NamedAttribute segmentSizes) const;
164164

165-
/// Returns the value at the given index.
166-
Value operator[](unsigned index) const {
167-
return operator OperandRange()[index];
168-
}
165+
/// Returns the OpOperand at the given index.
166+
OpOperand &operator[](unsigned index) const;
169167

170168
OperandRange::iterator begin() const {
171169
return static_cast<OperandRange>(*this).begin();

mlir/include/mlir/Interfaces/ControlFlowInterfaces.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class SuccessorOperands {
7676
Value operator[](unsigned index) const {
7777
if (isOperandProduced(index))
7878
return Value();
79-
return forwardedOperands[index - producedOperandCount];
79+
return forwardedOperands[index - producedOperandCount].get();
8080
}
8181

8282
/// Get the range of operands that are simply forwarded to the successor.

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -549,22 +549,18 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
549549

550550
bool MaterializeInDestinationOp::bufferizesToMemoryRead(
551551
OpOperand &opOperand, const AnalysisState &state) {
552-
if (&opOperand == &getOperation()->getOpOperand(0) /*source*/)
553-
return true;
554-
return false;
552+
return &opOperand == &getSourceMutable()[0];
555553
}
556554

557555
bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
558556
OpOperand &opOperand, const AnalysisState &state) {
559-
if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
560-
return true;
561-
return false;
557+
return &opOperand == &getDestMutable()[0];
562558
}
563559

564560
AliasingValueList
565561
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
566562
const AnalysisState &state) {
567-
if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
563+
if (&opOperand == &getDestMutable()[0])
568564
return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
569565
return {};
570566
}

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,7 @@ struct FoldReshapeWithGenericOpByExpansion
949949
reshapeOp, "failed preconditions of fusion with producer generic op");
950950
}
951951

952-
if (!controlFoldingReshapes(&reshapeOp->getOpOperand(0))) {
952+
if (!controlFoldingReshapes(&reshapeOp.getSrcMutable()[0])) {
953953
return rewriter.notifyMatchFailure(reshapeOp,
954954
"fusion blocked by control function");
955955
}

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
509509
// 1. Get the producer of the source (potentially walking through
510510
// `iter_args` of nested `scf.for`)
511511
auto [fusableProducer, destinationIterArg] =
512-
getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
512+
getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
513513
loops);
514514
if (!fusableProducer)
515515
return std::nullopt;

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -644,11 +644,11 @@ struct InsertSliceOpInterface
644644
RankedTensorType destType = insertSliceOp.getDestType();
645645

646646
// The source is always read.
647-
if (&opOperand == &op->getOpOperand(0) /*src*/)
647+
if (&opOperand == &insertSliceOp.getSourceMutable()[0])
648648
return true;
649649

650650
// For the destination, it depends...
651-
assert(&opOperand == &insertSliceOp->getOpOperand(1) && "expected dest");
651+
assert(&opOperand == &insertSliceOp.getDestMutable()[0] && "expected dest");
652652

653653
// Dest is not read if it is entirely overwritten. E.g.:
654654
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
@@ -851,9 +851,8 @@ struct ReshapeOpInterface
851851
tensor::ReshapeOp> {
852852
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
853853
const AnalysisState &state) const {
854-
if (&opOperand == &op->getOpOperand(1) /* shape */)
855-
return true;
856-
return false;
854+
auto reshapeOp = cast<tensor::ReshapeOp>(op);
855+
return &opOperand == &reshapeOp.getShapeMutable()[0];
857856
}
858857

859858
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -915,7 +914,8 @@ struct ParallelInsertSliceOpInterface
915914

916915
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
917916
const AnalysisState &state) const {
918-
return &opOperand == &op->getOpOperand(1) /*dest*/;
917+
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
918+
return &opOperand == &parallelInsertSliceOp.getDestMutable()[0];
919919
}
920920

921921
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,

mlir/lib/IR/OperationSupport.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,11 @@ void MutableOperandRange::updateLength(unsigned newLength) {
517517
}
518518
}
519519

520+
OpOperand &MutableOperandRange::operator[](unsigned index) const {
521+
assert(index < length && "index is out of bounds");
522+
return owner->getOpOperand(start + index);
523+
}
524+
520525
//===----------------------------------------------------------------------===//
521526
// MutableOperandRangeRange
522527

mlir/lib/Transforms/Utils/CFGToSCF.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,8 @@ class EdgeMultiplexer {
277277
if (index >= result->second &&
278278
index < result->second + edge.getSuccessor()->getNumArguments()) {
279279
// Original block arguments to the entry block.
280-
newSuccOperands[index] = successorOperands[index - result->second];
280+
newSuccOperands[index] =
281+
successorOperands[index - result->second].get();
281282
continue;
282283
}
283284

0 commit comments

Comments
 (0)