Skip to content

[MLIR][OpenMP] Support basic materialization for omp.private ops #81715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1957,7 +1957,10 @@ LogicalResult PrivateClauseOp::verify() {
Type symType = getType();

auto verifyTerminator = [&](Operation *terminator) -> LogicalResult {
if (!terminator->hasSuccessors() && !llvm::isa<YieldOp>(terminator))
if (!terminator->getBlock()->getSuccessors().empty())
return success();

if (!llvm::isa<YieldOp>(terminator))
return mlir::emitError(terminator->getLoc())
<< "expected exit block terminator to be an `omp.yield` op.";

Expand Down
182 changes: 151 additions & 31 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,9 @@ collectReductionDecls(T loop,

/// Translates the blocks contained in the given region and appends them to at
/// the current insertion point of `builder`. The operations of the entry block
/// are appended to the current insertion block, which is not expected to have a
/// terminator. If set, `continuationBlockArgs` is populated with translated
/// values that correspond to the values omp.yield'ed from the region.
/// are appended to the current insertion block. If set, `continuationBlockArgs`
/// is populated with translated values that correspond to the values
/// omp.yield'ed from the region.
static LogicalResult inlineConvertOmpRegions(
Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
Expand All @@ -409,7 +409,14 @@ static LogicalResult inlineConvertOmpRegions(
// Special case for single-block regions that don't create additional blocks:
// insert operations without creating additional blocks.
if (llvm::hasSingleElement(region)) {
llvm::Instruction *potentialTerminator =
builder.GetInsertBlock()->empty() ? nullptr
: &builder.GetInsertBlock()->back();

if (potentialTerminator && potentialTerminator->isTerminator())
potentialTerminator->removeFromParent();
moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());

if (failed(moduleTranslation.convertBlock(
region.front(), /*ignoreArguments=*/true, builder)))
return failure();
Expand All @@ -423,6 +430,10 @@ static LogicalResult inlineConvertOmpRegions(
// Drop the mapping that is no longer necessary so that the same region can
// be processed multiple times.
moduleTranslation.forgetMapping(region);

if (potentialTerminator && potentialTerminator->isTerminator())
potentialTerminator->insertAfter(&builder.GetInsertBlock()->back());

return success();
}

Expand Down Expand Up @@ -1000,11 +1011,50 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}

/// A RAII class that on construction replaces the region arguments of the
/// parallel op (which correspond to private variables) with the actual private
/// variables they correspond to. This prepares the parallel op so that it
/// matches what is expected by the OMPIRBuilder.
///
/// On destruction, it restores the original state of the operation so that on
/// the MLIR side, the op is not affected by conversion to LLVM IR.
class OmpParallelOpConversionManager {
public:
OmpParallelOpConversionManager(omp::ParallelOp opInst)
: region(opInst.getRegion()), privateVars(opInst.getPrivateVars()),
privateArgBeginIdx(opInst.getNumReductionVars()),
privateArgEndIdx(privateArgBeginIdx + privateVars.size()) {
auto privateVarsIt = privateVars.begin();

for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
++argIdx, ++privateVarsIt)
mlir::replaceAllUsesInRegionWith(region.getArgument(argIdx),
*privateVarsIt, region);
}

~OmpParallelOpConversionManager() {
auto privateVarsIt = privateVars.begin();

for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
++argIdx, ++privateVarsIt)
mlir::replaceAllUsesInRegionWith(*privateVarsIt,
region.getArgument(argIdx), region);
}

private:
Region &region;
OperandRange privateVars;
unsigned privateArgBeginIdx;
unsigned privateArgEndIdx;
};

/// Converts the OpenMP parallel operation to LLVM IR.
static LogicalResult
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
OmpParallelOpConversionManager raii(opInst);

// TODO: support error propagation in OpenMPIRBuilder and use it instead of
// relying on captured variables.
LogicalResult bodyGenStatus = success();
Expand Down Expand Up @@ -1086,12 +1136,81 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,

// TODO: Perform appropriate actions according to the data-sharing
// attribute (shared, private, firstprivate, ...) of variables.
// Currently defaults to shared.
// Currently shared and private are supported.
auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
llvm::Value &, llvm::Value &vPtr,
llvm::Value *&replacementValue) -> InsertPointTy {
replacementValue = &vPtr;

// If this is a private value, this lambda will return the corresponding
// mlir value and its `PrivateClauseOp`. Otherwise, empty values are
// returned.
auto [privVar, privatizerClone] =
[&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
if (!opInst.getPrivateVars().empty()) {
auto privVars = opInst.getPrivateVars();
auto privatizers = opInst.getPrivatizers();

for (auto [privVar, privatizerAttr] :
llvm::zip_equal(privVars, *privatizers)) {
// Find the MLIR private variable corresponding to the LLVM value
// being privatized.
llvm::Value *llvmPrivVar = moduleTranslation.lookupValue(privVar);
if (llvmPrivVar != &vPtr)
continue;

SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
omp::PrivateClauseOp privatizer =
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
opInst, privSym);

// Clone the privatizer in case it is used by more than one parallel
// region. The privatizer is processed in-place (see below) before it
// gets inlined in the parallel region and therefore processing the
// original op is dangerous.
return {privVar, privatizer.clone()};
}
}

return {mlir::Value(), omp::PrivateClauseOp()};
}();

if (privVar) {
if (privatizerClone.getDataSharingType() ==
omp::DataSharingClauseType::FirstPrivate) {
privatizerClone.emitOpError(
"TODO: delayed privatization is not "
"supported for `firstprivate` clauses yet.");
bodyGenStatus = failure();
return codeGenIP;
}

Region &allocRegion = privatizerClone.getAllocRegion();

// Replace the privatizer block argument with mlir value being privatized.
// This way, the body of the privatizer will be changed from using the
// region/block argument to the value being privatized.
auto allocRegionArg = allocRegion.getArgument(0);
replaceAllUsesInRegionWith(allocRegionArg, privVar, allocRegion);

auto oldIP = builder.saveIP();
builder.restoreIP(allocaIP);

SmallVector<llvm::Value *, 1> yieldedValues;
if (failed(inlineConvertOmpRegions(allocRegion, "omp.privatizer", builder,
moduleTranslation, &yieldedValues))) {
opInst.emitError("failed to inline `alloc` region of an `omp.private` "
"op in the parallel region");
bodyGenStatus = failure();
} else {
assert(yieldedValues.size() == 1);
replacementValue = yieldedValues.front();
}

privatizerClone.erase();
builder.restoreIP(oldIP);
}

return codeGenIP;
};

Expand Down Expand Up @@ -1635,7 +1754,7 @@ getRefPtrIfDeclareTarget(mlir::Value value,
// A small helper structure to contain data gathered
// for map lowering and coalese it into one area and
// avoiding extra computations such as searches in the
// llvm module for lowered mapped varibles or checking
// llvm module for lowered mapped variables or checking
// if something is declare target (and retrieving the
// value) more than neccessary.
struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
Expand Down Expand Up @@ -2854,26 +2973,26 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
moduleTranslation);
return failure();
})
.Case(
"omp.requires",
[&](Attribute attr) {
if (auto requiresAttr = attr.dyn_cast<omp::ClauseRequiresAttr>()) {
using Requires = omp::ClauseRequires;
Requires flags = requiresAttr.getValue();
llvm::OpenMPIRBuilderConfig &config =
moduleTranslation.getOpenMPBuilder()->Config;
config.setHasRequiresReverseOffload(
bitEnumContainsAll(flags, Requires::reverse_offload));
config.setHasRequiresUnifiedAddress(
bitEnumContainsAll(flags, Requires::unified_address));
config.setHasRequiresUnifiedSharedMemory(
bitEnumContainsAll(flags, Requires::unified_shared_memory));
config.setHasRequiresDynamicAllocators(
bitEnumContainsAll(flags, Requires::dynamic_allocators));
return success();
}
return failure();
})
.Case("omp.requires",
[&](Attribute attr) {
if (auto requiresAttr =
attr.dyn_cast<omp::ClauseRequiresAttr>()) {
using Requires = omp::ClauseRequires;
Requires flags = requiresAttr.getValue();
llvm::OpenMPIRBuilderConfig &config =
moduleTranslation.getOpenMPBuilder()->Config;
config.setHasRequiresReverseOffload(
bitEnumContainsAll(flags, Requires::reverse_offload));
config.setHasRequiresUnifiedAddress(
bitEnumContainsAll(flags, Requires::unified_address));
config.setHasRequiresUnifiedSharedMemory(
bitEnumContainsAll(flags, Requires::unified_shared_memory));
config.setHasRequiresDynamicAllocators(
bitEnumContainsAll(flags, Requires::dynamic_allocators));
return success();
}
return failure();
})
.Default([](Attribute) {
// Fall through for omp attributes that do not require lowering.
return success();
Expand Down Expand Up @@ -2988,12 +3107,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
.Case([&](omp::TargetOp) {
return convertOmpTarget(*op, builder, moduleTranslation);
})
.Case<omp::MapInfoOp, omp::DataBoundsOp>([&](auto op) {
// No-op, should be handled by relevant owning operations e.g.
// TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
// discarded
return success();
})
.Case<omp::MapInfoOp, omp::DataBoundsOp, omp::PrivateClauseOp>(
[&](auto op) {
// No-op, should be handled by relevant owning operations e.g.
// TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
// discarded
return success();
})
.Default([&](Operation *inst) {
return inst->emitError("unsupported OpenMP operation: ")
<< inst->getName();
Expand Down
Loading