Skip to content

Commit 2e39f5d

Browse files
committed
Use constant total size if available
1 parent a372a07 commit 2e39f5d

File tree

1 file changed

+4
-26
lines changed

1 file changed

+4
-26
lines changed

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -78,29 +78,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
7878

7979
// Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
8080
int64_t scaler = dstBits / srcBits;
81-
82-
// If all strides and sizes are constant, we can compute the result
83-
// directly without creating the AffineMaxOp.
84-
int64_t constResult = 0;
85-
int64_t constStride = 0;
86-
int64_t constSize = 0;
87-
bool isAllConstant = true;
88-
for (unsigned i = 0; i < sourceRank; ++i) {
89-
if (auto constantStride = getConstantIntValue(strides[i])) {
90-
constStride = *constantStride;
91-
} else {
92-
isAllConstant = false;
93-
break;
94-
}
95-
if (auto constantSize = getConstantIntValue(sizes[i])) {
96-
constSize = *constantSize;
97-
} else {
98-
isAllConstant = false;
99-
break;
100-
}
101-
constResult = std::max(constResult, constStride * constSize / scaler);
102-
}
103-
10481
size_t symbolIndex = 0;
10582
SmallVector<Value> values;
10683
SmallVector<AffineExpr> productExpressions;
@@ -129,10 +106,11 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
129106
builder.getContext());
130107

131108
OpFoldResult linearizedSize;
132-
if (isAllConstant) {
133-
linearizedSize = builder.getIndexAttr(constResult);
109+
Value totalSize =
110+
builder.createOrFold<affine::AffineMaxOp>(loc, maxMap, values);
111+
if (auto constantSize = getConstantIntValue(totalSize)) {
112+
linearizedSize = builder.getIndexAttr(*constantSize);
134113
} else {
135-
Value totalSize = builder.create<affine::AffineMaxOp>(loc, maxMap, values);
136114
linearizedSize = totalSize;
137115
}
138116

0 commit comments

Comments
 (0)