Skip to content

Commit 43ed2f0

Browse files
cxy-1993chenxunyu
authored andcommitted
[mlir][linalg] Add more precise memory effects to linalg op
1 parent 79a6a7e commit 43ed2f0

File tree

3 files changed

+52
-27
lines changed

3 files changed

+52
-27
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,29 +1103,38 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
11031103
static void getGenericEffectsImpl(
11041104
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11051105
&effects,
1106-
ValueRange results, const ValueRange inputOperands,
1107-
ValueRange outputOperands) {
1108-
for (auto operand : inputOperands) {
1106+
LinalgOp linalgOp) {
1107+
SmallVector<Value> inputOperands = linalgOp.getDpsInputs();
1108+
for (auto [index, operand] : llvm::enumerate(inputOperands)) {
11091109
if (!llvm::isa<MemRefType>(operand.getType()))
11101110
continue;
1111-
effects.emplace_back(MemoryEffects::Read::get(), operand,
1112-
SideEffects::DefaultResource::get());
1111+
if (linalgOp.payloadUsesValueFromOperand(&linalgOp->getOpOperand(index))) {
1112+
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
1113+
/*effectOnFullRegion=*/true,
1114+
SideEffects::DefaultResource::get());
1115+
}
11131116
}
1114-
for (auto operand : outputOperands) {
1117+
unsigned inputOperandSize = inputOperands.size();
1118+
1119+
for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInits())) {
11151120
if (!llvm::isa<MemRefType>(operand.getType()))
11161121
continue;
1117-
effects.emplace_back(MemoryEffects::Read::get(), operand,
1118-
SideEffects::DefaultResource::get());
1119-
effects.emplace_back(MemoryEffects::Write::get(), operand,
1122+
if (linalgOp.payloadUsesValueFromOperand(
1123+
&linalgOp->getOpOperand(index + inputOperandSize))) {
1124+
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
1125+
/*effectOnFullRegion=*/true,
1126+
SideEffects::DefaultResource::get());
1127+
}
1128+
effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
1129+
/*effectOnFullRegion=*/true,
11201130
SideEffects::DefaultResource::get());
11211131
}
11221132
}
11231133

11241134
void GenericOp::getEffects(
11251135
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11261136
&effects) {
1127-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1128-
getDpsInits());
1137+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
11291138
}
11301139

11311140
LogicalResult GenericOp::verify() { return success(); }
@@ -1473,8 +1482,7 @@ ArrayAttr MapOp::getIndexingMaps() {
14731482
void MapOp::getEffects(
14741483
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
14751484
&effects) {
1476-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1477-
getDpsInits());
1485+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
14781486
}
14791487

14801488
//===----------------------------------------------------------------------===//
@@ -1542,8 +1550,7 @@ ArrayAttr ReduceOp::getIndexingMaps() {
15421550
void ReduceOp::getEffects(
15431551
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
15441552
&effects) {
1545-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1546-
getDpsInits());
1553+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
15471554
}
15481555

15491556
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
@@ -1827,8 +1834,7 @@ ArrayAttr TransposeOp::getIndexingMaps() {
18271834
void TransposeOp::getEffects(
18281835
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
18291836
&effects) {
1830-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1831-
getDpsInits());
1837+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
18321838
}
18331839

18341840
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
@@ -1965,8 +1971,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
19651971
void BroadcastOp::getEffects(
19661972
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
19671973
&effects) {
1968-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1969-
getDpsInits());
1974+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
19701975
}
19711976

19721977
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2494,8 +2499,23 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
24942499
void SoftmaxOp::getEffects(
24952500
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
24962501
&effects) {
2497-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
2498-
getDpsInits());
2502+
for (Value operand : getDpsInputs()) {
2503+
if (!llvm::isa<MemRefType>(operand.getType()))
2504+
continue;
2505+
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
2506+
/*effectOnFullRegion=*/true,
2507+
SideEffects::DefaultResource::get());
2508+
}
2509+
for (Value operand : getDpsInits()) {
2510+
if (!llvm::isa<MemRefType>(operand.getType()))
2511+
continue;
2512+
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
2513+
/*effectOnFullRegion=*/true,
2514+
SideEffects::DefaultResource::get());
2515+
effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
2516+
/*effectOnFullRegion=*/true,
2517+
SideEffects::DefaultResource::get());
2518+
}
24992519
}
25002520

25012521
// Helper functions for softmax decomposition.

mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,16 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
7676
// new op. Since the new op does not have any tensor results, it does not
7777
// return anything.
7878
assert(op->getNumRegions() == 1 && "expected that op has 1 region");
79-
auto newOp = cast<DestinationStyleOpInterface>(cloneWithoutRegions(
80-
rewriter, op, /*newResultTypes=*/TypeRange{}, newOperands));
81-
rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0),
82-
newOp->getRegion(0).begin());
79+
OperationState state(op->getLoc(), op->getName(), newOperands, TypeRange{},
80+
op->getAttrs());
81+
state.addRegion();
82+
Operation *newOp = Operation::create(state);
83+
newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(),
84+
op->getRegion(0).getBlocks());
85+
86+
// We don't want the rewriter tracks an incomplete operation, so insert new
87+
// operation after op was fully constructed.
88+
rewriter.insert(newOp);
8389

8490
// Replace the results of the old op with the new output buffers.
8591
replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);

mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,8 +659,7 @@ LogicalResult {0}::fold(FoldAdaptor,
659659
void {0}::getEffects(SmallVectorImpl<
660660
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
661661
if (hasPureTensorSemantics()) return;
662-
getGenericEffectsImpl(effects,
663-
getOperation()->getResults(), getDpsInputs(), getDpsInits());
662+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
664663
}
665664
)FMT";
666665

0 commit comments

Comments
 (0)