@@ -78,29 +78,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
78
78
79
79
// Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
80
80
int64_t scaler = dstBits / srcBits;
81
-
82
- // If all strides and sizes are constant, we can compute the result
83
- // directly without creating the AffineMaxOp.
84
- int64_t constResult = 0 ;
85
- int64_t constStride = 0 ;
86
- int64_t constSize = 0 ;
87
- bool isAllConstant = true ;
88
- for (unsigned i = 0 ; i < sourceRank; ++i) {
89
- if (auto constantStride = getConstantIntValue (strides[i])) {
90
- constStride = *constantStride;
91
- } else {
92
- isAllConstant = false ;
93
- break ;
94
- }
95
- if (auto constantSize = getConstantIntValue (sizes[i])) {
96
- constSize = *constantSize;
97
- } else {
98
- isAllConstant = false ;
99
- break ;
100
- }
101
- constResult = std::max (constResult, constStride * constSize / scaler);
102
- }
103
-
104
81
size_t symbolIndex = 0 ;
105
82
SmallVector<Value> values;
106
83
SmallVector<AffineExpr> productExpressions;
@@ -129,10 +106,11 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
129
106
builder.getContext ());
130
107
131
108
OpFoldResult linearizedSize;
132
- if (isAllConstant) {
133
- linearizedSize = builder.getIndexAttr (constResult);
109
+ Value totalSize =
110
+ builder.createOrFold <affine::AffineMaxOp>(loc, maxMap, values);
111
+ if (auto constantSize = getConstantIntValue (totalSize)) {
112
+ linearizedSize = builder.getIndexAttr (*constantSize);
134
113
} else {
135
- Value totalSize = builder.create <affine::AffineMaxOp>(loc, maxMap, values);
136
114
linearizedSize = totalSize;
137
115
}
138
116
0 commit comments