Skip to content

Commit

Permalink
[Flang] [OpenMP] Support for lowering task_reduction and in_reduction…
Browse files Browse the repository at this point in the history
… to MLIR
  • Loading branch information
kaviya2510 committed Oct 7, 2024
1 parent 66b2820 commit 711fe33
Show file tree
Hide file tree
Showing 8 changed files with 307 additions and 28 deletions.
76 changes: 75 additions & 1 deletion flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ bool ClauseProcessor::processReduction(
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
rp.addDeclareReduction(
rp.addDeclareReduction<omp::clause::Reduction>(
currentLocation, converter, clause, reductionVars, reduceVarByRef,
reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr);

Expand All @@ -1085,6 +1085,80 @@ bool ClauseProcessor::processReduction(
});
}

/// Lowers every `task_reduction` clause found by findRepeatableClause.
///
/// For each clause, ReductionProcessor::addDeclareReduction fills per-clause
/// lists of reduction variables, by-reference flags and reduction declaration
/// symbols; these are then appended to the matching fields of \p result.
///
/// \param currentLocation   location forwarded to addDeclareReduction.
/// \param result            clause operands that receive the lowered values.
/// \param outReductionTypes if non-null, receives the type of every reduction
///                          variable, in the same order as the variables.
/// \param outReductionSyms  if non-null, receives the semantic symbols that
///                          addDeclareReduction reported for the clause.
/// \returns the value returned by findRepeatableClause.
bool ClauseProcessor::processTaskReduction(
    mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
    llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
    llvm::SmallVectorImpl<const semantics::Symbol *> *outReductionSyms) const {
  const bool wantSyms = outReductionSyms != nullptr;

  auto lowerClause = [&](const omp::clause::TaskReduction &clause,
                         const parser::CharBlock &) {
    // Per-clause scratch lists; their contents are appended to the outputs
    // once the clause has been processed.
    llvm::SmallVector<mlir::Value> vars;
    llvm::SmallVector<bool> byRefFlags;
    llvm::SmallVector<mlir::Attribute> declSyms;
    llvm::SmallVector<const semantics::Symbol *> syms;

    // Symbols are only collected when the caller asked for them.
    ReductionProcessor::addDeclareReduction<omp::clause::TaskReduction>(
        currentLocation, converter, clause, vars, byRefFlags, declSyms,
        wantSyms ? &syms : nullptr);

    for (mlir::Value var : vars)
      result.taskReductionVars.push_back(var);
    for (bool isByRef : byRefFlags)
      result.taskReductionByref.push_back(isByRef);
    for (mlir::Attribute declSym : declSyms)
      result.taskReductionSyms.push_back(declSym);

    if (outReductionTypes) {
      outReductionTypes->reserve(outReductionTypes->size() + vars.size());
      for (mlir::Value var : vars)
        outReductionTypes->push_back(var.getType());
    }

    if (wantSyms)
      outReductionSyms->append(syms.begin(), syms.end());
  };

  return findRepeatableClause<omp::clause::TaskReduction>(lowerClause);
}

/// Lowers every `in_reduction` clause found by findRepeatableClause.
///
/// For each clause, ReductionProcessor::addDeclareReduction fills per-clause
/// lists of reduction variables, by-reference flags and reduction declaration
/// symbols; these are then appended to the matching fields of \p result.
///
/// \param currentLocation   location forwarded to addDeclareReduction.
/// \param result            clause operands that receive the lowered values.
/// \param outReductionTypes if non-null, receives the type of every reduction
///                          variable, in the same order as the variables.
/// \param outReductionSyms  if non-null, receives the semantic symbols that
///                          addDeclareReduction reported for the clause.
/// \returns the value returned by findRepeatableClause.
bool ClauseProcessor::processInReduction(
    mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
    llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
    llvm::SmallVectorImpl<const semantics::Symbol *> *outReductionSyms) const {
  return findRepeatableClause<omp::clause::InReduction>(
      // The clause source range is not needed here, so the parameter is left
      // unnamed (as in processTaskReduction) to avoid an unused-parameter
      // warning.
      [&](const omp::clause::InReduction &clause, const parser::CharBlock &) {
        llvm::SmallVector<mlir::Value> inReductionVars;
        llvm::SmallVector<bool> inReductionByref;
        llvm::SmallVector<mlir::Attribute> inReductionDeclSymbols;
        llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
        ReductionProcessor rp;
        // Symbols are only collected when the caller asked for them.
        rp.addDeclareReduction<omp::clause::InReduction>(
            currentLocation, converter, clause, inReductionVars,
            inReductionByref, inReductionDeclSymbols,
            outReductionSyms ? &inReductionSyms : nullptr);

        // Copy local lists into the output.
        llvm::copy(inReductionVars, std::back_inserter(result.inReductionVars));
        llvm::copy(inReductionByref,
                   std::back_inserter(result.inReductionByref));
        llvm::copy(inReductionDeclSymbols,
                   std::back_inserter(result.inReductionSyms));

        if (outReductionTypes) {
          outReductionTypes->reserve(outReductionTypes->size() +
                                     inReductionVars.size());
          llvm::transform(inReductionVars,
                          std::back_inserter(*outReductionTypes),
                          [](mlir::Value v) { return v.getType(); });
        }

        if (outReductionSyms)
          llvm::copy(inReductionSyms, std::back_inserter(*outReductionSyms));
      });
}

bool ClauseProcessor::processTo(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
return findRepeatableClause<omp::clause::To>(
Expand Down
10 changes: 10 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ class ClauseProcessor {
llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSyms =
nullptr) const;
/// Processes the `task_reduction` clauses and appends the lowered reduction
/// variables, by-reference flags and declaration symbols to \p result.
/// When non-null, \p taskReductionTypes receives the type of each reduction
/// variable and \p taskReductionSyms receives the corresponding semantic
/// symbols. Returns the result of the repeatable-clause lookup.
bool processTaskReduction(
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> *taskReductionTypes = nullptr,
llvm::SmallVectorImpl<const semantics::Symbol *> *taskReductionSyms =
nullptr) const;
/// Processes the `in_reduction` clauses and appends the lowered reduction
/// variables, by-reference flags and declaration symbols to \p result.
/// When non-null, \p inReductionTypes receives the type of each reduction
/// variable and \p inReductionSyms receives the corresponding semantic
/// symbols. Returns the result of the repeatable-clause lookup.
bool processInReduction(
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
llvm::SmallVectorImpl<mlir::Type> *inReductionTypes = nullptr,
llvm::SmallVectorImpl<const semantics::Symbol *> *inReductionSyms =
nullptr) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processUseDeviceAddr(
lower::StatementContext &stmtCtx,
Expand Down
63 changes: 45 additions & 18 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1243,33 +1243,37 @@ static void genTargetEnterExitUpdateDataClauses(
cp.processNowait(clauseOps);
}

static void genTaskClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskOperands &clauseOps) {
static void genTaskClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
mlir::Location loc, mlir::omp::TaskOperands &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &inReductionTypes,
llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processDepend(clauseOps);
cp.processFinal(stmtCtx, clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
cp.processInReduction(loc, clauseOps, &inReductionTypes, &inReductionSyms);
cp.processMergeable(clauseOps);
cp.processPriority(stmtCtx, clauseOps);
cp.processUntied(clauseOps);
// TODO Support delayed privatization.

cp.processTODO<clause::Affinity, clause::Detach, clause::InReduction>(
cp.processTODO<clause::Affinity, clause::Detach>(
loc, llvm::omp::Directive::OMPD_task);
}

static void genTaskgroupClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskgroupOperands &clauseOps) {
static void genTaskgroupClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskgroupOperands &clauseOps,
llvm::SmallVectorImpl<mlir::Type> &taskReductionTypes,
llvm::SmallVectorImpl<const semantics::Symbol *> &taskReductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processTODO<clause::TaskReduction>(loc,
llvm::omp::Directive::OMPD_taskgroup);
cp.processTaskReduction(loc, clauseOps, &taskReductionTypes,
&taskReductionSyms);
}

static void genTaskwaitClauses(lower::AbstractConverter &converter,
Expand Down Expand Up @@ -1869,13 +1873,26 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
ConstructQueue::const_iterator item) {
lower::StatementContext stmtCtx;
mlir::omp::TaskOperands clauseOps;
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
llvm::SmallVector<mlir::Type> inReductionTypes;
llvm::SmallVector<const semantics::Symbol *> inreductionSyms;
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
inReductionTypes, inreductionSyms);

return genOpWithBody<mlir::omp::TaskOp>(
auto reductionCallback = [&](mlir::Operation *op) {
genReductionVars(op, converter, loc, inreductionSyms, inReductionTypes);
return inreductionSyms;
};

auto taskOp = genOpWithBody<mlir::omp::TaskOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_task)
.setClauses(&item->clauses),
.setClauses(&item->clauses)
.setGenRegionEntryCb(reductionCallback),
queue, item, clauseOps);
// Add reduction variables as arguments
llvm::SmallVector<mlir::Location> blockArgLocs(inReductionTypes.size(), loc);
taskOp->getRegion(0).addArguments(inReductionTypes, blockArgLocs);
return taskOp;
}

static mlir::omp::TaskgroupOp
Expand All @@ -1885,13 +1902,21 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
mlir::omp::TaskgroupOperands clauseOps;
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps);
llvm::SmallVector<mlir::Type> taskReductionTypes;
llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps,
taskReductionTypes, taskReductionSyms);

return genOpWithBody<mlir::omp::TaskgroupOp>(
auto taskgroupOp = genOpWithBody<mlir::omp::TaskgroupOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_taskgroup)
.setClauses(&item->clauses),
queue, item, clauseOps);

// Add reduction variables as arguments
llvm::SmallVector<mlir::Location> blockArgLocs(taskReductionSyms.size(), loc);
taskgroupOp->getRegion(0).addArguments(taskReductionTypes, blockArgLocs);
return taskgroupOp;
}

static mlir::omp::TaskwaitOp
Expand Down Expand Up @@ -2767,7 +2792,9 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
!std::holds_alternative<clause::ThreadLimit>(clause.u) &&
!std::holds_alternative<clause::Threads>(clause.u) &&
!std::holds_alternative<clause::UseDeviceAddr>(clause.u) &&
!std::holds_alternative<clause::UseDevicePtr>(clause.u)) {
!std::holds_alternative<clause::UseDevicePtr>(clause.u) &&
!std::holds_alternative<clause::TaskReduction>(clause.u) &&
!std::holds_alternative<clause::InReduction>(clause.u)) {
TODO(clauseLocation, "OpenMP Block construct clause");
}
}
Expand Down
41 changes: 34 additions & 7 deletions flang/lib/Lower/OpenMP/ReductionProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"
#include <type_traits>

static llvm::cl::opt<bool> forceByrefReduction(
"force-byref-reduction",
Expand All @@ -34,6 +35,32 @@ namespace Fortran {
namespace lower {
namespace omp {

// Explicit instantiation definitions of addDeclareReduction for each clause
// type it is called with (Reduction, TaskReduction, InReduction). The template
// body is defined in this file, so callers in other translation units rely on
// these instantiations being emitted here.
template void ReductionProcessor::addDeclareReduction<omp::clause::Reduction>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols);

template void
ReductionProcessor::addDeclareReduction<omp::clause::TaskReduction>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::TaskReduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols);

template void ReductionProcessor::addDeclareReduction<omp::clause::InReduction>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::InReduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols);

ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
const omp::clause::ProcedureDesignator &pd) {
auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
Expand Down Expand Up @@ -716,22 +743,22 @@ static bool doReductionByRef(mlir::Value reductionVar) {
return false;
}

template <class T>
void ReductionProcessor::addDeclareReduction(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
reduction.t))
TODO(currentLocation, "Reduction modifiers are not supported");
if constexpr (std::is_same<T, omp::clause::Reduction>::value) {
if (std::get<std::optional<typename T::ReductionModifier>>(reduction.t))
TODO(currentLocation, "Reduction modifiers are not supported");
}

mlir::omp::DeclareReductionOp decl;
const auto &redOperatorList{
std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
std::get<typename T::ReductionIdentifiers>(reduction.t)};
assert(redOperatorList.size() == 1 && "Expecting single operator");
const auto &redOperator = redOperatorList.front();
const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Lower/OpenMP/ReductionProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ class ReductionProcessor {

/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
template <class T>
static void addDeclareReduction(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSymbols =
Expand Down
50 changes: 50 additions & 0 deletions flang/test/Lower/OpenMP/task_array_reduction.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s

! CHECK-LABEL: omp.declare_reduction @add_reduction_byref_box_Uxf32 : !fir.ref<!fir.box<!fir.array<?xf32>>> alloc {
! [...]
! CHECK: omp.yield
! CHECK-LABEL: } init {
! [...]
! CHECK: omp.yield
! CHECK-LABEL: } combiner {
! [...]
! CHECK: omp.yield
! CHECK-LABEL: } cleanup {
! [...]
! CHECK: omp.yield
! CHECK: }

! CHECK-LABEL: func.func @_QPtaskreduction
! CHECK-SAME: (%[[VAL_0:.*]]: !fir.box<!fir.array<?xf32>> {fir.bindc_name = "x"}) {
! CHECK: %[[VAL_1:.*]] = fir.dummy_scope : !fir.dscope
! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_1]]
! CHECK-SAME: {uniq_name = "_QFtaskreductionEx"} : (!fir.box<!fir.array<?xf32>>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>)
! CHECK: omp.parallel {
! CHECK: %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.array<?xf32>>
! CHECK: fir.store %[[VAL_2]]#1 to %[[VAL_3]] : !fir.ref<!fir.box<!fir.array<?xf32>>>
! CHECK: omp.taskgroup task_reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_3]] -> %[[VAL_4:.*]]: !fir.ref<!fir.box<!fir.array<?xf32>>>) {
! CHECK: %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<?xf32>>
! CHECK: fir.store %[[VAL_2]]#1 to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<?xf32>>>
! CHECK: omp.task in_reduction(byref @add_reduction_byref_box_Uxf32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<?xf32>>>) {
! [...]
! CHECK: omp.terminator
! CHECK: }
! CHECK: omp.terminator
! CHECK: }
! CHECK: omp.terminator
! CHECK: }
! CHECK: return
! CHECK: }

subroutine taskReduction(x)
real, dimension(:) :: x
!$omp parallel
!$omp taskgroup task_reduction(+:x)
!$omp task in_reduction(+:x)
x = x + 1
!$omp end task
!$omp end taskgroup
!$omp end parallel
end subroutine

Loading

0 comments on commit 711fe33

Please sign in to comment.