Skip to content

Commit dabdec1

Browse files
authored
Fix memref.expand_shape verifier (#91501)
Torch-mlir integration is currently blocked on `memref.expand_shape` verifier errors of the form ``` 'memref.expand_shape' op invalid output shape provided at pos 1 ``` The verifier code generating these errors was introduced in #91245. I have commented there why I believe it's incorrect. This PR has my suggested fix. Unfortunately, this does not seem to be directly testable on `memref` IR, because `static_output_shape` is not directly exposed in the custom assembly format.
1 parent bb6df08 commit dabdec1

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2356,12 +2356,11 @@ LogicalResult ExpandShapeOp::verify() {
23562356
// Verify if provided output shapes are in agreement with output type.
23572357
DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
23582358
ArrayRef<int64_t> resShape = getResult().getType().getShape();
2359-
unsigned staticShapeNum = 0;
2360-
2361-
for (auto [pos, shape] : llvm::enumerate(resShape))
2362-
if (!ShapedType::isDynamic(shape) &&
2363-
shape != staticOutputShapes[staticShapeNum++])
2364-
emitOpError("invalid output shape provided at pos ") << pos;
2359+
for (auto [pos, shape] : llvm::enumerate(resShape)) {
2360+
if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
2361+
return emitOpError("invalid output shape provided at pos ") << pos;
2362+
}
2363+
}
23652364

23662365
return success();
23672366
}

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
502502
// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
503503
func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
504504
%subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
505-
%expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [1, 16, %sz0, 1] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
505+
%expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [%sz0, 1, 8, 2] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
506506
%dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
507507

508508
affine.for %arg6 = 0 to %dim step 64 {

mlir/test/Dialect/MemRef/ops.mlir

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,8 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
203203
%arg3: memref<?x42xf32, strided<[42, 1], offset: 0>>,
204204
%arg4: index,
205205
%arg5: index,
206-
%arg6: index) {
206+
%arg6: index,
207+
%arg7: memref<4x?x4xf32>) {
207208
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
208209
// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
209210
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
@@ -248,6 +249,10 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
248249
// CHECK-SAME: memref<?xf32, strided<[1]>> into memref<?x42xf32>
249250
%r3 = memref.expand_shape %3 [[0, 1]] output_shape [%arg6, 42] :
250251
memref<?xf32, strided<[1]>> into memref<?x42xf32>
252+
253+
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
254+
%4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
255+
: memref<4x?x4xf32> into memref<2x2x?x2x2xf32>
251256
return
252257
}
253258

0 commit comments

Comments
 (0)