Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flang][OpenMP]Support for lowering task_reduction and in_reduction to MLIR #111155

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ bool ClauseProcessor::processReduction(
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
rp.addDeclareReduction(
rp.addDeclareReduction<omp::clause::Reduction>(
currentLocation, converter, clause, reductionVars, reduceVarByRef,
reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr);

Expand All @@ -1085,6 +1085,80 @@ bool ClauseProcessor::processReduction(
});
}

/// Lowers every `task_reduction` clause attached to the construct.
///
/// For each clause, the reduction variables, their by-reference flags and the
/// declare-reduction symbol attributes produced by
/// ReductionProcessor::addDeclareReduction are appended to \p result.
///
/// \param currentLocation   location forwarded to addDeclareReduction.
/// \param result            clause operands that receive the lowered lists.
/// \param outReductionTypes if non-null, receives the MLIR type of every
///                          reduction variable, in order.
/// \param outReductionSyms  if non-null, receives the semantic symbols of the
///                          reduction list items.
/// \returns whatever findRepeatableClause reports for this clause kind.
bool ClauseProcessor::processTaskReduction(
    mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
    llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
    llvm::SmallVectorImpl<const semantics::Symbol *> *outReductionSyms) const {
  const bool wantSyms = outReductionSyms != nullptr;

  auto handleClause = [&](const omp::clause::TaskReduction &clause,
                          const parser::CharBlock &) {
    // Per-clause scratch lists; merged into the outputs below.
    llvm::SmallVector<mlir::Value> vars;
    llvm::SmallVector<bool> byref;
    llvm::SmallVector<mlir::Attribute> declSyms;
    llvm::SmallVector<const semantics::Symbol *> syms;

    ReductionProcessor rp;
    rp.addDeclareReduction<omp::clause::TaskReduction>(
        currentLocation, converter, clause, vars, byref, declSyms,
        wantSyms ? &syms : nullptr);

    // Append this clause's results to the operation operands.
    llvm::copy(vars, std::back_inserter(result.taskReductionVars));
    llvm::copy(byref, std::back_inserter(result.taskReductionByref));
    llvm::copy(declSyms, std::back_inserter(result.taskReductionSyms));

    if (outReductionTypes) {
      outReductionTypes->reserve(outReductionTypes->size() + vars.size());
      for (mlir::Value var : vars)
        outReductionTypes->push_back(var.getType());
    }

    if (wantSyms)
      llvm::copy(syms, std::back_inserter(*outReductionSyms));
  };

  return findRepeatableClause<omp::clause::TaskReduction>(handleClause);
}

/// Lowers every `in_reduction` clause attached to the construct.
///
/// For each clause, the reduction variables, their by-reference flags and the
/// declare-reduction symbol attributes produced by
/// ReductionProcessor::addDeclareReduction are appended to \p result.
///
/// \param currentLocation   location forwarded to addDeclareReduction.
/// \param result            clause operands that receive the lowered lists.
/// \param outReductionTypes if non-null, receives the MLIR type of every
///                          reduction variable, in order.
/// \param outReductionSyms  if non-null, receives the semantic symbols of the
///                          reduction list items.
/// \returns whatever findRepeatableClause reports for this clause kind.
bool ClauseProcessor::processInReduction(
    mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
    llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
    llvm::SmallVectorImpl<const semantics::Symbol *> *outReductionSyms) const {
  return findRepeatableClause<omp::clause::InReduction>(
      // The clause's source range is not needed here, so the second parameter
      // is left unnamed (matches processTaskReduction).
      [&](const omp::clause::InReduction &clause, const parser::CharBlock &) {
        llvm::SmallVector<mlir::Value> inReductionVars;
        llvm::SmallVector<bool> inReductionByref;
        llvm::SmallVector<mlir::Attribute> inReductionDeclSymbols;
        llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
        ReductionProcessor rp;
        rp.addDeclareReduction<omp::clause::InReduction>(
            currentLocation, converter, clause, inReductionVars,
            inReductionByref, inReductionDeclSymbols,
            outReductionSyms ? &inReductionSyms : nullptr);

        // Copy local lists into the output.
        llvm::copy(inReductionVars, std::back_inserter(result.inReductionVars));
        llvm::copy(inReductionByref,
                   std::back_inserter(result.inReductionByref));
        llvm::copy(inReductionDeclSymbols,
                   std::back_inserter(result.inReductionSyms));

        // Optionally report the type of each reduction variable.
        if (outReductionTypes) {
          outReductionTypes->reserve(outReductionTypes->size() +
                                     inReductionVars.size());
          llvm::transform(inReductionVars,
                          std::back_inserter(*outReductionTypes),
                          [](mlir::Value v) { return v.getType(); });
        }

        // Optionally report the symbols of the reduction list items.
        if (outReductionSyms)
          llvm::copy(inReductionSyms, std::back_inserter(*outReductionSyms));
      });
}

bool ClauseProcessor::processTo(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
return findRepeatableClause<omp::clause::To>(
Expand Down
10 changes: 10 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ class ClauseProcessor {
llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSyms =
nullptr) const;
/// Processes all `task_reduction` clauses: appends the reduction variables,
/// by-ref flags and declare-reduction symbols to \p result. When non-null,
/// \p taskReductionTypes receives the type of each reduction variable and
/// \p taskReductionSyms receives the symbols of the reduction list items.
bool processTaskReduction(
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> *taskReductionTypes = nullptr,
llvm::SmallVectorImpl<const semantics::Symbol *> *taskReductionSyms =
nullptr) const;
/// Processes all `in_reduction` clauses: appends the reduction variables,
/// by-ref flags and declare-reduction symbols to \p result. When non-null,
/// \p inReductionTypes receives the type of each reduction variable and
/// \p inReductionSyms receives the symbols of the reduction list items.
bool processInReduction(
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> *inReductionTypes = nullptr,
llvm::SmallVectorImpl<const semantics::Symbol *> *inReductionSyms =
nullptr) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processUseDeviceAddr(
lower::StatementContext &stmtCtx,
Expand Down
63 changes: 45 additions & 18 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1243,33 +1243,37 @@ static void genTargetEnterExitUpdateDataClauses(
cp.processNowait(clauseOps);
}

static void genTaskClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskOperands &clauseOps) {
static void genTaskClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
mlir::Location loc, mlir::omp::TaskOperands &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &inReductionTypes,
llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processDepend(clauseOps);
cp.processFinal(stmtCtx, clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
cp.processInReduction(loc, clauseOps, &inReductionTypes, &inReductionSyms);
cp.processMergeable(clauseOps);
cp.processPriority(stmtCtx, clauseOps);
cp.processUntied(clauseOps);
// TODO Support delayed privatization.

cp.processTODO<clause::Affinity, clause::Detach, clause::InReduction>(
cp.processTODO<clause::Affinity, clause::Detach>(
loc, llvm::omp::Directive::OMPD_task);
}

static void genTaskgroupClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskgroupOperands &clauseOps) {
static void genTaskgroupClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskgroupOperands &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &taskReductionTypes,
llvm::SmallVectorImpl<const semantics::Symbol *> &taskReductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processTODO<clause::TaskReduction>(loc,
llvm::omp::Directive::OMPD_taskgroup);
cp.processTaskReduction(loc, clauseOps, &taskReductionTypes,
&taskReductionSyms);
}

static void genTaskwaitClauses(lower::AbstractConverter &converter,
Expand Down Expand Up @@ -1869,13 +1873,26 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
ConstructQueue::const_iterator item) {
lower::StatementContext stmtCtx;
mlir::omp::TaskOperands clauseOps;
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
llvm::SmallVector<mlir::Type> inReductionTypes;
llvm::SmallVector<const semantics::Symbol *> inreductionSyms;
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
inReductionTypes, inreductionSyms);

return genOpWithBody<mlir::omp::TaskOp>(
auto reductionCallback = [&](mlir::Operation *op) {
genReductionVars(op, converter, loc, inreductionSyms, inReductionTypes);
return inreductionSyms;
};

auto taskOp = genOpWithBody<mlir::omp::TaskOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_task)
.setClauses(&item->clauses),
.setClauses(&item->clauses)
.setGenRegionEntryCb(reductionCallback),
queue, item, clauseOps);
// Add reduction variables as arguments
llvm::SmallVector<mlir::Location> blockArgLocs(inReductionTypes.size(), loc);
taskOp->getRegion(0).addArguments(inReductionTypes, blockArgLocs);
return taskOp;
}

static mlir::omp::TaskgroupOp
Expand All @@ -1885,13 +1902,21 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
mlir::omp::TaskgroupOperands clauseOps;
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps);
llvm::SmallVector<mlir::Type> taskReductionTypes;
llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps,
taskReductionTypes, taskReductionSyms);

return genOpWithBody<mlir::omp::TaskgroupOp>(
auto taskgroupOp = genOpWithBody<mlir::omp::TaskgroupOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_taskgroup)
.setClauses(&item->clauses),
queue, item, clauseOps);

// Add reduction variables as arguments
llvm::SmallVector<mlir::Location> blockArgLocs(taskReductionSyms.size(), loc);
taskgroupOp->getRegion(0).addArguments(taskReductionTypes, blockArgLocs);
return taskgroupOp;
}

static mlir::omp::TaskwaitOp
Expand Down Expand Up @@ -2767,7 +2792,9 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
!std::holds_alternative<clause::ThreadLimit>(clause.u) &&
!std::holds_alternative<clause::Threads>(clause.u) &&
!std::holds_alternative<clause::UseDeviceAddr>(clause.u) &&
!std::holds_alternative<clause::UseDevicePtr>(clause.u)) {
!std::holds_alternative<clause::UseDevicePtr>(clause.u) &&
!std::holds_alternative<clause::TaskReduction>(clause.u) &&
!std::holds_alternative<clause::InReduction>(clause.u)) {
TODO(clauseLocation, "OpenMP Block construct clause");
}
}
Expand Down
41 changes: 34 additions & 7 deletions flang/lib/Lower/OpenMP/ReductionProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"
#include <type_traits>

static llvm::cl::opt<bool> forceByrefReduction(
"force-byref-reduction",
Expand All @@ -34,6 +35,32 @@ namespace Fortran {
namespace lower {
namespace omp {

// Explicit instantiation definitions of ReductionProcessor::addDeclareReduction
// for the three clause types it is instantiated with: Reduction, TaskReduction
// and InReduction.
// NOTE(review): presumably needed because the template body is defined in this
// .cpp file rather than in ReductionProcessor.h, so other translation units
// cannot instantiate it implicitly - confirm.
template void ReductionProcessor::addDeclareReduction<omp::clause::Reduction>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols);

template void
ReductionProcessor::addDeclareReduction<omp::clause::TaskReduction>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::TaskReduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols);

template void ReductionProcessor::addDeclareReduction<omp::clause::InReduction>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::InReduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols);

ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
const omp::clause::ProcedureDesignator &pd) {
auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
Expand Down Expand Up @@ -716,22 +743,22 @@ static bool doReductionByRef(mlir::Value reductionVar) {
return false;
}

template <class T>
void ReductionProcessor::addDeclareReduction(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
reduction.t))
TODO(currentLocation, "Reduction modifiers are not supported");
if constexpr (std::is_same<T, omp::clause::Reduction>::value) {
if (std::get<std::optional<typename T::ReductionModifier>>(reduction.t))
TODO(currentLocation, "Reduction modifiers are not supported");
}

mlir::omp::DeclareReductionOp decl;
const auto &redOperatorList{
std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
std::get<typename T::ReductionIdentifiers>(reduction.t)};
assert(redOperatorList.size() == 1 && "Expecting single operator");
const auto &redOperator = redOperatorList.front();
const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Lower/OpenMP/ReductionProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ class ReductionProcessor {

/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
template <class T>
static void addDeclareReduction(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols =
Expand Down
50 changes: 50 additions & 0 deletions flang/test/Lower/OpenMP/task_array_reduction.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s

! CHECK-LABEL: omp.declare_reduction @add_reduction_byref_box_Uxf32 : !fir.ref<!fir.box<!fir.array<?xf32>>> alloc {
! [...]
! CHECK: omp.yield
! CHECK-LABEL: } init {
! [...]
! CHECK: omp.yield
! CHECK-LABEL: } combiner {
! [...]
! CHECK: omp.yield
! CHECK-LABEL: } cleanup {
! [...]
! CHECK: omp.yield
! CHECK: }

! CHECK-LABEL: func.func @_QPtaskreduction
! CHECK-SAME: (%[[VAL_0:.*]]: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "x"}) {
! CHECK: %[[VAL_1:.*]] = fir.dummy_scope : !fir.dscope
! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_1]]
! CHECK-SAME: {uniq_name = "_QFtaskreductionEx"} : (!fir.box<!fir.array<?xf32>>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>)
! CHECK: omp.parallel {
! CHECK: %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.array<?xf32>>
! CHECK: fir.store %[[VAL_2]]#1 to %[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xf32>>>
! CHECK: omp.taskgroup task_reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_3]] -> %[[VAL_4:.*]]: !fir.ref<!fir.box<!fir.array<?xf32>>>) {
! CHECK: %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<?xf32>>
! CHECK: fir.store %[[VAL_2]]#1 to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<?xf32>>>
! CHECK: omp.task in_reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<?xf32>>>) {
! [...]
! CHECK: omp.terminator
! CHECK: }
! CHECK: omp.terminator
! CHECK: }
! CHECK: omp.terminator
! CHECK: }
! CHECK: return
! CHECK: }

! Test input: x is an assumed-shape real array used as a `+` task_reduction
! list item on the taskgroup; the enclosed task takes part in that reduction
! through in_reduction and adds 1 to every element of x.
subroutine taskReduction(x)
real, dimension(:) :: x
!$omp parallel
!$omp taskgroup task_reduction(+:x)
!$omp task in_reduction(+:x)
x = x + 1
!$omp end task
!$omp end taskgroup
!$omp end parallel
end subroutine

Loading
Loading