Skip to content

Commit

Permalink
[Flang][OpenMP] Prevent re-composition of composite constructs (llvm#…
Browse files Browse the repository at this point in the history
…102613)

After decomposition of OpenMP compound constructs and assignment of
applicable clauses to each leaf construct, composite constructs are then
combined again into a single element in the construct queue. This helps
later lowering stages easily identify composite constructs.

However, as a result of the re-composition stage, the same list of
clauses is used to produce all MLIR operations corresponding to each
leaf of the original composite construct. This undoes existing logic
introducing implicit clauses and deciding to which leaf construct(s)
each clause applies.

This patch removes construct re-composition logic and updates Flang
lowering to be able to identify composite constructs from a list of leaf
constructs. As a result, the right set of clauses is produced for each
operation representing a leaf of a composite construct.

PR stack:
- llvm#102612
- llvm#102613
  • Loading branch information
skatrak authored and dmpolukhin committed Sep 2, 2024
1 parent d061551 commit 9b2a999
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 496 deletions.
60 changes: 26 additions & 34 deletions flang/lib/Lower/OpenMP/Decomposer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Frontend/OpenMP/ClauseT.h"
#include "llvm/Frontend/OpenMP/ConstructCompositionT.h"
#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -68,12 +67,6 @@ struct ConstructDecomposition {
};
} // namespace

static UnitConstruct mergeConstructs(uint32_t version,
llvm::ArrayRef<UnitConstruct> units) {
tomp::ConstructCompositionT compose(version, units);
return compose.merged;
}

namespace Fortran::lower::omp {
LLVM_DUMP_METHOD llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const UnitConstruct &uc) {
Expand All @@ -90,38 +83,37 @@ ConstructQueue buildConstructQueue(
Fortran::lower::pft::Evaluation &eval, const parser::CharBlock &source,
llvm::omp::Directive compound, const List<Clause> &clauses) {

List<UnitConstruct> constructs;

ConstructDecomposition decompose(modOp, semaCtx, eval, compound, clauses);
assert(!decompose.output.empty() && "Construct decomposition failed");

llvm::SmallVector<llvm::omp::Directive> loweringUnits;
std::ignore =
llvm::omp::getLeafOrCompositeConstructs(compound, loweringUnits);
uint32_t version = getOpenMPVersionAttribute(modOp);

int leafIndex = 0;
for (llvm::omp::Directive dir_id : loweringUnits) {
llvm::ArrayRef<llvm::omp::Directive> leafsOrSelf =
llvm::omp::getLeafConstructsOrSelf(dir_id);
size_t numLeafs = leafsOrSelf.size();

llvm::ArrayRef<UnitConstruct> toMerge{&decompose.output[leafIndex],
numLeafs};
auto &uc = constructs.emplace_back(mergeConstructs(version, toMerge));

if (!transferLocations(clauses, uc.clauses)) {
// If some clauses are left without source information, use the
// directive's source.
for (auto &clause : uc.clauses) {
if (clause.source.empty())
clause.source = source;
}
}
leafIndex += numLeafs;
for (UnitConstruct &uc : decompose.output) {
assert(getLeafConstructs(uc.id).empty() && "unexpected compound directive");
// If some clauses are left without source information, use the directive's
// source.
for (auto &clause : uc.clauses)
if (clause.source.empty())
clause.source = source;
}

return decompose.output;
}

bool matchLeafSequence(ConstructQueue::const_iterator item,
const ConstructQueue &queue,
llvm::omp::Directive directive) {
llvm::ArrayRef<llvm::omp::Directive> leafDirs =
llvm::omp::getLeafConstructsOrSelf(directive);

for (auto [dir, leaf] :
llvm::zip_longest(leafDirs, llvm::make_range(item, queue.end()))) {
if (!dir.has_value() || !leaf.has_value())
return false;

if (*dir != leaf->id)
return false;
}

return constructs;
return true;
}

bool isLastItemInQueue(ConstructQueue::const_iterator item,
Expand Down
10 changes: 9 additions & 1 deletion flang/lib/Lower/OpenMP/Decomposer.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

#include "Clauses.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/Frontend/OpenMP/ConstructCompositionT.h"
#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
#include "llvm/Frontend/OpenMP/OMP.h"
#include "llvm/Support/Compiler.h"
Expand Down Expand Up @@ -49,6 +48,15 @@ ConstructQueue buildConstructQueue(mlir::ModuleOp modOp,

bool isLastItemInQueue(ConstructQueue::const_iterator item,
const ConstructQueue &queue);

/// Try to match the leaf constructs that make up the given \c directive
/// against the range of leaf constructs starting at \c item and running to the
/// end of the \c queue. If \c directive doesn't represent a compound directive,
/// check that \c item matches that directive and is the only element before the
/// end of the \c queue.
bool matchLeafSequence(ConstructQueue::const_iterator item,
const ConstructQueue &queue,
llvm::omp::Directive directive);
} // namespace Fortran::lower::omp

#endif // FORTRAN_LOWER_OPENMP_DECOMPOSER_H
96 changes: 65 additions & 31 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2044,6 +2044,7 @@ static void genCompositeDistributeParallelDoSimd(
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
mlir::Location loc, const ConstructQueue &queue,
ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
assert(std::distance(item, queue.end()) == 4 && "Invalid leaf constructs");
TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD");
}

Expand All @@ -2054,17 +2055,23 @@ static void genCompositeDistributeSimd(
ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
lower::StatementContext stmtCtx;

assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
ConstructQueue::const_iterator distributeItem = item;
ConstructQueue::const_iterator simdItem = std::next(distributeItem);

// Clause processing.
mlir::omp::DistributeOperands distributeClauseOps;
genDistributeClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
distributeClauseOps);
genDistributeClauses(converter, semaCtx, stmtCtx, distributeItem->clauses,
loc, distributeClauseOps);

mlir::omp::SimdOperands simdClauseOps;
genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps);
genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);

// Pass the innermost leaf construct's clauses because that's where COLLAPSE
// is placed by construct decomposition.
mlir::omp::LoopNestOperands loopNestClauseOps;
llvm::SmallVector<const semantics::Symbol *> iv;
genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc,
genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc,
loopNestClauseOps, iv);

// Operation creation.
Expand All @@ -2086,7 +2093,7 @@ static void genCompositeDistributeSimd(
llvm::concat<mlir::BlockArgument>(distributeOp.getRegion().getArguments(),
simdOp.getRegion().getArguments()));

genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem,
loopNestClauseOps, iv, /*wrapperSyms=*/{}, wrapperArgs,
llvm::omp::Directive::OMPD_distribute_simd, dsp);
}
Expand All @@ -2100,19 +2107,25 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
DataSharingProcessor &dsp) {
lower::StatementContext stmtCtx;

assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
ConstructQueue::const_iterator doItem = item;
ConstructQueue::const_iterator simdItem = std::next(doItem);

// Clause processing.
mlir::omp::WsloopOperands wsloopClauseOps;
llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
llvm::SmallVector<mlir::Type> wsloopReductionTypes;
genWsloopClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);

mlir::omp::SimdOperands simdClauseOps;
genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps);
genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);

// Pass the innermost leaf construct's clauses because that's where COLLAPSE
// is placed by construct decomposition.
mlir::omp::LoopNestOperands loopNestClauseOps;
llvm::SmallVector<const semantics::Symbol *> iv;
genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc,
genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc,
loopNestClauseOps, iv);

// Operation creation.
Expand All @@ -2133,7 +2146,7 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
auto wrapperArgs = llvm::to_vector(llvm::concat<mlir::BlockArgument>(
wsloopOp.getRegion().getArguments(), simdOp.getRegion().getArguments()));

genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem,
loopNestClauseOps, iv, wsloopReductionSyms, wrapperArgs,
llvm::omp::Directive::OMPD_do_simd, dsp);
}
Expand All @@ -2143,13 +2156,44 @@ static void genCompositeTaskloopSimd(
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
mlir::Location loc, const ConstructQueue &queue,
ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
TODO(loc, "Composite TASKLOOP SIMD");
}

//===----------------------------------------------------------------------===//
// Dispatch
//===----------------------------------------------------------------------===//

/// Attempt to lower the leaf constructs from \c item to the end of \c queue as
/// a single composite construct.
///
/// Each supported composite directive is tried in turn via
/// \c matchLeafSequence. On the first match the corresponding
/// \c genComposite* function is invoked and \c true is returned. If none of the
/// composite directives match, nothing is generated and \c false is returned so
/// that the caller can handle \c item on its own.
static bool genOMPCompositeDispatch(
    lower::AbstractConverter &converter, lower::SymMap &symTable,
    semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
    mlir::Location loc, const ConstructQueue &queue,
    ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
  using llvm::omp::Directive;
  using lower::omp::matchLeafSequence;

  // Check whether the remaining leaf constructs spell out `composite`.
  auto remainingLeavesAre = [&](Directive composite) {
    return matchLeafSequence(item, queue, composite);
  };

  if (remainingLeavesAre(Directive::OMPD_distribute_parallel_do)) {
    genCompositeDistributeParallelDo(converter, symTable, semaCtx, eval, loc,
                                     queue, item, dsp);
    return true;
  }

  if (remainingLeavesAre(Directive::OMPD_distribute_parallel_do_simd)) {
    genCompositeDistributeParallelDoSimd(converter, symTable, semaCtx, eval,
                                         loc, queue, item, dsp);
    return true;
  }

  if (remainingLeavesAre(Directive::OMPD_distribute_simd)) {
    genCompositeDistributeSimd(converter, symTable, semaCtx, eval, loc, queue,
                               item, dsp);
    return true;
  }

  if (remainingLeavesAre(Directive::OMPD_do_simd)) {
    genCompositeDoSimd(converter, symTable, semaCtx, eval, loc, queue, item,
                       dsp);
    return true;
  }

  if (remainingLeavesAre(Directive::OMPD_taskloop_simd)) {
    genCompositeTaskloopSimd(converter, symTable, semaCtx, eval, loc, queue,
                             item, dsp);
    return true;
  }

  // Not a composite construct.
  return false;
}

static void genOMPDispatch(lower::AbstractConverter &converter,
lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
Expand All @@ -2163,10 +2207,18 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
llvm::omp::Association::Loop;
if (loopLeaf) {
symTable.pushScope();
// TODO: Use one DataSharingProcessor for each leaf of a composite
// construct.
loopDsp.emplace(converter, semaCtx, item->clauses, eval,
/*shouldCollectPreDeterminedSymbols=*/true,
/*useDelayedPrivatization=*/false, &symTable);
loopDsp->processStep1();

if (genOMPCompositeDispatch(converter, symTable, semaCtx, eval, loc, queue,
item, *loopDsp)) {
symTable.popScope();
return;
}
}

switch (llvm::omp::Directive dir = item->id) {
Expand Down Expand Up @@ -2262,29 +2314,11 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
// that use this construct, add a single construct for now.
genSingleOp(converter, symTable, semaCtx, eval, loc, queue, item);
break;

// Composite constructs
case llvm::omp::Directive::OMPD_distribute_parallel_do:
genCompositeDistributeParallelDo(converter, symTable, semaCtx, eval, loc,
queue, item, *loopDsp);
break;
case llvm::omp::Directive::OMPD_distribute_parallel_do_simd:
genCompositeDistributeParallelDoSimd(converter, symTable, semaCtx, eval,
loc, queue, item, *loopDsp);
break;
case llvm::omp::Directive::OMPD_distribute_simd:
genCompositeDistributeSimd(converter, symTable, semaCtx, eval, loc, queue,
item, *loopDsp);
break;
case llvm::omp::Directive::OMPD_do_simd:
genCompositeDoSimd(converter, symTable, semaCtx, eval, loc, queue, item,
*loopDsp);
break;
case llvm::omp::Directive::OMPD_taskloop_simd:
genCompositeTaskloopSimd(converter, symTable, semaCtx, eval, loc, queue,
item, *loopDsp);
break;
default:
// Combined and composite constructs should have been split into a sequence
// of leaf constructs when building the construct queue.
assert(!llvm::omp::isLeafConstruct(dir) &&
"Unexpected compound construct.");
break;
}

Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/OpenMP/Todo/omp-do-simd-linear.f90
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
subroutine testDoSimdLinear(int_array)
integer :: int_array(*)
!CHECK: not yet implemented: Unhandled clause LINEAR in DO construct
!CHECK: not yet implemented: Unhandled clause LINEAR in SIMD construct
!$omp do simd linear(int_array)
do index_ = 1, 10
end do
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenMP/default-clause-byref.f90
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,9 @@ subroutine nested_default_clause_tests
!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFnested_default_clause_testsEz"}
!CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel private({{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
!CHECK: %[[PRIVATE_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_K_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_K]] {uniq_name = "_QFnested_default_clause_testsEk"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_X:.*]] : {{.*}}) {
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenMP/default-clause.f90
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ end program default_clause_lowering
!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFnested_default_clause_test1Ey"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFnested_default_clause_test1Ez"}
!CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z]] {uniq_name = "_QFnested_default_clause_test1Ez"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_test1Ey"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel private({{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
!CHECK: %[[PRIVATE_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_X]] {uniq_name = "_QFnested_default_clause_test1Ex"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_test1Ey"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_test1Ez"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[PRIVATE_K_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_K]] {uniq_name = "_QFnested_default_clause_test1Ek"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_X:.*]] : {{.*}}) {
Expand Down
Loading

0 comments on commit 9b2a999

Please sign in to comment.