Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions flang/include/flang/Semantics/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -754,12 +754,12 @@ class Symbol {
// OpenMP data-copying attribute
OmpCopyIn, OmpCopyPrivate,
// OpenMP miscellaneous flags
OmpCommonBlock, OmpReduction, OmpAligned, OmpNontemporal, OmpAllocate,
OmpDeclarativeAllocateDirective, OmpExecutableAllocateDirective,
OmpDeclareSimd, OmpDeclareTarget, OmpThreadprivate, OmpDeclareReduction,
OmpFlushed, OmpCriticalLock, OmpIfSpecified, OmpNone, OmpPreDetermined,
OmpImplicit, OmpDependObject, OmpInclusiveScan, OmpExclusiveScan,
OmpInScanReduction);
OmpCommonBlock, OmpReduction, OmpInReduction, OmpAligned, OmpNontemporal,
OmpAllocate, OmpDeclarativeAllocateDirective,
OmpExecutableAllocateDirective, OmpDeclareSimd, OmpDeclareTarget,
OmpThreadprivate, OmpDeclareReduction, OmpFlushed, OmpCriticalLock,
OmpIfSpecified, OmpNone, OmpPreDetermined, OmpImplicit, OmpDependObject,
OmpInclusiveScan, OmpExclusiveScan, OmpInScanReduction);
using Flags = common::EnumSet<Flag, Flag_enumSize>;

const Scope &owner() const { return *owner_; }
Expand Down
51 changes: 49 additions & 2 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,29 @@ bool ClauseProcessor::processIf(
});
return found;
}
bool ClauseProcessor::processInReduction(
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
return findRepeatableClause<omp::clause::InReduction>(
[&](const omp::clause::InReduction &clause, const parser::CharBlock &) {
llvm::SmallVector<mlir::Value> inReductionVars;
llvm::SmallVector<bool> inReduceVarByRef;
llvm::SmallVector<mlir::Attribute> inReductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
ReductionProcessor rp;
rp.processReductionArguments<omp::clause::InReduction>(
currentLocation, converter, clause, inReductionVars,
inReduceVarByRef, inReductionDeclSymbols, inReductionSyms);

// Copy local lists into the output.
llvm::copy(inReductionVars, std::back_inserter(result.inReductionVars));
llvm::copy(inReduceVarByRef,
std::back_inserter(result.inReductionByref));
llvm::copy(inReductionDeclSymbols,
std::back_inserter(result.inReductionSyms));
llvm::copy(inReductionSyms, std::back_inserter(outReductionSyms));
});
}

bool ClauseProcessor::processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
Expand Down Expand Up @@ -1250,9 +1273,9 @@ bool ClauseProcessor::processReduction(
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
rp.processReductionArguments(
rp.processReductionArguments<omp::clause::Reduction>(
currentLocation, converter, clause, reductionVars, reduceVarByRef,
reductionDeclSymbols, reductionSyms, result.reductionMod);
reductionDeclSymbols, reductionSyms, &result.reductionMod);
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
Expand All @@ -1262,6 +1285,30 @@ bool ClauseProcessor::processReduction(
});
}

bool ClauseProcessor::processTaskReduction(
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
return findRepeatableClause<omp::clause::TaskReduction>(
[&](const omp::clause::TaskReduction &clause, const parser::CharBlock &) {
llvm::SmallVector<mlir::Value> taskReductionVars;
llvm::SmallVector<bool> TaskReduceVarByRef;
llvm::SmallVector<mlir::Attribute> TaskReductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> TaskReductionSyms;
ReductionProcessor rp;
rp.processReductionArguments<omp::clause::TaskReduction>(
currentLocation, converter, clause, taskReductionVars,
TaskReduceVarByRef, TaskReductionDeclSymbols, TaskReductionSyms);
// Copy local lists into the output.
llvm::copy(taskReductionVars,
std::back_inserter(result.taskReductionVars));
llvm::copy(TaskReduceVarByRef,
std::back_inserter(result.taskReductionByref));
llvm::copy(TaskReductionDeclSymbols,
std::back_inserter(result.taskReductionSyms));
llvm::copy(TaskReductionSyms, std::back_inserter(outReductionSyms));
});
}

bool ClauseProcessor::processTo(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
return findRepeatableClause<omp::clause::To>(
Expand Down
6 changes: 6 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ class ClauseProcessor {
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
mlir::omp::IfClauseOps &result) const;
bool processInReduction(
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
bool processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
Expand All @@ -133,6 +136,9 @@ class ClauseProcessor {
bool processReduction(
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const;
bool processTaskReduction(
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processUseDeviceAddr(
lower::StatementContext &stmtCtx,
Expand Down
44 changes: 27 additions & 17 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1773,34 +1773,34 @@ static void genTargetEnterExitUpdateDataClauses(
cp.processNowait(clauseOps);
}

static void genTaskClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
lower::SymMap &symTable,
lower::StatementContext &stmtCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskOperands &clauseOps) {
static void genTaskClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
lower::SymMap &symTable, lower::StatementContext &stmtCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskOperands &clauseOps,
llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processDepend(symTable, stmtCtx, clauseOps);
cp.processFinal(stmtCtx, clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
cp.processInReduction(loc, clauseOps, inReductionSyms);
cp.processMergeable(clauseOps);
cp.processPriority(stmtCtx, clauseOps);
cp.processUntied(clauseOps);
cp.processDetach(clauseOps);

cp.processTODO<clause::Affinity, clause::InReduction>(
loc, llvm::omp::Directive::OMPD_task);
cp.processTODO<clause::Affinity>(loc, llvm::omp::Directive::OMPD_task);
}

static void genTaskgroupClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskgroupOperands &clauseOps) {
static void genTaskgroupClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskgroupOperands &clauseOps,
llvm::SmallVectorImpl<const semantics::Symbol *> &taskReductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processTODO<clause::TaskReduction>(loc,
llvm::omp::Directive::OMPD_taskgroup);
cp.processTaskReduction(loc, clauseOps, taskReductionSyms);
}

static void genTaskwaitClauses(lower::AbstractConverter &converter,
Expand Down Expand Up @@ -2480,8 +2480,9 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
mlir::Location loc, const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
mlir::omp::TaskOperands clauseOps;
llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
genTaskClauses(converter, semaCtx, symTable, stmtCtx, item->clauses, loc,
clauseOps);
clauseOps, inReductionSyms);

if (!enableDelayedPrivatization)
return genOpWithBody<mlir::omp::TaskOp>(
Expand All @@ -2498,6 +2499,8 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
EntryBlockArgs taskArgs;
taskArgs.priv.syms = dsp.getDelayedPrivSymbols();
taskArgs.priv.vars = clauseOps.privateVars;
taskArgs.inReduction.syms = inReductionSyms;
taskArgs.inReduction.vars = clauseOps.inReductionVars;

return genOpWithBody<mlir::omp::TaskOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
Expand All @@ -2515,12 +2518,19 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
mlir::omp::TaskgroupOperands clauseOps;
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps);
llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps,
taskReductionSyms);

EntryBlockArgs taskgroupArgs;
taskgroupArgs.taskReduction.syms = taskReductionSyms;
taskgroupArgs.taskReduction.vars = clauseOps.taskReductionVars;

return genOpWithBody<mlir::omp::TaskgroupOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_taskgroup)
.setClauses(&item->clauses),
.setClauses(&item->clauses)
.setEntryBlockArgs(&taskgroupArgs),
queue, item, clauseOps);
}

Expand Down
56 changes: 45 additions & 11 deletions flang/lib/Lower/OpenMP/ReductionProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"
#include <type_traits>

static llvm::cl::opt<bool> forceByrefReduction(
"force-byref-reduction",
Expand All @@ -38,6 +39,37 @@ namespace Fortran {
namespace lower {
namespace omp {

// explicit template declarations
template void
ReductionProcessor::processReductionArguments<omp::clause::Reduction>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
mlir::omp::ReductionModifierAttr *reductionMod);

template void
ReductionProcessor::processReductionArguments<omp::clause::TaskReduction>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::TaskReduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
mlir::omp::ReductionModifierAttr *reductionMod);

template void
ReductionProcessor::processReductionArguments<omp::clause::InReduction>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::InReduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
mlir::omp::ReductionModifierAttr *reductionMod);

ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
const omp::clause::ProcedureDesignator &pd) {
auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
Expand Down Expand Up @@ -538,28 +570,30 @@ mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
return mlir::omp::ReductionModifier::defaultmod;
}

template <class T>
void ReductionProcessor::processReductionArguments(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
mlir::omp::ReductionModifierAttr &reductionMod) {
mlir::omp::ReductionModifierAttr *reductionMod) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
if (mod.has_value()) {
if (mod.value() == ReductionModifier::Task)
TODO(currentLocation, "Reduction modifier `task` is not supported");
else
reductionMod = mlir::omp::ReductionModifierAttr::get(
firOpBuilder.getContext(), translateReductionModifier(mod.value()));
if constexpr (std::is_same_v<T, omp::clause::Reduction>) {
auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
if (mod.has_value()) {
if (mod.value() == ReductionModifier::Task)
TODO(currentLocation, "Reduction modifier `task` is not supported");
else
*reductionMod = mlir::omp::ReductionModifierAttr::get(
firOpBuilder.getContext(), translateReductionModifier(mod.value()));
}
}

mlir::omp::DeclareReductionOp decl;
const auto &redOperatorList{
std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
std::get<typename T::ReductionIdentifiers>(reduction.t)};
assert(redOperatorList.size() == 1 && "Expecting single operator");
const auto &redOperator = redOperatorList.front();
const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
Expand Down
6 changes: 3 additions & 3 deletions flang/lib/Lower/OpenMP/ReductionProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,14 @@ class ReductionProcessor {

/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
template <class T>
static void processReductionArguments(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
mlir::omp::ReductionModifierAttr &reductionMod);
mlir::omp::ReductionModifierAttr *reductionMod = nullptr);
};

template <typename FloatOp, typename IntegerOp>
Expand Down
6 changes: 6 additions & 0 deletions flang/lib/Semantics/resolve-directives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,12 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
return false;
}

bool Pre(const parser::OmpInReductionClause &x) {
auto &objects{std::get<parser::OmpObjectList>(x.t)};
ResolveOmpObjectList(objects, Symbol::Flag::OmpInReduction);
return false;
}

bool Pre(const parser::OmpClause::Reduction &x) {
const auto &objList{std::get<parser::OmpObjectList>(x.v.t)};
ResolveOmpObjectList(objList, Symbol::Flag::OmpReduction);
Expand Down
15 changes: 0 additions & 15 deletions flang/test/Lower/OpenMP/Todo/task-inreduction.f90

This file was deleted.

10 changes: 0 additions & 10 deletions flang/test/Lower/OpenMP/Todo/taskgroup-task-reduction.f90

This file was deleted.

35 changes: 35 additions & 0 deletions flang/test/Lower/OpenMP/task-inreduction.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s

!CHECK-LABEL: omp.declare_reduction
!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init {
!CHECK: ^bb0(%{{.*}}: i32):
!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32
!CHECK: omp.yield(%[[C0_1]] : i32)
!CHECK: } combiner {
!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
!CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
!CHECK: omp.yield(%[[RES]] : i32)
!CHECK: }

!CHECK-LABEL: func.func @_QPomp_task_in_reduction() {
! [...]
!CHECK: omp.task in_reduction(@[[RED_I32_NAME]] %[[VAL_1:.*]]#0 -> %[[ARG0]] : !fir.ref<i32>) {
!CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[ARG0]]
!CHECK-SAME: {uniq_name = "_QFomp_task_in_reductionEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<i32>
!CHECK: %[[VAL_6:.*]] = arith.constant 1 : i32
!CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_5]], %[[VAL_6]] : i32
!CHECK: hlfir.assign %[[VAL_7]] to %[[VAL_4]]#0 : i32, !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
!CHECK: }

subroutine omp_task_in_reduction()
integer i
i = 0
!$omp task in_reduction(+:i)
i = i + 1
!$omp end task
end subroutine omp_task_in_reduction
Loading