@@ -1083,17 +1083,34 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1083
1083
return rewriter.notifyMatchFailure (
1084
1084
dim, " Dim op is not defined by a reshape op." );
1085
1085
1086
+ // dim of a memref reshape can be folded if dim.getIndex() dominates the
1087
+ // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1088
+ // cheaply check that either of the following conditions hold:
1089
+ // 1. dim.getIndex() is defined in the same block as reshape but before
1090
+ // reshape.
1091
+ // 2. dim.getIndex() is defined in a parent block of
1092
+ // reshape.
1093
+
1094
+ // Check condition 1
1086
1095
if (dim.getIndex ().getParentBlock () == reshape->getBlock ()) {
1087
1096
if (auto *definingOp = dim.getIndex ().getDefiningOp ()) {
1088
- if (reshape->isBeforeInBlock (definingOp))
1097
+ if (reshape->isBeforeInBlock (definingOp)) {
1089
1098
return rewriter.notifyMatchFailure (
1090
1099
dim,
1091
1100
" dim.getIndex is not defined before reshape in the same block." );
1092
- } // else dim.getIndex is a block argument to reshape->getBlock
1093
- } else if (!dim.getIndex ().getParentRegion ()->isProperAncestor (
1094
- reshape->getParentRegion ()))
1101
+ }
1102
+ } // else dim.getIndex is a block argument to reshape->getBlock and
1103
+ // dominates reshape
1104
+ } // Check condition 2
1105
+ else if (dim->getBlock () != reshape->getBlock () &&
1106
+ !dim.getIndex ().getParentRegion ()->isProperAncestor (
1107
+ reshape->getParentRegion ())) {
1108
+ // If dim and reshape are in the same block but dim.getIndex() isn't, we
1109
+ // already know dim.getIndex() dominates reshape without calling
1110
+ // `isProperAncestor`
1095
1111
return rewriter.notifyMatchFailure (
1096
1112
dim, " dim.getIndex does not dominate reshape." );
1113
+ }
1097
1114
1098
1115
// Place the load directly after the reshape to ensure that the shape memref
1099
1116
// was not mutated.
0 commit comments