Skip to content

Commit c6f67b8

Browse files
authored
[mlir][affine] Add ValueBoundsOpInterface to [de]linearize_index (#121833)
Since a need for it came up dowstream (in proving that loops run at least once), this commit implements the ValueBoundsOpInterface for affine.delinearize_index and affine.linearize_index, using affine map representations of the operations they perform. These implementations also use information from outer bounds to impose additional constraints when those are available.
1 parent 2015c0a commit c6f67b8

File tree

2 files changed

+143
-0
lines changed

2 files changed

+143
-0
lines changed

mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,64 @@ struct AffineMaxOpInterface
9191
};
9292
};
9393

94+
struct AffineDelinearizeIndexOpInterface
95+
: public ValueBoundsOpInterface::ExternalModel<
96+
AffineDelinearizeIndexOpInterface, AffineDelinearizeIndexOp> {
97+
void populateBoundsForIndexValue(Operation *rawOp, Value value,
98+
ValueBoundsConstraintSet &cstr) const {
99+
auto op = cast<AffineDelinearizeIndexOp>(rawOp);
100+
auto result = cast<OpResult>(value);
101+
assert(result.getOwner() == rawOp &&
102+
"bounded value isn't a result of this delinearize_index");
103+
unsigned resIdx = result.getResultNumber();
104+
105+
AffineExpr linearIdx = cstr.getExpr(op.getLinearIndex());
106+
107+
SmallVector<OpFoldResult> basis = op.getPaddedBasis();
108+
AffineExpr divisor = cstr.getExpr(1);
109+
for (OpFoldResult basisElem : llvm::drop_begin(basis, resIdx + 1))
110+
divisor = divisor * cstr.getExpr(basisElem);
111+
112+
if (resIdx == 0) {
113+
cstr.bound(value) == linearIdx.floorDiv(divisor);
114+
if (!basis.front().isNull())
115+
cstr.bound(value) < cstr.getExpr(basis.front());
116+
return;
117+
}
118+
AffineExpr thisBasis = cstr.getExpr(basis[resIdx]);
119+
cstr.bound(value) == (linearIdx % (thisBasis * divisor)).floorDiv(divisor);
120+
}
121+
};
122+
123+
struct AffineLinearizeIndexOpInterface
124+
: public ValueBoundsOpInterface::ExternalModel<
125+
AffineLinearizeIndexOpInterface, AffineLinearizeIndexOp> {
126+
void populateBoundsForIndexValue(Operation *rawOp, Value value,
127+
ValueBoundsConstraintSet &cstr) const {
128+
auto op = cast<AffineLinearizeIndexOp>(rawOp);
129+
assert(value == op.getResult() &&
130+
"value isn't the result of this linearize");
131+
132+
AffineExpr bound = cstr.getExpr(0);
133+
AffineExpr stride = cstr.getExpr(1);
134+
SmallVector<OpFoldResult> basis = op.getPaddedBasis();
135+
OperandRange multiIndex = op.getMultiIndex();
136+
unsigned numArgs = multiIndex.size();
137+
for (auto [revArgNum, length] : llvm::enumerate(llvm::reverse(basis))) {
138+
unsigned argNum = numArgs - (revArgNum + 1);
139+
if (argNum == 0)
140+
break;
141+
OpFoldResult indexAsFoldRes = getAsOpFoldResult(multiIndex[argNum]);
142+
bound = bound + cstr.getExpr(indexAsFoldRes) * stride;
143+
stride = stride * cstr.getExpr(length);
144+
}
145+
bound = bound + cstr.getExpr(op.getMultiIndex().front()) * stride;
146+
cstr.bound(value) == bound;
147+
if (op.getDisjoint() && !basis.front().isNull()) {
148+
cstr.bound(value) < stride *cstr.getExpr(basis.front());
149+
}
150+
}
151+
};
94152
} // namespace
95153
} // namespace mlir
96154

@@ -100,6 +158,10 @@ void mlir::affine::registerValueBoundsOpInterfaceExternalModels(
100158
AffineApplyOp::attachInterface<AffineApplyOpInterface>(*ctx);
101159
AffineMaxOp::attachInterface<AffineMaxOpInterface>(*ctx);
102160
AffineMinOp::attachInterface<AffineMinOpInterface>(*ctx);
161+
AffineDelinearizeIndexOp::attachInterface<
162+
AffineDelinearizeIndexOpInterface>(*ctx);
163+
AffineLinearizeIndexOp::attachInterface<AffineLinearizeIndexOpInterface>(
164+
*ctx);
103165
});
104166
}
105167

mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,84 @@ func.func @compare_maps(%a: index, %b: index) {
155155
: (index, index, index, index) -> ()
156156
return
157157
}
158+
159+
// -----
160+
161+
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 floordiv 15)>
162+
// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> ((s0 mod 15) floordiv 5)>
163+
// CHECK-DAG: #[[$map3:.+]] = affine_map<()[s0] -> (s0 mod 5)>
164+
// CHECK-LABEL: func.func @delinearize_static
165+
// CHECK-SAME: (%[[arg0:.+]]: index)
166+
// CHECK-DAG: %[[v1:.+]] = affine.apply #[[$map1]]()[%[[arg0]]]
167+
// CHECK-DAG: %[[v2:.+]] = affine.apply #[[$map2]]()[%[[arg0]]]
168+
// CHECK-DAG: %[[v3:.+]] = affine.apply #[[$map3]]()[%[[arg0]]]
169+
// CHECK: return %[[v1]], %[[v2]], %[[v3]]
170+
func.func @delinearize_static(%arg0: index) -> (index, index, index) {
171+
%c2 = arith.constant 2 : index
172+
%c3 = arith.constant 3 : index
173+
%0:3 = affine.delinearize_index %arg0 into (2, 3, 5) : index, index, index
174+
%1 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index)
175+
%2 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index)
176+
%3 = "test.reify_bound"(%0#2) {type = "EQ"} : (index) -> (index)
177+
// expected-remark @below{{true}}
178+
"test.compare"(%0#0, %c2) {cmp = "LT"} : (index, index) -> ()
179+
// expected-remark @below{{true}}
180+
"test.compare"(%0#1, %c3) {cmp = "LT"} : (index, index) -> ()
181+
return %1, %2, %3 : index, index, index
182+
}
183+
184+
// -----
185+
186+
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 floordiv 15)>
187+
// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> ((s0 mod 15) floordiv 5)>
188+
// CHECK-DAG: #[[$map3:.+]] = affine_map<()[s0] -> (s0 mod 5)>
189+
// CHECK-LABEL: func.func @delinearize_static_no_outer_bound
190+
// CHECK-SAME: (%[[arg0:.+]]: index)
191+
// CHECK-DAG: %[[v1:.+]] = affine.apply #[[$map1]]()[%[[arg0]]]
192+
// CHECK-DAG: %[[v2:.+]] = affine.apply #[[$map2]]()[%[[arg0]]]
193+
// CHECK-DAG: %[[v3:.+]] = affine.apply #[[$map3]]()[%[[arg0]]]
194+
// CHECK: return %[[v1]], %[[v2]], %[[v3]]
195+
func.func @delinearize_static_no_outer_bound(%arg0: index) -> (index, index, index) {
196+
%c2 = arith.constant 2 : index
197+
%c3 = arith.constant 3 : index
198+
%0:3 = affine.delinearize_index %arg0 into (3, 5) : index, index, index
199+
%1 = "test.reify_bound"(%0#0) {type = "EQ"} : (index) -> (index)
200+
%2 = "test.reify_bound"(%0#1) {type = "EQ"} : (index) -> (index)
201+
%3 = "test.reify_bound"(%0#2) {type = "EQ"} : (index) -> (index)
202+
"test.compaare"(%0#0, %c2) {cmp = "LT"} : (index, index) -> ()
203+
// expected-remark @below{{true}}
204+
"test.compare"(%0#1, %c3) {cmp = "LT"} : (index, index) -> ()
205+
return %1, %2, %3 : index, index, index
206+
}
207+
208+
// -----
209+
210+
// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
211+
// CHECK-LABEL: func.func @linearize_static
212+
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index)
213+
// CHECK: %[[v1:.+]] = affine.apply #[[$map]]()[%[[arg1]], %[[arg0]]]
214+
// CHECK: return %[[v1]]
215+
func.func @linearize_static(%arg0: index, %arg1: index) -> index {
216+
%c6 = arith.constant 6 : index
217+
%0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 3) : index
218+
%1 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
219+
// expected-remark @below{{true}}
220+
"test.compare"(%0, %c6) {cmp = "LT"} : (index, index) -> ()
221+
return %1 : index
222+
}
223+
224+
// -----
225+
226+
// CHECK: #[[$map:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
227+
// CHECK-LABEL: func.func @linearize_static_no_outer_bound
228+
// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index)
229+
// CHECK: %[[v1:.+]] = affine.apply #[[$map]]()[%[[arg1]], %[[arg0]]]
230+
// CHECK: return %[[v1]]
231+
func.func @linearize_static_no_outer_bound(%arg0: index, %arg1: index) -> index {
232+
%c6 = arith.constant 6 : index
233+
%0 = affine.linearize_index disjoint [%arg0, %arg1] by (3) : index
234+
%1 = "test.reify_bound"(%0) {type = "EQ"} : (index) -> (index)
235+
// expected-error @below{{unknown}}
236+
"test.compare"(%0, %c6) {cmp = "LT"} : (index, index) -> ()
237+
return %1 : index
238+
}

0 commit comments

Comments
 (0)