diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index a4d2524bccf5c3..6a97b32d80c4b3 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -1063,7 +1063,7 @@ bool ClauseProcessor::processReduction( llvm::SmallVector reductionDeclSymbols; llvm::SmallVector reductionSyms; ReductionProcessor rp; - rp.addDeclareReduction( + rp.addDeclareReduction( currentLocation, converter, clause, reductionVars, reduceVarByRef, reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr); @@ -1085,6 +1085,80 @@ bool ClauseProcessor::processReduction( }); } +bool ClauseProcessor::processTaskReduction( + mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result, + llvm::SmallVectorImpl *outReductionTypes, + llvm::SmallVectorImpl *outReductionSyms) const { + return findRepeatableClause( + [&](const omp::clause::TaskReduction &clause, const parser::CharBlock &) { + llvm::SmallVector taskReductionVars; + llvm::SmallVector taskReductionByref; + llvm::SmallVector taskReductionDeclSymbols; + llvm::SmallVector taskReductionSyms; + ReductionProcessor rp; + rp.addDeclareReduction( + currentLocation, converter, clause, taskReductionVars, + taskReductionByref, taskReductionDeclSymbols, + outReductionSyms ? &taskReductionSyms : nullptr); + + // Copy local lists into the output. 
+ llvm::copy(taskReductionVars, + std::back_inserter(result.taskReductionVars)); + llvm::copy(taskReductionByref, + std::back_inserter(result.taskReductionByref)); + llvm::copy(taskReductionDeclSymbols, + std::back_inserter(result.taskReductionSyms)); + + if (outReductionTypes) { + outReductionTypes->reserve(outReductionTypes->size() + + taskReductionVars.size()); + llvm::transform(taskReductionVars, + std::back_inserter(*outReductionTypes), + [](mlir::Value v) { return v.getType(); }); + } + + if (outReductionSyms) + llvm::copy(taskReductionSyms, std::back_inserter(*outReductionSyms)); + }); +} + +bool ClauseProcessor::processInReduction( + mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result, + llvm::SmallVectorImpl *outReductionTypes, + llvm::SmallVectorImpl *outReductionSyms) const { + return findRepeatableClause( + [&](const omp::clause::InReduction &clause, + const parser::CharBlock &source) { + llvm::SmallVector inReductionVars; + llvm::SmallVector inReductionByref; + llvm::SmallVector inReductionDeclSymbols; + llvm::SmallVector inReductionSyms; + ReductionProcessor rp; + rp.addDeclareReduction( + currentLocation, converter, clause, inReductionVars, + inReductionByref, inReductionDeclSymbols, + outReductionSyms ? &inReductionSyms : nullptr); + + // Copy local lists into the output. 
+ llvm::copy(inReductionVars, std::back_inserter(result.inReductionVars)); + llvm::copy(inReductionByref, + std::back_inserter(result.inReductionByref)); + llvm::copy(inReductionDeclSymbols, + std::back_inserter(result.inReductionSyms)); + + if (outReductionTypes) { + outReductionTypes->reserve(outReductionTypes->size() + + inReductionVars.size()); + llvm::transform(inReductionVars, + std::back_inserter(*outReductionTypes), + [](mlir::Value v) { return v.getType(); }); + } + + if (outReductionSyms) + llvm::copy(inReductionSyms, std::back_inserter(*outReductionSyms)); + }); +} + bool ClauseProcessor::processTo( llvm::SmallVectorImpl &result) const { return findRepeatableClause( diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 0c8e7bd47ab5a6..04416a927a1c37 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -129,6 +129,16 @@ class ClauseProcessor { llvm::SmallVectorImpl *reductionTypes = nullptr, llvm::SmallVectorImpl *reductionSyms = nullptr) const; + bool processTaskReduction( + mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result, + llvm::SmallVectorImpl *taskReductionTypes = nullptr, + llvm::SmallVectorImpl *taskReductionSyms = + nullptr) const; + bool processInReduction( + mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result, + llvm::SmallVectorImpl *inReductionTypes = nullptr, + llvm::SmallVectorImpl *inReductionSyms = + nullptr) const; bool processTo(llvm::SmallVectorImpl &result) const; bool processUseDeviceAddr( lower::StatementContext &stmtCtx, diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 8195f4a897a90b..7e1e2afc760f0d 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1243,33 +1243,37 @@ static void genTargetEnterExitUpdateDataClauses( cp.processNowait(clauseOps); } -static void genTaskClauses(lower::AbstractConverter &converter, 
- semantics::SemanticsContext &semaCtx, - lower::StatementContext &stmtCtx, - const List &clauses, mlir::Location loc, - mlir::omp::TaskOperands &clauseOps) { +static void genTaskClauses( + lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, + lower::StatementContext &stmtCtx, const List &clauses, + mlir::Location loc, mlir::omp::TaskOperands &clauseOps, + llvm::SmallVectorImpl &inReductionTypes, + llvm::SmallVectorImpl &inReductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processDepend(clauseOps); cp.processFinal(stmtCtx, clauseOps); cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps); + cp.processInReduction(loc, clauseOps, &inReductionTypes, &inReductionSyms); cp.processMergeable(clauseOps); cp.processPriority(stmtCtx, clauseOps); cp.processUntied(clauseOps); // TODO Support delayed privatization. - cp.processTODO( + cp.processTODO( loc, llvm::omp::Directive::OMPD_task); } -static void genTaskgroupClauses(lower::AbstractConverter &converter, - semantics::SemanticsContext &semaCtx, - const List &clauses, mlir::Location loc, - mlir::omp::TaskgroupOperands &clauseOps) { +static void genTaskgroupClauses( + lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, + const List &clauses, mlir::Location loc, + mlir::omp::TaskgroupOperands &clauseOps, + llvm::SmallVectorImpl &taskReductionTypes, + llvm::SmallVectorImpl &taskReductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); - cp.processTODO(loc, - llvm::omp::Directive::OMPD_taskgroup); + cp.processTaskReduction(loc, clauseOps, &taskReductionTypes, + &taskReductionSyms); } static void genTaskwaitClauses(lower::AbstractConverter &converter, @@ -1869,13 +1873,26 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable, ConstructQueue::const_iterator item) { lower::StatementContext stmtCtx; mlir::omp::TaskOperands clauseOps; - genTaskClauses(converter, semaCtx, stmtCtx, 
item->clauses, loc, clauseOps); + llvm::SmallVector inReductionTypes; + llvm::SmallVector inreductionSyms; + genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps, + inReductionTypes, inreductionSyms); - return genOpWithBody( + auto reductionCallback = [&](mlir::Operation *op) { + genReductionVars(op, converter, loc, inreductionSyms, inReductionTypes); + return inreductionSyms; + }; + + auto taskOp = genOpWithBody( OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, llvm::omp::Directive::OMPD_task) - .setClauses(&item->clauses), + .setClauses(&item->clauses) + .setGenRegionEntryCb(reductionCallback), queue, item, clauseOps); + // Add reduction variables as arguments + llvm::SmallVector blockArgLocs(inReductionTypes.size(), loc); + taskOp->getRegion(0).addArguments(inReductionTypes, blockArgLocs); + return taskOp; } static mlir::omp::TaskgroupOp @@ -1885,13 +1902,21 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable, const ConstructQueue &queue, ConstructQueue::const_iterator item) { mlir::omp::TaskgroupOperands clauseOps; - genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps); + llvm::SmallVector taskReductionTypes; + llvm::SmallVector taskReductionSyms; + genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps, + taskReductionTypes, taskReductionSyms); - return genOpWithBody( + auto taskgroupOp = genOpWithBody( OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, llvm::omp::Directive::OMPD_taskgroup) .setClauses(&item->clauses), queue, item, clauseOps); + + // Add reduction variables as arguments + llvm::SmallVector blockArgLocs(taskReductionSyms.size(), loc); + taskgroupOp->getRegion(0).addArguments(taskReductionTypes, blockArgLocs); + return taskgroupOp; } static mlir::omp::TaskwaitOp @@ -2767,7 +2792,9 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, !std::holds_alternative(clause.u) && !std::holds_alternative(clause.u) && 
!std::holds_alternative(clause.u) && - !std::holds_alternative(clause.u)) { + !std::holds_alternative(clause.u) && + !std::holds_alternative(clause.u) && + !std::holds_alternative(clause.u)) { TODO(clauseLocation, "OpenMP Block construct clause"); } } diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp index 9da15ba303a475..66853aa121ff00 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp @@ -24,6 +24,7 @@ #include "flang/Parser/tools.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "llvm/Support/CommandLine.h" +#include static llvm::cl::opt forceByrefReduction( "force-byref-reduction", @@ -34,6 +35,32 @@ namespace Fortran { namespace lower { namespace omp { +// explicit template instantiations +template void ReductionProcessor::addDeclareReduction( + mlir::Location currentLocation, lower::AbstractConverter &converter, + const omp::clause::Reduction &reduction, + llvm::SmallVectorImpl &reductionVars, + llvm::SmallVectorImpl &reduceVarByRef, + llvm::SmallVectorImpl &reductionDeclSymbols, + llvm::SmallVectorImpl *reductionSymbols); + +template void +ReductionProcessor::addDeclareReduction( + mlir::Location currentLocation, lower::AbstractConverter &converter, + const omp::clause::TaskReduction &reduction, + llvm::SmallVectorImpl &reductionVars, + llvm::SmallVectorImpl &reduceVarByRef, + llvm::SmallVectorImpl &reductionDeclSymbols, + llvm::SmallVectorImpl *reductionSymbols); + +template void ReductionProcessor::addDeclareReduction( + mlir::Location currentLocation, lower::AbstractConverter &converter, + const omp::clause::InReduction &reduction, + llvm::SmallVectorImpl &reductionVars, + llvm::SmallVectorImpl &reduceVarByRef, + llvm::SmallVectorImpl &reductionDeclSymbols, + llvm::SmallVectorImpl *reductionSymbols); + ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( const omp::clause::ProcedureDesignator &pd) { auto 
redType = llvm::StringSwitch>( @@ -716,22 +743,22 @@ static bool doReductionByRef(mlir::Value reductionVar) { return false; } +template void ReductionProcessor::addDeclareReduction( mlir::Location currentLocation, lower::AbstractConverter &converter, - const omp::clause::Reduction &reduction, - llvm::SmallVectorImpl &reductionVars, + const T &reduction, llvm::SmallVectorImpl &reductionVars, llvm::SmallVectorImpl &reduceVarByRef, llvm::SmallVectorImpl &reductionDeclSymbols, llvm::SmallVectorImpl *reductionSymbols) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - - if (std::get>( - reduction.t)) - TODO(currentLocation, "Reduction modifiers are not supported"); + if constexpr (std::is_same::value) { + if (std::get>(reduction.t)) + TODO(currentLocation, "Reduction modifiers are not supported"); + } mlir::omp::DeclareReductionOp decl; const auto &redOperatorList{ - std::get(reduction.t)}; + std::get(reduction.t)}; assert(redOperatorList.size() == 1 && "Expecting single operator"); const auto &redOperator = redOperatorList.front(); const auto &objectList{std::get(reduction.t)}; diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h index 0ed5782e5da1b7..d71fde93de33d0 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.h +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h @@ -120,10 +120,10 @@ class ReductionProcessor { /// Creates a reduction declaration and associates it with an OpenMP block /// directive. 
+ template static void addDeclareReduction( mlir::Location currentLocation, lower::AbstractConverter &converter, - const omp::clause::Reduction &reduction, - llvm::SmallVectorImpl &reductionVars, + const T &reduction, llvm::SmallVectorImpl &reductionVars, llvm::SmallVectorImpl &reduceVarByRef, llvm::SmallVectorImpl &reductionDeclSymbols, llvm::SmallVectorImpl *reductionSymbols = diff --git a/flang/test/Lower/OpenMP/task_array_reduction.f90 b/flang/test/Lower/OpenMP/task_array_reduction.f90 new file mode 100644 index 00000000000000..74693343744c26 --- /dev/null +++ b/flang/test/Lower/OpenMP/task_array_reduction.f90 @@ -0,0 +1,50 @@ +! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s + +! CHECK-LABEL: omp.declare_reduction @add_reduction_byref_box_Uxf32 : !fir.ref>> alloc { +! [...] +! CHECK: omp.yield +! CHECK-LABEL: } init { +! [...] +! CHECK: omp.yield +! CHECK-LABEL: } combiner { +! [...] +! CHECK: omp.yield +! CHECK-LABEL: } cleanup { +! [...] +! CHECK: omp.yield +! CHECK: } + +! CHECK-LABEL: func.func @_QPtaskreduction +! CHECK-SAME: (%[[VAL_0:.*]]: !fir.box> {fir.bindc_name = "x"}) { +! CHECK: %[[VAL_1:.*]] = fir.dummy_scope : !fir.dscope +! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_1]] +! CHECK-SAME: {uniq_name = "_QFtaskreductionEx"} : (!fir.box>, !fir.dscope) -> (!fir.box>, !fir.box>) +! CHECK: omp.parallel { +! CHECK: %[[VAL_3:.*]] = fir.alloca !fir.box> +! CHECK: fir.store %[[VAL_2]]#1 to %[[VAL_3]] : !fir.ref>> +! CHECK: omp.taskgroup task_reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_3]] -> %[[VAL_4:.*]]: !fir.ref>>) { +! CHECK: %[[VAL_5:.*]] = fir.alloca !fir.box> +! CHECK: fir.store %[[VAL_2]]#1 to %[[VAL_5]] : !fir.ref>> +! CHECK: omp.task in_reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref>>) { +! [...] +! CHECK: omp.terminator +! CHECK: } +! 
CHECK: omp.terminator +! CHECK: } +! CHECK: omp.terminator +! CHECK: } +! CHECK: return +! CHECK: } + +subroutine taskReduction(x) + real, dimension(:) :: x + !$omp parallel + !$omp taskgroup task_reduction(+:x) + !$omp task in_reduction(+:x) + x = x + 1 + !$omp end task + !$omp end taskgroup + !$omp end parallel +end subroutine + diff --git a/flang/test/Lower/OpenMP/task_in_reduction.f90 b/flang/test/Lower/OpenMP/task_in_reduction.f90 new file mode 100644 index 00000000000000..26c079d5ac8aa5 --- /dev/null +++ b/flang/test/Lower/OpenMP/task_in_reduction.f90 @@ -0,0 +1,48 @@ +! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s + +!CHECK-LABEL: omp.declare_reduction +!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init { +!CHECK: ^bb0(%{{.*}}: i32): +!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32 +!CHECK: omp.yield(%[[C0_1]] : i32) +!CHECK: } combiner { +!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32): +!CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32 +!CHECK: omp.yield(%[[RES]] : i32) +!CHECK: } + +!CHECK-LABEL: func.func @_QPin_reduction() { +!CHECK: %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFin_reductionEx"} +!CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFin_reductionEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32 +!CHECK: hlfir.assign %[[VAL_2]] to %[[VAL_1]]#0 : i32, !fir.ref +!CHECK: omp.parallel { +!CHECK: omp.taskgroup task_reduction(@[[RED_I32_NAME]] %[[VAL_1]]#0 -> %[[VAL_3:.*]] : !fir.ref) { +!CHECK: omp.task in_reduction(@[[RED_I32_NAME]] %[[VAL_1]]#0 -> %[[VAL_4:.*]] : !fir.ref) { +!CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_4]] {uniq_name = "_QFin_reductionEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref +!CHECK: %[[VAL_7:.*]] = arith.constant 1 : i32 +!CHECK: %[[VAL_8:.*]] = arith.addi 
%[[VAL_6]], %[[VAL_7]] : i32 +!CHECK: hlfir.assign %[[VAL_8]] to %[[VAL_5]]#0 : i32, !fir.ref +!CHECK: omp.terminator +!CHECK: } +!CHECK: omp.terminator +!CHECK: } +!CHECK: omp.terminator +!CHECK: } +!CHECK: return +!CHECK: } + +subroutine in_reduction + integer :: x + x = 0 + !$omp parallel + !$omp taskgroup task_reduction(+:x) + !$omp task in_reduction(+:x) + x = x + 1 + !$omp end task + !$omp end taskgroup + !$omp end parallel +end subroutine + diff --git a/flang/test/Lower/OpenMP/task_reduction.f90 b/flang/test/Lower/OpenMP/task_reduction.f90 new file mode 100644 index 00000000000000..70f64ba34d3405 --- /dev/null +++ b/flang/test/Lower/OpenMP/task_reduction.f90 @@ -0,0 +1,43 @@ +! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s + +!CHECK-LABEL: omp.declare_reduction +!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init { +!CHECK: ^bb0(%{{.*}}: i32): +!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32 +!CHECK: omp.yield(%[[C0_1]] : i32) +!CHECK: } combiner { +!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32): +!CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32 +!CHECK: omp.yield(%[[RES]] : i32) +!CHECK: } + +!CHECK-LABEL: func.func @_QPtest_add() { +!CHECK: %[[VAL_0:.*]] = fir.address_of(@_QFtest_addEx) : !fir.ref +!CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFtest_addEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: omp.taskgroup task_reduction(@[[RED_I32_NAME]] %[[VAL_1]]#0 -> %[[VAL_2:.*]] : !fir.ref) { +!CHECK: omp.task { +!CHECK: %[[VAL_3:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFtest_addEx"} +!CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_3]] {uniq_name = "_QFtest_addEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_1]]#0 : !fir.ref +!CHECK: hlfir.assign %[[VAL_5]] to %[[VAL_4]]#0 : i32, !fir.ref +!CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref +!CHECK: 
%[[VAL_7:.*]] = arith.constant 1 : i32 +!CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_6]], %[[VAL_7]] : i32 +!CHECK: hlfir.assign %[[VAL_8]] to %[[VAL_4]]#0 : i32, !fir.ref +!CHECK: omp.terminator +!CHECK: } +!CHECK: omp.terminator +!CHECK: } +!CHECK: return +!CHECK: } + +subroutine test_add + integer :: x = 0 + !$omp taskgroup task_reduction(+:x) + !$omp task + x = x + 1 + !$omp end task + !$omp end taskgroup +end subroutine +