Skip to content

[mlir][memref.expand_shape] Add verifier check to ensure correct output_shape is provided by user #91245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

meshtag
Copy link
Member

@meshtag meshtag commented May 6, 2024

The verifier was not checking for the case when the user provided shape in output_shape is different than the one inferred from output type. Fix this.

@meshtag meshtag changed the title [mlir][memref] Add verifier check to ensure correct output_shape is provided [mlir][memref.expand_shape] Add verifier check to ensure correct output_shape is provided May 6, 2024
@llvmbot
Copy link
Member

llvmbot commented May 6, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Prathamesh Tagore (meshtag)

Changes

The verifier was not checking for the case when the user provided shape in output_shape is different than the one inferred from output type. Fix this.


Full diff: https://github.com/llvm/llvm-project/pull/91245.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+10)
  • (modified) mlir/test/Dialect/MemRef/invalid.mlir (+11)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 393f73dc65cd8d..e6a93bf42199a4 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2353,6 +2353,16 @@ LogicalResult ExpandShapeOp::verify() {
            << " dynamic dims while output_shape has " << getOutputShape().size()
            << " values";
 
+  // Verify if provided output shapes are in agreement with output type.
+  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
+  ArrayRef<int64_t> resShape = getResult().getType().getShape();
+  unsigned staticShapeNum = 0;
+
+  for (unsigned i = 0, e = resShape.size(); i < e; ++i)
+    if (!ShapedType::isDynamic(resShape[i]) &&
+        resShape[i] != staticOutputShapes[staticShapeNum++])
+      emitOpError("invalid output shape provided at pos ") << i;
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 70c96aad9555ef..0f533cb95a0ca9 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1103,3 +1103,14 @@ func.func @subview_invalid_strides_rank_reduction(%m: memref<7x22x333x4444xi32>)
       : memref<7x22x333x4444xi32> to memref<7x11x4444xi32>
   return
 }
+
+// -----
+
+func.func @expand_shape_invalid_output_shape(
+    %arg0: memref<30x20xf32, strided<[4000, 2], offset: 100>>) {
+  // expected-error @+1 {{invalid output shape provided at pos 2}}
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 15, 21] :
+      memref<30x20xf32, strided<[4000, 2], offset: 100>>
+      into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>
+  return
+}

@meshtag meshtag changed the title [mlir][memref.expand_shape] Add verifier check to ensure correct output_shape is provided [mlir][memref.expand_shape] Add verifier check to ensure correct output_shape is provided by user May 6, 2024
@meshtag meshtag force-pushed the prathamesh/expand_shape_verifier branch from 891740e to 2bf7ced Compare May 6, 2024 17:29
@meshtag
Copy link
Member Author

meshtag commented May 7, 2024

Can someone help me land this commit. Thanks.

Copy link
Collaborator

@qcolombet qcolombet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM too

…ut_shape is provided by user

The verifier was not checking for the case when the user provided shape in output_shape is different than the one inferred from output type. Fix this.
@meshtag meshtag force-pushed the prathamesh/expand_shape_verifier branch from 2bf7ced to 9b6844b Compare May 7, 2024 11:03
@MaheshRavishankar MaheshRavishankar merged commit 54401b4 into llvm:main May 7, 2024
4 checks passed
Comment on lines +2359 to +2365
unsigned staticShapeNum = 0;

for (auto [pos, shape] : llvm::enumerate(resShape))
if (!ShapedType::isDynamic(shape) &&
shape != staticOutputShapes[staticShapeNum++])
emitOpError("invalid output shape provided at pos ") << pos;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct? Concretely, this verifier is producing an error when resShape and staticOutputShape are equal, and contain a mix of static and kDynamic values. Indeed, since in C++ the && operator is a sequence point, the staticShapeNum++ increment is only happening if the left hand side, !ShapedType::isDynamic(shape), evaluated to true. In other words, this verifier code requires staticOutputShape to only list static dimensions and omit kDynamic values.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If my understanding is correct, I would suggest this modification:

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 78201ae29cd9..4f33137dcff4 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2356,11 +2356,8 @@ LogicalResult ExpandShapeOp::verify() {
   // Verify if provided output shapes are in agreement with output type.
   DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
   ArrayRef<int64_t> resShape = getResult().getType().getShape();
-  unsigned staticShapeNum = 0;
-
   for (auto [pos, shape] : llvm::enumerate(resShape))
-    if (!ShapedType::isDynamic(shape) &&
-        shape != staticOutputShapes[staticShapeNum++])
+    if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos])
       emitOpError("invalid output shape provided at pos ") << pos;
 
   return success();

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense... does this fix the issue you were hitting?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes it does.

bjacob added a commit that referenced this pull request May 8, 2024
Torch-mlir integration is currently blocked on `memref.expand_shape`
verifier errors of the form

```
'memref.expand_shape' op invalid output shape provided at pos 1
```

The verifier code generating these errors was introduced in
#91245. I have commented there
why I believe it's incorrect. This PR has my suggested fix.

Unfortunately, this does not seem to be directly testable on `memref`
IR, because `static_output_shape` is not directly exposed in the custom
assembly format.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants