Skip to content

Commit 9d56be0

Browse files
authored
[MLIR][OpenMP] Support basic materialization for omp.private ops (#81715)
Adds basic support for materializing delayed privatization. So far, the implementation has the following restriction: only `private` clauses are supported (`firstprivate` support will be added in a later PR).
1 parent 7c206c7 commit 9d56be0

File tree

3 files changed

+297
-32
lines changed

3 files changed

+297
-32
lines changed

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1957,7 +1957,10 @@ LogicalResult PrivateClauseOp::verify() {
19571957
Type symType = getType();
19581958

19591959
auto verifyTerminator = [&](Operation *terminator) -> LogicalResult {
1960-
if (!terminator->hasSuccessors() && !llvm::isa<YieldOp>(terminator))
1960+
if (!terminator->getBlock()->getSuccessors().empty())
1961+
return success();
1962+
1963+
if (!llvm::isa<YieldOp>(terminator))
19611964
return mlir::emitError(terminator->getLoc())
19621965
<< "expected exit block terminator to be an `omp.yield` op.";
19631966

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 151 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,9 @@ collectReductionDecls(T loop,
396396

397397
/// Translates the blocks contained in the given region and appends them at
398398
/// the current insertion point of `builder`. The operations of the entry block
399-
/// are appended to the current insertion block, which is not expected to have a
400-
/// terminator. If set, `continuationBlockArgs` is populated with translated
401-
/// values that correspond to the values omp.yield'ed from the region.
399+
/// are appended to the current insertion block. If set, `continuationBlockArgs`
400+
/// is populated with translated values that correspond to the values
401+
/// omp.yield'ed from the region.
402402
static LogicalResult inlineConvertOmpRegions(
403403
Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
404404
LLVM::ModuleTranslation &moduleTranslation,
@@ -409,7 +409,14 @@ static LogicalResult inlineConvertOmpRegions(
409409
// Special case for single-block regions that don't create additional blocks:
410410
// insert operations without creating additional blocks.
411411
if (llvm::hasSingleElement(region)) {
412+
llvm::Instruction *potentialTerminator =
413+
builder.GetInsertBlock()->empty() ? nullptr
414+
: &builder.GetInsertBlock()->back();
415+
416+
if (potentialTerminator && potentialTerminator->isTerminator())
417+
potentialTerminator->removeFromParent();
412418
moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
419+
413420
if (failed(moduleTranslation.convertBlock(
414421
region.front(), /*ignoreArguments=*/true, builder)))
415422
return failure();
@@ -423,6 +430,10 @@ static LogicalResult inlineConvertOmpRegions(
423430
// Drop the mapping that is no longer necessary so that the same region can
424431
// be processed multiple times.
425432
moduleTranslation.forgetMapping(region);
433+
434+
if (potentialTerminator && potentialTerminator->isTerminator())
435+
potentialTerminator->insertAfter(&builder.GetInsertBlock()->back());
436+
426437
return success();
427438
}
428439

@@ -1000,11 +1011,50 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
10001011
return success();
10011012
}
10021013

1014+
/// A RAII class that on construction replaces the region arguments of the
1015+
/// parallel op (which correspond to private variables) with the actual private
1016+
/// variables they correspond to. This prepares the parallel op so that it
1017+
/// matches what is expected by the OMPIRBuilder.
1018+
///
1019+
/// On destruction, it restores the original state of the operation so that on
1020+
/// the MLIR side, the op is not affected by conversion to LLVM IR.
1021+
class OmpParallelOpConversionManager {
1022+
public:
1023+
OmpParallelOpConversionManager(omp::ParallelOp opInst)
1024+
: region(opInst.getRegion()), privateVars(opInst.getPrivateVars()),
1025+
privateArgBeginIdx(opInst.getNumReductionVars()),
1026+
privateArgEndIdx(privateArgBeginIdx + privateVars.size()) {
1027+
auto privateVarsIt = privateVars.begin();
1028+
1029+
for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1030+
++argIdx, ++privateVarsIt)
1031+
mlir::replaceAllUsesInRegionWith(region.getArgument(argIdx),
1032+
*privateVarsIt, region);
1033+
}
1034+
1035+
~OmpParallelOpConversionManager() {
1036+
auto privateVarsIt = privateVars.begin();
1037+
1038+
for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1039+
++argIdx, ++privateVarsIt)
1040+
mlir::replaceAllUsesInRegionWith(*privateVarsIt,
1041+
region.getArgument(argIdx), region);
1042+
}
1043+
1044+
private:
1045+
Region &region;
1046+
OperandRange privateVars;
1047+
unsigned privateArgBeginIdx;
1048+
unsigned privateArgEndIdx;
1049+
};
1050+
10031051
/// Converts the OpenMP parallel operation to LLVM IR.
10041052
static LogicalResult
10051053
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10061054
LLVM::ModuleTranslation &moduleTranslation) {
10071055
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1056+
OmpParallelOpConversionManager raii(opInst);
1057+
10081058
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
10091059
// relying on captured variables.
10101060
LogicalResult bodyGenStatus = success();
@@ -1086,12 +1136,81 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10861136

10871137
// TODO: Perform appropriate actions according to the data-sharing
10881138
// attribute (shared, private, firstprivate, ...) of variables.
1089-
// Currently defaults to shared.
1139+
// Currently shared and private are supported.
10901140
auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
10911141
llvm::Value &, llvm::Value &vPtr,
10921142
llvm::Value *&replacementValue) -> InsertPointTy {
10931143
replacementValue = &vPtr;
10941144

1145+
// If this is a private value, this lambda will return the corresponding
1146+
// mlir value and its `PrivateClauseOp`. Otherwise, empty values are
1147+
// returned.
1148+
auto [privVar, privatizerClone] =
1149+
[&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1150+
if (!opInst.getPrivateVars().empty()) {
1151+
auto privVars = opInst.getPrivateVars();
1152+
auto privatizers = opInst.getPrivatizers();
1153+
1154+
for (auto [privVar, privatizerAttr] :
1155+
llvm::zip_equal(privVars, *privatizers)) {
1156+
// Find the MLIR private variable corresponding to the LLVM value
1157+
// being privatized.
1158+
llvm::Value *llvmPrivVar = moduleTranslation.lookupValue(privVar);
1159+
if (llvmPrivVar != &vPtr)
1160+
continue;
1161+
1162+
SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
1163+
omp::PrivateClauseOp privatizer =
1164+
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1165+
opInst, privSym);
1166+
1167+
// Clone the privatizer in case it is used by more than one parallel
1168+
// region. The privatizer is processed in-place (see below) before it
1169+
// gets inlined in the parallel region and therefore processing the
1170+
// original op is dangerous.
1171+
return {privVar, privatizer.clone()};
1172+
}
1173+
}
1174+
1175+
return {mlir::Value(), omp::PrivateClauseOp()};
1176+
}();
1177+
1178+
if (privVar) {
1179+
if (privatizerClone.getDataSharingType() ==
1180+
omp::DataSharingClauseType::FirstPrivate) {
1181+
privatizerClone.emitOpError(
1182+
"TODO: delayed privatization is not "
1183+
"supported for `firstprivate` clauses yet.");
1184+
bodyGenStatus = failure();
1185+
return codeGenIP;
1186+
}
1187+
1188+
Region &allocRegion = privatizerClone.getAllocRegion();
1189+
1190+
// Replace the privatizer block argument with mlir value being privatized.
1191+
// This way, the body of the privatizer will be changed from using the
1192+
// region/block argument to the value being privatized.
1193+
auto allocRegionArg = allocRegion.getArgument(0);
1194+
replaceAllUsesInRegionWith(allocRegionArg, privVar, allocRegion);
1195+
1196+
auto oldIP = builder.saveIP();
1197+
builder.restoreIP(allocaIP);
1198+
1199+
SmallVector<llvm::Value *, 1> yieldedValues;
1200+
if (failed(inlineConvertOmpRegions(allocRegion, "omp.privatizer", builder,
1201+
moduleTranslation, &yieldedValues))) {
1202+
opInst.emitError("failed to inline `alloc` region of an `omp.private` "
1203+
"op in the parallel region");
1204+
bodyGenStatus = failure();
1205+
} else {
1206+
assert(yieldedValues.size() == 1);
1207+
replacementValue = yieldedValues.front();
1208+
}
1209+
1210+
privatizerClone.erase();
1211+
builder.restoreIP(oldIP);
1212+
}
1213+
10951214
return codeGenIP;
10961215
};
10971216

@@ -1635,7 +1754,7 @@ getRefPtrIfDeclareTarget(mlir::Value value,
16351754
// A small helper structure to contain data gathered
16361755
// for map lowering and coalesce it into one area and
16371756
// avoiding extra computations such as searches in the
1638-
// llvm module for lowered mapped varibles or checking
1757+
// llvm module for lowered mapped variables or checking
16391758
// if something is declare target (and retrieving the
16401759
// value) more than necessary.
16411760
struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
@@ -2854,26 +2973,26 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
28542973
moduleTranslation);
28552974
return failure();
28562975
})
2857-
.Case(
2858-
"omp.requires",
2859-
[&](Attribute attr) {
2860-
if (auto requiresAttr = attr.dyn_cast<omp::ClauseRequiresAttr>()) {
2861-
using Requires = omp::ClauseRequires;
2862-
Requires flags = requiresAttr.getValue();
2863-
llvm::OpenMPIRBuilderConfig &config =
2864-
moduleTranslation.getOpenMPBuilder()->Config;
2865-
config.setHasRequiresReverseOffload(
2866-
bitEnumContainsAll(flags, Requires::reverse_offload));
2867-
config.setHasRequiresUnifiedAddress(
2868-
bitEnumContainsAll(flags, Requires::unified_address));
2869-
config.setHasRequiresUnifiedSharedMemory(
2870-
bitEnumContainsAll(flags, Requires::unified_shared_memory));
2871-
config.setHasRequiresDynamicAllocators(
2872-
bitEnumContainsAll(flags, Requires::dynamic_allocators));
2873-
return success();
2874-
}
2875-
return failure();
2876-
})
2976+
.Case("omp.requires",
2977+
[&](Attribute attr) {
2978+
if (auto requiresAttr =
2979+
attr.dyn_cast<omp::ClauseRequiresAttr>()) {
2980+
using Requires = omp::ClauseRequires;
2981+
Requires flags = requiresAttr.getValue();
2982+
llvm::OpenMPIRBuilderConfig &config =
2983+
moduleTranslation.getOpenMPBuilder()->Config;
2984+
config.setHasRequiresReverseOffload(
2985+
bitEnumContainsAll(flags, Requires::reverse_offload));
2986+
config.setHasRequiresUnifiedAddress(
2987+
bitEnumContainsAll(flags, Requires::unified_address));
2988+
config.setHasRequiresUnifiedSharedMemory(
2989+
bitEnumContainsAll(flags, Requires::unified_shared_memory));
2990+
config.setHasRequiresDynamicAllocators(
2991+
bitEnumContainsAll(flags, Requires::dynamic_allocators));
2992+
return success();
2993+
}
2994+
return failure();
2995+
})
28772996
.Default([](Attribute) {
28782997
// Fall through for omp attributes that do not require lowering.
28792998
return success();
@@ -2988,12 +3107,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
29883107
.Case([&](omp::TargetOp) {
29893108
return convertOmpTarget(*op, builder, moduleTranslation);
29903109
})
2991-
.Case<omp::MapInfoOp, omp::DataBoundsOp>([&](auto op) {
2992-
// No-op, should be handled by relevant owning operations e.g.
2993-
// TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
2994-
// discarded
2995-
return success();
2996-
})
3110+
.Case<omp::MapInfoOp, omp::DataBoundsOp, omp::PrivateClauseOp>(
3111+
[&](auto op) {
3112+
// No-op, should be handled by relevant owning operations e.g.
3113+
// TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
3114+
// discarded
3115+
return success();
3116+
})
29973117
.Default([&](Operation *inst) {
29983118
return inst->emitError("unsupported OpenMP operation: ")
29993119
<< inst->getName();

0 commit comments

Comments
 (0)