Skip to content

Commit a372a07

Browse files
committed
Fixing the memref linearization size computation
1 parent 673047e commit a372a07

File tree

3 files changed

+73
-22
lines changed

3 files changed

+73
-22
lines changed

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
6666
SmallVector<AffineExpr> symbols(2 * sourceRank);
6767
bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
6868
AffineExpr addMulMap = builder.getAffineConstantExpr(0);
69-
AffineExpr mulMap = builder.getAffineConstantExpr(1);
7069

7170
SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
7271

@@ -75,18 +74,70 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
7574
addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
7675
offsetValues[offsetIdx] = indicesVec[i];
7776
offsetValues[offsetIdx + 1] = strides[i];
78-
79-
mulMap = mulMap * symbols[i];
8077
}
8178

8279
// Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
8380
int64_t scaler = dstBits / srcBits;
84-
mulMap = mulMap.floorDiv(scaler);
81+
82+
// If all strides and sizes are constant, we can compute the result
83+
// directly without creating the AffineMaxOp.
84+
int64_t constResult = 0;
85+
int64_t constStride = 0;
86+
int64_t constSize = 0;
87+
bool isAllConstant = true;
88+
for (unsigned i = 0; i < sourceRank; ++i) {
89+
if (auto constantStride = getConstantIntValue(strides[i])) {
90+
constStride = *constantStride;
91+
} else {
92+
isAllConstant = false;
93+
break;
94+
}
95+
if (auto constantSize = getConstantIntValue(sizes[i])) {
96+
constSize = *constantSize;
97+
} else {
98+
isAllConstant = false;
99+
break;
100+
}
101+
constResult = std::max(constResult, constStride * constSize / scaler);
102+
}
103+
104+
size_t symbolIndex = 0;
105+
SmallVector<Value> values;
106+
SmallVector<AffineExpr> productExpressions;
107+
for (unsigned i = 0; i < sourceRank; ++i) {
108+
AffineExpr strideExpr, sizeExpr;
109+
OpFoldResult stride = strides[i];
110+
OpFoldResult size = sizes[i];
111+
if (auto constantStride = getConstantIntValue(stride)) {
112+
strideExpr = builder.getAffineConstantExpr(*constantStride);
113+
} else {
114+
strideExpr = symbols[symbolIndex++];
115+
values.push_back(getValueOrCreateConstantIndexOp(builder, loc, stride));
116+
}
117+
118+
if (auto constantSize = getConstantIntValue(size)) {
119+
sizeExpr = builder.getAffineConstantExpr(*constantSize);
120+
} else {
121+
sizeExpr = symbols[symbolIndex++];
122+
values.push_back(getValueOrCreateConstantIndexOp(builder, loc, size));
123+
}
124+
125+
productExpressions.push_back((strideExpr * sizeExpr).floorDiv(scaler));
126+
}
127+
AffineMap maxMap = AffineMap::get(
128+
/*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
129+
builder.getContext());
130+
131+
OpFoldResult linearizedSize;
132+
if (isAllConstant) {
133+
linearizedSize = builder.getIndexAttr(constResult);
134+
} else {
135+
Value totalSize = builder.create<affine::AffineMaxOp>(loc, maxMap, values);
136+
linearizedSize = totalSize;
137+
}
85138

86139
OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
87140
builder, loc, addMulMap.floorDiv(scaler), offsetValues);
88-
OpFoldResult linearizedSize =
89-
affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);
90141

91142
// Adjust baseOffset by the scale factor (dstBits / srcBits).
92143
AffineExpr s0;

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,15 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
104104
%1 = memref.load %0[%arg2, %arg3] : memref<?x?xi4>
105105
return %1 : i4
106106
}
107-
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
107+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
108108
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
109109
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
110110
// CHECK: func @memref_load_i4_dynamic(
111111
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
112112
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
113113
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
114114
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
115-
// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
115+
// CHECK: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
116116
// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
117117
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
118118
// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
@@ -122,15 +122,15 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
122122
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
123123
// CHECK: return %[[TRUNC]]
124124

125-
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
125+
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
126126
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
127127
// CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
128128
// CHECK32: func @memref_load_i4_dynamic(
129129
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
130130
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
131131
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
132132
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
133-
// CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
133+
// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
134134
// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
135135
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
136136
// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
@@ -399,7 +399,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
399399
memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
400400
return
401401
}
402-
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
402+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
403403
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
404404
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
405405
// CHECK: func @memref_store_i4_dynamic(
@@ -408,7 +408,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
408408
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
409409
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
410410
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
411-
// CHECK-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
411+
// CHECK-DAG: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
412412
// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
413413
// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
414414
// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
@@ -423,7 +423,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
423423
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
424424
// CHECK: return
425425

426-
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
426+
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
427427
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
428428
// CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
429429
// CHECK32: func @memref_store_i4_dynamic(
@@ -432,7 +432,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
432432
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
433433
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
434434
// CHECK32-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
435-
// CHECK32-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
435+
// CHECK32-DAG: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
436436
// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
437437
// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i32
438438
// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,27 +58,27 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
5858
%1 = vector.load %0[%arg2, %arg3] : memref<?x?xi4>, vector<8xi4>
5959
return %1 : vector<8xi4>
6060
}
61-
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
61+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
6262
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
6363
// CHECK: func.func @vector_load_i4_dynamic(
6464
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
6565
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
6666
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
6767
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
68-
// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
68+
// CHECK: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
6969
// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
7070
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
7171
// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi8>, vector<4xi8>
7272
// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
7373

74-
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
74+
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
7575
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
7676
// CHECK32: func.func @vector_load_i4_dynamic(
7777
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
7878
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
7979
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
8080
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
81-
// CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
81+
// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
8282
// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
8383
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
8484
// CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
@@ -450,29 +450,29 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
450450
return
451451
}
452452

453-
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
453+
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
454454
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
455455
// CHECK: func @vector_store_i4_dynamic
456456
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
457457
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
458458
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
459459
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
460460
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
461-
// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
461+
// CHECK: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]], %[[ARG2]]]
462462
// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
463463
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
464464
// CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
465465
// CHECK: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<?xi8>, vector<4xi8>
466466

467-
// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
467+
// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
468468
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
469469
// CHECK32: func @vector_store_i4_dynamic
470470
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
471471
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
472472
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
473473
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
474474
// CHECK32-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
475-
// CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
475+
// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]], %[[ARG2]]]
476476
// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
477477
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
478478
// CHECK32: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>

0 commit comments

Comments
 (0)