Skip to content

Commit 9d40a4f

Browse files
cxy-1993chenxunyu
authored andcommitted
[mlir][linalg] Add more precise memory effects to linalg op
1 parent 79a6a7e commit 9d40a4f

File tree

2 files changed

+26
-23
lines changed

2 files changed

+26
-23
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,29 +1103,38 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
11031103
static void getGenericEffectsImpl(
11041104
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11051105
&effects,
1106-
ValueRange results, const ValueRange inputOperands,
1107-
ValueRange outputOperands) {
1108-
for (auto operand : inputOperands) {
1106+
LinalgOp linalgOp) {
1107+
ValueRange inputOperands = linalgOp.getDpsInputs();
1108+
for (auto [index, operand] : llvm::enumerate(inputOperands)) {
11091109
if (!llvm::isa<MemRefType>(operand.getType()))
11101110
continue;
1111-
effects.emplace_back(MemoryEffects::Read::get(), operand,
1112-
SideEffects::DefaultResource::get());
1111+
if (linalgOp.payloadUsesValueFromOperand(&linalgOp->getOpOperand(index))) {
1112+
effects.emplace_back(MemoryEffects::Read::get(), operand, 0, true,
1113+
SideEffects::DefaultResource::get());
1114+
}
11131115
}
1114-
for (auto operand : outputOperands) {
1116+
unsigned inputOperandSize = inputOperands.size();
1117+
unsigned usedOutputSize =
1118+
linalgOp.getOpOperandsMatchingBBargs().size() - inputOperandSize;
1119+
1120+
for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInits())) {
11151121
if (!llvm::isa<MemRefType>(operand.getType()))
11161122
continue;
1117-
effects.emplace_back(MemoryEffects::Read::get(), operand,
1118-
SideEffects::DefaultResource::get());
1119-
effects.emplace_back(MemoryEffects::Write::get(), operand,
1123+
if (index < usedOutputSize &&
1124+
linalgOp.payloadUsesValueFromOperand(
1125+
&linalgOp->getOpOperand(index + inputOperandSize))) {
1126+
effects.emplace_back(MemoryEffects::Read::get(), operand, 0, true,
1127+
SideEffects::DefaultResource::get());
1128+
}
1129+
effects.emplace_back(MemoryEffects::Write::get(), operand, 0, true,
11201130
SideEffects::DefaultResource::get());
11211131
}
11221132
}
11231133

11241134
void GenericOp::getEffects(
11251135
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11261136
&effects) {
1127-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1128-
getDpsInits());
1137+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
11291138
}
11301139

11311140
LogicalResult GenericOp::verify() { return success(); }
@@ -1473,8 +1482,7 @@ ArrayAttr MapOp::getIndexingMaps() {
14731482
void MapOp::getEffects(
14741483
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
14751484
&effects) {
1476-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1477-
getDpsInits());
1485+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
14781486
}
14791487

14801488
//===----------------------------------------------------------------------===//
@@ -1542,8 +1550,7 @@ ArrayAttr ReduceOp::getIndexingMaps() {
15421550
void ReduceOp::getEffects(
15431551
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
15441552
&effects) {
1545-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1546-
getDpsInits());
1553+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
15471554
}
15481555

15491556
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
@@ -1827,8 +1834,7 @@ ArrayAttr TransposeOp::getIndexingMaps() {
18271834
void TransposeOp::getEffects(
18281835
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
18291836
&effects) {
1830-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1831-
getDpsInits());
1837+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
18321838
}
18331839

18341840
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
@@ -1965,8 +1971,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
19651971
void BroadcastOp::getEffects(
19661972
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
19671973
&effects) {
1968-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1969-
getDpsInits());
1974+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
19701975
}
19711976

19721977
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2494,8 +2499,7 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
24942499
void SoftmaxOp::getEffects(
24952500
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
24962501
&effects) {
2497-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
2498-
getDpsInits());
2502+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
24992503
}
25002504

25012505
// Helper functions for softmax decomposition.

mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,8 +659,7 @@ LogicalResult {0}::fold(FoldAdaptor,
659659
void {0}::getEffects(SmallVectorImpl<
660660
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
661661
if (hasPureTensorSemantics()) return;
662-
getGenericEffectsImpl(effects,
663-
getOperation()->getResults(), getDpsInputs(), getDpsInits());
662+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
664663
}
665664
)FMT";
666665

0 commit comments

Comments
 (0)