Skip to content

[Clang][Sema][OpenMP] Allow thread_limit to accept multiple expressions #102715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion clang/docs/OpenMPSupport.rst
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ considered for standardization. Please post on the
| device extension | `'ompx_bare' clause on 'target teams' construct | :good:`prototyped` | #66844, #70612 |
| | <https://www.osti.gov/servlets/purl/2205717>`_ | | |
+------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+
| device extension | Multi-dim 'num_teams' clause on 'target teams ompx_bare' construct | :good:`partial` | #99732, #101407 |
| device extension | Multi-dim 'num_teams' and 'thread_limit' clause on 'target teams ompx_bare' | :good:`partial` | #99732, #101407, #102715 |
| | construct | | |
+------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+

.. _Discourse forums (Runtimes - OpenMP category): https://discourse.llvm.org/c/runtimes/openmp/35
5 changes: 3 additions & 2 deletions clang/docs/ReleaseNotes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,9 @@ Improvements
^^^^^^^^^^^^
- Improve the handling of mapping array-section for struct containing nested structs with user defined mappers

- `num_teams` now accepts multiple expressions when it is used along in ``target teams ompx_bare`` construct.
This allows the target region to be launched with multi-dim grid on GPUs.
- `num_teams` and `thead_limit` now accept multiple expressions when it is used
along in ``target teams ompx_bare`` construct. This allows the target region
to be launched with multi-dim grid on GPUs.

Additional Information
======================
Expand Down
81 changes: 49 additions & 32 deletions clang/include/clang/AST/OpenMPClause.h
Original file line number Diff line number Diff line change
Expand Up @@ -6462,61 +6462,78 @@ class OMPNumTeamsClause final
/// \endcode
/// In this example directive '#pragma omp teams' has clause 'thread_limit'
/// with single expression 'n'.
class OMPThreadLimitClause : public OMPClause, public OMPClauseWithPreInit {
friend class OMPClauseReader;
///
/// When 'ompx_bare' clause exists on a 'target' directive, 'thread_limit'
/// clause can accept up to three expressions.
///
/// \code
/// #pragma omp target teams ompx_bare thread_limit(x, y, z)
/// \endcode
class OMPThreadLimitClause final
: public OMPVarListClause<OMPThreadLimitClause>,
public OMPClauseWithPreInit,
private llvm::TrailingObjects<OMPThreadLimitClause, Expr *> {
friend OMPVarListClause;
friend TrailingObjects;

/// Location of '('.
SourceLocation LParenLoc;

/// ThreadLimit number.
Stmt *ThreadLimit = nullptr;
OMPThreadLimitClause(const ASTContext &C, SourceLocation StartLoc,
SourceLocation LParenLoc, SourceLocation EndLoc,
unsigned N)
: OMPVarListClause(llvm::omp::OMPC_thread_limit, StartLoc, LParenLoc,
EndLoc, N),
OMPClauseWithPreInit(this) {}

/// Set the ThreadLimit number.
///
/// \param E ThreadLimit number.
void setThreadLimit(Expr *E) { ThreadLimit = E; }
/// Build an empty clause.
OMPThreadLimitClause(unsigned N)
: OMPVarListClause(llvm::omp::OMPC_thread_limit, SourceLocation(),
SourceLocation(), SourceLocation(), N),
OMPClauseWithPreInit(this) {}

public:
/// Build 'thread_limit' clause.
/// Creates clause with a list of variables \a VL.
///
/// \param E Expression associated with this clause.
/// \param HelperE Helper Expression associated with this clause.
/// \param CaptureRegion Innermost OpenMP region where expressions in this
/// clause must be captured.
/// \param C AST context.
/// \param StartLoc Starting location of the clause.
/// \param LParenLoc Location of '('.
/// \param EndLoc Ending location of the clause.
OMPThreadLimitClause(Expr *E, Stmt *HelperE,
OpenMPDirectiveKind CaptureRegion,
SourceLocation StartLoc, SourceLocation LParenLoc,
SourceLocation EndLoc)
: OMPClause(llvm::omp::OMPC_thread_limit, StartLoc, EndLoc),
OMPClauseWithPreInit(this), LParenLoc(LParenLoc), ThreadLimit(E) {
setPreInitStmt(HelperE, CaptureRegion);
}
/// \param VL List of references to the variables.
/// \param PreInit
static OMPThreadLimitClause *
Create(const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
SourceLocation StartLoc, SourceLocation LParenLoc,
SourceLocation EndLoc, ArrayRef<Expr *> VL, Stmt *PreInit);

/// Build an empty clause.
OMPThreadLimitClause()
: OMPClause(llvm::omp::OMPC_thread_limit, SourceLocation(),
SourceLocation()),
OMPClauseWithPreInit(this) {}
/// Creates an empty clause with \a N variables.
///
/// \param C AST context.
/// \param N The number of variables.
static OMPThreadLimitClause *CreateEmpty(const ASTContext &C, unsigned N);

/// Sets the location of '('.
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }

/// Returns the location of '('.
SourceLocation getLParenLoc() const { return LParenLoc; }

/// Return ThreadLimit number.
Expr *getThreadLimit() { return cast<Expr>(ThreadLimit); }
/// Return ThreadLimit expressions.
ArrayRef<Expr *> getThreadLimit() { return getVarRefs(); }

/// Return ThreadLimit number.
Expr *getThreadLimit() const { return cast<Expr>(ThreadLimit); }
/// Return ThreadLimit expressions.
ArrayRef<Expr *> getThreadLimit() const {
return const_cast<OMPThreadLimitClause *>(this)->getThreadLimit();
}

child_range children() { return child_range(&ThreadLimit, &ThreadLimit + 1); }
child_range children() {
return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
reinterpret_cast<Stmt **>(varlist_end()));
}

const_child_range children() const {
return const_child_range(&ThreadLimit, &ThreadLimit + 1);
auto Children = const_cast<OMPThreadLimitClause *>(this)->children();
return const_child_range(Children.begin(), Children.end());
}

child_range used_children() {
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3836,8 +3836,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPThreadLimitClause(
OMPThreadLimitClause *C) {
TRY_TO(VisitOMPClauseList(C));
TRY_TO(VisitOMPClauseWithPreInit(C));
TRY_TO(TraverseStmt(C->getThreadLimit()));
return true;
}

Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/Sema/SemaOpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -1264,7 +1264,7 @@ class SemaOpenMP : public SemaBase {
SourceLocation LParenLoc,
SourceLocation EndLoc);
/// Called on well-formed 'thread_limit' clause.
OMPClause *ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,
OMPClause *ActOnOpenMPThreadLimitClause(ArrayRef<Expr *> VarList,
SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);
Expand Down
26 changes: 23 additions & 3 deletions clang/lib/AST/OpenMPClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1773,6 +1773,24 @@ OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
return new (Mem) OMPNumTeamsClause(N);
}

OMPThreadLimitClause *OMPThreadLimitClause::Create(
const ASTContext &C, OpenMPDirectiveKind CaptureRegion,
SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc,
ArrayRef<Expr *> VL, Stmt *PreInit) {
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
OMPThreadLimitClause *Clause =
new (Mem) OMPThreadLimitClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
Clause->setVarRefs(VL);
Clause->setPreInitStmt(PreInit, CaptureRegion);
return Clause;
}

OMPThreadLimitClause *OMPThreadLimitClause::CreateEmpty(const ASTContext &C,
unsigned N) {
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
return new (Mem) OMPThreadLimitClause(N);
}

//===----------------------------------------------------------------------===//
// OpenMP clauses printing methods
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2081,9 +2099,11 @@ void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
}

void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {
OS << "thread_limit(";
Node->getThreadLimit()->printPretty(OS, nullptr, Policy, 0);
OS << ")";
if (!Node->varlist_empty()) {
OS << "thread_limit";
VisitOMPClauseList(Node, '(');
OS << ")";
}
}

void OMPClausePrinter::VisitOMPPriorityClause(OMPPriorityClause *Node) {
Expand Down
3 changes: 1 addition & 2 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,9 +862,8 @@ void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
}
void OMPClauseProfiler::VisitOMPThreadLimitClause(
const OMPThreadLimitClause *C) {
VisitOMPClauseList(C);
VistOMPClauseWithPreInit(C);
if (C->getThreadLimit())
Profiler->VisitStmt(C->getThreadLimit());
}
void OMPClauseProfiler::VisitOMPPriorityClause(const OMPPriorityClause *C) {
VistOMPClauseWithPreInit(C);
Expand Down
15 changes: 10 additions & 5 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6332,7 +6332,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
CGOpenMPInnerExprInfo CGInfo(CGF, *CS);
CodeGenFunction::CGCapturedStmtRAII CapInfoRAII(CGF, &CGInfo);
CodeGenFunction::LexicalScope Scope(
CGF, ThreadLimitClause->getThreadLimit()->getSourceRange());
CGF,
ThreadLimitClause->getThreadLimit().front()->getSourceRange());
if (const auto *PreInit =
cast_or_null<DeclStmt>(ThreadLimitClause->getPreInitStmt())) {
for (const auto *I : PreInit->decls()) {
Expand All @@ -6349,7 +6350,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
}
}
if (ThreadLimitClause)
CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
ThreadLimitExpr);
if (const auto *Dir = dyn_cast_or_null<OMPExecutableDirective>(Child)) {
if (isOpenMPTeamsDirective(Dir->getDirectiveKind()) &&
!isOpenMPDistributeDirective(Dir->getDirectiveKind())) {
Expand All @@ -6370,7 +6372,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
if (D.hasClausesOfKind<OMPThreadLimitClause>()) {
CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF);
const auto *ThreadLimitClause = D.getSingleClause<OMPThreadLimitClause>();
CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
ThreadLimitExpr);
}
const CapturedStmt *CS = D.getInnermostCapturedStmt();
getNumThreads(CGF, CS, NTPtr, UpperBound, UpperBoundOnly, CondVal);
Expand All @@ -6388,7 +6391,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
if (D.hasClausesOfKind<OMPThreadLimitClause>()) {
CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF);
const auto *ThreadLimitClause = D.getSingleClause<OMPThreadLimitClause>();
CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
ThreadLimitExpr);
}
getNumThreads(CGF, D.getInnermostCapturedStmt(), NTPtr, UpperBound,
UpperBoundOnly, CondVal);
Expand Down Expand Up @@ -6424,7 +6428,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective(
if (D.hasClausesOfKind<OMPThreadLimitClause>()) {
CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF);
const auto *ThreadLimitClause = D.getSingleClause<OMPThreadLimitClause>();
CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr);
CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(),
ThreadLimitExpr);
}
if (D.hasClausesOfKind<OMPNumThreadsClause>()) {
CodeGenFunction::RunCleanupsScope NumThreadsScope(CGF);
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CodeGen/CGStmtOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5259,7 +5259,7 @@ void CodeGenFunction::EmitOMPTargetTaskBasedDirective(
// enclosing this target region. This will indirectly set the thread_limit
// for every applicable construct within target region.
CGF.CGM.getOpenMPRuntime().emitThreadLimitClause(
CGF, TL->getThreadLimit(), S.getBeginLoc());
CGF, TL->getThreadLimit().front(), S.getBeginLoc());
}
BodyGen(CGF);
};
Expand Down Expand Up @@ -6860,7 +6860,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF,
const auto *TL = S.getSingleClause<OMPThreadLimitClause>();
if (NT || TL) {
const Expr *NumTeams = NT ? NT->getNumTeams().front() : nullptr;
const Expr *ThreadLimit = TL ? TL->getThreadLimit() : nullptr;
const Expr *ThreadLimit = TL ? TL->getThreadLimit().front() : nullptr;

CGF.CGM.getOpenMPRuntime().emitNumTeamsClause(CGF, NumTeams, ThreadLimit,
S.getBeginLoc());
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/Parse/ParseOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3175,7 +3175,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
case OMPC_simdlen:
case OMPC_collapse:
case OMPC_ordered:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
case OMPC_num_tasks:
Expand Down Expand Up @@ -3332,6 +3331,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
: ParseOpenMPClause(CKind, WrongDirective);
break;
case OMPC_num_teams:
case OMPC_thread_limit:
if (!FirstClause) {
Diag(Tok, diag::err_omp_more_one_clause)
<< getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
Expand Down
Loading
Loading