[flang][OpenMP] Enable tiling #143715

Open
wants to merge 23 commits into
base: main
943ecef
Initial implementation of tiling.
jsjodin Mar 14, 2025
bcaacf9
Fix tests and limit the nesting of construct to only tiling.
jsjodin Jun 10, 2025
8c2cd48
Enable stand-alone tiling, but it gives a warning and converting to s…
jsjodin Jun 10, 2025
5c17a0d
Add minimal test, remove debug print.
jsjodin Jun 11, 2025
7e7437b
Fix formatting
jsjodin Jun 13, 2025
3367c5e
Fix formatting
jsjodin Jun 14, 2025
753653d
Fix test.
jsjodin Jun 19, 2025
2d09a69
Add more mlir tests. Set collapse value when lowering from SCF to Ope…
jsjodin Jun 20, 2025
db09c91
Fix formatting
jsjodin Jun 20, 2025
ab74c1a
Use llvm::SmallVector instead of std::stack
jsjodin Jun 20, 2025
88348a3
Improve test a bit to make sure IVs are used as expected.
jsjodin Jun 21, 2025
d362617
Fix comments to clarify canonicalization.
jsjodin Jun 21, 2025
60aa4b0
Special handling of tile directive when dealing with start and end lo…
jsjodin Jun 21, 2025
ad32cc3
Inline functions.
jsjodin Jun 21, 2025
6da33a4
Remove debug code.
jsjodin Jun 23, 2025
991e042
Reuse loop op lowering, add comment.
jsjodin Jun 23, 2025
b4e2109
Fix formatting.
jsjodin Jun 23, 2025
5a3d8d2
Remove curly braces.
jsjodin Jun 23, 2025
3029793
Avoid attaching the sizes clause to the parent construct, instead fin…
jsjodin Jun 25, 2025
bdead72
Fix formatting
jsjodin Jun 25, 2025
f1c260d
Fix unparse and add a test for nested loop constructs.
jsjodin Jun 26, 2025
7701e07
Use more convenient function to get OpenMPLoopConstruct. Fix comments.
jsjodin Jun 26, 2025
7c3b3e5
Fix formatting.
jsjodin Jun 26, 2025
1 change: 0 additions & 1 deletion flang/include/flang/Lower/OpenMP.h
@@ -79,7 +79,6 @@ void genOpenMPDeclarativeConstruct(AbstractConverter &,
void genOpenMPSymbolProperties(AbstractConverter &converter,
const pft::Variable &var);

- int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &);
bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
4 changes: 3 additions & 1 deletion flang/include/flang/Parser/parse-tree.h
@@ -5024,8 +5024,10 @@ struct OpenMPBlockConstruct {
struct OpenMPLoopConstruct {
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
- : t({std::move(a), std::nullopt, std::nullopt}) {}
+ : t({std::move(a), std::nullopt, std::nullopt, std::nullopt}) {}
std::tuple<OmpBeginLoopDirective, std::optional<DoConstruct>,
// Inner loop construct used to handle tiling for now.
std::optional<common::Indirection<OpenMPLoopConstruct>>,
Member

nit: can you add a comment here explaining that this is used mainly for tiling constructs?

Contributor Author

Done

Contributor

Does this strictly support nesting of only two loop constructs, or do you plan for it to work like a linked list, with each loop construct pointing to the next innermost construct?

Contributor Author

It can work as a linked list for multiple nested loop constructs. Note that in the future we might also want to keep the DoConstruct inside the innermost loop construct instead of the outer one. Keeping it with the outer loop construct minimizes the changes in the rest of the lowering, since we can only handle tiling for now. Moving the DoConstruct to the tiling (innermost) loop construct during canonicalization can be a separate PR if that is a better representation. Allowing multiple nested ops would be the final step, at that point we would need to use the CLIs, canonical loops and loop transformation ops.

Contributor

@Stylie777 Jun 26, 2025

> Allowing multiple nested ops would be the final step, at that point we would need to use the CLIs, canonical loops and loop transformation ops.

As part of my work for #110008 I have implemented nested ops for tile and unroll in the semantics checks (PR coming soon, just doing finishing touches). It works slightly differently, in that std::optional<DoConstruct> has become std::optional<std::variant<DoConstruct, common::Indirection<OpenMPLoopConstruct>>> in the parse tree.

This approach stores either one or the other, and Flang can then determine which is being used. I personally think this is a better approach, as it rules out having both a DoConstruct and an OpenMPLoopConstruct at the same time. It works in a similar way, linking the loop constructs together until a DoConstruct is reached.

It's currently a standalone change, as this PR is not merged yet, but it would be good to agree on an approach so that minimal changes are required when one of these is merged.

Contributor Author

I agree that nesting in this way is a better approach. I forgot that I had actually tried putting the DoConstruct inside the innermost OpenMPLoopConstruct, and the code generation was fine, so my comment above was incorrect. The parse tree after canonicalization looks like this:

| | ExecutionPartConstruct -> ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
| | | OmpBeginLoopDirective
| | | | OmpLoopDirective -> llvm::omp::Directive = do
| | | | OmpClauseList -> 
| | | OpenMPLoopConstruct
| | | | OmpBeginLoopDirective
| | | | | OmpLoopDirective -> llvm::omp::Directive = tile
| | | | | OmpClauseList -> OmpClause -> Sizes -> Scalar -> Integer -> Expr = '2_4'
| | | | | | LiteralConstant -> IntLiteralConstant = '2'
| | | | DoConstruct
| | | | | NonLabelDoStmt
| | | | | | LoopControl -> LoopBounds
| | | | | | | Scalar -> Name = 'x'
| | | | | | | Scalar -> Expr = '1_4'
| | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
| | | | | | | Scalar -> Expr = '100_4'
| | | | | | | | LiteralConstant -> IntLiteralConstant = '100'

So this should be pretty compatible with what you are working on.
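
To make the shape of the dump above concrete, here is a minimal sketch of reading the tile sizes from the construct nested directly inside the outer one. It assumes simplified stand-in types: `Directive`, `DoLoop`, `LoopConstruct`, `collectInnerTileSizes` and `makeExample` are hypothetical names for illustration, not flang's parse-tree API.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-ins for the flang parse-tree nodes.
enum class Directive { Do, Tile };

struct DoLoop {
  std::int64_t lower;
  std::int64_t upper;
};

struct LoopConstruct {
  Directive directive = Directive::Do;
  std::vector<std::int64_t> sizes;      // operands of a sizes clause, if any
  std::unique_ptr<LoopConstruct> inner; // nested loop construct, if any
  std::optional<DoLoop> doLoop;         // the DO loop, if held by this node
};

// Return the tile sizes of the construct nested directly in `outer`,
// but only when that nested construct is a tile directive.
std::vector<std::int64_t> collectInnerTileSizes(const LoopConstruct &outer) {
  std::vector<std::int64_t> result;
  if (outer.inner && outer.inner->directive == Directive::Tile)
    result = outer.inner->sizes;
  return result;
}

// Build the tree from the dump: do -> tile sizes(2) -> DO x = 1, 100.
LoopConstruct makeExample() {
  auto tile = std::make_unique<LoopConstruct>();
  tile->directive = Directive::Tile;
  tile->sizes = {2};
  tile->doLoop = DoLoop{1, 100};

  LoopConstruct outer;
  outer.directive = Directive::Do;
  outer.inner = std::move(tile);
  return outer;
}
```

Calling `collectInnerTileSizes(makeExample())` yields a one-element vector holding 2, and the DO loop hangs off the inner (tile) node, matching the canonicalized tree.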

Contributor

OK, the changes I have worked on give a similar parsing outcome. For reference, this will be my version of OpenMPLoopConstruct:

struct OpenMPLoopConstruct {
  TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
  OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
      : t({std::move(a), std::nullopt, std::nullopt}) {}
  std::tuple<OmpBeginLoopDirective, std::optional<std::variant<DoConstruct, common::Indirection<OpenMPLoopConstruct>>>,
      std::optional<OmpEndLoopDirective>>
      t;
};
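
As a rough illustration of the "either a DoConstruct or another loop construct" idea, the sketch below walks such a chain until it reaches the loop. It assumes C++17 and simplified stand-in types: `LoopConstructPtr` plays the role of `common::Indirection<OpenMPLoopConstruct>`, and `findDoConstruct` and `makeChain` are hypothetical helpers, not part of flang.

```cpp
#include <cassert>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-ins; the real types live in flang's parse-tree.h.
struct DoConstruct {
  std::string inductionVar;
};

struct LoopConstruct;
// Stand-in for common::Indirection<OpenMPLoopConstruct>.
using LoopConstructPtr = std::unique_ptr<LoopConstruct>;

struct LoopConstruct {
  std::string directive;
  // Holds either the DO loop or the next nested loop construct, never both.
  std::optional<std::variant<DoConstruct, LoopConstructPtr>> body;
};

// Follow nested loop constructs until a DoConstruct is reached, recording
// each directive on the way. Returns nullptr if the chain has no DO loop.
const DoConstruct *findDoConstruct(const LoopConstruct &outer,
                                   std::vector<std::string> &directives) {
  const LoopConstruct *current = &outer;
  while (current) {
    directives.push_back(current->directive);
    if (!current->body)
      return nullptr;
    if (const auto *loop = std::get_if<DoConstruct>(&*current->body))
      return loop;
    current = std::get<LoopConstructPtr>(*current->body).get();
  }
  return nullptr;
}

// Build the chain do -> tile -> DO x.
LoopConstruct makeChain() {
  auto tile = std::make_unique<LoopConstruct>();
  tile->directive = "tile";
  tile->body = DoConstruct{"x"};

  LoopConstruct outer;
  outer.directive = "do";
  outer.body = std::move(tile);
  return outer;
}
```

With this layout a node cannot hold both a loop and a nested construct, which is the property argued for in the thread; the walk visits "do", then "tile", and stops at the DoConstruct.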

Contributor Author

I just pushed a change for a test and to fix unparse.

Contributor

Just for reference, my PR for the semantics covering unroll and tile is here: #145917

There does seem to be some crossover with this patch, so some rebasing will be needed depending on which is merged first.

std::optional<OmpEndLoopDirective>>
t;
};
68 changes: 53 additions & 15 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -45,6 +45,7 @@

using namespace Fortran::lower::omp;
using namespace Fortran::common::openmp;
using namespace Fortran::semantics;

static llvm::cl::opt<bool> DumpAtomicAnalysis("fdebug-dump-atomic-analysis");

@@ -456,6 +457,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
return;

const parser::OmpClauseList *beginClauseList = nullptr;
const parser::OmpClauseList *middleClauseList = nullptr;
const parser::OmpClauseList *endClauseList = nullptr;
common::visit(
common::visitors{
@@ -473,6 +475,23 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
beginClauseList =
&std::get<parser::OmpClauseList>(beginDirective.t);

// FIXME(JAN): For now we check if there is an inner
// OpenMPLoopConstruct, and extract the size clause from there
const auto &innerOptional = std::get<std::optional<
common::Indirection<parser::OpenMPLoopConstruct>>>(
ompConstruct.t);
if (innerOptional.has_value()) {
const auto &innerLoopDirective = innerOptional.value().value();
const auto &innerBegin =
std::get<parser::OmpBeginLoopDirective>(
innerLoopDirective.t);
const auto &innerDirective =
std::get<parser::OmpLoopDirective>(innerBegin.t);
if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
middleClauseList =
&std::get<parser::OmpClauseList>(innerBegin.t);
}
}
if (auto &endDirective =
std::get<std::optional<parser::OmpEndLoopDirective>>(
ompConstruct.t))
@@ -485,6 +504,9 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
assert(beginClauseList && "expected begin directive");
clauses.append(makeClauses(*beginClauseList, semaCtx));

if (middleClauseList)
clauses.append(makeClauses(*middleClauseList, semaCtx));

if (endClauseList)
clauses.append(makeClauses(*endClauseList, semaCtx));
};
@@ -960,6 +982,7 @@ static void genLoopVars(
storeOp =
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
}

Contributor

nit: Extra whitespace

firOpBuilder.setInsertionPointAfter(storeOp);
}

@@ -1712,6 +1735,30 @@ genLoopNestClauses(lower::AbstractConverter &converter,
cp.processCollapse(loc, eval, clauseOps, iv);

clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();

fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
for (auto &clause : clauses) {
if (clause.id == llvm::omp::Clause::OMPC_collapse) {
const auto &collapse = std::get<clause::Collapse>(clause.u);
int64_t collapseValue = evaluate::ToInt64(collapse.v).value();
clauseOps.numCollapse = firOpBuilder.getI64IntegerAttr(collapseValue);
} else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
// This case handles the stand-alone tiling construct
const auto &sizes = std::get<clause::Sizes>(clause.u);
llvm::SmallVector<int64_t> sizeValues;
for (auto &size : sizes.v) {
int64_t sizeValue = evaluate::ToInt64(size).value();
sizeValues.push_back(sizeValue);
}
clauseOps.tileSizes = sizeValues;
}
}

llvm::SmallVector<int64_t> sizeValues;
auto *ompCons{eval.getIf<parser::OpenMPConstruct>()};
collectTileSizesFromOpenMPConstruct(ompCons, sizeValues, semaCtx);
if (sizeValues.size() > 0)
clauseOps.tileSizes = sizeValues;
}

static void genLoopClauses(
@@ -2085,9 +2132,9 @@ static mlir::omp::LoopNestOp genLoopNestOp(
return llvm::SmallVector<const semantics::Symbol *>(iv);
};

- auto *nestedEval =
- getCollapsedLoopEval(eval, getCollapseValue(item->clauses));
-
+ uint64_t nestValue = getCollapseValue(item->clauses);
+ nestValue = nestValue < iv.size() ? iv.size() : nestValue;
+ auto *nestedEval = getCollapsedLoopEval(eval, nestValue);
return genOpWithBody<mlir::omp::LoopNestOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *nestedEval,
directive)
@@ -3610,6 +3657,8 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
item);
break;
case llvm::omp::Directive::OMPD_tile:
newOp = genLoopOp(converter, symTable, semaCtx, eval, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_unroll: {
unsigned version = semaCtx.langOptions().OpenMPVersion;
TODO(loc, "Unhandled loop directive (" +
@@ -4186,6 +4235,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
std::get<parser::OmpBeginLoopDirective>(loopConstruct.t);
List<Clause> clauses = makeClauses(
std::get<parser::OmpClauseList>(beginLoopDirective.t), semaCtx);

if (auto &endLoopDirective =
std::get<std::optional<parser::OmpEndLoopDirective>>(
loopConstruct.t)) {
@@ -4292,18 +4342,6 @@ void Fortran::lower::genOpenMPSymbolProperties(
lower::genDeclareTargetIntGlobal(converter, var);
}

- int64_t
- Fortran::lower::getCollapseValue(const parser::OmpClauseList &clauseList) {
- for (const parser::OmpClause &clause : clauseList.v) {
- if (const auto &collapseClause =
- std::get_if<parser::OmpClause::Collapse>(&clause.u)) {
- const auto *expr = semantics::GetExpr(collapseClause->v);
- return evaluate::ToInt64(*expr).value();
- }
- }
- return 1;
- }
-
void Fortran::lower::genThreadprivateOp(lower::AbstractConverter &converter,
const lower::pft::Variable &var) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
112 changes: 104 additions & 8 deletions flang/lib/Lower/OpenMP/Utils.cpp
@@ -15,6 +15,7 @@
#include "Clauses.h"

#include "ClauseFinder.h"
#include "flang/Evaluate/fold.h"
#include <flang/Lower/AbstractConverter.h>
#include <flang/Lower/ConvertType.h>
#include <flang/Lower/DirectivesCommon.h>
@@ -24,10 +25,30 @@
#include <flang/Parser/parse-tree.h>
#include <flang/Parser/tools.h>
#include <flang/Semantics/tools.h>
#include <flang/Semantics/type.h>
#include <llvm/Support/CommandLine.h>

#include <iterator>

using namespace Fortran::semantics;

template <typename T>
MaybeIntExpr EvaluateIntExpr(SemanticsContext &context, const T &expr) {
if (MaybeExpr maybeExpr{
Fold(context.foldingContext(), AnalyzeExpr(context, expr))}) {
if (auto *intExpr{Fortran::evaluate::UnwrapExpr<SomeIntExpr>(*maybeExpr)}) {
return std::move(*intExpr);
}
}
return std::nullopt;
}

template <typename T>
std::optional<std::int64_t> EvaluateInt64(SemanticsContext &context,
const T &expr) {
return Fortran::evaluate::ToInt64(EvaluateIntExpr(context, expr));
}

llvm::cl::opt<bool> treatIndexAsSection(
"openmp-treat-index-as-section",
llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
@@ -38,14 +59,21 @@ namespace lower {
namespace omp {

int64_t getCollapseValue(const List<Clause> &clauses) {
- auto iter = llvm::find_if(clauses, [](const Clause &clause) {
- return clause.id == llvm::omp::Clause::OMPC_collapse;
- });
- if (iter != clauses.end()) {
- const auto &collapse = std::get<clause::Collapse>(iter->u);
- return evaluate::ToInt64(collapse.v).value();
+ int64_t collapseValue = 1;
+ int64_t numTileSizes = 0;
+ for (auto &clause : clauses) {
+ if (clause.id == llvm::omp::Clause::OMPC_collapse) {
+ const auto &collapse = std::get<clause::Collapse>(clause.u);
+ collapseValue = evaluate::ToInt64(collapse.v).value();
+ } else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
+ const auto &sizes = std::get<clause::Sizes>(clause.u);
+ numTileSizes = sizes.v.size();
+ }
}
- return 1;
+
+ collapseValue = collapseValue - numTileSizes;
+ int64_t result = collapseValue > numTileSizes ? collapseValue : numTileSizes;
+ return result;
Comment on lines +74 to +76
Member

Can you explain to me why we need this calculation? To see its effect, I tried replacing these few lines with simply return collapseValue; and ran all tests, but none failed. So it seems this part is not tested. A test would also help explain the purpose of the change.

Contributor Author

This computes the number of loops in the source code that need to be considered. If you have collapse(4) on a loop nest with 2 loops, that would be incorrect, since we can collapse at most 2 loops. However, tiling creates new loops, so collapse(4) would theoretically be legal if tiling is done first, e.g. tile sizes(5,10), since that results in 4 loops. This is not really testable, though, since collapse requires independent loops, which is only true for the 2 outer loops after tiling is done. There is a check for this, and an error message is given if the collapse value is larger than the number of tiled loops, to prevent incorrect code. We could just use numTileSizes if it is present, but if collapse could handle dependent loops in the future, the above calculation should be the correct one.
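
The arithmetic described above can be checked in isolation. This is a minimal sketch that mirrors only the final two statements of the patched getCollapseValue; `sourceLoopCount` is a hypothetical name, and the inputs are plain integers rather than clause objects.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Number of source-level loops to consider, given the collapse value
// (1 when there is no collapse clause) and the number of operands in the
// sizes clause (0 when there is no tile directive).
// Equivalent to: collapse -= numTileSizes; return max(collapse, numTileSizes).
std::int64_t sourceLoopCount(std::int64_t collapseValue,
                             std::int64_t numTileSizes) {
  return std::max(collapseValue - numTileSizes, numTileSizes);
}
```

For the case in the comment, collapse(4) with two tile sizes gives max(4 - 2, 2) = 2 source loops. Without tiling the collapse value passes through unchanged, and with tiling but no collapse clause the result is the number of tile sizes, e.g. max(1 - 2, 2) = 2.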

}

void genObjectList(const ObjectList &objects,
@@ -606,11 +634,48 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
}
}

// Populates the sizes vector with values if the given OpenMPConstruct
// contains a loop construct with an inner tiling construct.
void collectTileSizesFromOpenMPConstruct(
const parser::OpenMPConstruct *ompCons,
llvm::SmallVectorImpl<int64_t> &tileSizes, SemanticsContext &semaCtx) {
if (!ompCons)
return;

if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
const auto &innerOptional = std::get<
std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
ompLoop->t);
if (innerOptional.has_value()) {
const auto &innerLoopDirective = innerOptional.value().value();
const auto &innerBegin =
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
const auto &innerDirective =
std::get<parser::OmpLoopDirective>(innerBegin.t).v;

if (innerDirective == llvm::omp::Directive::OMPD_tile) {
// Get the size values from parse tree and convert to a vector
const auto &innerClauseList{
std::get<parser::OmpClauseList>(innerBegin.t)};
for (const auto &clause : innerClauseList.v)
if (const auto tclause{
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
for (auto &tval : tclause->v) {
if (const auto v{EvaluateInt64(semaCtx, tval)})
tileSizes.push_back(*v);
}
}
}
}
}
}

bool collectLoopRelatedInfo(
lower::AbstractConverter &converter, mlir::Location currentLocation,
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {

bool found = false;
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

@@ -627,6 +692,38 @@ bool collectLoopRelatedInfo(
found = true;
}

// Collect sizes from tile directive if present
std::int64_t sizesLengthValue = 0l;
if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) {
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
const auto &innerOptional = std::get<
std::optional<common::Indirection<parser::OpenMPLoopConstruct>>>(
ompLoop->t);
if (innerOptional.has_value()) {
const auto &innerLoopDirective = innerOptional.value().value();
const auto &innerBegin =
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
const auto &innerDirective =
std::get<parser::OmpLoopDirective>(innerBegin.t).v;

if (innerDirective == llvm::omp::Directive::OMPD_tile) {
// Get the size values from parse tree and convert to a vector
const auto &innerClauseList{
std::get<parser::OmpClauseList>(innerBegin.t)};
for (const auto &clause : innerClauseList.v)
if (const auto tclause{
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
sizesLengthValue = tclause->v.size();
found = true;
}
}
}
}
}

collapseValue = collapseValue - sizesLengthValue;
collapseValue =
collapseValue < sizesLengthValue ? sizesLengthValue : collapseValue;
std::size_t loopVarTypeSize = 0;
do {
lower::pft::Evaluation *doLoop =
@@ -659,7 +756,6 @@ bool collectLoopRelatedInfo(
} while (collapseValue > 0);

convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);

return found;
}
} // namespace omp
6 changes: 6 additions & 0 deletions flang/lib/Lower/OpenMP/Utils.h
@@ -166,6 +166,12 @@ bool collectLoopRelatedInfo(
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);

void collectTileSizesFromOpenMPConstruct(
const parser::OpenMPConstruct *ompCons,
llvm::SmallVectorImpl<int64_t> &tileSizes,
Fortran::semantics::SemanticsContext &semaCtx);

} // namespace omp
} // namespace lower
} // namespace Fortran
2 changes: 2 additions & 0 deletions flang/lib/Parser/unparse.cpp
@@ -2926,6 +2926,8 @@ class UnparseVisitor {
Walk(std::get<OmpBeginLoopDirective>(x.t));
Put("\n");
EndOpenMP();
Walk(
std::get<std::optional<common::Indirection<OpenMPLoopConstruct>>>(x.t));
Walk(std::get<std::optional<DoConstruct>>(x.t));
Walk(std::get<std::optional<OmpEndLoopDirective>>(x.t));
}