Skip to content

Commit 91ccce2

Browse files
committed
[flang][OpenMP][MLIR] Basic support for delayed privatization code-gen
Adds basic support for emitting delayed privatizers from flang. So far, only simple symbol types (i.e. scalars) are supported; support for more complicated types will be added later. This also makes sure that reduction and delayed privatization work properly together by merging the body-gen callbacks for both in case both clauses are present on the parallel construct.
1 parent 0bb1415 commit 91ccce2

File tree

6 files changed

+307
-28
lines changed

6 files changed

+307
-28
lines changed

flang/include/flang/Lower/AbstractConverter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "flang/Common/Fortran.h"
1717
#include "flang/Lower/LoweringOptions.h"
1818
#include "flang/Lower/PFTDefs.h"
19+
#include "flang/Lower/SymbolMap.h"
1920
#include "flang/Optimizer/Builder/BoxValue.h"
2021
#include "flang/Semantics/symbol.h"
2122
#include "mlir/IR/Builders.h"
@@ -296,6 +297,9 @@ class AbstractConverter {
296297
return loweringOptions;
297298
}
298299

300+
virtual Fortran::lower::SymbolBox
301+
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) = 0;
302+
299303
private:
300304
/// Options controlling lowering behavior.
301305
const Fortran::lower::LoweringOptions &loweringOptions;

flang/lib/Lower/Bridge.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1070,7 +1070,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
10701070
/// Find the symbol in one level up of symbol map such as for host-association
10711071
/// in OpenMP code or return null.
10721072
Fortran::lower::SymbolBox
1073-
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) {
1073+
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) override {
10741074
if (Fortran::lower::SymbolBox v = localSymbols.lookupOneLevelUpSymbol(sym))
10751075
return v;
10761076
return {};

flang/lib/Lower/OpenMP.cpp

Lines changed: 208 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ static llvm::cl::opt<bool> treatIndexAsSection(
4040
llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
4141
llvm::cl::init(true));
4242

43+
static llvm::cl::opt<bool> enableDelayedPrivatization(
44+
"openmp-enable-delayed-privatization",
45+
llvm::cl::desc(
46+
"Emit `[first]private` variables as clauses on the MLIR ops."),
47+
llvm::cl::init(false));
48+
4349
using DeclareTargetCapturePair =
4450
std::pair<mlir::omp::DeclareTargetCaptureClause,
4551
Fortran::semantics::Symbol>;
@@ -147,6 +153,14 @@ static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter,
147153
//===----------------------------------------------------------------------===//
148154

149155
class DataSharingProcessor {
156+
public:
157+
struct DelayedPrivatizationInfo {
158+
llvm::SmallVector<mlir::SymbolRefAttr> privatizers;
159+
llvm::SmallVector<mlir::Value> hostAddresses;
160+
llvm::SmallVector<const Fortran::semantics::Symbol *> hostSymbols;
161+
};
162+
163+
private:
150164
bool hasLastPrivateOp;
151165
mlir::OpBuilder::InsertPoint lastPrivIP;
152166
mlir::OpBuilder::InsertPoint insPt;
@@ -161,6 +175,11 @@ class DataSharingProcessor {
161175
const Fortran::parser::OmpClauseList &opClauseList;
162176
Fortran::lower::pft::Evaluation &eval;
163177

178+
bool useDelayedPrivatization;
179+
Fortran::lower::SymMap *symTable;
180+
181+
DelayedPrivatizationInfo delayedPrivatizationInfo;
182+
164183
bool needBarrier();
165184
void collectSymbols(Fortran::semantics::Symbol::Flag flag);
166185
void collectOmpObjectListSymbol(
@@ -171,21 +190,28 @@ class DataSharingProcessor {
171190
void collectDefaultSymbols();
172191
void privatize();
173192
void defaultPrivatize();
193+
void doPrivatize(const Fortran::semantics::Symbol *sym);
174194
void copyLastPrivatize(mlir::Operation *op);
175195
void insertLastPrivateCompare(mlir::Operation *op);
176196
void cloneSymbol(const Fortran::semantics::Symbol *sym);
177-
void copyFirstPrivateSymbol(const Fortran::semantics::Symbol *sym);
197+
void
198+
copyFirstPrivateSymbol(const Fortran::semantics::Symbol *sym,
199+
mlir::OpBuilder::InsertPoint *copyAssignIP = nullptr);
178200
void copyLastPrivateSymbol(const Fortran::semantics::Symbol *sym,
179201
mlir::OpBuilder::InsertPoint *lastPrivIP);
180202
void insertDeallocs();
181203

182204
public:
183205
DataSharingProcessor(Fortran::lower::AbstractConverter &converter,
184206
const Fortran::parser::OmpClauseList &opClauseList,
185-
Fortran::lower::pft::Evaluation &eval)
207+
Fortran::lower::pft::Evaluation &eval,
208+
bool useDelayedPrivatization = false,
209+
Fortran::lower::SymMap *symTable = nullptr)
186210
: hasLastPrivateOp(false), converter(converter),
187211
firOpBuilder(converter.getFirOpBuilder()), opClauseList(opClauseList),
188-
eval(eval) {}
212+
eval(eval), useDelayedPrivatization(useDelayedPrivatization),
213+
symTable(symTable) {}
214+
189215
// Privatisation is split into two steps.
190216
// Step1 performs cloning of all privatisation clauses and copying for
191217
// firstprivates. Step1 is performed at the place where process/processStep1
@@ -204,6 +230,10 @@ class DataSharingProcessor {
204230
assert(!loopIV && "Loop iteration variable already set");
205231
loopIV = iv;
206232
}
233+
234+
const DelayedPrivatizationInfo &getDelayedPrivatizationInfo() const {
235+
return delayedPrivatizationInfo;
236+
}
207237
};
208238

209239
void DataSharingProcessor::processStep1() {
@@ -250,9 +280,10 @@ void DataSharingProcessor::cloneSymbol(const Fortran::semantics::Symbol *sym) {
250280
}
251281

252282
void DataSharingProcessor::copyFirstPrivateSymbol(
253-
const Fortran::semantics::Symbol *sym) {
283+
const Fortran::semantics::Symbol *sym,
284+
mlir::OpBuilder::InsertPoint *copyAssignIP) {
254285
if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate))
255-
converter.copyHostAssociateVar(*sym);
286+
converter.copyHostAssociateVar(*sym, copyAssignIP);
256287
}
257288

258289
void DataSharingProcessor::copyLastPrivateSymbol(
@@ -491,14 +522,10 @@ void DataSharingProcessor::privatize() {
491522
for (const Fortran::semantics::Symbol *sym : privatizedSymbols) {
492523
if (const auto *commonDet =
493524
sym->detailsIf<Fortran::semantics::CommonBlockDetails>()) {
494-
for (const auto &mem : commonDet->objects()) {
495-
cloneSymbol(&*mem);
496-
copyFirstPrivateSymbol(&*mem);
497-
}
498-
} else {
499-
cloneSymbol(sym);
500-
copyFirstPrivateSymbol(sym);
501-
}
525+
for (const auto &mem : commonDet->objects())
526+
doPrivatize(&*mem);
527+
} else
528+
doPrivatize(sym);
502529
}
503530
}
504531

@@ -522,11 +549,96 @@ void DataSharingProcessor::defaultPrivatize() {
522549
!sym->GetUltimate().has<Fortran::semantics::NamelistDetails>() &&
523550
!symbolsInNestedRegions.contains(sym) &&
524551
!symbolsInParentRegions.contains(sym) &&
525-
!privatizedSymbols.contains(sym)) {
552+
!privatizedSymbols.contains(sym))
553+
doPrivatize(sym);
554+
}
555+
}
556+
557+
void DataSharingProcessor::doPrivatize(const Fortran::semantics::Symbol *sym) {
558+
if (!useDelayedPrivatization) {
559+
cloneSymbol(sym);
560+
copyFirstPrivateSymbol(sym);
561+
return;
562+
}
563+
564+
Fortran::lower::SymbolBox hsb = converter.lookupOneLevelUpSymbol(*sym);
565+
assert(hsb && "Host symbol box not found");
566+
567+
mlir::Type symType = hsb.getAddr().getType();
568+
mlir::Location symLoc = hsb.getAddr().getLoc();
569+
std::string privatizerName = sym->name().ToString() + ".privatizer";
570+
bool isFirstPrivate =
571+
sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate);
572+
573+
mlir::omp::PrivateClauseOp privatizerOp = [&]() {
574+
auto moduleOp = firOpBuilder.getModule();
575+
576+
auto uniquePrivatizerName = fir::getTypeAsString(
577+
symType, converter.getKindMap(),
578+
sym->name().ToString() +
579+
(isFirstPrivate ? "_firstprivate" : "_private"));
580+
581+
if (auto existingPrivatizer =
582+
moduleOp.lookupSymbol<mlir::omp::PrivateClauseOp>(
583+
uniquePrivatizerName))
584+
return existingPrivatizer;
585+
586+
auto ip = firOpBuilder.saveInsertionPoint();
587+
firOpBuilder.setInsertionPoint(&moduleOp.getBodyRegion().front(),
588+
moduleOp.getBodyRegion().front().begin());
589+
auto result = firOpBuilder.create<mlir::omp::PrivateClauseOp>(
590+
symLoc, uniquePrivatizerName, symType,
591+
isFirstPrivate ? mlir::omp::DataSharingClauseType ::FirstPrivate
592+
: mlir::omp::DataSharingClauseType::Private);
593+
594+
symTable->pushScope();
595+
596+
// Populate the `alloc` region.
597+
{
598+
mlir::Region &allocRegion = result.getAllocRegion();
599+
mlir::Block *allocEntryBlock = firOpBuilder.createBlock(
600+
&allocRegion, /*insertPt=*/{}, symType, symLoc);
601+
602+
firOpBuilder.setInsertionPointToEnd(allocEntryBlock);
603+
symTable->addSymbol(*sym, allocRegion.getArgument(0));
604+
symTable->pushScope();
526605
cloneSymbol(sym);
527-
copyFirstPrivateSymbol(sym);
606+
firOpBuilder.create<mlir::omp::YieldOp>(
607+
hsb.getAddr().getLoc(),
608+
symTable->shallowLookupSymbol(*sym).getAddr());
609+
symTable->popScope();
528610
}
529-
}
611+
612+
// Populate the `copy` region if this is a `firstprivate`.
613+
if (isFirstPrivate) {
614+
mlir::Region &copyRegion = result.getCopyRegion();
615+
// The first block argument corresponds to the original/host value, while
616+
// the second block argument corresponds to the privatized value.
617+
mlir::Block *copyEntryBlock = firOpBuilder.createBlock(
618+
&copyRegion, /*insertPt=*/{}, {symType, symType}, {symLoc, symLoc});
619+
firOpBuilder.setInsertionPointToEnd(copyEntryBlock);
620+
symTable->addSymbol(*sym, copyRegion.getArgument(0),
621+
/*force=*/true);
622+
symTable->pushScope();
623+
symTable->addSymbol(*sym, copyRegion.getArgument(1));
624+
auto ip = firOpBuilder.saveInsertionPoint();
625+
copyFirstPrivateSymbol(sym, &ip);
626+
627+
firOpBuilder.create<mlir::omp::YieldOp>(
628+
hsb.getAddr().getLoc(),
629+
symTable->shallowLookupSymbol(*sym).getAddr());
630+
symTable->popScope();
631+
}
632+
633+
symTable->popScope();
634+
firOpBuilder.restoreInsertionPoint(ip);
635+
return result;
636+
}();
637+
638+
delayedPrivatizationInfo.privatizers.push_back(
639+
mlir::SymbolRefAttr::get(privatizerOp));
640+
delayedPrivatizationInfo.hostAddresses.push_back(hsb.getAddr());
641+
delayedPrivatizationInfo.hostSymbols.push_back(sym);
530642
}
531643

532644
//===----------------------------------------------------------------------===//
@@ -2585,6 +2697,7 @@ genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
25852697

25862698
static mlir::omp::ParallelOp
25872699
genParallelOp(Fortran::lower::AbstractConverter &converter,
2700+
Fortran::lower::SymMap &symTable,
25882701
Fortran::semantics::SemanticsContext &semaCtx,
25892702
Fortran::lower::pft::Evaluation &eval, bool genNested,
25902703
mlir::Location currentLocation,
@@ -2617,31 +2730,99 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
26172730
auto reductionCallback = [&](mlir::Operation *op) {
26182731
llvm::SmallVector<mlir::Location> locs(reductionVars.size(),
26192732
currentLocation);
2620-
auto block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {},
2621-
reductionTypes, locs);
2733+
auto *block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {},
2734+
reductionTypes, locs);
26222735
for (auto [arg, prv] :
26232736
llvm::zip_equal(reductionSymbols, block->getArguments())) {
26242737
converter.bindSymbol(*arg, prv);
26252738
}
26262739
return reductionSymbols;
26272740
};
26282741

2629-
return genOpWithBody<mlir::omp::ParallelOp>(
2742+
OpWithBodyGenInfo genInfo =
26302743
OpWithBodyGenInfo(converter, semaCtx, currentLocation, eval)
26312744
.setGenNested(genNested)
26322745
.setOuterCombined(outerCombined)
26332746
.setClauses(&clauseList)
26342747
.setReductions(&reductionSymbols, &reductionTypes)
2635-
.setGenRegionEntryCb(reductionCallback),
2748+
.setGenRegionEntryCb(reductionCallback);
2749+
2750+
if (!enableDelayedPrivatization) {
2751+
return genOpWithBody<mlir::omp::ParallelOp>(
2752+
genInfo,
2753+
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
2754+
numThreadsClauseOperand, allocateOperands, allocatorOperands,
2755+
reductionVars,
2756+
reductionDeclSymbols.empty()
2757+
? nullptr
2758+
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
2759+
reductionDeclSymbols),
2760+
procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
2761+
/*privatizers=*/nullptr);
2762+
}
2763+
2764+
bool privatize = !outerCombined;
2765+
DataSharingProcessor dsp(converter, clauseList, eval,
2766+
/*useDelayedPrivatization=*/true, &symTable);
2767+
2768+
if (privatize)
2769+
dsp.processStep1();
2770+
2771+
const auto &delayedPrivatizationInfo = dsp.getDelayedPrivatizationInfo();
2772+
2773+
auto genRegionEntryCB = [&](mlir::Operation *op) {
2774+
auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
2775+
2776+
llvm::SmallVector<mlir::Location> reductionLocs(reductionVars.size(),
2777+
currentLocation);
2778+
2779+
auto privateVars = parallelOp.getPrivateVars();
2780+
auto &region = parallelOp.getRegion();
2781+
2782+
llvm::SmallVector<mlir::Type> privateVarTypes = reductionTypes;
2783+
privateVarTypes.reserve(privateVars.size());
2784+
llvm::transform(privateVars, std::back_inserter(privateVarTypes),
2785+
[](mlir::Value v) { return v.getType(); });
2786+
2787+
llvm::SmallVector<mlir::Location> privateVarLocs = reductionLocs;
2788+
privateVarLocs.reserve(privateVars.size());
2789+
llvm::transform(privateVars, std::back_inserter(privateVarLocs),
2790+
[](mlir::Value v) { return v.getLoc(); });
2791+
2792+
converter.getFirOpBuilder().createBlock(&region, /*insertPt=*/{},
2793+
privateVarTypes, privateVarLocs);
2794+
2795+
llvm::SmallVector<const Fortran::semantics::Symbol *> allSymbols =
2796+
reductionSymbols;
2797+
allSymbols.append(delayedPrivatizationInfo.hostSymbols);
2798+
for (auto [arg, prv] : llvm::zip_equal(allSymbols, region.getArguments())) {
2799+
converter.bindSymbol(*arg, prv);
2800+
}
2801+
2802+
return allSymbols;
2803+
};
2804+
2805+
// TODO Merge with the reduction CB.
2806+
genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp);
2807+
2808+
llvm::SmallVector<mlir::Attribute> privatizers(
2809+
delayedPrivatizationInfo.privatizers.begin(),
2810+
delayedPrivatizationInfo.privatizers.end());
2811+
2812+
return genOpWithBody<mlir::omp::ParallelOp>(
2813+
genInfo,
26362814
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
26372815
numThreadsClauseOperand, allocateOperands, allocatorOperands,
26382816
reductionVars,
26392817
reductionDeclSymbols.empty()
26402818
? nullptr
26412819
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
26422820
reductionDeclSymbols),
2643-
procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
2644-
/*privatizers=*/nullptr);
2821+
procBindKindAttr, delayedPrivatizationInfo.hostAddresses,
2822+
delayedPrivatizationInfo.privatizers.empty()
2823+
? nullptr
2824+
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
2825+
privatizers));
26452826
}
26462827

26472828
static mlir::omp::SectionOp
@@ -3633,7 +3814,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
36333814
if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet)
36343815
.test(ompDirective)) {
36353816
validDirective = true;
3636-
genParallelOp(converter, semaCtx, eval, /*genNested=*/false,
3817+
genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/false,
36373818
currentLocation, loopOpClauseList,
36383819
/*outerCombined=*/true);
36393820
}
@@ -3722,8 +3903,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
37223903
currentLocation);
37233904
break;
37243905
case llvm::omp::Directive::OMPD_parallel:
3725-
genParallelOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
3726-
beginClauseList);
3906+
genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/true,
3907+
currentLocation, beginClauseList);
37273908
break;
37283909
case llvm::omp::Directive::OMPD_single:
37293910
genSingleOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation,
@@ -3780,7 +3961,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
37803961
.test(directive.v)) {
37813962
bool outerCombined =
37823963
directive.v != llvm::omp::Directive::OMPD_target_parallel;
3783-
genParallelOp(converter, semaCtx, eval, /*genNested=*/false,
3964+
genParallelOp(converter, symTable, semaCtx, eval, /*genNested=*/false,
37843965
currentLocation, beginClauseList, outerCombined);
37853966
combinedDirective = true;
37863967
}
@@ -3863,7 +4044,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
38634044

38644045
// Parallel wrapper of PARALLEL SECTIONS construct
38654046
if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
3866-
genParallelOp(converter, semaCtx, eval,
4047+
genParallelOp(converter, symTable, semaCtx, eval,
38674048
/*genNested=*/false, currentLocation, sectionsClauseList,
38684049
/*outerCombined=*/true);
38694050
} else {

0 commit comments

Comments
 (0)