From 8ae2bbdb9afec6fd086efb791746affb87af3662 Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Fri, 16 Aug 2024 16:02:30 +0100 Subject: [PATCH] [Flang][OpenMP] PFT-based detection of target SPMD This patch improves the fix in #125 to detect target SPMD kernels during Flang lowering to MLIR. It transitions from a MLIR-based check to a PFT-based check, which is a more resilient alternative since the MLIR representation is in process of being built where it's being checked. --- flang/lib/Lower/OpenMP/OpenMP.cpp | 196 +++++++++++++++++++----- flang/test/Lower/OpenMP/target-spmd.f90 | 157 +++++++++++++++++++ 2 files changed, 312 insertions(+), 41 deletions(-) create mode 100644 flang/test/Lower/OpenMP/target-spmd.f90 diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 245a5e63ea1b7b..4a640eb1a37dda 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -46,7 +46,68 @@ using namespace Fortran::lower::omp; // Code generation helper functions //===----------------------------------------------------------------------===// -static bool evalHasSiblings(lower::pft::Evaluation &eval) { +/// Get the directive enumeration value corresponding to the given OpenMP +/// construct PFT node. +llvm::omp::Directive +extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) { + return common::visit( + common::visitors{ + [](const parser::OpenMPAllocatorsConstruct &c) { + return llvm::omp::OMPD_allocators; + }, + [](const parser::OpenMPAtomicConstruct &c) { + return llvm::omp::OMPD_atomic; + }, + [](const parser::OpenMPBlockConstruct &c) { + return std::get( + std::get(c.t).t) + .v; + }, + [](const parser::OpenMPCriticalConstruct &c) { + return llvm::omp::OMPD_critical; + }, + [](const parser::OpenMPDeclarativeAllocate &c) { + return llvm::omp::OMPD_allocate; + }, + [](const parser::OpenMPExecutableAllocate &c) { + return llvm::omp::OMPD_allocate; + }, + [](const parser::OpenMPLoopConstruct &c) { + return std::get( + std::get(c.t).t) + .v; + }, + [](const parser::OpenMPSectionConstruct &c) { + return llvm::omp::OMPD_section; + }, + [](const parser::OpenMPSectionsConstruct &c) { + return std::get( + std::get(c.t).t) + .v; + }, + [](const parser::OpenMPStandaloneConstruct &c) { + return common::visit( + common::visitors{ + [](const parser::OpenMPSimpleStandaloneConstruct &c) { + return std::get(c.t) + .v; + }, + [](const parser::OpenMPFlushConstruct &c) { + return llvm::omp::OMPD_flush; + }, + [](const parser::OpenMPCancelConstruct &c) { + return llvm::omp::OMPD_cancel; + }, + [](const parser::OpenMPCancellationPointConstruct &c) { + return llvm::omp::OMPD_cancellation_point; + }}, + c.u); + }}, + ompConstruct.u); +} + +/// Check whether the parent of the given evaluation contains other evaluations. +static bool evalHasSiblings(const lower::pft::Evaluation &eval) { auto checkSiblings = [&eval](const lower::pft::EvaluationList &siblings) { for (auto &sibling : siblings) if (&sibling != &eval && !sibling.isEndStmt()) @@ -67,6 +128,80 @@ static bool evalHasSiblings(lower::pft::Evaluation &eval) { }}); } +/// Check whether a given evaluation points to an OpenMP loop construct that +/// represents a target SPMD kernel. For this to be true, it must be a `target +/// teams distribute parallel do [simd]` or equivalent construct. +/// +/// Currently, this is limited to cases where all relevant OpenMP constructs are +/// either combined or directly nested within the same function. Also, the +/// composite `distribute parallel do` is not identified if split into two +/// explicit nested loops (a `distribute` loop and a `parallel do` loop). +static bool isTargetSPMDLoop(const lower::pft::Evaluation &eval) { + using namespace llvm::omp; + + const auto *ompEval = eval.getIf(); + if (!ompEval) + return false; + + switch (extractOmpDirective(*ompEval)) { + case OMPD_distribute_parallel_do: + case OMPD_distribute_parallel_do_simd: { + // It will return true only if one of these are true: + // - It has a 'target teams' parent and no siblings. + // - It has a 'teams' parent and no siblings, and the 'teams' has a + // 'target' parent and no siblings. + if (evalHasSiblings(eval)) + return false; + + const auto *parentEval = eval.parent.getIf(); + if (!parentEval) + return false; + + const auto *parentOmpEval = parentEval->getIf(); + if (!parentOmpEval) + return false; + + auto parentDir = extractOmpDirective(*parentOmpEval); + if (parentDir == OMPD_target_teams) + return true; + + if (parentDir != OMPD_teams) + return false; + + if (evalHasSiblings(*parentEval)) + return false; + + const auto *parentOfParentEval = + parentEval->parent.getIf(); + if (!parentEval) + return false; + + const auto *parentOfParentOmpEval = + parentOfParentEval->getIf(); + return parentOfParentOmpEval && + extractOmpDirective(*parentOfParentOmpEval) == OMPD_target; + } + case OMPD_teams_distribute_parallel_do: + case OMPD_teams_distribute_parallel_do_simd: { + // Check there's a 'target' parent and no siblings. + if (evalHasSiblings(eval)) + return false; + + const auto *parentEval = eval.parent.getIf(); + if (!parentEval) + return false; + + const auto *parentOmpEval = parentEval->getIf(); + return parentOmpEval && extractOmpDirective(*parentOmpEval) == OMPD_target; + } + case OMPD_target_teams_distribute_parallel_do: + case OMPD_target_teams_distribute_parallel_do_simd: + return true; + default: + return false; + } +} + static mlir::omp::TargetOp findParentTargetOp(mlir::OpBuilder &builder) { mlir::Operation *parentOp = builder.getBlock()->getParentOp(); if (!parentOp) @@ -113,8 +248,9 @@ static void genNestedEvaluations(lower::AbstractConverter &converter, converter.genEval(e); } -static bool mustEvalTeamsThreadsOutsideTarget(lower::pft::Evaluation &eval, - mlir::omp::TargetOp targetOp) { +static bool +mustEvalTeamsThreadsOutsideTarget(const lower::pft::Evaluation &eval, + mlir::omp::TargetOp targetOp) { if (!targetOp) return false; @@ -123,25 +259,8 @@ static bool mustEvalTeamsThreadsOutsideTarget(lower::pft::Evaluation &eval, if (offloadModOp.getIsTargetDevice()) return false; - auto dir = Fortran::common::visit( - common::visitors{ - [&](const parser::OpenMPBlockConstruct &c) { - return std::get( - std::get(c.t).t) - .v; - }, - [&](const parser::OpenMPLoopConstruct &c) { - return std::get( - std::get(c.t).t) - .v; - }, - [&](const auto &) { - llvm_unreachable("Unexpected OpenMP construct"); - return llvm::omp::OMPD_unknown; - }, - }, - eval.get().u); - + llvm::omp::Directive dir = + extractOmpDirective(eval.get()); return llvm::omp::allTargetSet.test(dir) || !evalHasSiblings(eval); } @@ -1722,25 +1841,20 @@ genLoopNestOp(lower::AbstractConverter &converter, lower::SymMap &symTable, firOpBuilder.getModule().getOperation()); auto targetOp = loopNestOp->getParentOfType(); - if (offloadMod && targetOp && !offloadMod.getIsTargetDevice()) { - if (targetOp.isTargetSPMDLoop()) { - // Lower loop bounds and step, and process collapsing again, putting - // lowered values outside of omp.target this time. This enables - // calculating and accessing the trip count in the host, which is needed - // when lowering to LLVM IR via the OMPIRBuilder. - HostClausesInsertionGuard guard(firOpBuilder); - mlir::omp::LoopRelatedOps loopRelatedOps; - llvm::SmallVector iv; - ClauseProcessor cp(converter, semaCtx, item->clauses); - cp.processCollapse(loc, eval, loopRelatedOps, iv); - targetOp.getTripCountMutable().assign( - calculateTripCount(converter.getFirOpBuilder(), loc, loopRelatedOps)); - } else if (targetOp.getTripCountMutable().size()) { - // The MLIR target operation was updated during PFT lowering, - // and it is no longer an SPMD kernel. Erase the trip count because - // as it is now invalid. - targetOp.getTripCountMutable().erase(0); - } + if (offloadMod && !offloadMod.getIsTargetDevice() && isTargetSPMDLoop(eval)) { + assert(targetOp && "must have omp.target parent"); + + // Lower loop bounds and step, and process collapsing again, putting lowered + // values outside of omp.target this time. This enables calculating and + // accessing the trip count in the host, which is needed when lowering to + // LLVM IR via the OMPIRBuilder. + HostClausesInsertionGuard guard(firOpBuilder); + mlir::omp::LoopRelatedOps loopRelatedOps; + llvm::SmallVector iv; + ClauseProcessor cp(converter, semaCtx, item->clauses); + cp.processCollapse(loc, eval, loopRelatedOps, iv); + targetOp.getTripCountMutable().assign( + calculateTripCount(converter.getFirOpBuilder(), loc, loopRelatedOps)); } return loopNestOp; } diff --git a/flang/test/Lower/OpenMP/target-spmd.f90 b/flang/test/Lower/OpenMP/target-spmd.f90 new file mode 100644 index 00000000000000..f1fc699c8d6dd2 --- /dev/null +++ b/flang/test/Lower/OpenMP/target-spmd.f90 @@ -0,0 +1,157 @@ +! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s + +!CHECK-LABEL: func.func @_QPdistribute_parallel_do_generic() { +subroutine distribute_parallel_do_generic() + ! CHECK: omp.target + ! CHECK-NOT: trip_count({{.*}}) + ! CHECK-SAME: { + !$omp target + !$omp teams + !$omp distribute parallel do + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do + call bar() + !$omp end teams + !$omp end target + + ! CHECK: omp.target + ! CHECK-NOT: trip_count({{.*}}) + ! CHECK-SAME: { + !$omp target teams + !$omp distribute parallel do + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do + call bar() + !$omp end target teams +end subroutine distribute_parallel_do_generic + +!CHECK-LABEL: func.func @_QPdistribute_parallel_do_spmd() { +subroutine distribute_parallel_do_spmd() + ! CHECK: omp.target + ! CHECK-SAME: trip_count({{.*}}) + !$omp target + !$omp teams + !$omp distribute parallel do + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do + !$omp end teams + !$omp end target + + ! CHECK: omp.target + ! CHECK-SAME: trip_count({{.*}}) + !$omp target teams + !$omp distribute parallel do + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do + !$omp end target teams +end subroutine distribute_parallel_do_spmd + +!CHECK-LABEL: func.func @_QPdistribute_parallel_do_simd_generic() { +subroutine distribute_parallel_do_simd_generic() + ! CHECK: omp.target + ! CHECK-NOT: trip_count({{.*}}) + ! CHECK-SAME: { + !$omp target + !$omp teams + !$omp distribute parallel do simd + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do simd + call bar() + !$omp end teams + !$omp end target + + ! CHECK: omp.target + ! CHECK-NOT: trip_count({{.*}}) + ! CHECK-SAME: { + !$omp target teams + !$omp distribute parallel do simd + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do simd + call bar() + !$omp end target teams +end subroutine distribute_parallel_do_simd_generic + +!CHECK-LABEL: func.func @_QPdistribute_parallel_do_simd_spmd() { +subroutine distribute_parallel_do_simd_spmd() + ! CHECK: omp.target + ! CHECK-SAME: trip_count({{.*}}) + !$omp target + !$omp teams + !$omp distribute parallel do simd + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do simd + !$omp end teams + !$omp end target + + ! CHECK: omp.target + ! CHECK-SAME: trip_count({{.*}}) + !$omp target teams + !$omp distribute parallel do simd + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do simd + !$omp end target teams +end subroutine distribute_parallel_do_simd_spmd + +!CHECK-LABEL: func.func @_QPteams_distribute_parallel_do_spmd() { +subroutine teams_distribute_parallel_do_spmd() + ! CHECK: omp.target + ! CHECK-SAME: trip_count({{.*}}) + !$omp target + !$omp teams distribute parallel do + do i = 1, 10 + call foo(i) + end do + !$omp end teams distribute parallel do + !$omp end target +end subroutine teams_distribute_parallel_do_spmd + +!CHECK-LABEL: func.func @_QPteams_distribute_parallel_do_simd_spmd() { +subroutine teams_distribute_parallel_do_simd_spmd() + ! CHECK: omp.target + ! CHECK-SAME: trip_count({{.*}}) + !$omp target + !$omp teams distribute parallel do simd + do i = 1, 10 + call foo(i) + end do + !$omp end teams distribute parallel do simd + !$omp end target +end subroutine teams_distribute_parallel_do_simd_spmd + +!CHECK-LABEL: func.func @_QPtarget_teams_distribute_parallel_do_spmd() { +subroutine target_teams_distribute_parallel_do_spmd() + ! CHECK: omp.target + ! CHECK-SAME: trip_count({{.*}}) + !$omp target teams distribute parallel do + do i = 1, 10 + call foo(i) + end do + !$omp end target teams distribute parallel do +end subroutine target_teams_distribute_parallel_do_spmd + +!CHECK-LABEL: func.func @_QPtarget_teams_distribute_parallel_do_simd_spmd() { +subroutine target_teams_distribute_parallel_do_simd_spmd() + ! CHECK: omp.target + ! CHECK-SAME: trip_count({{.*}}) + !$omp target teams distribute parallel do simd + do i = 1, 10 + call foo(i) + end do + !$omp end target teams distribute parallel do simd +end subroutine target_teams_distribute_parallel_do_simd_spmd