Skip to content

Commit 8823e96

Browse files
[mlir][ODS] Change get...Mutable to return OpOperand & for single operands (#66519)
The TableGen code generator now generates C++ code that returns a single `OpOperand &` for `get...Mutable` of operands that are not variadic and not optional. `OpOperand::set`/`assign` can be used to set a value (same as `MutableOperandRange::assign`). This is safer than `MutableOperandRange` because only single values (and no longer `ValueRange`) can be assigned. E.g.: ``` // Assignment of multiple values to non-variadic operand. // Before: Compiles, but produces invalid op. // After: Compilation error. extractSliceOp.getSourceMutable().assign({v1, v2}); ```
1 parent 3b34c11 commit 8823e96

File tree

11 files changed

+44
-20
lines changed

11 files changed

+44
-20
lines changed

mlir/include/mlir/IR/Value.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,9 @@ class OpOperand : public IROperand<OpOperand, Value> {
268268
/// Return which operand this is in the OpOperand list of the Operation.
269269
unsigned getOperandNumber();
270270

271+
/// Set the current value being used by this operand.
272+
void assign(Value value) { set(value); }
273+
271274
private:
272275
/// Keep the constructor private and accessible to the OperandStorage class
273276
/// only to avoid hard-to-debug typo/programming mistakes.

mlir/include/mlir/IR/ValueRange.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ class MutableOperandRange {
126126
ArrayRef<OperandSegment> operandSegments = std::nullopt);
127127
MutableOperandRange(Operation *owner);
128128

129+
/// Construct a new mutable range for the given OpOperand.
130+
MutableOperandRange(OpOperand &opOperand);
131+
129132
/// Slice this range into a sub range, with the additional operand segment.
130133
MutableOperandRange
131134
slice(unsigned subStart, unsigned subLen,

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -537,18 +537,18 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
537537

538538
bool MaterializeInDestinationOp::bufferizesToMemoryRead(
539539
OpOperand &opOperand, const AnalysisState &state) {
540-
return &opOperand == &getSourceMutable()[0];
540+
return &opOperand == &getSourceMutable();
541541
}
542542

543543
bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
544544
OpOperand &opOperand, const AnalysisState &state) {
545-
return &opOperand == &getDestMutable()[0];
545+
return &opOperand == &getDestMutable();
546546
}
547547

548548
AliasingValueList
549549
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
550550
const AnalysisState &state) {
551-
if (&opOperand == &getDestMutable()[0])
551+
if (&opOperand == &getDestMutable())
552552
return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
553553
return {};
554554
}

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,7 @@ struct FoldReshapeWithGenericOpByExpansion
949949
reshapeOp, "failed preconditions of fusion with producer generic op");
950950
}
951951

952-
if (!controlFoldingReshapes(&reshapeOp.getSrcMutable()[0])) {
952+
if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
953953
return rewriter.notifyMatchFailure(reshapeOp,
954954
"fusion blocked by control function");
955955
}

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
526526
// 1. Get the producer of the source (potentially walking through
527527
// `iter_args` of nested `scf.for`)
528528
auto [fusableProducer, destinationInitArg] =
529-
getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
529+
getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
530530
loops);
531531
if (!fusableProducer)
532532
return std::nullopt;

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -636,11 +636,11 @@ struct InsertSliceOpInterface
636636
RankedTensorType destType = insertSliceOp.getDestType();
637637

638638
// The source is always read.
639-
if (&opOperand == &insertSliceOp.getSourceMutable()[0])
639+
if (&opOperand == &insertSliceOp.getSourceMutable())
640640
return true;
641641

642642
// For the destination, it depends...
643-
assert(&opOperand == &insertSliceOp.getDestMutable()[0] && "expected dest");
643+
assert(&opOperand == &insertSliceOp.getDestMutable() && "expected dest");
644644

645645
// Dest is not read if it is entirely overwritten. E.g.:
646646
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
@@ -840,7 +840,7 @@ struct ReshapeOpInterface
840840
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
841841
const AnalysisState &state) const {
842842
auto reshapeOp = cast<tensor::ReshapeOp>(op);
843-
return &opOperand == &reshapeOp.getShapeMutable()[0];
843+
return &opOperand == &reshapeOp.getShapeMutable();
844844
}
845845

846846
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -917,7 +917,7 @@ struct ParallelInsertSliceOpInterface
917917
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
918918
const AnalysisState &state) const {
919919
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
920-
return &opOperand == &parallelInsertSliceOp.getDestMutable()[0];
920+
return &opOperand == &parallelInsertSliceOp.getDestMutable();
921921
}
922922

923923
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,

mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ struct InsertSliceOpInterface
6363
: public SubsetInsertionOpInterface::ExternalModel<InsertSliceOpInterface,
6464
tensor::InsertSliceOp> {
6565
OpOperand &getSourceOperand(Operation *op) const {
66-
return op->getOpOperand(0);
66+
return cast<tensor::InsertSliceOp>(op).getSourceMutable();
6767
}
6868

6969
bool
@@ -91,11 +91,11 @@ struct ParallelInsertSliceOpInterface
9191
: public SubsetInsertionOpInterface::ExternalModel<
9292
ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp> {
9393
OpOperand &getSourceOperand(Operation *op) const {
94-
return op->getOpOperand(0);
94+
return cast<tensor::ParallelInsertSliceOp>(op).getSourceMutable();
9595
}
9696

9797
OpOperand &getDestinationOperand(Operation *op) const {
98-
return op->getOpOperand(1);
98+
return cast<tensor::ParallelInsertSliceOp>(op).getDestMutable();
9999
}
100100

101101
bool

mlir/lib/IR/OperationSupport.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,12 @@ MutableOperandRange::MutableOperandRange(
437437
MutableOperandRange::MutableOperandRange(Operation *owner)
438438
: MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
439439

440+
/// Construct a new mutable range for the given OpOperand.
441+
MutableOperandRange::MutableOperandRange(OpOperand &opOperand)
442+
: MutableOperandRange(opOperand.getOwner(),
443+
/*start=*/opOperand.getOperandNumber(),
444+
/*length=*/1) {}
445+
440446
/// Slice this range into a sub range, with the additional operand segment.
441447
MutableOperandRange
442448
MutableOperandRange::slice(unsigned subStart, unsigned subLen,

mlir/test/lib/Dialect/Test/TestDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -998,7 +998,7 @@ void LoopBlockOp::getSuccessorRegions(
998998

999999
OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
10001000
assert(point == getBody());
1001-
return getInitMutable();
1001+
return MutableOperandRange(getInitMutable());
10021002
}
10031003

10041004
//===----------------------------------------------------------------------===//

mlir/test/mlir-tblgen/op-decl-and-defs.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
9797
// CHECK: ::mlir::Operation::operand_range getODSOperands(unsigned index);
9898
// CHECK: ::mlir::TypedValue<::mlir::IntegerType> getA();
9999
// CHECK: ::mlir::Operation::operand_range getB();
100-
// CHECK: ::mlir::MutableOperandRange getAMutable();
100+
// CHECK: ::mlir::OpOperand &getAMutable();
101101
// CHECK: ::mlir::MutableOperandRange getBMutable();
102102
// CHECK: ::mlir::Operation::result_range getODSResults(unsigned index);
103103
// CHECK: ::mlir::TypedValue<::mlir::IntegerType> getR();

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2071,14 +2071,26 @@ void OpEmitter::genNamedOperandSetters() {
20712071
continue;
20722072
std::string name = op.getGetterName(operand.name);
20732073

2074-
auto *m = opClass.addMethod(operand.isVariadicOfVariadic()
2075-
? "::mlir::MutableOperandRangeRange"
2076-
: "::mlir::MutableOperandRange",
2077-
name + "Mutable");
2074+
StringRef returnType;
2075+
if (operand.isVariadicOfVariadic()) {
2076+
returnType = "::mlir::MutableOperandRangeRange";
2077+
} else if (operand.isVariableLength()) {
2078+
returnType = "::mlir::MutableOperandRange";
2079+
} else {
2080+
returnType = "::mlir::OpOperand &";
2081+
}
2082+
auto *m = opClass.addMethod(returnType, name + "Mutable");
20782083
ERROR_IF_PRUNED(m, name, op);
20792084
auto &body = m->body();
2080-
body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
2081-
<< " auto mutableRange = "
2085+
body << " auto range = getODSOperandIndexAndLength(" << i << ");\n";
2086+
2087+
if (!operand.isVariadicOfVariadic() && !operand.isVariableLength()) {
2088+
// In case of a single operand, return a single OpOperand.
2089+
body << " return getOperation()->getOpOperand(range.first);\n";
2090+
continue;
2091+
}
2092+
2093+
body << " auto mutableRange = "
20822094
"::mlir::MutableOperandRange(getOperation(), "
20832095
"range.first, range.second";
20842096
if (attrSizedOperands) {

0 commit comments

Comments
 (0)