Skip to content

Commit a0fdb38

Browse files
authored
[mlir][linalg] Add more precise memory effects to linalg op (#92079)
This patch add more precise memory effect to linalg op. Including the following points: 1. Remove the read side effects for operands that are not used. 2. Set the effect for all side effects to "full".
1 parent 34ba1c0 commit a0fdb38

File tree

3 files changed

+52
-27
lines changed

3 files changed

+52
-27
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,29 +1122,38 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
11221122
static void getGenericEffectsImpl(
11231123
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11241124
&effects,
1125-
ValueRange results, const ValueRange inputOperands,
1126-
ValueRange outputOperands) {
1127-
for (auto operand : inputOperands) {
1125+
LinalgOp linalgOp) {
1126+
SmallVector<Value> inputOperands = linalgOp.getDpsInputs();
1127+
for (auto [index, operand] : llvm::enumerate(inputOperands)) {
11281128
if (!llvm::isa<MemRefType>(operand.getType()))
11291129
continue;
1130-
effects.emplace_back(MemoryEffects::Read::get(), operand,
1131-
SideEffects::DefaultResource::get());
1130+
if (linalgOp.payloadUsesValueFromOperand(&linalgOp->getOpOperand(index))) {
1131+
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
1132+
/*effectOnFullRegion=*/true,
1133+
SideEffects::DefaultResource::get());
1134+
}
11321135
}
1133-
for (auto operand : outputOperands) {
1136+
unsigned inputOperandSize = inputOperands.size();
1137+
1138+
for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInits())) {
11341139
if (!llvm::isa<MemRefType>(operand.getType()))
11351140
continue;
1136-
effects.emplace_back(MemoryEffects::Read::get(), operand,
1137-
SideEffects::DefaultResource::get());
1138-
effects.emplace_back(MemoryEffects::Write::get(), operand,
1141+
if (linalgOp.payloadUsesValueFromOperand(
1142+
&linalgOp->getOpOperand(index + inputOperandSize))) {
1143+
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
1144+
/*effectOnFullRegion=*/true,
1145+
SideEffects::DefaultResource::get());
1146+
}
1147+
effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
1148+
/*effectOnFullRegion=*/true,
11391149
SideEffects::DefaultResource::get());
11401150
}
11411151
}
11421152

11431153
void GenericOp::getEffects(
11441154
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11451155
&effects) {
1146-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1147-
getDpsInits());
1156+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
11481157
}
11491158

11501159
LogicalResult GenericOp::verify() { return success(); }
@@ -1492,8 +1501,7 @@ ArrayAttr MapOp::getIndexingMaps() {
14921501
void MapOp::getEffects(
14931502
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
14941503
&effects) {
1495-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1496-
getDpsInits());
1504+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
14971505
}
14981506

14991507
//===----------------------------------------------------------------------===//
@@ -1561,8 +1569,7 @@ ArrayAttr ReduceOp::getIndexingMaps() {
15611569
void ReduceOp::getEffects(
15621570
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
15631571
&effects) {
1564-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1565-
getDpsInits());
1572+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
15661573
}
15671574

15681575
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
@@ -1846,8 +1853,7 @@ ArrayAttr TransposeOp::getIndexingMaps() {
18461853
void TransposeOp::getEffects(
18471854
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
18481855
&effects) {
1849-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1850-
getDpsInits());
1856+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
18511857
}
18521858

18531859
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
@@ -1984,8 +1990,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
19841990
void BroadcastOp::getEffects(
19851991
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
19861992
&effects) {
1987-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1988-
getDpsInits());
1993+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
19891994
}
19901995

19911996
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2513,8 +2518,23 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
25132518
void SoftmaxOp::getEffects(
25142519
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
25152520
&effects) {
2516-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
2517-
getDpsInits());
2521+
for (Value operand : getDpsInputs()) {
2522+
if (!llvm::isa<MemRefType>(operand.getType()))
2523+
continue;
2524+
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
2525+
/*effectOnFullRegion=*/true,
2526+
SideEffects::DefaultResource::get());
2527+
}
2528+
for (Value operand : getDpsInits()) {
2529+
if (!llvm::isa<MemRefType>(operand.getType()))
2530+
continue;
2531+
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
2532+
/*effectOnFullRegion=*/true,
2533+
SideEffects::DefaultResource::get());
2534+
effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
2535+
/*effectOnFullRegion=*/true,
2536+
SideEffects::DefaultResource::get());
2537+
}
25182538
}
25192539

25202540
// Helper functions for softmax decomposition.

mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,16 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
7676
// new op. Since the new op does not have any tensor results, it does not
7777
// return anything.
7878
assert(op->getNumRegions() == 1 && "expected that op has 1 region");
79-
auto newOp = cast<DestinationStyleOpInterface>(cloneWithoutRegions(
80-
rewriter, op, /*newResultTypes=*/TypeRange{}, newOperands));
81-
rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0),
82-
newOp->getRegion(0).begin());
79+
OperationState state(op->getLoc(), op->getName(), newOperands, TypeRange{},
80+
op->getAttrs());
81+
state.addRegion();
82+
Operation *newOp = Operation::create(state);
83+
newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(),
84+
op->getRegion(0).getBlocks());
85+
86+
// We don't want the rewriter tracks an incomplete operation, so insert new
87+
// operation after op was fully constructed.
88+
rewriter.insert(newOp);
8389

8490
// Replace the results of the old op with the new output buffers.
8591
replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);

mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -667,8 +667,7 @@ LogicalResult {0}::fold(FoldAdaptor,
667667
void {0}::getEffects(SmallVectorImpl<
668668
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
669669
if (hasPureTensorSemantics()) return;
670-
getGenericEffectsImpl(effects,
671-
getOperation()->getResults(), getDpsInputs(), getDpsInits());
670+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
672671
}
673672
)FMT";
674673

0 commit comments

Comments
 (0)