Skip to content

Commit 54401b4

Browse files
authored
[mlir][memref.expand_shape] Add verifier check to ensure correct output_shape is provided by user (#91245)
The verifier was not checking for the case when the user provided shape in output_shape is different than the one inferred from output type. Fix this.
1 parent ff0c5cc commit 54401b4

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2353,6 +2353,16 @@ LogicalResult ExpandShapeOp::verify() {
23532353
<< " dynamic dims while output_shape has " << getOutputShape().size()
23542354
<< " values";
23552355

2356+
// Verify if provided output shapes are in agreement with output type.
2357+
DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2358+
ArrayRef<int64_t> resShape = getResult().getType().getShape();
2359+
unsigned staticShapeNum = 0;
2360+
2361+
for (auto [pos, shape] : llvm::enumerate(resShape))
2362+
if (!ShapedType::isDynamic(shape) &&
2363+
shape != staticOutputShapes[staticShapeNum++])
2364+
emitOpError("invalid output shape provided at pos ") << pos;
2365+
23562366
return success();
23572367
}
23582368

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,3 +1103,14 @@ func.func @subview_invalid_strides_rank_reduction(%m: memref<7x22x333x4444xi32>)
11031103
: memref<7x22x333x4444xi32> to memref<7x11x4444xi32>
11041104
return
11051105
}
1106+
1107+
// -----
1108+
1109+
func.func @expand_shape_invalid_output_shape(
1110+
%arg0: memref<30x20xf32, strided<[4000, 2], offset: 100>>) {
1111+
// expected-error @+1 {{invalid output shape provided at pos 2}}
1112+
%0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 15, 21] :
1113+
memref<30x20xf32, strided<[4000, 2], offset: 100>>
1114+
into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>
1115+
return
1116+
}

0 commit comments

Comments
 (0)