-
Notifications
You must be signed in to change notification settings - Fork 14k
[mlir][memref.expand_shape] Add verifier check to ensure correct output_shape is provided by user #91245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][memref.expand_shape] Add verifier check to ensure correct output_shape is provided by user #91245
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref Author: Prathamesh Tagore (meshtag) ChangesThe verifier was not checking for the case when the user provided shape in output_shape is different than the one inferred from output type. Fix this. Full diff: https://github.com/llvm/llvm-project/pull/91245.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 393f73dc65cd8d..e6a93bf42199a4 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2353,6 +2353,16 @@ LogicalResult ExpandShapeOp::verify() {
<< " dynamic dims while output_shape has " << getOutputShape().size()
<< " values";
+ // Verify if provided output shapes are in agreement with output type.
+ DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
+ ArrayRef<int64_t> resShape = getResult().getType().getShape();
+ unsigned staticShapeNum = 0;
+
+ for (unsigned i = 0, e = resShape.size(); i < e; ++i)
+ if (!ShapedType::isDynamic(resShape[i]) &&
+ resShape[i] != staticOutputShapes[staticShapeNum++])
+ emitOpError("invalid output shape provided at pos ") << i;
+
return success();
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 70c96aad9555ef..0f533cb95a0ca9 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1103,3 +1103,14 @@ func.func @subview_invalid_strides_rank_reduction(%m: memref<7x22x333x4444xi32>)
: memref<7x22x333x4444xi32> to memref<7x11x4444xi32>
return
}
+
+// -----
+
+func.func @expand_shape_invalid_output_shape(
+ %arg0: memref<30x20xf32, strided<[4000, 2], offset: 100>>) {
+ // expected-error @+1 {{invalid output shape provided at pos 2}}
+ %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 15, 21] :
+ memref<30x20xf32, strided<[4000, 2], offset: 100>>
+ into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>
+ return
+}
|
891740e
to
2bf7ced
Compare
Can someone help me land this commit. Thanks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM too
…ut_shape is provided by user The verifier was not checking for the case when the user provided shape in output_shape is different than the one inferred from output type. Fix this.
2bf7ced
to
9b6844b
Compare
unsigned staticShapeNum = 0; | ||
|
||
for (auto [pos, shape] : llvm::enumerate(resShape)) | ||
if (!ShapedType::isDynamic(shape) && | ||
shape != staticOutputShapes[staticShapeNum++]) | ||
emitOpError("invalid output shape provided at pos ") << pos; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this correct? Concretely, this verifier is producing an error when resShape
and staticOutputShape
are equal, and contain a mix of static and kDynamic
values. Indeed, since in C++ the &&
operator is a sequence point, the staticShapeNum++
increment is only happening if the left hand side, !ShapedType::isDynamic(shape)
, evaluated to true. In other words, this verifier code requires staticOutputShape
to only list static dimensions and omit kDynamic
values.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If my understanding is correct, I would suggest this modification:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 78201ae29cd9..4f33137dcff4 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2356,11 +2356,8 @@ LogicalResult ExpandShapeOp::verify() {
// Verify if provided output shapes are in agreement with output type.
DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
ArrayRef<int64_t> resShape = getResult().getType().getShape();
- unsigned staticShapeNum = 0;
-
for (auto [pos, shape] : llvm::enumerate(resShape))
- if (!ShapedType::isDynamic(shape) &&
- shape != staticOutputShapes[staticShapeNum++])
+ if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos])
emitOpError("invalid output shape provided at pos ") << pos;
return success();
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes sense... does this fix the issue you were hitting?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes it does.
Torch-mlir integration is currently blocked on `memref.expand_shape` verifier errors of the form ``` 'memref.expand_shape' op invalid output shape provided at pos 1 ``` The verifier code generating these errors was introduced in #91245. I have commented there why I believe it's incorrect. This PR has my suggested fix. Unfortunately, this does not seem to be directly testable on `memref` IR, because `static_output_shape` is not directly exposed in the custom assembly format.
The verifier was not checking for the case when the user provided shape in output_shape is different than the one inferred from output type. Fix this.