diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index d2a2b44c042fb7..66f63fc02fe2f3 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -207,8 +207,9 @@ def TeamsOp : OpenMP_Op<"teams", traits = [ // 2.8.1 Sections Construct //===----------------------------------------------------------------------===// -def SectionOp : OpenMP_Op<"section", [HasParent<"SectionsOp">], - singleRegion = true> { +def SectionOp : OpenMP_Op<"section", traits = [ + BlockArgOpenMPOpInterface, HasParent<"SectionsOp"> + ], singleRegion = true> { let summary = "section directive"; let description = [{ A section operation encloses a region which represents one section in a @@ -218,6 +219,13 @@ def SectionOp : OpenMP_Op<"section", [HasParent<"SectionsOp">], operation. This is done to reflect situations where these block arguments represent variables private to each section. }]; + let extraClassDeclaration = [{ + // Override BlockArgOpenMPOpInterface methods based on the parent + // omp.sections operation. Only forward-declare here because SectionsOp is + // not completely defined at this point. + unsigned numPrivateBlockArgs(); + unsigned numReductionBlockArgs(); + }] # clausesExtraClassDeclaration; let assemblyFormat = "$region attr-dict"; } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index bb886323238263..d516c8d9e0be6c 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1844,6 +1844,18 @@ LogicalResult TeamsOp::verify() { getReductionByref()); } +//===----------------------------------------------------------------------===// +// SectionOp +//===----------------------------------------------------------------------===// + +unsigned SectionOp::numPrivateBlockArgs() { + return getParentOp().numPrivateBlockArgs(); +} + +unsigned SectionOp::numReductionBlockArgs() { + return getParentOp().numReductionBlockArgs(); +} + //===----------------------------------------------------------------------===// // SectionsOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 273aeb975c9c3c..a780efe6d22e16 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1584,6 +1584,31 @@ func.func @omp_sections() { // ----- +omp.declare_reduction @add_f32 : f32 +init { +^bb0(%arg: f32): + %0 = arith.constant 0.0 : f32 + omp.yield (%0 : f32) +} +combiner { +^bb1(%arg0: f32, %arg1: f32): + %1 = arith.addf %arg0, %arg1 : f32 + omp.yield (%1 : f32) +} + +func.func @omp_sections(%x : !llvm.ptr) { + omp.sections reduction(@add_f32 %x -> %arg0 : !llvm.ptr) { + // expected-error @below {{op expected at least 1 entry block argument(s)}} + omp.section { + omp.terminator + } + omp.terminator + } + return +} + +// ----- + func.func @omp_single(%data_var : memref) -> () { // expected-error @below {{expected equal sizes for allocate and allocator variables}} "omp.single" (%data_var) ({ diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index ce3351ba1149f3..a4423782a723bf 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -1127,11 +1127,13 @@ func.func @sections_reduction() { omp.sections reduction(@add_f32 %0 -> %arg0 : !llvm.ptr) { // CHECK: omp.section omp.section { + ^bb0(%arg1 : !llvm.ptr): %1 = arith.constant 2.0 : f32 omp.terminator } // CHECK: omp.section omp.section { + ^bb0(%arg1 : !llvm.ptr): %1 = arith.constant 3.0 : f32 omp.terminator } @@ -1148,11 +1150,13 @@ func.func @sections_reduction_byref() { omp.sections reduction(byref @add_f32 %0 -> %arg0 : !llvm.ptr) { // CHECK: omp.section omp.section { + ^bb0(%arg1 : !llvm.ptr): %1 = arith.constant 2.0 : f32 omp.terminator } // CHECK: omp.section omp.section { + ^bb0(%arg1 : !llvm.ptr): %1 = arith.constant 3.0 : f32 omp.terminator } @@ -1246,10 +1250,12 @@ func.func @sections_reduction2() { // CHECK: omp.sections reduction(@add2_f32 %{{.+}} -> %{{.+}} : memref<1xf32>) omp.sections reduction(@add2_f32 %0 -> %arg0 : memref<1xf32>) { omp.section { + ^bb0(%arg1 : !llvm.ptr): %1 = arith.constant 2.0 : f32 omp.terminator } omp.section { + ^bb0(%arg1 : !llvm.ptr): %1 = arith.constant 2.0 : f32 omp.terminator }