Skip to content

Commit 49caecb

Browse files
committed
Reapply "[Flang][OpenMP][Lower] NFC: Move clause processing helpers into the ClauseProcessor (#85258)"
This patch contains slight modifications to the reverted PR #85258 to avoid issues with constructs containing multiple reduction clauses, uncovered by a test on the gfortran testsuite. This reverts commit 9f80444.
1 parent c20596c commit 49caecb

File tree

5 files changed

+90
-80
lines changed

5 files changed

+90
-80
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,25 @@ addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
208208
useDeviceSymbols.push_back(object.id());
209209
}
210210

211+
static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
212+
mlir::Location loc,
213+
llvm::SmallVectorImpl<mlir::Value> &lowerBound,
214+
llvm::SmallVectorImpl<mlir::Value> &upperBound,
215+
llvm::SmallVectorImpl<mlir::Value> &step,
216+
std::size_t loopVarTypeSize) {
217+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
218+
// The types of lower bound, upper bound, and step are converted into the
219+
// type of the loop variable if necessary.
220+
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
221+
for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) {
222+
lowerBound[it] =
223+
firOpBuilder.createConvert(loc, loopVarType, lowerBound[it]);
224+
upperBound[it] =
225+
firOpBuilder.createConvert(loc, loopVarType, upperBound[it]);
226+
step[it] = firOpBuilder.createConvert(loc, loopVarType, step[it]);
227+
}
228+
}
229+
211230
//===----------------------------------------------------------------------===//
212231
// ClauseProcessor unique clauses
213232
//===----------------------------------------------------------------------===//
@@ -217,8 +236,7 @@ bool ClauseProcessor::processCollapse(
217236
llvm::SmallVectorImpl<mlir::Value> &lowerBound,
218237
llvm::SmallVectorImpl<mlir::Value> &upperBound,
219238
llvm::SmallVectorImpl<mlir::Value> &step,
220-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
221-
std::size_t &loopVarTypeSize) const {
239+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const {
222240
bool found = false;
223241
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
224242

@@ -236,7 +254,7 @@ bool ClauseProcessor::processCollapse(
236254
found = true;
237255
}
238256

239-
loopVarTypeSize = 0;
257+
std::size_t loopVarTypeSize = 0;
240258
do {
241259
Fortran::lower::pft::Evaluation *doLoop =
242260
&doConstructEval->getFirstNestedEvaluation();
@@ -267,6 +285,9 @@ bool ClauseProcessor::processCollapse(
267285
&*std::next(doConstructEval->getNestedEvaluations().begin());
268286
} while (collapseValue > 0);
269287

288+
convertLoopBounds(converter, currentLocation, lowerBound, upperBound, step,
289+
loopVarTypeSize);
290+
270291
return found;
271292
}
272293

@@ -906,16 +927,38 @@ bool ClauseProcessor::processMap(
906927

907928
bool ClauseProcessor::processReduction(
908929
mlir::Location currentLocation,
909-
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
910-
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
911-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
912-
const {
930+
llvm::SmallVectorImpl<mlir::Value> &outReductionVars,
931+
llvm::SmallVectorImpl<mlir::Type> &outReductionTypes,
932+
llvm::SmallVectorImpl<mlir::Attribute> &outReductionDeclSymbols,
933+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
934+
*outReductionSymbols) const {
913935
return findRepeatableClause<omp::clause::Reduction>(
914936
[&](const omp::clause::Reduction &clause,
915937
const Fortran::parser::CharBlock &) {
938+
// Use local lists of reductions to prevent variables from other
939+
// already-processed reduction clauses from impacting this reduction.
940+
// For example, the whole `reductionVars` array is queried to decide
941+
// whether to do the reduction byref.
942+
llvm::SmallVector<mlir::Value> reductionVars;
943+
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
944+
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
916945
ReductionProcessor rp;
917946
rp.addReductionDecl(currentLocation, converter, clause, reductionVars,
918-
reductionDeclSymbols, reductionSymbols);
947+
reductionDeclSymbols,
948+
outReductionSymbols ? &reductionSymbols : nullptr);
949+
950+
// Copy local lists into the output.
951+
llvm::copy(reductionVars, std::back_inserter(outReductionVars));
952+
llvm::copy(reductionDeclSymbols,
953+
std::back_inserter(outReductionDeclSymbols));
954+
if (outReductionSymbols)
955+
llvm::copy(reductionSymbols,
956+
std::back_inserter(*outReductionSymbols));
957+
958+
outReductionTypes.reserve(outReductionTypes.size() +
959+
reductionVars.size());
960+
llvm::transform(reductionVars, std::back_inserter(outReductionTypes),
961+
[](mlir::Value v) { return v.getType(); });
919962
});
920963
}
921964

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,12 @@ class ClauseProcessor {
5656
clauses(makeList(clauses, semaCtx)) {}
5757

5858
// 'Unique' clauses: They can appear at most once in the clause list.
59-
bool
60-
processCollapse(mlir::Location currentLocation,
61-
Fortran::lower::pft::Evaluation &eval,
62-
llvm::SmallVectorImpl<mlir::Value> &lowerBound,
63-
llvm::SmallVectorImpl<mlir::Value> &upperBound,
64-
llvm::SmallVectorImpl<mlir::Value> &step,
65-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
66-
std::size_t &loopVarTypeSize) const;
59+
bool processCollapse(
60+
mlir::Location currentLocation, Fortran::lower::pft::Evaluation &eval,
61+
llvm::SmallVectorImpl<mlir::Value> &lowerBound,
62+
llvm::SmallVectorImpl<mlir::Value> &upperBound,
63+
llvm::SmallVectorImpl<mlir::Value> &step,
64+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) const;
6765
bool processDefault() const;
6866
bool processDevice(Fortran::lower::StatementContext &stmtCtx,
6967
mlir::Value &result) const;
@@ -126,6 +124,7 @@ class ClauseProcessor {
126124
bool
127125
processReduction(mlir::Location currentLocation,
128126
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
127+
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
129128
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
130129
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
131130
*reductionSymbols = nullptr) const;

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 10 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -214,24 +214,6 @@ static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter,
214214
firOpBuilder.restoreInsertionPoint(insPt);
215215
}
216216

217-
static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
218-
std::size_t loopVarTypeSize) {
219-
// OpenMP runtime requires 32-bit or 64-bit loop variables.
220-
loopVarTypeSize = loopVarTypeSize * 8;
221-
if (loopVarTypeSize < 32) {
222-
loopVarTypeSize = 32;
223-
} else if (loopVarTypeSize > 64) {
224-
loopVarTypeSize = 64;
225-
mlir::emitWarning(converter.getCurrentLocation(),
226-
"OpenMP loop iteration variable cannot have more than 64 "
227-
"bits size and will be narrowed into 64 bits.");
228-
}
229-
assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) &&
230-
"OpenMP loop iteration variable size must be transformed into 32-bit "
231-
"or 64-bit");
232-
return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
233-
}
234-
235217
static mlir::Operation *
236218
createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
237219
mlir::Location loc, mlir::Value indexVal,
@@ -568,6 +550,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
568550
mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
569551
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
570552
reductionVars;
553+
llvm::SmallVector<mlir::Type> reductionTypes;
571554
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
572555
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
573556

@@ -578,13 +561,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
578561
cp.processDefault();
579562
cp.processAllocate(allocatorOperands, allocateOperands);
580563
if (!outerCombined)
581-
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols,
582-
&reductionSymbols);
583-
584-
llvm::SmallVector<mlir::Type> reductionTypes;
585-
reductionTypes.reserve(reductionVars.size());
586-
llvm::transform(reductionVars, std::back_inserter(reductionTypes),
587-
[](mlir::Value v) { return v.getType(); });
564+
cp.processReduction(currentLocation, reductionVars, reductionTypes,
565+
reductionDeclSymbols, &reductionSymbols);
588566

589567
auto reductionCallback = [&](mlir::Operation *op) {
590568
llvm::SmallVector<mlir::Location> locs(reductionVars.size(),
@@ -1468,25 +1446,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
14681446
standaloneConstruct.u);
14691447
}
14701448

1471-
static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
1472-
mlir::Location loc,
1473-
llvm::SmallVectorImpl<mlir::Value> &lowerBound,
1474-
llvm::SmallVectorImpl<mlir::Value> &upperBound,
1475-
llvm::SmallVectorImpl<mlir::Value> &step,
1476-
std::size_t loopVarTypeSize) {
1477-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1478-
// The types of lower bound, upper bound, and step are converted into the
1479-
// type of the loop variable if necessary.
1480-
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
1481-
for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) {
1482-
lowerBound[it] =
1483-
firOpBuilder.createConvert(loc, loopVarType, lowerBound[it]);
1484-
upperBound[it] =
1485-
firOpBuilder.createConvert(loc, loopVarType, upperBound[it]);
1486-
step[it] = firOpBuilder.createConvert(loc, loopVarType, step[it]);
1487-
}
1488-
}
1489-
14901449
static llvm::SmallVector<const Fortran::semantics::Symbol *>
14911450
genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
14921451
mlir::Location &loc,
@@ -1520,7 +1479,7 @@ genLoopAndReductionVars(
15201479
mlir::Location &loc,
15211480
llvm::ArrayRef<const Fortran::semantics::Symbol *> loopArgs,
15221481
llvm::ArrayRef<const Fortran::semantics::Symbol *> reductionArgs,
1523-
llvm::SmallVectorImpl<mlir::Type> &reductionTypes) {
1482+
llvm::ArrayRef<mlir::Type> reductionTypes) {
15241483
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
15251484

15261485
llvm::SmallVector<mlir::Type> blockArgTypes;
@@ -1582,16 +1541,15 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
15821541
llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, reductionVars;
15831542
llvm::SmallVector<mlir::Value> alignedVars, nontemporalVars;
15841543
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
1544+
llvm::SmallVector<mlir::Type> reductionTypes;
15851545
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
15861546
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
15871547
mlir::IntegerAttr simdlenClauseOperand, safelenClauseOperand;
1588-
std::size_t loopVarTypeSize;
15891548

15901549
ClauseProcessor cp(converter, semaCtx, loopOpClauseList);
1591-
cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv,
1592-
loopVarTypeSize);
1550+
cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv);
15931551
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
1594-
cp.processReduction(loc, reductionVars, reductionDeclSymbols);
1552+
cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols);
15951553
cp.processIf(clause::If::DirectiveNameModifier::Simd, ifClauseOperand);
15961554
cp.processSimdlen(simdlenClauseOperand);
15971555
cp.processSafelen(safelenClauseOperand);
@@ -1601,9 +1559,6 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
16011559
Fortran::parser::OmpClause::Nontemporal,
16021560
Fortran::parser::OmpClause::Order>(loc, ompDirective);
16031561

1604-
convertLoopBounds(converter, loc, lowerBound, upperBound, step,
1605-
loopVarTypeSize);
1606-
16071562
mlir::TypeRange resultType;
16081563
auto simdLoopOp = firOpBuilder.create<mlir::omp::SimdLoopOp>(
16091564
loc, resultType, lowerBound, upperBound, step, alignedVars,
@@ -1641,27 +1596,23 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
16411596
llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, reductionVars;
16421597
llvm::SmallVector<mlir::Value> linearVars, linearStepVars;
16431598
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
1599+
llvm::SmallVector<mlir::Type> reductionTypes;
16441600
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
16451601
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
16461602
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
16471603
mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
16481604
mlir::UnitAttr nowaitClauseOperand, byrefOperand, scheduleSimdClauseOperand;
16491605
mlir::IntegerAttr orderedClauseOperand;
16501606
mlir::omp::ScheduleModifierAttr scheduleModClauseOperand;
1651-
std::size_t loopVarTypeSize;
16521607

16531608
ClauseProcessor cp(converter, semaCtx, beginClauseList);
1654-
cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv,
1655-
loopVarTypeSize);
1609+
cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv);
16561610
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
1657-
cp.processReduction(loc, reductionVars, reductionDeclSymbols,
1611+
cp.processReduction(loc, reductionVars, reductionTypes, reductionDeclSymbols,
16581612
&reductionSymbols);
16591613
cp.processTODO<Fortran::parser::OmpClause::Linear,
16601614
Fortran::parser::OmpClause::Order>(loc, ompDirective);
16611615

1662-
convertLoopBounds(converter, loc, lowerBound, upperBound, step,
1663-
loopVarTypeSize);
1664-
16651616
if (ReductionProcessor::doReductionByRef(reductionVars))
16661617
byrefOperand = firOpBuilder.getUnitAttr();
16671618

@@ -1702,11 +1653,6 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
17021653
auto *nestedEval = getCollapsedLoopEval(
17031654
eval, Fortran::lower::getCollapseValue(beginClauseList));
17041655

1705-
llvm::SmallVector<mlir::Type> reductionTypes;
1706-
reductionTypes.reserve(reductionVars.size());
1707-
llvm::transform(reductionVars, std::back_inserter(reductionTypes),
1708-
[](mlir::Value v) { return v.getType(); });
1709-
17101656
auto ivCallback = [&](mlir::Operation *op) {
17111657
return genLoopAndReductionVars(op, converter, loc, iv, reductionSymbols,
17121658
reductionTypes);

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include <flang/Lower/AbstractConverter.h>
1717
#include <flang/Lower/ConvertType.h>
18+
#include <flang/Optimizer/Builder/FIRBuilder.h>
1819
#include <flang/Parser/parse-tree.h>
1920
#include <flang/Parser/tools.h>
2021
#include <flang/Semantics/tools.h>
@@ -70,6 +71,24 @@ void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
7071
}
7172
}
7273

74+
mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
75+
std::size_t loopVarTypeSize) {
76+
// OpenMP runtime requires 32-bit or 64-bit loop variables.
77+
loopVarTypeSize = loopVarTypeSize * 8;
78+
if (loopVarTypeSize < 32) {
79+
loopVarTypeSize = 32;
80+
} else if (loopVarTypeSize > 64) {
81+
loopVarTypeSize = 64;
82+
mlir::emitWarning(converter.getCurrentLocation(),
83+
"OpenMP loop iteration variable cannot have more than 64 "
84+
"bits size and will be narrowed into 64 bits.");
85+
}
86+
assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) &&
87+
"OpenMP loop iteration variable size must be transformed into 32-bit "
88+
"or 64-bit");
89+
return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
90+
}
91+
7392
void gatherFuncAndVarSyms(
7493
const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
7594
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {

flang/lib/Lower/OpenMP/Utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
5151
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
5252
bool isVal = false);
5353

54+
mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
55+
std::size_t loopVarTypeSize);
56+
5457
void gatherFuncAndVarSyms(
5558
const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
5659
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause);

0 commit comments

Comments
 (0)