Skip to content

Add ZA directives for Flang. #76505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions flang/docs/Directives.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,28 @@ A list of non-standard directives supported by Flang
end
end interface
```

## ARM Streaming SVE directives

These directives are added to support ARM specific instructions. All of
these attributes apply to a specific subroutine or function. These directives
are identical to the attributes provided in C and C++ for the same purpose.
See https://arm-software.github.io/acle/main/acle.html#controlling-the-use-of-streaming-mode for more in depth details. (For the following, function is used
to mean both subroutine and function).

### Directives relating to ARM Streaming mode

* `!dir$ arm_streaming` - The function is intended to be used in streaming
mode.
* `!dir$ arm_streaming_compatible` - The function can work both in streaming
mode and non-streaming mode.
* `!dir$ arm_streaming` - The function will enter streaming mode, and return to

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean arm_locally_streaming?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, of course. Typical copy-pasta error. Thanks for pointing it out.

non-streaming mode when reaturning.

### Directives relating to ZA

* `!dir$ arm_shared_za` - A function that uses ZA for input or output.
* `!dir$ arm_new_za` - A function that has ZA state created and destroyed within
the function.
* `!dir$ arm_preserves_za` - Optimisation hint for the compiler that the
function either doesn't alter, or saves and restores the ZA state.
10 changes: 8 additions & 2 deletions flang/include/flang/Lower/PFTBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,12 @@ VariableList getDependentVariableList(const Fortran::semantics::Symbol &);

void dump(VariableList &, std::string s = {}); // `s` is an optional dump label

/// Things that can be nested inside of a module or function
/// TODO: add the rest
struct FunctionLikeUnit;
struct CompilerDirectiveUnit;
using NestedUnit = std::variant<FunctionLikeUnit, CompilerDirectiveUnit>;

/// Function-like units may contain evaluations (executable statements) and
/// nested function-like units (internal procedures and function statements).
struct FunctionLikeUnit : public ProgramUnit {
Expand Down Expand Up @@ -695,7 +701,7 @@ struct FunctionLikeUnit : public ProgramUnit {
EvaluationList evaluationList;
LabelEvalMap labelEvaluationMap;
SymbolLabelMap assignSymbolLabelMap;
std::list<FunctionLikeUnit> nestedFunctions;
std::list<NestedUnit> nestedUnits;
/// <Symbol, Evaluation> pairs for each entry point. The pair at index 0
/// is the primary entry point; remaining pairs are alternate entry points.
/// The primary entry point symbol is Null for an anonymous program.
Expand Down Expand Up @@ -741,7 +747,7 @@ struct ModuleLikeUnit : public ProgramUnit {

ModuleStatement beginStmt;
ModuleStatement endStmt;
std::list<FunctionLikeUnit> nestedFunctions;
std::list<NestedUnit> nestedUnits;
EvaluationList evaluationList;
};

Expand Down
3 changes: 2 additions & 1 deletion flang/include/flang/Parser/parse-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -2894,7 +2894,8 @@ struct ModuleSubprogram {
UNION_CLASS_BOILERPLATE(ModuleSubprogram);
std::variant<common::Indirection<FunctionSubprogram>,
common::Indirection<SubroutineSubprogram>,
common::Indirection<SeparateModuleSubprogram>>
common::Indirection<SeparateModuleSubprogram>,
common::Indirection<CompilerDirective>>
u;
};

Expand Down
123 changes: 109 additions & 14 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#include "flang/Semantics/runtime-type-info.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Parser/Parser.h"
Expand Down Expand Up @@ -303,9 +304,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
for (Fortran::lower::pft::FunctionLikeUnit &f :
m.nestedFunctions)
declareFunction(f);
for (Fortran::lower::pft::NestedUnit &unit :
m.nestedUnits) {
if (auto *f = std::get_if<
Fortran::lower::pft::FunctionLikeUnit>(&unit))
declareFunction(*f);
}
},
[&](Fortran::lower::pft::BlockDataUnit &b) {
if (!globalOmpRequiresSymbol)
Expand All @@ -322,13 +326,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
[&]() { createIntrinsicModuleDefinitions(pft); });

// Primary translation pass.
for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
std::list<Fortran::lower::pft::Program::Units> &units = pft.getUnits();
for (auto it = units.begin(); it != units.end(); it = std::next(it)) {
std::visit(
Fortran::common::visitors{
[&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
[&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); },
[&](Fortran::lower::pft::BlockDataUnit &b) {},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
processSubprogramDirective(it, units.end(), d);
},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
builder = new fir::FirOpBuilder(bridge.getModule(),
bridge.getKindMap());
Expand All @@ -338,7 +345,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder = nullptr;
},
},
u);
*it);
}

// Once all the code has been translated, create global runtime type info
Expand Down Expand Up @@ -387,13 +394,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {

// Compute the set of host associated entities from the nested functions.
llvm::SetVector<const Fortran::semantics::Symbol *> escapeHost;
for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
collectHostAssociatedVariables(f, escapeHost);
for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
collectHostAssociatedVariables(*f, escapeHost);
}
funit.setHostAssociatedSymbols(escapeHost);

// Declare internal procedures
for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
declareFunction(f);
for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
declareFunction(*f);
}
}

/// Get the scope that is defining or using \p sym. The returned scope is not
Expand Down Expand Up @@ -4667,8 +4678,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
endNewFunction(funit);
}
funit.setActiveEntry(0);
for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
lowerFunc(f); // internal procedure
for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
lowerFunc(*f); // internal procedure
}
}

/// Lower module variable definitions to fir::globalOp and OpenMP/OpenACC
Expand All @@ -4692,8 +4705,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {

/// Lower functions contained in a module.
void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions)
lowerFunc(f);
for (auto it = mod.nestedUnits.begin(); it != mod.nestedUnits.end();
it = std::next(it)) {
std::visit(
Fortran::common::visitors{
[&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
processSubprogramDirective(it, mod.nestedUnits.end(), d);
}},
*it);
}
}

void setCurrentPosition(const Fortran::parser::CharBlock &position) {
Expand Down Expand Up @@ -5001,6 +5022,80 @@ class FirConverter : public Fortran::lower::AbstractConverter {
globalOmpRequiresSymbol);
}

/// Process compiler directives that apply to subprograms
template <typename ITERATOR>
void
processSubprogramDirective(ITERATOR it, ITERATOR endIt,
Fortran::lower::pft::CompilerDirectiveUnit &d) {
auto *parserDirective = d.getIf<Fortran::parser::CompilerDirective>();
if (!parserDirective)
return;
auto *nvList =
std::get_if<std::list<Fortran::parser::CompilerDirective::NameValue>>(
&parserDirective->u);
if (!nvList)
return;

// get the function the directive applies to (hopefully the next unit)
mlir::func::FuncOp mlirFunc;
it = std::next(it);
if (it != endIt) {
auto *pftFunction =
std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&*it);
if (pftFunction) {
Fortran::lower::CalleeInterface callee{*pftFunction, *this};
mlirFunc = callee.getFuncOp();
}
}

for (const Fortran::parser::CompilerDirective::NameValue &nv : *nvList) {
std::string name = std::get<Fortran::parser::Name>(nv.t).ToString();

// arm streaming sve directives
auto streamingMode = mlir::arm_sme::ArmStreamingMode::Disabled;
if (name == "arm_streaming")
streamingMode = mlir::arm_sme::ArmStreamingMode::Streaming;
else if (name == "arm_locally_streaming")
streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingLocally;
else if (name == "arm_streaming_compatible")
streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingCompatible;
Comment on lines +5055 to +5061
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be a generated helper for this (that returns std::optional<arm_sme::ArmStreamingMode>). I believe it would be mlir::arm_sme::symbolizeArmStreamingMode().

if (streamingMode != mlir::arm_sme::ArmStreamingMode::Disabled) {
if (!mlirFunc) {
// TODO: share diagnostic code with warnings elsewhere
// TODO: source location is printed as loc<"file.f90":line:col>
mlir::Location loc = genLocation(parserDirective->source);
llvm::errs() << loc << ": warning: ignoring directive '" << name
<< "' because it has no associated subprogram\n";
continue;
}
llvm::StringRef attrName =
mlir::arm_sme::stringifyArmStreamingMode(streamingMode);
mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
mlirFunc->setAttr(attrName, unitAttr);
}
auto zaMode = mlir::arm_sme::ArmZaMode::Disabled;
if (name == "arm_new_za")
zaMode = mlir::arm_sme::ArmZaMode::NewZA;
else if (name == "arm_shared_za")
zaMode = mlir::arm_sme::ArmZaMode::SharedZA;
else if (name == "arm_preserves_za")
zaMode = mlir::arm_sme::ArmZaMode::PreservesZA;
Comment on lines +5077 to +5082
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be mlir::arm_sme::symbolizeArmZaMode() for this too.

if (zaMode != mlir::arm_sme::ArmZaMode::Disabled) {
if (!mlirFunc) {
// TODO: share diagnostic code with warnings elsewhere
// TODO: source location is printed as loc<"file.f90":line:col>
mlir::Location loc = genLocation(parserDirective->source);
llvm::errs() << loc << ": warning: ignoring directive '" << name
<< "' because it has no associated subprogram\n";
continue;
}
llvm::StringRef attrName = mlir::arm_sme::stringifyArmZaMode(zaMode);
mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
mlirFunc->setAttr(attrName, unitAttr);
}
}
}

//===--------------------------------------------------------------------===//

Fortran::lower::LoweringBridge &bridge;
Expand Down
Loading