Skip to content

Commit

Permalink
[MLIR][OpenMP] Improve omp.section block arguments handling
Browse files Browse the repository at this point in the history
The `omp.section` operation is an outlier in that the block arguments it has
are defined by clauses on the required parent `omp.sections` operation.

This patch updates the definition of this operation introducing the
`BlockArgOpenMPOpInterface` to simplify the handling and verification of these
block arguments, implemented based on the parent `omp.sections`.
  • Loading branch information
skatrak committed Oct 1, 2024
1 parent 4e52e6a commit 2ebd822
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 2 deletions.
12 changes: 10 additions & 2 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,9 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
// 2.8.1 Sections Construct
//===----------------------------------------------------------------------===//

def SectionOp : OpenMP_Op<"section", [HasParent<"SectionsOp">],
singleRegion = true> {
def SectionOp : OpenMP_Op<"section", traits = [
BlockArgOpenMPOpInterface, HasParent<"SectionsOp">
], singleRegion = true> {
let summary = "section directive";
let description = [{
A section operation encloses a region which represents one section in a
Expand All @@ -218,6 +219,13 @@ def SectionOp : OpenMP_Op<"section", [HasParent<"SectionsOp">],
operation. This is done to reflect situations where these block arguments
represent variables private to each section.
}];
let extraClassDeclaration = [{
// Override BlockArgOpenMPOpInterface methods based on the parent
// omp.sections operation. Only forward-declare here because SectionsOp is
// not completely defined at this point.
unsigned numPrivateBlockArgs();
unsigned numReductionBlockArgs();
}] # clausesExtraClassDeclaration;
let assemblyFormat = "$region attr-dict";
}

Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1844,6 +1844,18 @@ LogicalResult TeamsOp::verify() {
getReductionByref());
}

//===----------------------------------------------------------------------===//
// SectionOp
//===----------------------------------------------------------------------===//

unsigned SectionOp::numPrivateBlockArgs() {
return getParentOp().numPrivateBlockArgs();
}

unsigned SectionOp::numReductionBlockArgs() {
return getParentOp().numReductionBlockArgs();
}

//===----------------------------------------------------------------------===//
// SectionsOp
//===----------------------------------------------------------------------===//
Expand Down
25 changes: 25 additions & 0 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1584,6 +1584,31 @@ func.func @omp_sections() {

// -----

omp.declare_reduction @add_f32 : f32
init {
^bb0(%arg: f32):
%0 = arith.constant 0.0 : f32
omp.yield (%0 : f32)
}
combiner {
^bb1(%arg0: f32, %arg1: f32):
%1 = arith.addf %arg0, %arg1 : f32
omp.yield (%1 : f32)
}

func.func @omp_sections(%x : !llvm.ptr) {
omp.sections reduction(@add_f32 %x -> %arg0 : !llvm.ptr) {
// expected-error @below {{op expected at least 1 entry block argument(s)}}
omp.section {
omp.terminator
}
omp.terminator
}
return
}

// -----

func.func @omp_single(%data_var : memref<i32>) -> () {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.single" (%data_var) ({
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1127,11 +1127,13 @@ func.func @sections_reduction() {
omp.sections reduction(@add_f32 %0 -> %arg0 : !llvm.ptr) {
// CHECK: omp.section
omp.section {
^bb0(%arg1 : !llvm.ptr):
%1 = arith.constant 2.0 : f32
omp.terminator
}
// CHECK: omp.section
omp.section {
^bb0(%arg1 : !llvm.ptr):
%1 = arith.constant 3.0 : f32
omp.terminator
}
Expand All @@ -1148,11 +1150,13 @@ func.func @sections_reduction_byref() {
omp.sections reduction(byref @add_f32 %0 -> %arg0 : !llvm.ptr) {
// CHECK: omp.section
omp.section {
^bb0(%arg1 : !llvm.ptr):
%1 = arith.constant 2.0 : f32
omp.terminator
}
// CHECK: omp.section
omp.section {
^bb0(%arg1 : !llvm.ptr):
%1 = arith.constant 3.0 : f32
omp.terminator
}
Expand Down Expand Up @@ -1246,10 +1250,12 @@ func.func @sections_reduction2() {
// CHECK: omp.sections reduction(@add2_f32 %{{.+}} -> %{{.+}} : memref<1xf32>)
omp.sections reduction(@add2_f32 %0 -> %arg0 : memref<1xf32>) {
omp.section {
^bb0(%arg1 : !llvm.ptr):
%1 = arith.constant 2.0 : f32
omp.terminator
}
omp.section {
^bb0(%arg1 : !llvm.ptr):
%1 = arith.constant 2.0 : f32
omp.terminator
}
Expand Down

0 comments on commit 2ebd822

Please sign in to comment.