Skip to content

Commit 6c36b37

Browse files
committed
[MLIR][Python] Python binding support for AffineIfOp
1 parent 929cbe7 commit 6c36b37

File tree

4 files changed

+139
-1
lines changed

4 files changed

+139
-1
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,8 @@ def AffineIfOp : Affine_Op<"if",
407407
}
408408
```
409409
}];
410-
let arguments = (ins Variadic<AnyType>);
410+
let arguments = (ins Variadic<AnyType>,
411+
IntegerSetAttr:$condition);
411412
let results = (outs Variadic<AnyType>:$results);
412413
let regions = (region SizedRegion<1>:$thenRegion, AnyRegion:$elseRegion);
413414

mlir/include/mlir/IR/CommonAttrConstraints.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,15 @@ CPred<"::llvm::isa<::mlir::AffineMapAttr>($_self)">, "AffineMap attribute"> {
558558
let constBuilderCall = "::mlir::AffineMapAttr::get($0)";
559559
}
560560

561+
// Attributes containing integer sets.
562+
def IntegerSetAttr : Attr<
563+
CPred<"::llvm::isa<::mlir::IntegerSetAttr>($_self)">, "IntegerSet attribute"> {
564+
let storageType = [{::mlir::IntegerSetAttr }];
565+
let returnType = [{ ::mlir::IntegerSet }];
566+
let valueType = NoneType;
567+
let constBuilderCall = "::mlir::IntegerSetAttr::get($0)";
568+
}
569+
561570
// Base class for array attributes.
562571
class ArrayAttrBase<Pred condition, string summary> : Attr<condition, summary> {
563572
let storageType = [{ ::mlir::ArrayAttr }];

mlir/python/mlir/dialects/affine.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,61 @@ def for_(
156156
yield iv, iter_args[0]
157157
else:
158158
yield iv
159+
160+
161+
@_ods_cext.register_operation(_Dialect, replace=True)
162+
class AffineIfOp(AffineIfOp):
163+
"""Specialization for the Affine if op class."""
164+
165+
def __init__(
166+
self,
167+
cond: IntegerSet,
168+
results_: Optional[Type] = None,
169+
*,
170+
cond_operands: Optional[_VariadicResultValueT] = None,
171+
has_else: bool = False,
172+
loc=None,
173+
ip=None,
174+
):
175+
"""Creates an Affine `if` operation.
176+
177+
- `cond` is the integer set used to determine which regions of code
178+
will be executed.
179+
- `results` are the list of types to be yielded by the operand.
180+
- `cond_operands` is the list of arguments to substitute the
181+
dimensions, then symbols in the `cond` integer set expression to
182+
determine whether they are in the set.
183+
- `has_else` determines whether the affine if operation has the else
184+
branch.
185+
"""
186+
if results_ is None:
187+
results_ = []
188+
if cond_operands is None:
189+
cond_operands = []
190+
191+
if cond.n_inputs != len(cond_operands):
192+
raise ValueError(
193+
f"expected {cond.n_inputs} condition operands, got {len(cond_operands)}"
194+
)
195+
196+
operands = []
197+
operands.extend(cond_operands)
198+
results = []
199+
results.extend(results_)
200+
201+
super().__init__(results, cond_operands, cond)
202+
self.regions[0].blocks.append(*[])
203+
if has_else:
204+
self.regions[1].blocks.append(*[])
205+
206+
@property
207+
def then_block(self) -> Block:
208+
"""Returns the then block of the if operation."""
209+
return self.regions[0].blocks[0]
210+
211+
@property
212+
def else_block(self) -> Optional[Block]:
213+
"""Returns the else block of the if operation."""
214+
if len(self.regions[1].blocks) == 0:
215+
return None
216+
return self.regions[1].blocks[0]

mlir/test/python/dialects/affine.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,73 @@ def range_loop_8(lb, ub, memref_v):
263263
add = arith.addi(i, i)
264264
memref.store(add, it, [i])
265265
affine.yield_([it])
266+
267+
268+
# CHECK-LABEL: TEST: testAffineIfWithoutElse
269+
@constructAndPrintInModule
270+
def testAffineIfWithoutElse():
271+
index = IndexType.get()
272+
i32 = IntegerType.get_signless(32)
273+
d0 = AffineDimExpr.get(0)
274+
275+
# CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
276+
cond = IntegerSet.get(1, 0, [d0 - 5], [False])
277+
278+
# CHECK-LABEL: func.func @simple_affine_if(
279+
# CHECK-SAME: %[[VAL_0:.*]]: index) {
280+
# CHECK: affine.if #[[$SET0]](%[[VAL_0]]) {
281+
# CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
282+
# CHECK: %[[VAL_2:.*]] = arith.addi %[[VAL_1]], %[[VAL_1]] : i32
283+
# CHECK: }
284+
# CHECK: return
285+
# CHECK: }
286+
@func.FuncOp.from_py_func(index)
287+
def simple_affine_if(cond_operands):
288+
if_op = affine.AffineIfOp(cond, cond_operands=[cond_operands])
289+
with InsertionPoint(if_op.then_block):
290+
one = arith.ConstantOp(i32, 1)
291+
add = arith.AddIOp(one, one)
292+
affine.AffineYieldOp([])
293+
return
294+
295+
296+
# CHECK-LABEL: TEST: testAffineIfWithElse
297+
@constructAndPrintInModule
298+
def testAffineIfWithElse():
299+
index = IndexType.get()
300+
i32 = IntegerType.get_signless(32)
301+
d0 = AffineDimExpr.get(0)
302+
303+
# CHECK: #[[$SET0:.*]] = affine_set<(d0) : (d0 - 5 >= 0)>
304+
cond = IntegerSet.get(1, 0, [d0 - 5], [False])
305+
306+
# CHECK-LABEL: func.func @simple_affine_if_else(
307+
# CHECK-SAME: %[[VAL_0:.*]]: index) {
308+
# CHECK: %[[VAL_IF:.*]]:2 = affine.if #[[$SET0]](%[[VAL_0]]) -> (i32, i32) {
309+
# CHECK: %[[VAL_XT:.*]] = arith.constant 0 : i32
310+
# CHECK: %[[VAL_YT:.*]] = arith.constant 1 : i32
311+
# CHECK: affine.yield %[[VAL_XT]], %[[VAL_YT]] : i32, i32
312+
# CHECK: } else {
313+
# CHECK: %[[VAL_XF:.*]] = arith.constant 2 : i32
314+
# CHECK: %[[VAL_YF:.*]] = arith.constant 3 : i32
315+
# CHECK: affine.yield %[[VAL_XF]], %[[VAL_YF]] : i32, i32
316+
# CHECK: }
317+
# CHECK: %[[VAL_ADD:.*]] = arith.addi %[[VAL_IF]]#0, %[[VAL_IF]]#1 : i32
318+
# CHECK: return
319+
# CHECK: }
320+
321+
@func.FuncOp.from_py_func(index)
322+
def simple_affine_if_else(cond_operands):
323+
if_op = affine.AffineIfOp(
324+
cond, [i32, i32], cond_operands=[cond_operands], has_else=True
325+
)
326+
with InsertionPoint(if_op.then_block):
327+
x_true = arith.ConstantOp(i32, 0)
328+
y_true = arith.ConstantOp(i32, 1)
329+
affine.AffineYieldOp([x_true, y_true])
330+
with InsertionPoint(if_op.else_block):
331+
x_false = arith.ConstantOp(i32, 2)
332+
y_false = arith.ConstantOp(i32, 3)
333+
affine.AffineYieldOp([x_false, y_false])
334+
add = arith.AddIOp(if_op.results[0], if_op.results[1])
335+
return

0 commit comments

Comments
 (0)