Skip to content

Commit 7734138

Browse files
authored
[mlir][OpenMP] allow cancellation to not be directly nested (#134084)
omp.cancel and omp.cancellationpoint contain an attribute describing the type of parent construct which should be cancelled. e.g. ``` !$omp cancel do ``` Must be inside of a wsloop. Previously the verifer required the immediate parent to be this operation. This is not quite right because something like the following is valid: ``` !$omp parallel do do i = 1, N if (cond) then !$omp cancel do endif enddo ``` This patch relaxes the verifier to only require that some parent operation matches (not necessarily the immediate parent).
1 parent cbda72a commit 7734138

File tree

3 files changed

+131
-22
lines changed

3 files changed

+131
-22
lines changed

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

+33-22
Original file line numberDiff line numberDiff line change
@@ -3162,24 +3162,32 @@ void CancelOp::build(OpBuilder &builder, OperationState &state,
31623162
CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
31633163
}
31643164

3165+
static Operation *getParentInSameDialect(Operation *thisOp) {
3166+
Operation *parent = thisOp->getParentOp();
3167+
while (parent) {
3168+
if (parent->getDialect() == thisOp->getDialect())
3169+
return parent;
3170+
parent = parent->getParentOp();
3171+
}
3172+
return nullptr;
3173+
}
3174+
31653175
LogicalResult CancelOp::verify() {
31663176
ClauseCancellationConstructType cct = getCancelDirective();
3167-
Operation *parentOp = (*this)->getParentOp();
3168-
3169-
if (!parentOp) {
3170-
return emitOpError() << "must be used within a region supporting "
3171-
"cancel directive";
3172-
}
3177+
// The next OpenMP operation in the chain of parents
3178+
Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3179+
if (!structuralParent)
3180+
return emitOpError() << "Orphaned cancel construct";
31733181

31743182
if ((cct == ClauseCancellationConstructType::Parallel) &&
3175-
!isa<ParallelOp>(parentOp)) {
3183+
!mlir::isa<ParallelOp>(structuralParent)) {
31763184
return emitOpError() << "cancel parallel must appear "
31773185
<< "inside a parallel region";
31783186
}
31793187
if (cct == ClauseCancellationConstructType::Loop) {
3180-
auto loopOp = dyn_cast<LoopNestOp>(parentOp);
3181-
auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
3182-
loopOp ? loopOp->getParentOp() : nullptr);
3188+
// structural parent will be omp.loop_nest, directly nested inside
3189+
// omp.wsloop
3190+
auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
31833191

31843192
if (!wsloopOp) {
31853193
return emitOpError()
@@ -3195,12 +3203,15 @@ LogicalResult CancelOp::verify() {
31953203
}
31963204

31973205
} else if (cct == ClauseCancellationConstructType::Sections) {
3198-
if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
3206+
// structural parent will be an omp.section, directly nested inside
3207+
// omp.sections
3208+
auto sectionsOp =
3209+
mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
3210+
if (!sectionsOp) {
31993211
return emitOpError() << "cancel sections must appear "
32003212
<< "inside a sections region";
32013213
}
3202-
if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
3203-
cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
3214+
if (sectionsOp.getNowait()) {
32043215
return emitError() << "A sections construct that is canceled "
32053216
<< "must not have a nowait clause";
32063217
}
@@ -3220,25 +3231,25 @@ void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
32203231

32213232
LogicalResult CancellationPointOp::verify() {
32223233
ClauseCancellationConstructType cct = getCancelDirective();
3223-
Operation *parentOp = (*this)->getParentOp();
3224-
3225-
if (!parentOp) {
3226-
return emitOpError() << "must be used within a region supporting "
3227-
"cancellation point directive";
3228-
}
3234+
// The next OpenMP operation in the chain of parents
3235+
Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3236+
if (!structuralParent)
3237+
return emitOpError() << "Orphaned cancellation point";
32293238

32303239
if ((cct == ClauseCancellationConstructType::Parallel) &&
3231-
!(isa<ParallelOp>(parentOp))) {
3240+
!mlir::isa<ParallelOp>(structuralParent)) {
32323241
return emitOpError() << "cancellation point parallel must appear "
32333242
<< "inside a parallel region";
32343243
}
3244+
// Strucutal parent here will be an omp.loop_nest. Get the parent of that to
3245+
// find the wsloop
32353246
if ((cct == ClauseCancellationConstructType::Loop) &&
3236-
(!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) {
3247+
!mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
32373248
return emitOpError() << "cancellation point loop must appear "
32383249
<< "inside a worksharing-loop region";
32393250
}
32403251
if ((cct == ClauseCancellationConstructType::Sections) &&
3241-
!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
3252+
!mlir::isa<omp::SectionOp>(structuralParent)) {
32423253
return emitOpError() << "cancellation point sections must appear "
32433254
<< "inside a sections region";
32443255
}

mlir/test/Dialect/OpenMP/invalid.mlir

+16
Original file line numberDiff line numberDiff line change
@@ -1710,6 +1710,14 @@ func.func @omp_task(%mem: memref<1xf32>) {
17101710

17111711
// -----
17121712

1713+
func.func @omp_cancel() {
1714+
// expected-error @below {{Orphaned cancel construct}}
1715+
omp.cancel cancellation_construct_type(parallel)
1716+
return
1717+
}
1718+
1719+
// -----
1720+
17131721
func.func @omp_cancel() {
17141722
omp.sections {
17151723
// expected-error @below {{cancel parallel must appear inside a parallel region}}
@@ -1789,6 +1797,14 @@ func.func @omp_cancel5() -> () {
17891797

17901798
// -----
17911799

1800+
func.func @omp_cancellationpoint() {
1801+
// expected-error @below {{Orphaned cancellation point}}
1802+
omp.cancellation_point cancellation_construct_type(parallel)
1803+
return
1804+
}
1805+
1806+
// -----
1807+
17921808
func.func @omp_cancellationpoint() {
17931809
omp.sections {
17941810
// expected-error @below {{cancellation point parallel must appear inside a parallel region}}

mlir/test/Dialect/OpenMP/ops.mlir

+82
Original file line numberDiff line numberDiff line change
@@ -2201,6 +2201,48 @@ func.func @omp_cancel_sections() -> () {
22012201
return
22022202
}
22032203

2204+
func.func @omp_cancel_parallel_nested(%if_cond : i1) -> () {
2205+
omp.parallel {
2206+
scf.if %if_cond {
2207+
// CHECK: omp.cancel cancellation_construct_type(parallel)
2208+
omp.cancel cancellation_construct_type(parallel)
2209+
}
2210+
// CHECK: omp.terminator
2211+
omp.terminator
2212+
}
2213+
return
2214+
}
2215+
2216+
func.func @omp_cancel_wsloop_nested(%lb : index, %ub : index, %step : index,
2217+
%if_cond : i1) {
2218+
omp.wsloop {
2219+
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
2220+
scf.if %if_cond {
2221+
// CHECK: omp.cancel cancellation_construct_type(loop)
2222+
omp.cancel cancellation_construct_type(loop)
2223+
}
2224+
// CHECK: omp.yield
2225+
omp.yield
2226+
}
2227+
}
2228+
return
2229+
}
2230+
2231+
func.func @omp_cancel_sections_nested(%if_cond : i1) -> () {
2232+
omp.sections {
2233+
omp.section {
2234+
scf.if %if_cond {
2235+
// CHECK: omp.cancel cancellation_construct_type(sections)
2236+
omp.cancel cancellation_construct_type(sections)
2237+
}
2238+
omp.terminator
2239+
}
2240+
// CHECK: omp.terminator
2241+
omp.terminator
2242+
}
2243+
return
2244+
}
2245+
22042246
func.func @omp_cancellationpoint_parallel() -> () {
22052247
omp.parallel {
22062248
// CHECK: omp.cancellation_point cancellation_construct_type(parallel)
@@ -2241,6 +2283,46 @@ func.func @omp_cancellationpoint_sections() -> () {
22412283
return
22422284
}
22432285

2286+
func.func @omp_cancellationpoint_parallel_nested(%if_cond : i1) -> () {
2287+
omp.parallel {
2288+
scf.if %if_cond {
2289+
// CHECK: omp.cancellation_point cancellation_construct_type(parallel)
2290+
omp.cancellation_point cancellation_construct_type(parallel)
2291+
}
2292+
omp.terminator
2293+
}
2294+
return
2295+
}
2296+
2297+
func.func @omp_cancellationpoint_wsloop_nested(%lb : index, %ub : index, %step : index, %if_cond : i1) {
2298+
omp.wsloop {
2299+
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
2300+
scf.if %if_cond {
2301+
// CHECK: omp.cancellation_point cancellation_construct_type(loop)
2302+
omp.cancellation_point cancellation_construct_type(loop)
2303+
}
2304+
// CHECK: omp.yield
2305+
omp.yield
2306+
}
2307+
}
2308+
return
2309+
}
2310+
2311+
func.func @omp_cancellationpoint_sections_nested(%if_cond : i1) -> () {
2312+
omp.sections {
2313+
omp.section {
2314+
scf.if %if_cond {
2315+
// CHECK: omp.cancellation_point cancellation_construct_type(sections)
2316+
omp.cancellation_point cancellation_construct_type(sections)
2317+
}
2318+
omp.terminator
2319+
}
2320+
// CHECK: omp.terminator
2321+
omp.terminator
2322+
}
2323+
return
2324+
}
2325+
22442326
// CHECK-LABEL: @omp_taskgroup_no_tasks
22452327
func.func @omp_taskgroup_no_tasks() -> () {
22462328

0 commit comments

Comments
 (0)