-
Notifications
You must be signed in to change notification settings - Fork 14.1k
[mlir][linalg] Add more precise memory effects to linalg op #92079
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-core Author: donald chen (cxy-1993) ChangesThis patch add more precise memory effect to linalg op. Including the following points:
Full diff: https://github.com/llvm/llvm-project/pull/92079.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e5f83331baf81..5958e1a0f3206 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1103,20 +1103,28 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
static void getGenericEffectsImpl(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects,
- ValueRange results, const ValueRange inputOperands,
+ LinalgOp linalgOp, ValueRange results, const ValueRange inputOperands,
ValueRange outputOperands) {
for (auto operand : inputOperands) {
if (!llvm::isa<MemRefType>(operand.getType()))
continue;
- effects.emplace_back(MemoryEffects::Read::get(), operand,
+ effects.emplace_back(MemoryEffects::Read::get(), 0, true, operand,
SideEffects::DefaultResource::get());
}
- for (auto operand : outputOperands) {
+ unsigned inputOperandSize = inputOperands.size();
+ unsigned usedOutputSize =
+ linalgOp.getOpOperandsMatchingBBargs().size() - inputOperandSize;
+
+ for (auto [index, operand] : llvm::enumerate(outputOperands)) {
if (!llvm::isa<MemRefType>(operand.getType()))
continue;
- effects.emplace_back(MemoryEffects::Read::get(), operand,
- SideEffects::DefaultResource::get());
- effects.emplace_back(MemoryEffects::Write::get(), operand,
+ if (index < usedOutputSize &&
+ linalgOp.payloadUsesValueFromOperand(
+ &linalgOp->getOpOperand(index + inputOperandSize))) {
+ effects.emplace_back(MemoryEffects::Read::get(), 0, true, operand,
+ SideEffects::DefaultResource::get());
+ }
+ effects.emplace_back(MemoryEffects::Write::get(), 0, true, operand,
SideEffects::DefaultResource::get());
}
}
@@ -1124,7 +1132,8 @@ static void getGenericEffectsImpl(
void GenericOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+ getOperation()->getResults(), getDpsInputs(),
getDpsInits());
}
@@ -1473,7 +1482,8 @@ ArrayAttr MapOp::getIndexingMaps() {
void MapOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+ getOperation()->getResults(), getDpsInputs(),
getDpsInits());
}
@@ -1542,7 +1552,8 @@ ArrayAttr ReduceOp::getIndexingMaps() {
void ReduceOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+ getOperation()->getResults(), getDpsInputs(),
getDpsInits());
}
@@ -1827,7 +1838,8 @@ ArrayAttr TransposeOp::getIndexingMaps() {
void TransposeOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+ getOperation()->getResults(), getDpsInputs(),
getDpsInits());
}
@@ -1965,7 +1977,8 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
void BroadcastOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+ getOperation()->getResults(), getDpsInputs(),
getDpsInits());
}
@@ -2494,7 +2507,8 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
void SoftmaxOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+ getOperation()->getResults(), getDpsInputs(),
getDpsInits());
}
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index fe6ad15041126..f3071b81e21cb 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -659,7 +659,7 @@ LogicalResult {0}::fold(FoldAdaptor,
void {0}::getEffects(SmallVectorImpl<
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
if (hasPureTensorSemantics()) return;
- getGenericEffectsImpl(effects,
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
getOperation()->getResults(), getDpsInputs(), getDpsInits());
}
)FMT";
|
73f84df
to
a3bd52c
Compare
71d6a49
to
8d3b0c8
Compare
779df1f
to
9ccde8b
Compare
Counld you please help me merge this patch into master? @matthias-springer |
if (getOperation()->getRegion(0).empty()) { | ||
return true; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't quite get this. When the region is empty (is this even allowed?), the operand is considered used? But when it has a terminator not reading from the corresponding block argument, the operand is not considered used? Why? If we assume it is read anyway, it should be marked as such in both cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't quite get this. When the region is empty (is this even allowed?), the operand is considered used? But when it has a terminator not reading from the corresponding block argument, the operand is not considered used? Why? If we assume it is read anyway, it should be marked as such in both cases.
Thank you for your suggestion! For this issue, here is my understanding: When region is empty, it usually occurs in scenarios where region does not need to interpret the semantics of op, such as linalg.map:
%mapped = linalg.map { arith.addf }
ins(%arg0, %arg1 : tensor<10x100xf32>, tensor<10x100xf32>)
outs(%map_init : tensor<10x100xf32>)
In the example above, the linalg.map has an empty region. This is because, with arith.add present, we do not need to provide additional interpretation for this op, as its implementation simply involves adding the two input operands. In such scenarios where no interpretation is required, we consider all operands to be used in accordance with the semantics of the op.
In linalg operations implemented with regions, it is evident that when the block arguments corresponding to operands are not used, it indicates that they are not being read.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
linalg.map { arith.addf }
is just a pretty-printed version. It has a region with an arith.addf
and linalg.yield
terminator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
linalg.map { arith.addf }
is just a pretty-printed version. It has a region with anarith.addf
andlinalg.yield
terminator.
You are right. I reviewed why I thought this way: before this, I didn't add handling for when region was empty, leading to some core dump issues during part of the bufferization tests. This happened because the 'cloneWithoutRegions' function was called during the bufferization process, and in this function, the side effects of 'op' were queried, unfortunately, at this time region was still empty. So, I'd like to solicit opinions from both of you: in this situation, can we conservatively assume that all operands can be read and written to? @matthias-springer @ftynse
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @matthias-springer @ftynse,I wanted to send a friendly reminder to re-review the proposal above at your earliest convenience and hope to hear your thoughts on this issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does this lead to a crash? The bufferization impl for LinalgOps calls cloneWithoutRegions
, but inlines the original region immediately afterwards. There should be no call to payloadUsesValueFromOperand
in-between. (Even if there is one, this function would just return false
and should not crash.)
Basically, the op is invalid without a region body. Generally speaking, helper functions on an op are allowed to function incorrectly if the op is invalid.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does this lead to a crash?
When the cloneWithoutRegions function is called, the bufferize rewriter will invoke the notifyOperationInserted function. This function will query the side effects of the operation to count the number of memory allocations. Because this patch will call payloadUsesValueFromOperand when querying side effects, and the payloadUsesValueFromOperand of the map operation is implemented as follows:
bool payloadUsesValueFromOperand(OpOperand * opOperand) {
if (isDpsInit(opOperand)) return false;
return !getMatchingBlockArgument(opOperand).use_empty();
}
This will fetch the block inside the op's region, causing a core dump.
Basically, the op is invalid without a region body. Generally speaking, helper functions on an op are allowed to function incorrectly if the op is invalid.
So should we prohibit querying side effects before the block is inserted?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I see. The op is already inserted even though it was not fully constructed yet (region missing). Ideally, we would change the BufferizableOpInterface
impl., so that we insert the op only when it is fully constructed.
I think it's important to keep the LinalgOp
interface clean and we should not have to account for edge cases in the bufferization.
Could you change the implementation in Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
such that it clones the op without regions without a builder. Basically using Operation::create(OpState)
(instead of cloneOpWithoutRegions
). Then inline the region. Then rewriter.insert
. Then the callback in the bufferization will be triggered on a fully constructed and valid op.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with your perspective, we should keep the implementation within the interface clean. I have already modified the code according to your comment, thanks for your feedback.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @ftynse, do you have any further comments on the modifications proposed in this patch?
Our friend ftynse hasn't had time to work on this patch recently, and the issues he raised seem to have been resolved. Could you please help me merge this PR? @matthias-springer |
This patch add more precise memory effect to linalg op. Including the following points: