Skip to content

Commit 89f47c7

Browse files
[mlir][Interfaces] Clean up DestinationStyleOpInterface
* "init" operands are specified with `MutableOperandRange` (which gives access to the underlying `OpOperand *`). No more magic numbers. * Remove most interface methods and make them helper functions. Only `getInitsMutable` should be implemented. * Provide separate helper functions for accessing mutable/immutable operands (`OpOperand`/`Value`, in line with llvm#66515): `getInitsMutable` and `getInits` (same naming convention as auto-generated op accessors). `getInputOperands` was not renamed because this function cannot return a `MutableOperandRange` (because the operands are not necessarily consecutive). `OpOperandVector` is no longer needed. * The new `getDpsInits`/`getDpsInitsMutable` is more efficient than the old `getDpsInitOperands` because no `SmallVector` is created. The new functions return a range of operands. * Fix a bug in `getDpsInputOperands`: out-of-bounds operands were potentially returned. BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
1 parent 7fcbb64 commit 89f47c7

34 files changed

+356
-484
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,7 @@ def Bufferization_MaterializeInDestinationOp
264264
return ::llvm::cast<RankedTensorType>(getResult().getType());
265265
}
266266

267-
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
268-
return {1, 2}; // `dest` operand
269-
}
267+
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
270268
}];
271269

272270
let assemblyFormat = "$source `in` $dest attr-dict `:` type($source)";

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,12 +555,12 @@ def LinalgStructuredInterface
555555
are expection. For example, in `map` output operand isn't used in
556556
the block.
557557
}],
558-
/*retTy=*/"OpOperandVector",
558+
/*retTy=*/"::llvm::SmallVector<OpOperand *>",
559559
/*methodName=*/"getOpOperandsMatchingBBargs",
560560
/*args=*/(ins),
561561
/*methodBody=*/"",
562562
/*defaultImplementation=*/[{
563-
OpOperandVector result;
563+
::llvm::SmallVector<OpOperand *> result;
564564
result.reserve($_op->getNumOperands());
565565
llvm::transform(
566566
this->getOperation()->getOpOperands(),

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
149149
int64_t getOutputOperandRank() {
150150
return getOutputOperandType().getRank();
151151
}
152-
// Method to implement DestinationStyleOpInterface.
153-
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
154-
std::pair<unsigned, unsigned> outputsIndexAndLength =
155-
getODSOperandIndexAndLength(1);
156-
return std::make_pair<int64_t, int64_t>(
157-
outputsIndexAndLength.first,
158-
outputsIndexAndLength.first + outputsIndexAndLength.second);
159-
}
152+
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
160153
}];
161154
let hasVerifier = 1;
162155
}

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,8 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
207207
getRegionBuilder() {
208208
return nullptr;
209209
}
210-
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
211-
int64_t getNumOperands = this->getNumOperands();
212-
return {getNumOperands - getOutputs().size(), getNumOperands};
213-
}
210+
211+
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
214212
}];
215213

216214
let hasCanonicalizer = 1;
@@ -283,11 +281,9 @@ def MapOp : LinalgStructuredBase_Op<"map", [
283281
}
284282

285283
// Implement functions necessary for DestinationStyleOpInterface.
286-
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
287-
int64_t getNumOperands = this->getNumOperands();
288-
return {getNumOperands - 1, getNumOperands};
289-
}
290-
OpOperandVector getOpOperandsMatchingBBargs() {
284+
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
285+
286+
SmallVector<OpOperand *> getOpOperandsMatchingBBargs() {
291287
return getDpsInputOperands();
292288
}
293289

@@ -381,9 +377,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
381377
getRegionBuilder() {
382378
return nullptr;
383379
}
384-
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
385-
return {getInits().size(), getNumOperands()};
386-
}
380+
MutableOperandRange getDpsInitsMutable() { return getInitsMutable(); }
387381
}];
388382

389383
let hasCustomAssemblyFormat = 1;
@@ -446,10 +440,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
446440
}
447441

448442
// Implement functions necessary for DestinationStyleOpInterface.
449-
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
450-
int64_t getNumOperands = this->getNumOperands();
451-
return {getNumOperands - 1, getNumOperands};
452-
}
443+
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
453444

454445
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
455446
mlir::ArrayRef<mlir::NamedAttribute>)>
@@ -517,10 +508,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
517508
}
518509

519510
// Implement functions necessary for DestinationStyleOpInterface.
520-
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
521-
int64_t getNumOperands = this->getNumOperands();
522-
return {getNumOperands - 1, getNumOperands};
523-
}
511+
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
524512

525513
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
526514
mlir::ArrayRef<mlir::NamedAttribute>)>

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -750,9 +750,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
750750
}];
751751

752752
let extraClassDeclaration = [{
753-
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
754-
return {1, 2}; // `dest` operand
755-
}
753+
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
756754
}];
757755

758756
let hasFolder = 1;
@@ -892,9 +890,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
892890
/// and `strides` operands.
893891
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
894892

895-
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
896-
return {1, 2}; // `dest` operand
897-
}
893+
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
898894
}];
899895

900896
let hasCanonicalizer = 1;
@@ -1714,10 +1710,7 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
17141710
RankedTensorType getDestType() {
17151711
return ::llvm::cast<RankedTensorType>(getDest().getType()); };
17161712

1717-
/// Return position for init operand. Init operand is `dest`.
1718-
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
1719-
return {1, 2}; // `dest` operand
1720-
}
1713+
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
17211714

17221715
/// Interface method for ConditionallySpeculatable.
17231716
Speculation::Speculatability getSpeculatability();

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,8 +1330,8 @@ def Vector_TransferReadOp :
13301330
// MaskableOpInterface methods.
13311331
bool supportsPassthru() { return true; }
13321332

1333-
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
1334-
return {0, 0}; // empty range (no init operands)
1333+
MutableOperandRange getDpsInitsMutable() {
1334+
return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
13351335
}
13361336
}];
13371337

@@ -1494,9 +1494,7 @@ def Vector_TransferWriteOp :
14941494
/// ops of other dialects.
14951495
Value getValue() { return getVector(); }
14961496

1497-
std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
1498-
return {1, 2}; // `source` operand
1499-
}
1497+
MutableOperandRange getDpsInitsMutable() { return getSourceMutable(); }
15001498
}];
15011499

15021500
let hasFolder = 1;

mlir/include/mlir/Interfaces/DestinationStyleOpInterface.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,6 @@
1717
#include "llvm/ADT/SmallVector.h"
1818

1919
namespace mlir {
20-
/// OpOperand vector that implicitly converts to a Value vector.
21-
struct OpOperandVector : public llvm::SmallVector<OpOperand *> {
22-
operator SmallVector<Value>();
23-
};
24-
2520
namespace detail {
2621
/// Verify that `op` conforms to the invariants of DestinationStyleOpInterface
2722
LogicalResult verifyDestinationStyleOpInterface(Operation *op);

0 commit comments

Comments
 (0)