Skip to content

Commit df16ecc

Browse files
committed
Add RAII object to manage mapping of the op's arguments.
1 parent e0c5f0b commit df16ecc

File tree

1 file changed

+55
-50
lines changed

1 file changed

+55
-50
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 55 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,40 +1011,49 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
10111011
return success();
10121012
}
10131013

1014-
/// Replace the region arguments of the parallel op (which correspond to private
1015-
/// variables) with the actual private variables they correspond to. This
1016-
/// prepares the parallel op so that it matches what is expected by the
1017-
/// OMPIRBuilder. Instead of editing the original op in-place, this function
1018-
/// does the required changes to a cloned version which should then be erased by
1019-
/// the caller.
1020-
static omp::ParallelOp
1021-
prepareOmpParallelForPrivatization(omp::ParallelOp opInst) {
1022-
Region &region = opInst.getRegion();
1023-
auto privateVars = opInst.getPrivateVars();
1024-
1025-
auto privateVarsIt = privateVars.begin();
1026-
// Reduction precede private arguments, so skip them first.
1027-
unsigned privateArgBeginIdx = opInst.getNumReductionVars();
1028-
unsigned privateArgEndIdx = privateArgBeginIdx + privateVars.size();
1029-
1030-
mlir::IRMapping mapping;
1031-
for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1032-
++argIdx, ++privateVarsIt)
1033-
mapping.map(region.getArgument(argIdx), *privateVarsIt);
1034-
1035-
mlir::OpBuilder cloneBuilder(opInst);
1036-
omp::ParallelOp opInstClone =
1037-
llvm::cast<omp::ParallelOp>(cloneBuilder.clone(*opInst, mapping));
1038-
1039-
return opInstClone;
1040-
}
1014+
/// A RAII class that on construction replaces the region arguments of the
1015+
/// parallel op (which correspond to private variables) with the actual private
1016+
/// variables they correspond to. This prepares the parallel op so that it
1017+
/// matches what is expected by the OMPIRBuilder.
1018+
///
1019+
/// On destruction, it restores the original state of the operation so that on
1020+
/// the MLIR side, the op is not affected by conversion to LLVM IR.
1021+
class OmpParallelOpConversionManager {
1022+
public:
1023+
OmpParallelOpConversionManager(omp::ParallelOp opInst)
1024+
: region(opInst.getRegion()), privateVars(opInst.getPrivateVars()),
1025+
privateArgBeginIdx(opInst.getNumReductionVars()),
1026+
privateArgEndIdx(privateArgBeginIdx + privateVars.size()) {
1027+
auto privateVarsIt = privateVars.begin();
1028+
1029+
for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1030+
++argIdx, ++privateVarsIt)
1031+
mlir::replaceAllUsesInRegionWith(region.getArgument(argIdx),
1032+
*privateVarsIt, region);
1033+
}
1034+
1035+
~OmpParallelOpConversionManager() {
1036+
auto privateVarsIt = privateVars.begin();
1037+
1038+
for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1039+
++argIdx, ++privateVarsIt)
1040+
mlir::replaceAllUsesInRegionWith(*privateVarsIt,
1041+
region.getArgument(argIdx), region);
1042+
}
1043+
1044+
private:
1045+
Region &region;
1046+
OperandRange privateVars;
1047+
unsigned privateArgBeginIdx;
1048+
unsigned privateArgEndIdx;
1049+
};
10411050

10421051
/// Converts the OpenMP parallel operation to LLVM IR.
10431052
static LogicalResult
10441053
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10451054
LLVM::ModuleTranslation &moduleTranslation) {
10461055
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1047-
omp::ParallelOp opInstClone = prepareOmpParallelForPrivatization(opInst);
1056+
OmpParallelOpConversionManager raii(opInst);
10481057

10491058
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
10501059
// relying on captured variables.
@@ -1054,12 +1063,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10541063
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
10551064
// Collect reduction declarations
10561065
SmallVector<omp::ReductionDeclareOp> reductionDecls;
1057-
collectReductionDecls(opInstClone, reductionDecls);
1066+
collectReductionDecls(opInst, reductionDecls);
10581067

10591068
// Allocate reduction vars
10601069
SmallVector<llvm::Value *> privateReductionVariables;
10611070
DenseMap<Value, llvm::Value *> reductionVariableMap;
1062-
allocReductionVars(opInstClone, builder, moduleTranslation, allocaIP,
1071+
allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
10631072
reductionDecls, privateReductionVariables,
10641073
reductionVariableMap);
10651074

@@ -1071,7 +1080,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10711080

10721081
// Initialize reduction vars
10731082
builder.restoreIP(allocaIP);
1074-
for (unsigned i = 0; i < opInstClone.getNumReductionVars(); ++i) {
1083+
for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
10751084
SmallVector<llvm::Value *> phis;
10761085
if (failed(inlineConvertOmpRegions(
10771086
reductionDecls[i].getInitializerRegion(), "omp.reduction.neutral",
@@ -1092,19 +1101,18 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10921101
// ParallelOp has only one region associated with it.
10931102
builder.restoreIP(codeGenIP);
10941103
auto regionBlock =
1095-
convertOmpOpRegions(opInstClone.getRegion(), "omp.par.region", builder,
1104+
convertOmpOpRegions(opInst.getRegion(), "omp.par.region", builder,
10961105
moduleTranslation, bodyGenStatus);
10971106

10981107
// Process the reductions if required.
1099-
if (opInstClone.getNumReductionVars() > 0) {
1108+
if (opInst.getNumReductionVars() > 0) {
11001109
// Collect reduction info
11011110
SmallVector<OwningReductionGen> owningReductionGens;
11021111
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
11031112
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1104-
collectReductionInfo(opInstClone, builder, moduleTranslation,
1105-
reductionDecls, owningReductionGens,
1106-
owningAtomicReductionGens, privateReductionVariables,
1107-
reductionInfos);
1113+
collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
1114+
owningReductionGens, owningAtomicReductionGens,
1115+
privateReductionVariables, reductionInfos);
11081116

11091117
// Move to region cont block
11101118
builder.SetInsertPoint(regionBlock->getTerminator());
@@ -1117,8 +1125,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11171125
ompBuilder->createReductions(builder.saveIP(), allocaIP,
11181126
reductionInfos, false);
11191127
if (!contInsertPoint.getBlock()) {
1120-
bodyGenStatus = opInstClone->emitOpError()
1121-
<< "failed to convert reductions";
1128+
bodyGenStatus = opInst->emitOpError() << "failed to convert reductions";
11221129
return;
11231130
}
11241131

@@ -1140,9 +1147,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11401147
// returned.
11411148
auto [privVar, privatizerClone] =
11421149
[&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1143-
if (!opInstClone.getPrivateVars().empty()) {
1144-
auto privVars = opInstClone.getPrivateVars();
1145-
auto privatizers = opInstClone.getPrivatizers();
1150+
if (!opInst.getPrivateVars().empty()) {
1151+
auto privVars = opInst.getPrivateVars();
1152+
auto privatizers = opInst.getPrivatizers();
11461153

11471154
for (auto [privVar, privatizerAttr] :
11481155
llvm::zip_equal(privVars, *privatizers)) {
@@ -1155,7 +1162,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11551162
SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
11561163
omp::PrivateClauseOp privatizer =
11571164
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1158-
opInstClone, privSym);
1165+
opInst, privSym);
11591166

11601167
// Clone the privatizer in case it is used by more than one parallel
11611168
// region. The privatizer is processed in-place (see below) before it
@@ -1192,9 +1199,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11921199
SmallVector<llvm::Value *, 1> yieldedValues;
11931200
if (failed(inlineConvertOmpRegions(allocRegion, "omp.privatizer", builder,
11941201
moduleTranslation, &yieldedValues))) {
1195-
opInstClone.emitError(
1196-
"failed to inline `alloc` region of an `omp.private` "
1197-
"op in the parallel region");
1202+
opInst.emitError("failed to inline `alloc` region of an `omp.private` "
1203+
"op in the parallel region");
11981204
bodyGenStatus = failure();
11991205
} else {
12001206
assert(yieldedValues.size() == 1);
@@ -1213,13 +1219,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
12131219
auto finiCB = [&](InsertPointTy codeGenIP) {};
12141220

12151221
llvm::Value *ifCond = nullptr;
1216-
if (auto ifExprVar = opInstClone.getIfExprVar())
1222+
if (auto ifExprVar = opInst.getIfExprVar())
12171223
ifCond = moduleTranslation.lookupValue(ifExprVar);
12181224
llvm::Value *numThreads = nullptr;
1219-
if (auto numThreadsVar = opInstClone.getNumThreadsVar())
1225+
if (auto numThreadsVar = opInst.getNumThreadsVar())
12201226
numThreads = moduleTranslation.lookupValue(numThreadsVar);
12211227
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
1222-
if (auto bind = opInstClone.getProcBindVal())
1228+
if (auto bind = opInst.getProcBindVal())
12231229
pbKind = getProcBindKind(*bind);
12241230
// TODO: Is the Parallel construct cancellable?
12251231
bool isCancellable = false;
@@ -1232,7 +1238,6 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
12321238
ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
12331239
ifCond, numThreads, pbKind, isCancellable));
12341240

1235-
opInstClone.erase();
12361241
return bodyGenStatus;
12371242
}
12381243

0 commit comments

Comments
 (0)