Skip to content

Commit e0c5f0b

Browse files
committed
[MLIR][OpenMP] Support basic materialization for omp.private ops
Adds basic support for materializing delayed privatization. So far, the restrictions on the implementation are:
- Only `private` clauses are supported (`firstprivate` support will be added in a later PR).
- Only single-block `omp.private -> alloc` regions are supported (multi-block ones will be supported in a later PR).
1 parent 8c5e9cf commit e0c5f0b

File tree

3 files changed

+304
-44
lines changed

3 files changed

+304
-44
lines changed

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1957,7 +1957,10 @@ LogicalResult PrivateClauseOp::verify() {
19571957
Type symType = getType();
19581958

19591959
auto verifyTerminator = [&](Operation *terminator) -> LogicalResult {
1960-
if (!terminator->hasSuccessors() && !llvm::isa<YieldOp>(terminator))
1960+
if (!terminator->getBlock()->getSuccessors().empty())
1961+
return success();
1962+
1963+
if (!llvm::isa<YieldOp>(terminator))
19611964
return mlir::emitError(terminator->getLoc())
19621965
<< "expected exit block terminator to be an `omp.yield` op.";
19631966

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 158 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,9 @@ collectReductionDecls(T loop,
396396

397397
/// Translates the blocks contained in the given region and appends them at
398398
/// the current insertion point of `builder`. The operations of the entry block
399-
/// are appended to the current insertion block, which is not expected to have a
400-
/// terminator. If set, `continuationBlockArgs` is populated with translated
401-
/// values that correspond to the values omp.yield'ed from the region.
399+
/// are appended to the current insertion block. If set, `continuationBlockArgs`
400+
/// is populated with translated values that correspond to the values
401+
/// omp.yield'ed from the region.
402402
static LogicalResult inlineConvertOmpRegions(
403403
Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
404404
LLVM::ModuleTranslation &moduleTranslation,
@@ -409,7 +409,14 @@ static LogicalResult inlineConvertOmpRegions(
409409
// Special case for single-block regions that don't create additional blocks:
410410
// insert operations without creating additional blocks.
411411
if (llvm::hasSingleElement(region)) {
412+
llvm::Instruction *potentialTerminator =
413+
builder.GetInsertBlock()->empty() ? nullptr
414+
: &builder.GetInsertBlock()->back();
415+
416+
if (potentialTerminator && potentialTerminator->isTerminator())
417+
potentialTerminator->removeFromParent();
412418
moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
419+
413420
if (failed(moduleTranslation.convertBlock(
414421
region.front(), /*ignoreArguments=*/true, builder)))
415422
return failure();
@@ -423,6 +430,10 @@ static LogicalResult inlineConvertOmpRegions(
423430
// Drop the mapping that is no longer necessary so that the same region can
424431
// be processed multiple times.
425432
moduleTranslation.forgetMapping(region);
433+
434+
if (potentialTerminator && potentialTerminator->isTerminator())
435+
potentialTerminator->insertAfter(&builder.GetInsertBlock()->back());
436+
426437
return success();
427438
}
428439

@@ -1000,11 +1011,41 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
10001011
return success();
10011012
}
10021013

1014+
/// Replace the region arguments of the parallel op (which correspond to private
1015+
/// variables) with the actual private variables they correspond to. This
1016+
/// prepares the parallel op so that it matches what is expected by the
1017+
/// OMPIRBuilder. Instead of editing the original op in-place, this function
1018+
/// does the required changes to a cloned version which should then be erased by
1019+
/// the caller.
1020+
static omp::ParallelOp
1021+
prepareOmpParallelForPrivatization(omp::ParallelOp opInst) {
1022+
Region &region = opInst.getRegion();
1023+
auto privateVars = opInst.getPrivateVars();
1024+
1025+
auto privateVarsIt = privateVars.begin();
1026+
// Reductions precede private arguments, so skip them first.
1027+
unsigned privateArgBeginIdx = opInst.getNumReductionVars();
1028+
unsigned privateArgEndIdx = privateArgBeginIdx + privateVars.size();
1029+
1030+
mlir::IRMapping mapping;
1031+
for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1032+
++argIdx, ++privateVarsIt)
1033+
mapping.map(region.getArgument(argIdx), *privateVarsIt);
1034+
1035+
mlir::OpBuilder cloneBuilder(opInst);
1036+
omp::ParallelOp opInstClone =
1037+
llvm::cast<omp::ParallelOp>(cloneBuilder.clone(*opInst, mapping));
1038+
1039+
return opInstClone;
1040+
}
1041+
10031042
/// Converts the OpenMP parallel operation to LLVM IR.
10041043
static LogicalResult
10051044
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10061045
LLVM::ModuleTranslation &moduleTranslation) {
10071046
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1047+
omp::ParallelOp opInstClone = prepareOmpParallelForPrivatization(opInst);
1048+
10081049
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
10091050
// relying on captured variables.
10101051
LogicalResult bodyGenStatus = success();
@@ -1013,12 +1054,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10131054
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
10141055
// Collect reduction declarations
10151056
SmallVector<omp::ReductionDeclareOp> reductionDecls;
1016-
collectReductionDecls(opInst, reductionDecls);
1057+
collectReductionDecls(opInstClone, reductionDecls);
10171058

10181059
// Allocate reduction vars
10191060
SmallVector<llvm::Value *> privateReductionVariables;
10201061
DenseMap<Value, llvm::Value *> reductionVariableMap;
1021-
allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
1062+
allocReductionVars(opInstClone, builder, moduleTranslation, allocaIP,
10221063
reductionDecls, privateReductionVariables,
10231064
reductionVariableMap);
10241065

@@ -1030,7 +1071,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10301071

10311072
// Initialize reduction vars
10321073
builder.restoreIP(allocaIP);
1033-
for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
1074+
for (unsigned i = 0; i < opInstClone.getNumReductionVars(); ++i) {
10341075
SmallVector<llvm::Value *> phis;
10351076
if (failed(inlineConvertOmpRegions(
10361077
reductionDecls[i].getInitializerRegion(), "omp.reduction.neutral",
@@ -1051,18 +1092,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10511092
// ParallelOp has only one region associated with it.
10521093
builder.restoreIP(codeGenIP);
10531094
auto regionBlock =
1054-
convertOmpOpRegions(opInst.getRegion(), "omp.par.region", builder,
1095+
convertOmpOpRegions(opInstClone.getRegion(), "omp.par.region", builder,
10551096
moduleTranslation, bodyGenStatus);
10561097

10571098
// Process the reductions if required.
1058-
if (opInst.getNumReductionVars() > 0) {
1099+
if (opInstClone.getNumReductionVars() > 0) {
10591100
// Collect reduction info
10601101
SmallVector<OwningReductionGen> owningReductionGens;
10611102
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
10621103
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1063-
collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
1064-
owningReductionGens, owningAtomicReductionGens,
1065-
privateReductionVariables, reductionInfos);
1104+
collectReductionInfo(opInstClone, builder, moduleTranslation,
1105+
reductionDecls, owningReductionGens,
1106+
owningAtomicReductionGens, privateReductionVariables,
1107+
reductionInfos);
10661108

10671109
// Move to region cont block
10681110
builder.SetInsertPoint(regionBlock->getTerminator());
@@ -1075,7 +1117,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10751117
ompBuilder->createReductions(builder.saveIP(), allocaIP,
10761118
reductionInfos, false);
10771119
if (!contInsertPoint.getBlock()) {
1078-
bodyGenStatus = opInst->emitOpError() << "failed to convert reductions";
1120+
bodyGenStatus = opInstClone->emitOpError()
1121+
<< "failed to convert reductions";
10791122
return;
10801123
}
10811124

@@ -1086,12 +1129,82 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10861129

10871130
// TODO: Perform appropriate actions according to the data-sharing
10881131
// attribute (shared, private, firstprivate, ...) of variables.
1089-
// Currently defaults to shared.
1132+
// Currently shared and private are supported.
10901133
auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
10911134
llvm::Value &, llvm::Value &vPtr,
10921135
llvm::Value *&replacementValue) -> InsertPointTy {
10931136
replacementValue = &vPtr;
10941137

1138+
// If this is a private value, this lambda will return the corresponding
1139+
// mlir value and its `PrivateClauseOp`. Otherwise, empty values are
1140+
// returned.
1141+
auto [privVar, privatizerClone] =
1142+
[&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1143+
if (!opInstClone.getPrivateVars().empty()) {
1144+
auto privVars = opInstClone.getPrivateVars();
1145+
auto privatizers = opInstClone.getPrivatizers();
1146+
1147+
for (auto [privVar, privatizerAttr] :
1148+
llvm::zip_equal(privVars, *privatizers)) {
1149+
// Find the MLIR private variable corresponding to the LLVM value
1150+
// being privatized.
1151+
llvm::Value *llvmPrivVar = moduleTranslation.lookupValue(privVar);
1152+
if (llvmPrivVar != &vPtr)
1153+
continue;
1154+
1155+
SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
1156+
omp::PrivateClauseOp privatizer =
1157+
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1158+
opInstClone, privSym);
1159+
1160+
// Clone the privatizer in case it is used by more than one parallel
1161+
// region. The privatizer is processed in-place (see below) before it
1162+
// gets inlined in the parallel region and therefore processing the
1163+
// original op is dangerous.
1164+
return {privVar, privatizer.clone()};
1165+
}
1166+
}
1167+
1168+
return {mlir::Value(), omp::PrivateClauseOp()};
1169+
}();
1170+
1171+
if (privVar) {
1172+
if (privatizerClone.getDataSharingType() ==
1173+
omp::DataSharingClauseType::FirstPrivate) {
1174+
privatizerClone.emitOpError(
1175+
"TODO: delayed privatization is not "
1176+
"supported for `firstprivate` clauses yet.");
1177+
bodyGenStatus = failure();
1178+
return codeGenIP;
1179+
}
1180+
1181+
Region &allocRegion = privatizerClone.getAllocRegion();
1182+
1183+
// Replace the privatizer block argument with mlir value being privatized.
1184+
// This way, the body of the privatizer will be changed from using the
1185+
// region/block argument to the value being privatized.
1186+
auto allocRegionArg = allocRegion.getArgument(0);
1187+
replaceAllUsesInRegionWith(allocRegionArg, privVar, allocRegion);
1188+
1189+
auto oldIP = builder.saveIP();
1190+
builder.restoreIP(allocaIP);
1191+
1192+
SmallVector<llvm::Value *, 1> yieldedValues;
1193+
if (failed(inlineConvertOmpRegions(allocRegion, "omp.privatizer", builder,
1194+
moduleTranslation, &yieldedValues))) {
1195+
opInstClone.emitError(
1196+
"failed to inline `alloc` region of an `omp.private` "
1197+
"op in the parallel region");
1198+
bodyGenStatus = failure();
1199+
} else {
1200+
assert(yieldedValues.size() == 1);
1201+
replacementValue = yieldedValues.front();
1202+
}
1203+
1204+
privatizerClone.erase();
1205+
builder.restoreIP(oldIP);
1206+
}
1207+
10951208
return codeGenIP;
10961209
};
10971210

@@ -1100,13 +1213,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11001213
auto finiCB = [&](InsertPointTy codeGenIP) {};
11011214

11021215
llvm::Value *ifCond = nullptr;
1103-
if (auto ifExprVar = opInst.getIfExprVar())
1216+
if (auto ifExprVar = opInstClone.getIfExprVar())
11041217
ifCond = moduleTranslation.lookupValue(ifExprVar);
11051218
llvm::Value *numThreads = nullptr;
1106-
if (auto numThreadsVar = opInst.getNumThreadsVar())
1219+
if (auto numThreadsVar = opInstClone.getNumThreadsVar())
11071220
numThreads = moduleTranslation.lookupValue(numThreadsVar);
11081221
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
1109-
if (auto bind = opInst.getProcBindVal())
1222+
if (auto bind = opInstClone.getProcBindVal())
11101223
pbKind = getProcBindKind(*bind);
11111224
// TODO: Is the Parallel construct cancellable?
11121225
bool isCancellable = false;
@@ -1119,6 +1232,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11191232
ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
11201233
ifCond, numThreads, pbKind, isCancellable));
11211234

1235+
opInstClone.erase();
11221236
return bodyGenStatus;
11231237
}
11241238

@@ -1635,7 +1749,7 @@ getRefPtrIfDeclareTarget(mlir::Value value,
16351749
// A small helper structure to contain data gathered
16361750
// for map lowering and coalesce it into one area and
16371751
// avoiding extra computations such as searches in the
1638-
// llvm module for lowered mapped varibles or checking
1752+
// llvm module for lowered mapped variables or checking
16391753
// if something is declare target (and retrieving the
16401754
// value) more than necessary.
16411755
struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
@@ -2854,26 +2968,26 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
28542968
moduleTranslation);
28552969
return failure();
28562970
})
2857-
.Case(
2858-
"omp.requires",
2859-
[&](Attribute attr) {
2860-
if (auto requiresAttr = attr.dyn_cast<omp::ClauseRequiresAttr>()) {
2861-
using Requires = omp::ClauseRequires;
2862-
Requires flags = requiresAttr.getValue();
2863-
llvm::OpenMPIRBuilderConfig &config =
2864-
moduleTranslation.getOpenMPBuilder()->Config;
2865-
config.setHasRequiresReverseOffload(
2866-
bitEnumContainsAll(flags, Requires::reverse_offload));
2867-
config.setHasRequiresUnifiedAddress(
2868-
bitEnumContainsAll(flags, Requires::unified_address));
2869-
config.setHasRequiresUnifiedSharedMemory(
2870-
bitEnumContainsAll(flags, Requires::unified_shared_memory));
2871-
config.setHasRequiresDynamicAllocators(
2872-
bitEnumContainsAll(flags, Requires::dynamic_allocators));
2873-
return success();
2874-
}
2875-
return failure();
2876-
})
2971+
.Case("omp.requires",
2972+
[&](Attribute attr) {
2973+
if (auto requiresAttr =
2974+
attr.dyn_cast<omp::ClauseRequiresAttr>()) {
2975+
using Requires = omp::ClauseRequires;
2976+
Requires flags = requiresAttr.getValue();
2977+
llvm::OpenMPIRBuilderConfig &config =
2978+
moduleTranslation.getOpenMPBuilder()->Config;
2979+
config.setHasRequiresReverseOffload(
2980+
bitEnumContainsAll(flags, Requires::reverse_offload));
2981+
config.setHasRequiresUnifiedAddress(
2982+
bitEnumContainsAll(flags, Requires::unified_address));
2983+
config.setHasRequiresUnifiedSharedMemory(
2984+
bitEnumContainsAll(flags, Requires::unified_shared_memory));
2985+
config.setHasRequiresDynamicAllocators(
2986+
bitEnumContainsAll(flags, Requires::dynamic_allocators));
2987+
return success();
2988+
}
2989+
return failure();
2990+
})
28772991
.Default([](Attribute) {
28782992
// Fall through for omp attributes that do not require lowering.
28792993
return success();
@@ -2988,12 +3102,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
29883102
.Case([&](omp::TargetOp) {
29893103
return convertOmpTarget(*op, builder, moduleTranslation);
29903104
})
2991-
.Case<omp::MapInfoOp, omp::DataBoundsOp>([&](auto op) {
2992-
// No-op, should be handled by relevant owning operations e.g.
2993-
// TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
2994-
// discarded
2995-
return success();
2996-
})
3105+
.Case<omp::MapInfoOp, omp::DataBoundsOp, omp::PrivateClauseOp>(
3106+
[&](auto op) {
3107+
// No-op, should be handled by relevant owning operations e.g.
3108+
// TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
3109+
// discarded
3110+
return success();
3111+
})
29973112
.Default([&](Operation *inst) {
29983113
return inst->emitError("unsupported OpenMP operation: ")
29993114
<< inst->getName();

0 commit comments

Comments
 (0)