Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flang][OpenMP] PFT-based detection of target SPMD #144

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 155 additions & 41 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,68 @@ using namespace Fortran::lower::omp;
// Code generation helper functions
//===----------------------------------------------------------------------===//

static bool evalHasSiblings(lower::pft::Evaluation &eval) {
/// Get the directive enumeration value corresponding to the given OpenMP
/// construct PFT node.
llvm::omp::Directive
extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) {
return common::visit(
common::visitors{
[](const parser::OpenMPAllocatorsConstruct &c) {
return llvm::omp::OMPD_allocators;
},
[](const parser::OpenMPAtomicConstruct &c) {
return llvm::omp::OMPD_atomic;
},
[](const parser::OpenMPBlockConstruct &c) {
return std::get<parser::OmpBlockDirective>(
std::get<parser::OmpBeginBlockDirective>(c.t).t)
.v;
},
[](const parser::OpenMPCriticalConstruct &c) {
return llvm::omp::OMPD_critical;
},
[](const parser::OpenMPDeclarativeAllocate &c) {
return llvm::omp::OMPD_allocate;
},
[](const parser::OpenMPExecutableAllocate &c) {
return llvm::omp::OMPD_allocate;
},
[](const parser::OpenMPLoopConstruct &c) {
return std::get<parser::OmpLoopDirective>(
std::get<parser::OmpBeginLoopDirective>(c.t).t)
.v;
},
[](const parser::OpenMPSectionConstruct &c) {
return llvm::omp::OMPD_section;
},
[](const parser::OpenMPSectionsConstruct &c) {
return std::get<parser::OmpSectionsDirective>(
std::get<parser::OmpBeginSectionsDirective>(c.t).t)
.v;
},
[](const parser::OpenMPStandaloneConstruct &c) {
return common::visit(
common::visitors{
[](const parser::OpenMPSimpleStandaloneConstruct &c) {
return std::get<parser::OmpSimpleStandaloneDirective>(c.t)
.v;
},
[](const parser::OpenMPFlushConstruct &c) {
return llvm::omp::OMPD_flush;
},
[](const parser::OpenMPCancelConstruct &c) {
return llvm::omp::OMPD_cancel;
},
[](const parser::OpenMPCancellationPointConstruct &c) {
return llvm::omp::OMPD_cancellation_point;
}},
c.u);
}},
ompConstruct.u);
}

/// Check whether the parent of the given evaluation contains other evaluations.
static bool evalHasSiblings(const lower::pft::Evaluation &eval) {
auto checkSiblings = [&eval](const lower::pft::EvaluationList &siblings) {
for (auto &sibling : siblings)
if (&sibling != &eval && !sibling.isEndStmt())
Expand All @@ -67,6 +128,80 @@ static bool evalHasSiblings(lower::pft::Evaluation &eval) {
}});
}

/// Check whether a given evaluation points to an OpenMP loop construct that
/// represents a target SPMD kernel. For this to be true, it must be a `target
/// teams distribute parallel do [simd]` or equivalent construct.
///
/// Currently, this is limited to cases where all relevant OpenMP constructs are
/// either combined or directly nested within the same function. Also, the
/// composite `distribute parallel do` is not identified if split into two
/// explicit nested loops (a `distribute` loop and a `parallel do` loop).
static bool isTargetSPMDLoop(const lower::pft::Evaluation &eval) {
using namespace llvm::omp;

const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
if (!ompEval)
return false;

switch (extractOmpDirective(*ompEval)) {
case OMPD_distribute_parallel_do:
case OMPD_distribute_parallel_do_simd: {
// It will return true only if one of these are true:
// - It has a 'target teams' parent and no siblings.
// - It has a 'teams' parent and no siblings, and the 'teams' has a
// 'target' parent and no siblings.
if (evalHasSiblings(eval))
return false;

const auto *parentEval = eval.parent.getIf<lower::pft::Evaluation>();
if (!parentEval)
return false;

const auto *parentOmpEval = parentEval->getIf<parser::OpenMPConstruct>();
if (!parentOmpEval)
return false;

auto parentDir = extractOmpDirective(*parentOmpEval);
if (parentDir == OMPD_target_teams)
return true;

if (parentDir != OMPD_teams)
return false;

if (evalHasSiblings(*parentEval))
return false;

const auto *parentOfParentEval =
parentEval->parent.getIf<lower::pft::Evaluation>();
if (!parentEval)
return false;

const auto *parentOfParentOmpEval =
parentOfParentEval->getIf<parser::OpenMPConstruct>();
return parentOfParentOmpEval &&
extractOmpDirective(*parentOfParentOmpEval) == OMPD_target;
}
case OMPD_teams_distribute_parallel_do:
case OMPD_teams_distribute_parallel_do_simd: {
// Check there's a 'target' parent and no siblings.
if (evalHasSiblings(eval))
return false;

const auto *parentEval = eval.parent.getIf<lower::pft::Evaluation>();
if (!parentEval)
return false;

const auto *parentOmpEval = parentEval->getIf<parser::OpenMPConstruct>();
return parentOmpEval && extractOmpDirective(*parentOmpEval) == OMPD_target;
}
case OMPD_target_teams_distribute_parallel_do:
case OMPD_target_teams_distribute_parallel_do_simd:
return true;
default:
return false;
}
}

static mlir::omp::TargetOp findParentTargetOp(mlir::OpBuilder &builder) {
mlir::Operation *parentOp = builder.getBlock()->getParentOp();
if (!parentOp)
Expand Down Expand Up @@ -113,8 +248,9 @@ static void genNestedEvaluations(lower::AbstractConverter &converter,
converter.genEval(e);
}

static bool mustEvalTeamsThreadsOutsideTarget(lower::pft::Evaluation &eval,
mlir::omp::TargetOp targetOp) {
static bool
mustEvalTeamsThreadsOutsideTarget(const lower::pft::Evaluation &eval,
mlir::omp::TargetOp targetOp) {
if (!targetOp)
return false;

Expand All @@ -123,25 +259,8 @@ static bool mustEvalTeamsThreadsOutsideTarget(lower::pft::Evaluation &eval,
if (offloadModOp.getIsTargetDevice())
return false;

auto dir = Fortran::common::visit(
common::visitors{
[&](const parser::OpenMPBlockConstruct &c) {
return std::get<parser::OmpBlockDirective>(
std::get<parser::OmpBeginBlockDirective>(c.t).t)
.v;
},
[&](const parser::OpenMPLoopConstruct &c) {
return std::get<parser::OmpLoopDirective>(
std::get<parser::OmpBeginLoopDirective>(c.t).t)
.v;
},
[&](const auto &) {
llvm_unreachable("Unexpected OpenMP construct");
return llvm::omp::OMPD_unknown;
},
},
eval.get<parser::OpenMPConstruct>().u);

llvm::omp::Directive dir =
extractOmpDirective(eval.get<parser::OpenMPConstruct>());
return llvm::omp::allTargetSet.test(dir) || !evalHasSiblings(eval);
}

Expand Down Expand Up @@ -1722,25 +1841,20 @@ genLoopNestOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
firOpBuilder.getModule().getOperation());
auto targetOp = loopNestOp->getParentOfType<mlir::omp::TargetOp>();

if (offloadMod && targetOp && !offloadMod.getIsTargetDevice()) {
if (targetOp.isTargetSPMDLoop()) {
// Lower loop bounds and step, and process collapsing again, putting
// lowered values outside of omp.target this time. This enables
// calculating and accessing the trip count in the host, which is needed
// when lowering to LLVM IR via the OMPIRBuilder.
HostClausesInsertionGuard guard(firOpBuilder);
mlir::omp::LoopRelatedOps loopRelatedOps;
llvm::SmallVector<const semantics::Symbol *> iv;
ClauseProcessor cp(converter, semaCtx, item->clauses);
cp.processCollapse(loc, eval, loopRelatedOps, iv);
targetOp.getTripCountMutable().assign(
calculateTripCount(converter.getFirOpBuilder(), loc, loopRelatedOps));
} else if (targetOp.getTripCountMutable().size()) {
// The MLIR target operation was updated during PFT lowering,
// and it is no longer an SPMD kernel. Erase the trip count because
// as it is now invalid.
targetOp.getTripCountMutable().erase(0);
}
if (offloadMod && !offloadMod.getIsTargetDevice() && isTargetSPMDLoop(eval)) {
assert(targetOp && "must have omp.target parent");

// Lower loop bounds and step, and process collapsing again, putting lowered
// values outside of omp.target this time. This enables calculating and
// accessing the trip count in the host, which is needed when lowering to
// LLVM IR via the OMPIRBuilder.
HostClausesInsertionGuard guard(firOpBuilder);
mlir::omp::LoopRelatedOps loopRelatedOps;
llvm::SmallVector<const semantics::Symbol *> iv;
ClauseProcessor cp(converter, semaCtx, item->clauses);
cp.processCollapse(loc, eval, loopRelatedOps, iv);
targetOp.getTripCountMutable().assign(
calculateTripCount(converter.getFirOpBuilder(), loc, loopRelatedOps));
}
return loopNestOp;
}
Expand Down
Loading