Skip to content

Commit d7f3bd2

Browse files
committed
[Task] : Add comments + enhance check for index in parent block of reshape.
1 parent e50232a commit d7f3bd2

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,17 +1083,34 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
10831083
return rewriter.notifyMatchFailure(
10841084
dim, "Dim op is not defined by a reshape op.");
10851085

1086+
// dim of a memref reshape can be folded if dim.getIndex() dominates the
1087+
// reshape. Instead of using `DominanceInfo` (which is usually costly) we
1088+
// cheaply check that either of the following conditions hold:
1089+
// 1. dim.getIndex() is defined in the same block as reshape but before
1090+
// reshape.
1091+
// 2. dim.getIndex() is defined in a parent block of
1092+
// reshape.
1093+
1094+
// Check condition 1
10861095
if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
10871096
if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1088-
if (reshape->isBeforeInBlock(definingOp))
1097+
if (reshape->isBeforeInBlock(definingOp)) {
10891098
return rewriter.notifyMatchFailure(
10901099
dim,
10911100
"dim.getIndex is not defined before reshape in the same block.");
1092-
} // else dim.getIndex is a block argument to reshape->getBlock
1093-
} else if (!dim.getIndex().getParentRegion()->isProperAncestor(
1094-
reshape->getParentRegion()))
1101+
}
1102+
} // else dim.getIndex is a block argument to reshape->getBlock and
1103+
// dominates reshape
1104+
} // Check condition 2
1105+
else if (dim->getBlock() != reshape->getBlock() &&
1106+
!dim.getIndex().getParentRegion()->isProperAncestor(
1107+
reshape->getParentRegion())) {
1108+
// If dim and reshape are in the same block but dim.getIndex() isn't, we
1109+
// already know dim.getIndex() dominates reshape without calling
1110+
// `isProperAncestor`
10951111
return rewriter.notifyMatchFailure(
10961112
dim, "dim.getIndex does not dominate reshape.");
1113+
}
10971114

10981115
// Place the load directly after the reshape to ensure that the shape memref
10991116
// was not mutated.

0 commit comments

Comments
 (0)