Skip to content

Commit 26b81c4

Browse files
authored
[mlir][memref] Add terminator check to prevent a crash (#141972)
This PR adds terminator check to prevent a crash when invoke `lastNonTerminatorInRegion`. Fixes #137333.
1 parent 40e1f7d commit 26b81c4

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -398,8 +398,9 @@ static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
398398
/// and is only followed by a terminator. This prevents
399399
/// extending the lifetime of allocations.
400400
static bool lastNonTerminatorInRegion(Operation *op) {
401-
return op->getNextNode() == op->getBlock()->getTerminator() &&
402-
llvm::hasSingleElement(op->getParentRegion()->getBlocks());
401+
return op->getBlock()->mightHaveTerminator() &&
402+
op->getNextNode() == op->getBlock()->getTerminator() &&
403+
op->getParentRegion()->hasOneBlock();
403404
}
404405

405406
/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
@@ -2011,7 +2012,7 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
20112012
// Second, check the sizes.
20122013
if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
20132014
op.getConstifiedMixedSizes()))
2014-
return false;
2015+
return false;
20152016

20162017
// Finally, check the offset.
20172018
assert(op.getMixedOffsets().size() == 1 &&

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,8 @@ func.func @scopeMerge() {
739739
// CHECK: "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
740740
// CHECK: return
741741

742+
// -----
743+
742744
func.func @scopeMerge2() {
743745
"test.region"() ({
744746
memref.alloca_scope {
@@ -763,6 +765,8 @@ func.func @scopeMerge2() {
763765
// CHECK: return
764766
// CHECK: }
765767

768+
// -----
769+
766770
func.func @scopeMerge3() {
767771
%cnt = "test.count"() : () -> index
768772
"test.region"() ({
@@ -787,6 +791,8 @@ func.func @scopeMerge3() {
787791
// CHECK: return
788792
// CHECK: }
789793

794+
// -----
795+
790796
func.func @scopeMerge4() {
791797
%cnt = "test.count"() : () -> index
792798
"test.region"() ({
@@ -813,6 +819,8 @@ func.func @scopeMerge4() {
813819
// CHECK: return
814820
// CHECK: }
815821

822+
// -----
823+
816824
func.func @scopeMerge5() {
817825
"test.region"() ({
818826
memref.alloca_scope {
@@ -839,6 +847,8 @@ func.func @scopeMerge5() {
839847
// CHECK: return
840848
// CHECK: }
841849

850+
// -----
851+
842852
func.func @scopeInline(%arg : memref<index>) {
843853
%cnt = "test.count"() : () -> index
844854
"test.region"() ({
@@ -855,6 +865,24 @@ func.func @scopeInline(%arg : memref<index>) {
855865

856866
// -----
857867

868+
// Ensure this case not crash.
869+
870+
// CHECK-LABEL: func.func @scope_merge_without_terminator() {
871+
// CHECK: "test.region"()
872+
// CHECK: memref.alloca_scope
873+
func.func @scope_merge_without_terminator() {
874+
"test.region"() ({
875+
memref.alloca_scope {
876+
%cnt = "test.count"() : () -> index
877+
%a = memref.alloca(%cnt) : memref<?xi64>
878+
"test.use"(%a) : (memref<?xi64>) -> ()
879+
}
880+
}) : () -> ()
881+
return
882+
}
883+
884+
// -----
885+
858886
// CHECK-LABEL: func @reinterpret_noop
859887
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
860888
// CHECK-NEXT: return %[[ARG]]

0 commit comments

Comments
 (0)