Skip to content

Commit d55879d

Browse files
authored
[flang][OpenMP] Emit requirements in module files (#163449)
For each program unit, collect the set of requirements from REQUIRES directives in the source, and modules used by the program unit, and add them to the details of the program unit symbol. The requirements in the symbol details as now stored as clauses. Since requirements need to be emitted in the module files as OpenMP directives, this makes the clause emission straightforward via getOpenMPClauseName. Each program unit, including modules, the corresponding symbol will have the transitive closure of the requirements for everything contained or used in that program unit.
1 parent 1adbae9 commit d55879d

File tree

8 files changed

+147
-138
lines changed

8 files changed

+147
-138
lines changed

flang/include/flang/Semantics/symbol.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "flang/Semantics/module-dependences.h"
1717
#include "flang/Support/Fortran.h"
1818
#include "llvm/ADT/DenseMapInfo.h"
19+
#include "llvm/Frontend/OpenMP/OMP.h"
1920

2021
#include <array>
2122
#include <functional>
@@ -50,32 +51,31 @@ using MutableSymbolVector = std::vector<MutableSymbolRef>;
5051

5152
// Mixin for details with OpenMP declarative constructs.
5253
class WithOmpDeclarative {
53-
using OmpAtomicOrderType = common::OmpMemoryOrderType;
54-
5554
public:
56-
ENUM_CLASS(RequiresFlag, ReverseOffload, UnifiedAddress, UnifiedSharedMemory,
57-
DynamicAllocators);
58-
using RequiresFlags = common::EnumSet<RequiresFlag, RequiresFlag_enumSize>;
55+
// The set of requirements for any program unit include requirements
56+
// from any module used in the program unit.
57+
using RequiresClauses =
58+
common::EnumSet<llvm::omp::Clause, llvm::omp::Clause_enumSize>;
5959

6060
bool has_ompRequires() const { return ompRequires_.has_value(); }
61-
const RequiresFlags *ompRequires() const {
61+
const RequiresClauses *ompRequires() const {
6262
return ompRequires_ ? &*ompRequires_ : nullptr;
6363
}
64-
void set_ompRequires(RequiresFlags flags) { ompRequires_ = flags; }
64+
void set_ompRequires(RequiresClauses clauses) { ompRequires_ = clauses; }
6565

6666
bool has_ompAtomicDefaultMemOrder() const {
6767
return ompAtomicDefaultMemOrder_.has_value();
6868
}
69-
const OmpAtomicOrderType *ompAtomicDefaultMemOrder() const {
69+
const common::OmpMemoryOrderType *ompAtomicDefaultMemOrder() const {
7070
return ompAtomicDefaultMemOrder_ ? &*ompAtomicDefaultMemOrder_ : nullptr;
7171
}
72-
void set_ompAtomicDefaultMemOrder(OmpAtomicOrderType flags) {
72+
void set_ompAtomicDefaultMemOrder(common::OmpMemoryOrderType flags) {
7373
ompAtomicDefaultMemOrder_ = flags;
7474
}
7575

7676
private:
77-
std::optional<RequiresFlags> ompRequires_;
78-
std::optional<OmpAtomicOrderType> ompAtomicDefaultMemOrder_;
77+
std::optional<RequiresClauses> ompRequires_;
78+
std::optional<common::OmpMemoryOrderType> ompAtomicDefaultMemOrder_;
7979
};
8080

8181
// A module or submodule.

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4208,18 +4208,17 @@ bool Fortran::lower::markOpenMPDeferredDeclareTargetFunctions(
42084208
void Fortran::lower::genOpenMPRequires(mlir::Operation *mod,
42094209
const semantics::Symbol *symbol) {
42104210
using MlirRequires = mlir::omp::ClauseRequires;
4211-
using SemaRequires = semantics::WithOmpDeclarative::RequiresFlag;
42124211

42134212
if (auto offloadMod =
42144213
llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod)) {
4215-
semantics::WithOmpDeclarative::RequiresFlags semaFlags;
4214+
semantics::WithOmpDeclarative::RequiresClauses reqs;
42164215
if (symbol) {
42174216
common::visit(
42184217
[&](const auto &details) {
42194218
if constexpr (std::is_base_of_v<semantics::WithOmpDeclarative,
42204219
std::decay_t<decltype(details)>>) {
42214220
if (details.has_ompRequires())
4222-
semaFlags = *details.ompRequires();
4221+
reqs = *details.ompRequires();
42234222
}
42244223
},
42254224
symbol->details());
@@ -4228,14 +4227,14 @@ void Fortran::lower::genOpenMPRequires(mlir::Operation *mod,
42284227
// Use pre-populated omp.requires module attribute if it was set, so that
42294228
// the "-fopenmp-force-usm" compiler option is honored.
42304229
MlirRequires mlirFlags = offloadMod.getRequires();
4231-
if (semaFlags.test(SemaRequires::ReverseOffload))
4230+
if (reqs.test(llvm::omp::Clause::OMPC_dynamic_allocators))
4231+
mlirFlags = mlirFlags | MlirRequires::dynamic_allocators;
4232+
if (reqs.test(llvm::omp::Clause::OMPC_reverse_offload))
42324233
mlirFlags = mlirFlags | MlirRequires::reverse_offload;
4233-
if (semaFlags.test(SemaRequires::UnifiedAddress))
4234+
if (reqs.test(llvm::omp::Clause::OMPC_unified_address))
42344235
mlirFlags = mlirFlags | MlirRequires::unified_address;
4235-
if (semaFlags.test(SemaRequires::UnifiedSharedMemory))
4236+
if (reqs.test(llvm::omp::Clause::OMPC_unified_shared_memory))
42364237
mlirFlags = mlirFlags | MlirRequires::unified_shared_memory;
4237-
if (semaFlags.test(SemaRequires::DynamicAllocators))
4238-
mlirFlags = mlirFlags | MlirRequires::dynamic_allocators;
42394238

42404239
offloadMod.setRequires(mlirFlags);
42414240
}

flang/lib/Semantics/mod-file.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
#include "flang/Semantics/semantics.h"
1818
#include "flang/Semantics/symbol.h"
1919
#include "flang/Semantics/tools.h"
20+
#include "llvm/Frontend/OpenMP/OMP.h"
2021
#include "llvm/Support/FileSystem.h"
2122
#include "llvm/Support/MemoryBuffer.h"
2223
#include "llvm/Support/raw_ostream.h"
2324
#include <algorithm>
2425
#include <fstream>
2526
#include <set>
2627
#include <string_view>
28+
#include <type_traits>
2729
#include <variant>
2830
#include <vector>
2931

@@ -359,6 +361,40 @@ void ModFileWriter::PrepareRenamings(const Scope &scope) {
359361
}
360362
}
361363

364+
static void PutOpenMPRequirements(llvm::raw_ostream &os, const Symbol &symbol) {
365+
using RequiresClauses = WithOmpDeclarative::RequiresClauses;
366+
using OmpMemoryOrderType = common::OmpMemoryOrderType;
367+
368+
const auto [reqs, order]{common::visit(
369+
[&](auto &&details)
370+
-> std::pair<const RequiresClauses *, const OmpMemoryOrderType *> {
371+
if constexpr (std::is_convertible_v<decltype(details),
372+
const WithOmpDeclarative &>) {
373+
return {details.ompRequires(), details.ompAtomicDefaultMemOrder()};
374+
} else {
375+
return {nullptr, nullptr};
376+
}
377+
},
378+
symbol.details())};
379+
380+
if (order) {
381+
llvm::omp::Clause admo{llvm::omp::Clause::OMPC_atomic_default_mem_order};
382+
os << "!$omp requires "
383+
<< parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(admo))
384+
<< '(' << parser::ToLowerCaseLetters(EnumToString(*order)) << ")\n";
385+
}
386+
if (reqs) {
387+
os << "!$omp requires";
388+
reqs->IterateOverMembers([&](llvm::omp::Clause f) {
389+
if (f != llvm::omp::Clause::OMPC_atomic_default_mem_order) {
390+
os << ' '
391+
<< parser::ToLowerCaseLetters(llvm::omp::getOpenMPClauseName(f));
392+
}
393+
});
394+
os << "\n";
395+
}
396+
}
397+
362398
// Put out the visible symbols from scope.
363399
void ModFileWriter::PutSymbols(
364400
const Scope &scope, UnorderedSymbolSet *hermeticModules) {
@@ -396,6 +432,7 @@ void ModFileWriter::PutSymbols(
396432
for (const Symbol &symbol : uses) {
397433
PutUse(symbol);
398434
}
435+
PutOpenMPRequirements(decls_, DEREF(scope.symbol()));
399436
for (const auto &set : scope.equivalenceSets()) {
400437
if (!set.empty() &&
401438
!set.front().symbol.test(Symbol::Flag::CompilerCreated)) {

flang/lib/Semantics/resolve-directives.cpp

Lines changed: 48 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,22 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
435435
return true;
436436
}
437437

438+
bool Pre(const parser::UseStmt &x) {
439+
if (x.moduleName.symbol) {
440+
Scope &thisScope{context_.FindScope(x.moduleName.source)};
441+
common::visit(
442+
[&](auto &&details) {
443+
if constexpr (std::is_convertible_v<decltype(details),
444+
const WithOmpDeclarative &>) {
445+
AddOmpRequiresToScope(thisScope, details.ompRequires(),
446+
details.ompAtomicDefaultMemOrder());
447+
}
448+
},
449+
x.moduleName.symbol->details());
450+
}
451+
return true;
452+
}
453+
438454
bool Pre(const parser::OmpMetadirectiveDirective &x) {
439455
PushContext(x.v.source, llvm::omp::Directive::OMPD_metadirective);
440456
return true;
@@ -538,38 +554,37 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
538554
void Post(const parser::OpenMPFlushConstruct &) { PopContext(); }
539555

540556
bool Pre(const parser::OpenMPRequiresConstruct &x) {
541-
using Flags = WithOmpDeclarative::RequiresFlags;
542-
using Requires = WithOmpDeclarative::RequiresFlag;
557+
using RequiresClauses = WithOmpDeclarative::RequiresClauses;
543558
PushContext(x.source, llvm::omp::Directive::OMPD_requires);
544559

545560
// Gather information from the clauses.
546-
Flags flags;
547-
std::optional<common::OmpMemoryOrderType> memOrder;
561+
RequiresClauses reqs;
562+
const common::OmpMemoryOrderType *memOrder{nullptr};
548563
for (const parser::OmpClause &clause : x.v.Clauses().v) {
549-
flags |= common::visit(
564+
using OmpClause = parser::OmpClause;
565+
reqs |= common::visit(
550566
common::visitors{
551-
[&memOrder](
552-
const parser::OmpClause::AtomicDefaultMemOrder &atomic) {
553-
memOrder = atomic.v.v;
554-
return Flags{};
555-
},
556-
[](const parser::OmpClause::ReverseOffload &) {
557-
return Flags{Requires::ReverseOffload};
558-
},
559-
[](const parser::OmpClause::UnifiedAddress &) {
560-
return Flags{Requires::UnifiedAddress};
567+
[&](const OmpClause::AtomicDefaultMemOrder &atomic) {
568+
memOrder = &atomic.v.v;
569+
return RequiresClauses{};
561570
},
562-
[](const parser::OmpClause::UnifiedSharedMemory &) {
563-
return Flags{Requires::UnifiedSharedMemory};
564-
},
565-
[](const parser::OmpClause::DynamicAllocators &) {
566-
return Flags{Requires::DynamicAllocators};
571+
[&](auto &&s) {
572+
using TypeS = llvm::remove_cvref_t<decltype(s)>;
573+
if constexpr ( //
574+
std::is_same_v<TypeS, OmpClause::DynamicAllocators> ||
575+
std::is_same_v<TypeS, OmpClause::ReverseOffload> ||
576+
std::is_same_v<TypeS, OmpClause::UnifiedAddress> ||
577+
std::is_same_v<TypeS, OmpClause::UnifiedSharedMemory>) {
578+
return RequiresClauses{clause.Id()};
579+
} else {
580+
return RequiresClauses{};
581+
}
567582
},
568-
[](const auto &) { return Flags{}; }},
583+
},
569584
clause.u);
570585
}
571586
// Merge clauses into parents' symbols details.
572-
AddOmpRequiresToScope(currScope(), flags, memOrder);
587+
AddOmpRequiresToScope(currScope(), &reqs, memOrder);
573588
return true;
574589
}
575590
void Post(const parser::OpenMPRequiresConstruct &) { PopContext(); }
@@ -1001,8 +1016,9 @@ class OmpAttributeVisitor : DirectiveAttributeVisitor<llvm::omp::Directive> {
10011016

10021017
std::int64_t ordCollapseLevel{0};
10031018

1004-
void AddOmpRequiresToScope(Scope &, WithOmpDeclarative::RequiresFlags,
1005-
std::optional<common::OmpMemoryOrderType>);
1019+
void AddOmpRequiresToScope(Scope &,
1020+
const WithOmpDeclarative::RequiresClauses *,
1021+
const common::OmpMemoryOrderType *);
10061022
void IssueNonConformanceWarning(llvm::omp::Directive D,
10071023
parser::CharBlock source, unsigned EmitFromVersion);
10081024

@@ -3309,86 +3325,6 @@ void ResolveOmpParts(
33093325
}
33103326
}
33113327

3312-
void ResolveOmpTopLevelParts(
3313-
SemanticsContext &context, const parser::Program &program) {
3314-
if (!context.IsEnabled(common::LanguageFeature::OpenMP)) {
3315-
return;
3316-
}
3317-
3318-
// Gather REQUIRES clauses from all non-module top-level program unit symbols,
3319-
// combine them together ensuring compatibility and apply them to all these
3320-
// program units. Modules are skipped because their REQUIRES clauses should be
3321-
// propagated via USE statements instead.
3322-
WithOmpDeclarative::RequiresFlags combinedFlags;
3323-
std::optional<common::OmpMemoryOrderType> combinedMemOrder;
3324-
3325-
// Function to go through non-module top level program units and extract
3326-
// REQUIRES information to be processed by a function-like argument.
3327-
auto processProgramUnits{[&](auto processFn) {
3328-
for (const parser::ProgramUnit &unit : program.v) {
3329-
if (!std::holds_alternative<common::Indirection<parser::Module>>(
3330-
unit.u) &&
3331-
!std::holds_alternative<common::Indirection<parser::Submodule>>(
3332-
unit.u) &&
3333-
!std::holds_alternative<
3334-
common::Indirection<parser::CompilerDirective>>(unit.u)) {
3335-
Symbol *symbol{common::visit(
3336-
[&context](auto &x) {
3337-
Scope *scope = GetScope(context, x.value());
3338-
return scope ? scope->symbol() : nullptr;
3339-
},
3340-
unit.u)};
3341-
// FIXME There is no symbol defined for MainProgram units in certain
3342-
// circumstances, so REQUIRES information has no place to be stored in
3343-
// these cases.
3344-
if (!symbol) {
3345-
continue;
3346-
}
3347-
common::visit(
3348-
[&](auto &details) {
3349-
if constexpr (std::is_convertible_v<decltype(&details),
3350-
WithOmpDeclarative *>) {
3351-
processFn(*symbol, details);
3352-
}
3353-
},
3354-
symbol->details());
3355-
}
3356-
}
3357-
}};
3358-
3359-
// Combine global REQUIRES information from all program units except modules
3360-
// and submodules.
3361-
processProgramUnits([&](Symbol &symbol, WithOmpDeclarative &details) {
3362-
if (const WithOmpDeclarative::RequiresFlags *
3363-
flags{details.ompRequires()}) {
3364-
combinedFlags |= *flags;
3365-
}
3366-
if (const common::OmpMemoryOrderType *
3367-
memOrder{details.ompAtomicDefaultMemOrder()}) {
3368-
if (combinedMemOrder && *combinedMemOrder != *memOrder) {
3369-
context.Say(symbol.scope()->sourceRange(),
3370-
"Conflicting '%s' REQUIRES clauses found in compilation "
3371-
"unit"_err_en_US,
3372-
parser::ToUpperCaseLetters(llvm::omp::getOpenMPClauseName(
3373-
llvm::omp::Clause::OMPC_atomic_default_mem_order)
3374-
.str()));
3375-
}
3376-
combinedMemOrder = *memOrder;
3377-
}
3378-
});
3379-
3380-
// Update all program units except modules and submodules with the combined
3381-
// global REQUIRES information.
3382-
processProgramUnits([&](Symbol &, WithOmpDeclarative &details) {
3383-
if (combinedFlags.any()) {
3384-
details.set_ompRequires(combinedFlags);
3385-
}
3386-
if (combinedMemOrder) {
3387-
details.set_ompAtomicDefaultMemOrder(*combinedMemOrder);
3388-
}
3389-
});
3390-
}
3391-
33923328
static bool IsSymbolThreadprivate(const Symbol &symbol) {
33933329
if (const auto *details{symbol.detailsIf<HostAssocDetails>()}) {
33943330
return details->symbol().test(Symbol::Flag::OmpThreadprivate);
@@ -3547,23 +3483,22 @@ void OmpAttributeVisitor::CheckLabelContext(const parser::CharBlock source,
35473483
}
35483484

35493485
void OmpAttributeVisitor::AddOmpRequiresToScope(Scope &scope,
3550-
WithOmpDeclarative::RequiresFlags flags,
3551-
std::optional<common::OmpMemoryOrderType> memOrder) {
3486+
const WithOmpDeclarative::RequiresClauses *reqs,
3487+
const common::OmpMemoryOrderType *memOrder) {
35523488
const Scope &programUnit{omp::GetProgramUnit(scope)};
3489+
using RequiresClauses = WithOmpDeclarative::RequiresClauses;
3490+
RequiresClauses combinedReqs{reqs ? *reqs : RequiresClauses{}};
35533491

35543492
if (auto *symbol{const_cast<Symbol *>(programUnit.symbol())}) {
35553493
common::visit(
35563494
[&](auto &details) {
3557-
// Store clauses information into the symbol for the parent and
3558-
// enclosing modules, programs, functions and subroutines.
35593495
if constexpr (std::is_convertible_v<decltype(&details),
35603496
WithOmpDeclarative *>) {
3561-
if (flags.any()) {
3562-
if (const WithOmpDeclarative::RequiresFlags *otherFlags{
3563-
details.ompRequires()}) {
3564-
flags |= *otherFlags;
3497+
if (combinedReqs.any()) {
3498+
if (const RequiresClauses *otherReqs{details.ompRequires()}) {
3499+
combinedReqs |= *otherReqs;
35653500
}
3566-
details.set_ompRequires(flags);
3501+
details.set_ompRequires(combinedReqs);
35673502
}
35683503
if (memOrder) {
35693504
if (details.has_ompAtomicDefaultMemOrder() &&

flang/lib/Semantics/resolve-directives.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,5 @@ class SemanticsContext;
2323
void ResolveAccParts(
2424
SemanticsContext &, const parser::ProgramUnit &, Scope *topScope);
2525
void ResolveOmpParts(SemanticsContext &, const parser::ProgramUnit &);
26-
void ResolveOmpTopLevelParts(SemanticsContext &, const parser::Program &);
27-
2826
} // namespace Fortran::semantics
2927
#endif

flang/lib/Semantics/resolve-names.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10687,9 +10687,6 @@ void ResolveNamesVisitor::Post(const parser::Program &x) {
1068710687
CHECK(!attrs_);
1068810688
CHECK(!cudaDataAttr_);
1068910689
CHECK(!GetDeclTypeSpec());
10690-
// Top-level resolution to propagate information across program units after
10691-
// each of them has been resolved separately.
10692-
ResolveOmpTopLevelParts(context(), x);
1069310690
}
1069410691

1069510692
// A singleton instance of the scope -> IMPLICIT rules mapping is

0 commit comments

Comments
 (0)