Skip to content

Commit d6fb899

Browse files
authored
[MLIR][OpenMP] Improve loop wrapper representation (#97706)
This patch replaces the `SingleBlockImplicitTerminator<"TerminatorOp">` trait of loop wrapper operations for the `SingleBlock` trait. This enables a more robust implementation of the `LoopWrapperInterface::isWrapper()` method, since it does no longer have to deal with the potentially missing (implicit) terminator. The `LoopWrapperInterface::isWrapper()` method is also extended to not identify as wrappers those operations which have a loop wrapper operation inside that is not taking a wrapper role. This is important for cases where `omp.parallel` is nested, which can but is not required to work as a loop wrapper. Tests are updated to integrate these representation and validation changes.
1 parent 4c23625 commit d6fb899

File tree

7 files changed

+71
-12
lines changed

7 files changed

+71
-12
lines changed

flang/test/Fir/convert-to-llvm-openmp-and-fir.fir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ func.func @_QPsimd1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref
200200
fir.store %3 to %6 : !fir.ref<i32>
201201
omp.yield
202202
}
203+
omp.terminator
203204
}
204205
omp.terminator
205206
}
@@ -225,6 +226,7 @@ func.func @_QPsimd1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref
225226
// CHECK: llvm.store %[[I1]], %[[ARR_I_REF]] : i32, !llvm.ptr
226227
// CHECK: omp.yield
227228
// CHECK: }
229+
// CHECK: omp.terminator
228230
// CHECK: }
229231
// CHECK: omp.terminator
230232
// CHECK: }
@@ -518,6 +520,7 @@ func.func @_QPsimd_with_nested_loop() {
518520
fir.store %7 to %3 : !fir.ref<i32>
519521
omp.yield
520522
}
523+
omp.terminator
521524
}
522525
return
523526
}
@@ -538,6 +541,7 @@ func.func @_QPsimd_with_nested_loop() {
538541
// CHECK: ^bb3:
539542
// CHECK: omp.yield
540543
// CHECK: }
544+
// CHECK: omp.terminator
541545
// CHECK: }
542546
// CHECK: llvm.return
543547
// CHECK: }

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
354354

355355
def WsloopOp : OpenMP_Op<"wsloop", traits = [
356356
AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
357-
RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
357+
RecursiveMemoryEffects, SingleBlock
358358
], clauses = [
359359
// TODO: Complete clause list (allocate, private).
360360
// TODO: Sort clauses alphabetically.
@@ -418,7 +418,7 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
418418

419419
def SimdOp : OpenMP_Op<"simd", traits = [
420420
AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
421-
RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
421+
RecursiveMemoryEffects, SingleBlock
422422
], clauses = [
423423
// TODO: Complete clause list (linear, private, reduction).
424424
OpenMP_AlignedClause, OpenMP_IfClause, OpenMP_NontemporalClause,
@@ -485,7 +485,7 @@ def YieldOp : OpenMP_Op<"yield",
485485
//===----------------------------------------------------------------------===//
486486
def DistributeOp : OpenMP_Op<"distribute", traits = [
487487
AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
488-
RecursiveMemoryEffects, SingleBlockImplicitTerminator<"TerminatorOp">
488+
RecursiveMemoryEffects, SingleBlock
489489
], clauses = [
490490
// TODO: Complete clause list (private).
491491
// TODO: Sort clauses alphabetically.
@@ -575,7 +575,7 @@ def TaskOp : OpenMP_Op<"task", traits = [
575575
def TaskloopOp : OpenMP_Op<"taskloop", traits = [
576576
AttrSizedOperandSegments, AutomaticAllocationScope,
577577
DeclareOpInterfaceMethods<LoopWrapperInterface>, RecursiveMemoryEffects,
578-
SingleBlockImplicitTerminator<"TerminatorOp">
578+
SingleBlock
579579
], clauses = [
580580
// TODO: Complete clause list (private).
581581
// TODO: Sort clauses alphabetically.

mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
8484
/*description=*/[{
8585
Tell whether the operation could be taking the role of a loop wrapper.
8686
That is, it has a single region with a single block in which there are
87-
two operations: another wrapper or `omp.loop_nest` operation and a
88-
terminator.
87+
two operations: another wrapper (also taking a loop wrapper role) or
88+
`omp.loop_nest` operation and a terminator.
8989
}],
9090
/*retTy=*/"bool",
9191
/*methodName=*/"isWrapper",
@@ -102,8 +102,14 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
102102

103103
Operation &firstOp = *r.op_begin();
104104
Operation &secondOp = *(std::next(r.op_begin()));
105-
return ::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp) &&
106-
secondOp.hasTrait<OpTrait::IsTerminator>();
105+
106+
if (!secondOp.hasTrait<OpTrait::IsTerminator>())
107+
return false;
108+
109+
if (auto wrapper = ::llvm::dyn_cast<LoopWrapperInterface>(firstOp))
110+
return wrapper.isWrapper();
111+
112+
return ::llvm::isa<LoopNestOp>(firstOp);
107113
}]
108114
>,
109115
InterfaceMethod<

mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ func.func @loop_nest_block_arg(%val : i32, %ub : i32, %i : index) {
174174
^bb3:
175175
omp.yield
176176
}
177+
omp.terminator
177178
}
178179
return
179180
}

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ func.func @unknown_clause() {
1111
// -----
1212

1313
func.func @not_wrapper() {
14+
// expected-error@+1 {{op must be a loop wrapper}}
1415
omp.distribute {
15-
// expected-error@+1 {{op must take a loop wrapper role if nested inside of 'omp.distribute'}}
1616
omp.parallel {
1717
%0 = arith.constant 0 : i32
1818
omp.terminator
@@ -383,12 +383,16 @@ func.func @omp_simd() -> () {
383383

384384
// -----
385385

386-
func.func @omp_simd_nested_wrapper() -> () {
386+
func.func @omp_simd_nested_wrapper(%lb : index, %ub : index, %step : index) -> () {
387387
// expected-error @below {{op must wrap an 'omp.loop_nest' directly}}
388388
omp.simd {
389389
omp.distribute {
390+
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
391+
omp.yield
392+
}
390393
omp.terminator
391394
}
395+
omp.terminator
392396
}
393397
return
394398
}
@@ -1960,6 +1964,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
19601964
}
19611965
omp.terminator
19621966
}
1967+
omp.terminator
19631968
}
19641969
return
19651970
}
@@ -2158,11 +2163,13 @@ func.func @omp_distribute_wrapper() -> () {
21582163

21592164
// -----
21602165

2161-
func.func @omp_distribute_nested_wrapper(%data_var : memref<i32>) -> () {
2166+
func.func @omp_distribute_nested_wrapper(%lb: index, %ub: index, %step: index) -> () {
21622167
// expected-error @below {{only supported nested wrappers are 'omp.parallel' and 'omp.simd'}}
21632168
omp.distribute {
21642169
"omp.wsloop"() ({
2165-
%0 = arith.constant 0 : i32
2170+
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
2171+
"omp.yield"() : () -> ()
2172+
}
21662173
"omp.terminator"() : () -> ()
21672174
}) : () -> ()
21682175
"omp.terminator"() : () -> ()

0 commit comments

Comments
 (0)