Skip to content

Commit 7d27a0a

Browse files
committed
[Clang][OpenMP] Allow num_teams to accept multiple expressions
1 parent 9e97f80 commit 7d27a0a

11 files changed

+472
-304
lines changed

clang/include/clang/AST/OpenMPClause.h

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6131,60 +6131,77 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>,
61316131
/// \endcode
61326132
/// In this example directive '#pragma omp teams' has clause 'num_teams'
61336133
/// with single expression 'n'.
6134-
class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit {
6135-
friend class OMPClauseReader;
6134+
///
6135+
/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause
6136+
/// can accept up to three expressions.
6137+
///
6138+
/// \code
6139+
/// #pragma omp target teams ompx_bare num_teams(x, y, z)
6140+
/// \endcode
6141+
class OMPNumTeamsClause final
6142+
: public OMPVarListClause<OMPNumTeamsClause>,
6143+
public OMPClauseWithPreInit,
6144+
private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> {
6145+
friend OMPVarListClause;
6146+
friend TrailingObjects;
61366147

61376148
/// Location of '('.
61386149
SourceLocation LParenLoc;
61396150

6140-
/// NumTeams number.
6141-
Stmt *NumTeams = nullptr;
6151+
OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc,
6152+
SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N)
6153+
: OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc,
6154+
N),
6155+
OMPClauseWithPreInit(this) {}
61426156

6143-
/// Set the NumTeams number.
6144-
///
6145-
/// \param E NumTeams number.
6146-
void setNumTeams(Expr *E) { NumTeams = E; }
6157+
/// Build an empty clause.
6158+
OMPNumTeamsClause(unsigned N)
6159+
: OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(),
6160+
SourceLocation(), SourceLocation(), N),
6161+
OMPClauseWithPreInit(this) {}
61476162

61486163
public:
6149-
/// Build 'num_teams' clause.
6164+
/// Creates clause with a list of variables \a VL.
61506165
///
6151-
/// \param E Expression associated with this clause.
6152-
/// \param HelperE Helper Expression associated with this clause.
6153-
/// \param CaptureRegion Innermost OpenMP region where expressions in this
6154-
/// clause must be captured.
6166+
/// \param C AST context.
61556167
/// \param StartLoc Starting location of the clause.
61566168
/// \param LParenLoc Location of '('.
61576169
/// \param EndLoc Ending location of the clause.
6158-
OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion,
6159-
SourceLocation StartLoc, SourceLocation LParenLoc,
6160-
SourceLocation EndLoc)
6161-
: OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc),
6162-
OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) {
6163-
setPreInitStmt(HelperE, CaptureRegion);
6164-
}
6170+
/// \param VL List of references to the variables.
6171+
/// \param PreInit
6172+
static OMPNumTeamsClause *Create(const ASTContext &C, SourceLocation StartLoc,
6173+
SourceLocation LParenLoc,
6174+
SourceLocation EndLoc, ArrayRef<Expr *> VL,
6175+
Stmt *PreInit);
61656176

6166-
/// Build an empty clause.
6167-
OMPNumTeamsClause()
6168-
: OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(),
6169-
SourceLocation()),
6170-
OMPClauseWithPreInit(this) {}
6177+
/// Creates an empty clause with \a N variables.
6178+
///
6179+
/// \param C AST context.
6180+
/// \param N The number of variables.
6181+
static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N);
61716182

61726183
/// Sets the location of '('.
61736184
void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
61746185

61756186
/// Returns the location of '('.
61766187
SourceLocation getLParenLoc() const { return LParenLoc; }
61776188

6178-
/// Return NumTeams number.
6179-
Expr *getNumTeams() { return cast<Expr>(NumTeams); }
6189+
/// Return NumTeams number. By default, we return the first expression.
6190+
Expr *getNumTeams() { return getVarRefs().front(); }
61806191

6181-
/// Return NumTeams number.
6182-
Expr *getNumTeams() const { return cast<Expr>(NumTeams); }
6192+
/// Return NumTeams number. By default, we return the first expression.
6193+
Expr *getNumTeams() const {
6194+
return const_cast<OMPNumTeamsClause *>(this)->getNumTeams();
6195+
}
61836196

6184-
child_range children() { return child_range(&NumTeams, &NumTeams + 1); }
6197+
child_range children() {
6198+
return child_range(reinterpret_cast<Stmt **>(varlist_begin()),
6199+
reinterpret_cast<Stmt **>(varlist_end()));
6200+
}
61856201

61866202
const_child_range children() const {
6187-
return const_child_range(&NumTeams, &NumTeams + 1);
6203+
auto Children = const_cast<OMPNumTeamsClause *>(this)->children();
6204+
return const_child_range(Children.begin(), Children.end());
61886205
}
61896206

61906207
child_range used_children() {

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3793,8 +3793,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) {
37933793
template <typename Derived>
37943794
bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause(
37953795
OMPNumTeamsClause *C) {
3796+
TRY_TO(VisitOMPClauseList(C));
37963797
TRY_TO(VisitOMPClauseWithPreInit(C));
3797-
TRY_TO(TraverseStmt(C->getNumTeams()));
37983798
return true;
37993799
}
38003800

clang/include/clang/Sema/SemaOpenMP.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1227,7 +1227,8 @@ class SemaOpenMP : public SemaBase {
12271227
const OMPVarListLocTy &Locs, bool NoDiagnose = false,
12281228
ArrayRef<Expr *> UnresolvedMappers = std::nullopt);
12291229
/// Called on well-formed 'num_teams' clause.
1230-
OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc,
1230+
OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
1231+
SourceLocation StartLoc,
12311232
SourceLocation LParenLoc,
12321233
SourceLocation EndLoc);
12331234
/// Called on well-formed 'thread_limit' clause.

clang/lib/AST/OpenMPClause.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,6 +1720,24 @@ const Expr *OMPDoacrossClause::getLoopData(unsigned NumLoop) const {
17201720
return *It;
17211721
}
17221722

1723+
OMPNumTeamsClause *
1724+
OMPNumTeamsClause::Create(const ASTContext &C, SourceLocation StartLoc,
1725+
SourceLocation LParenLoc, SourceLocation EndLoc,
1726+
ArrayRef<Expr *> VL, Stmt *PreInit) {
1727+
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
1728+
OMPNumTeamsClause *Clause =
1729+
new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size());
1730+
Clause->setVarRefs(VL);
1731+
Clause->setPreInitStmt(PreInit);
1732+
return Clause;
1733+
}
1734+
1735+
OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C,
1736+
unsigned N) {
1737+
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N));
1738+
return new (Mem) OMPNumTeamsClause(N);
1739+
}
1740+
17231741
//===----------------------------------------------------------------------===//
17241742
// OpenMP clauses printing methods
17251743
//===----------------------------------------------------------------------===//
@@ -1977,9 +1995,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) {
19771995
}
19781996

19791997
void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) {
1980-
OS << "num_teams(";
1981-
Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0);
1982-
OS << ")";
1998+
if (!Node->varlist_empty()) {
1999+
OS << "num_teams";
2000+
VisitOMPClauseList(Node, '(');
2001+
OS << ")";
2002+
}
19832003
}
19842004

19852005
void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) {

clang/lib/AST/StmtProfile.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -843,9 +843,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) {
843843
VisitOMPClauseList(C);
844844
}
845845
void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
846+
VisitOMPClauseList(C);
846847
VistOMPClauseWithPreInit(C);
847-
if (C->getNumTeams())
848-
Profiler->VisitStmt(C->getNumTeams());
849848
}
850849
void OMPClauseProfiler::VisitOMPThreadLimitClause(
851850
const OMPThreadLimitClause *C) {

clang/lib/Parse/ParseOpenMP.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3098,7 +3098,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
30983098
case OMPC_simdlen:
30993099
case OMPC_collapse:
31003100
case OMPC_ordered:
3101-
case OMPC_num_teams:
31023101
case OMPC_thread_limit:
31033102
case OMPC_priority:
31043103
case OMPC_grainsize:
@@ -3252,6 +3251,13 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
32523251
? ParseOpenMPSimpleClause(CKind, WrongDirective)
32533252
: ParseOpenMPClause(CKind, WrongDirective);
32543253
break;
3254+
case OMPC_num_teams:
3255+
if (!FirstClause) {
3256+
Diag(Tok, diag::err_omp_more_one_clause)
3257+
<< getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0;
3258+
ErrorFound = true;
3259+
}
3260+
[[clang::fallthrough]];
32553261
case OMPC_private:
32563262
case OMPC_firstprivate:
32573263
case OMPC_lastprivate:

clang/lib/Sema/SemaOpenMP.cpp

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13901,6 +13901,20 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDirective(
1390113901
return StmtError();
1390213902
}
1390313903

13904+
const OMPClause *NumTeamsClause = nullptr;
13905+
bool HasNumTeamsClause = llvm::any_of(Clauses, [&](const OMPClause *C) {
13906+
NumTeamsClause = C;
13907+
return C->getClauseKind() == OMPC_num_teams;
13908+
});
13909+
13910+
if (HasNumTeamsClause) {
13911+
ArrayRef<const Expr *> NumTeams =
13912+
cast<OMPNumTeamsClause>(NumTeamsClause)->getVarRefs();
13913+
if (!HasBareClause && NumTeams.size() > 1) {
13914+
return StmtError();
13915+
}
13916+
}
13917+
1390413918
return OMPTargetTeamsDirective::Create(getASTContext(), StartLoc, EndLoc,
1390513919
Clauses, AStmt);
1390613920
}
@@ -15041,9 +15055,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
1504115055
case OMPC_ordered:
1504215056
Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr);
1504315057
break;
15044-
case OMPC_num_teams:
15045-
Res = ActOnOpenMPNumTeamsClause(Expr, StartLoc, LParenLoc, EndLoc);
15046-
break;
1504715058
case OMPC_thread_limit:
1504815059
Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc);
1504915060
break;
@@ -15147,6 +15158,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind,
1514715158
case OMPC_affinity:
1514815159
case OMPC_when:
1514915160
case OMPC_bind:
15161+
case OMPC_num_teams:
1515015162
default:
1515115163
llvm_unreachable("Clause is not allowed.");
1515215164
}
@@ -17010,6 +17022,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
1701017022
static_cast<OpenMPDoacrossClauseModifier>(ExtraModifier),
1701117023
ExtraModifierLoc, ColonLoc, VarList, StartLoc, LParenLoc, EndLoc);
1701217024
break;
17025+
case OMPC_num_teams:
17026+
Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc);
17027+
break;
1701317028
case OMPC_if:
1701417029
case OMPC_depobj:
1701517030
case OMPC_final:
@@ -17040,7 +17055,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
1704017055
case OMPC_device:
1704117056
case OMPC_threads:
1704217057
case OMPC_simd:
17043-
case OMPC_num_teams:
1704417058
case OMPC_thread_limit:
1704517059
case OMPC_priority:
1704617060
case OMPC_grainsize:
@@ -21703,32 +21717,40 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const {
2170321717
return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl();
2170421718
}
2170521719

21706-
OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(Expr *NumTeams,
21720+
OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList,
2170721721
SourceLocation StartLoc,
2170821722
SourceLocation LParenLoc,
2170921723
SourceLocation EndLoc) {
21710-
Expr *ValExpr = NumTeams;
21711-
Stmt *HelperValStmt = nullptr;
21712-
21713-
// OpenMP [teams Constrcut, Restrictions]
21714-
// The num_teams expression must evaluate to a positive integer value.
21715-
if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
21716-
/*StrictlyPositive=*/true))
21724+
if (VarList.empty())
2171721725
return nullptr;
2171821726

2171921727
OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
2172021728
OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause(
2172121729
DKind, OMPC_num_teams, getLangOpts().OpenMP);
21722-
if (CaptureRegion != OMPD_unknown &&
21723-
!SemaRef.CurContext->isDependentContext()) {
21730+
21731+
for (Expr *ValExpr : VarList) {
21732+
// OpenMP [teams Constrcut, Restrictions]
21733+
// The num_teams expression must evaluate to a positive integer value.
21734+
if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams,
21735+
/*StrictlyPositive=*/true))
21736+
return nullptr;
21737+
}
21738+
21739+
if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext())
21740+
return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc,
21741+
EndLoc, VarList, /*PreInit=*/nullptr);
21742+
21743+
llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
21744+
SmallVector<Expr *, 3> Vars;
21745+
for (Expr *ValExpr : VarList) {
2172421746
ValExpr = SemaRef.MakeFullExpr(ValExpr).get();
21725-
llvm::MapVector<const Expr *, DeclRefExpr *> Captures;
2172621747
ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get();
21727-
HelperValStmt = buildPreInits(getASTContext(), Captures);
21748+
Vars.push_back(ValExpr);
2172821749
}
2172921750

21730-
return new (getASTContext()) OMPNumTeamsClause(
21731-
ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc);
21751+
Stmt *PreInit = buildPreInits(getASTContext(), Captures);
21752+
return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc, EndLoc,
21753+
Vars, PreInit);
2173221754
}
2173321755

2173421756
OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,

clang/lib/Serialization/ASTReader.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10562,7 +10562,7 @@ OMPClause *OMPClauseReader::readClause() {
1056210562
break;
1056310563
}
1056410564
case llvm::omp::OMPC_num_teams:
10565-
C = new (Context) OMPNumTeamsClause();
10565+
C = OMPNumTeamsClause::CreateEmpty(Context, Record.readInt());
1056610566
break;
1056710567
case llvm::omp::OMPC_thread_limit:
1056810568
C = new (Context) OMPThreadLimitClause();
@@ -11350,8 +11350,13 @@ void OMPClauseReader::VisitOMPAllocateClause(OMPAllocateClause *C) {
1135011350

1135111351
void OMPClauseReader::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) {
1135211352
VisitOMPClauseWithPreInit(C);
11353-
C->setNumTeams(Record.readSubExpr());
1135411353
C->setLParenLoc(Record.readSourceLocation());
11354+
unsigned NumVars = C->varlist_size();
11355+
SmallVector<Expr *, 16> Vars;
11356+
Vars.reserve(NumVars);
11357+
for (unsigned i = 0; i != NumVars; ++i)
11358+
Vars.push_back(Record.readSubExpr());
11359+
C->setVarRefs(Vars);
1135511360
}
1135611361

1135711362
void OMPClauseReader::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) {

clang/lib/Serialization/ASTWriter.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7528,9 +7528,11 @@ void OMPClauseWriter::VisitOMPAllocateClause(OMPAllocateClause *C) {
75287528
}
75297529

75307530
void OMPClauseWriter::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) {
7531+
Record.push_back(C->varlist_size());
75317532
VisitOMPClauseWithPreInit(C);
7532-
Record.AddStmt(C->getNumTeams());
75337533
Record.AddSourceLocation(C->getLParenLoc());
7534+
for (auto *VE : C->varlists())
7535+
Record.AddStmt(VE);
75347536
}
75357537

75367538
void OMPClauseWriter::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) {

0 commit comments

Comments
 (0)