@@ -1011,40 +1011,49 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
1011
1011
return success ();
1012
1012
}
1013
1013
1014
- // / Replace the region arguments of the parallel op (which correspond to private
1015
- // / variables) with the actual private variables they correspond to. This
1016
- // / prepares the parallel op so that it matches what is expected by the
1017
- // / OMPIRBuilder. Instead of editing the original op in-place, this function
1018
- // / does the required changes to a cloned version which should then be erased by
1019
- // / the caller.
1020
- static omp::ParallelOp
1021
- prepareOmpParallelForPrivatization (omp::ParallelOp opInst) {
1022
- Region &region = opInst.getRegion ();
1023
- auto privateVars = opInst.getPrivateVars ();
1024
-
1025
- auto privateVarsIt = privateVars.begin ();
1026
- // Reduction precede private arguments, so skip them first.
1027
- unsigned privateArgBeginIdx = opInst.getNumReductionVars ();
1028
- unsigned privateArgEndIdx = privateArgBeginIdx + privateVars.size ();
1029
-
1030
- mlir::IRMapping mapping;
1031
- for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1032
- ++argIdx, ++privateVarsIt)
1033
- mapping.map (region.getArgument (argIdx), *privateVarsIt);
1034
-
1035
- mlir::OpBuilder cloneBuilder (opInst);
1036
- omp::ParallelOp opInstClone =
1037
- llvm::cast<omp::ParallelOp>(cloneBuilder.clone (*opInst, mapping));
1038
-
1039
- return opInstClone;
1040
- }
1014
+ // / A RAII class that on construction replaces the region arguments of the
1015
+ // / parallel op (which correspond to private variables) with the actual private
1016
+ // / variables they correspond to. This prepares the parallel op so that it
1017
+ // / matches what is expected by the OMPIRBuilder.
1018
+ // /
1019
+ // / On destruction, it restores the original state of the operation so that on
1020
+ // / the MLIR side, the op is not affected by conversion to LLVM IR.
1021
+ class OmpParallelOpConversionManager {
1022
+ public:
1023
+ OmpParallelOpConversionManager (omp::ParallelOp opInst)
1024
+ : region(opInst.getRegion()), privateVars(opInst.getPrivateVars()),
1025
+ privateArgBeginIdx (opInst.getNumReductionVars()),
1026
+ privateArgEndIdx(privateArgBeginIdx + privateVars.size()) {
1027
+ auto privateVarsIt = privateVars.begin ();
1028
+
1029
+ for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1030
+ ++argIdx, ++privateVarsIt)
1031
+ mlir::replaceAllUsesInRegionWith (region.getArgument (argIdx),
1032
+ *privateVarsIt, region);
1033
+ }
1034
+
1035
+ ~OmpParallelOpConversionManager () {
1036
+ auto privateVarsIt = privateVars.begin ();
1037
+
1038
+ for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1039
+ ++argIdx, ++privateVarsIt)
1040
+ mlir::replaceAllUsesInRegionWith (*privateVarsIt,
1041
+ region.getArgument (argIdx), region);
1042
+ }
1043
+
1044
+ private:
1045
+ Region ®ion;
1046
+ OperandRange privateVars;
1047
+ unsigned privateArgBeginIdx;
1048
+ unsigned privateArgEndIdx;
1049
+ };
1041
1050
1042
1051
// / Converts the OpenMP parallel operation to LLVM IR.
1043
1052
static LogicalResult
1044
1053
convertOmpParallel (omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1045
1054
LLVM::ModuleTranslation &moduleTranslation) {
1046
1055
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1047
- omp::ParallelOp opInstClone = prepareOmpParallelForPrivatization (opInst);
1056
+ OmpParallelOpConversionManager raii (opInst);
1048
1057
1049
1058
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
1050
1059
// relying on captured variables.
@@ -1054,12 +1063,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1054
1063
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1055
1064
// Collect reduction declarations
1056
1065
SmallVector<omp::ReductionDeclareOp> reductionDecls;
1057
- collectReductionDecls (opInstClone , reductionDecls);
1066
+ collectReductionDecls (opInst , reductionDecls);
1058
1067
1059
1068
// Allocate reduction vars
1060
1069
SmallVector<llvm::Value *> privateReductionVariables;
1061
1070
DenseMap<Value, llvm::Value *> reductionVariableMap;
1062
- allocReductionVars (opInstClone , builder, moduleTranslation, allocaIP,
1071
+ allocReductionVars (opInst , builder, moduleTranslation, allocaIP,
1063
1072
reductionDecls, privateReductionVariables,
1064
1073
reductionVariableMap);
1065
1074
@@ -1071,7 +1080,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1071
1080
1072
1081
// Initialize reduction vars
1073
1082
builder.restoreIP (allocaIP);
1074
- for (unsigned i = 0 ; i < opInstClone .getNumReductionVars (); ++i) {
1083
+ for (unsigned i = 0 ; i < opInst .getNumReductionVars (); ++i) {
1075
1084
SmallVector<llvm::Value *> phis;
1076
1085
if (failed (inlineConvertOmpRegions (
1077
1086
reductionDecls[i].getInitializerRegion (), " omp.reduction.neutral" ,
@@ -1092,19 +1101,18 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1092
1101
// ParallelOp has only one region associated with it.
1093
1102
builder.restoreIP (codeGenIP);
1094
1103
auto regionBlock =
1095
- convertOmpOpRegions (opInstClone .getRegion (), " omp.par.region" , builder,
1104
+ convertOmpOpRegions (opInst .getRegion (), " omp.par.region" , builder,
1096
1105
moduleTranslation, bodyGenStatus);
1097
1106
1098
1107
// Process the reductions if required.
1099
- if (opInstClone .getNumReductionVars () > 0 ) {
1108
+ if (opInst .getNumReductionVars () > 0 ) {
1100
1109
// Collect reduction info
1101
1110
SmallVector<OwningReductionGen> owningReductionGens;
1102
1111
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1103
1112
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1104
- collectReductionInfo (opInstClone, builder, moduleTranslation,
1105
- reductionDecls, owningReductionGens,
1106
- owningAtomicReductionGens, privateReductionVariables,
1107
- reductionInfos);
1113
+ collectReductionInfo (opInst, builder, moduleTranslation, reductionDecls,
1114
+ owningReductionGens, owningAtomicReductionGens,
1115
+ privateReductionVariables, reductionInfos);
1108
1116
1109
1117
// Move to region cont block
1110
1118
builder.SetInsertPoint (regionBlock->getTerminator ());
@@ -1117,8 +1125,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1117
1125
ompBuilder->createReductions (builder.saveIP (), allocaIP,
1118
1126
reductionInfos, false );
1119
1127
if (!contInsertPoint.getBlock ()) {
1120
- bodyGenStatus = opInstClone->emitOpError ()
1121
- << " failed to convert reductions" ;
1128
+ bodyGenStatus = opInst->emitOpError () << " failed to convert reductions" ;
1122
1129
return ;
1123
1130
}
1124
1131
@@ -1140,9 +1147,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1140
1147
// returned.
1141
1148
auto [privVar, privatizerClone] =
1142
1149
[&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1143
- if (!opInstClone .getPrivateVars ().empty ()) {
1144
- auto privVars = opInstClone .getPrivateVars ();
1145
- auto privatizers = opInstClone .getPrivatizers ();
1150
+ if (!opInst .getPrivateVars ().empty ()) {
1151
+ auto privVars = opInst .getPrivateVars ();
1152
+ auto privatizers = opInst .getPrivatizers ();
1146
1153
1147
1154
for (auto [privVar, privatizerAttr] :
1148
1155
llvm::zip_equal (privVars, *privatizers)) {
@@ -1155,7 +1162,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1155
1162
SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
1156
1163
omp::PrivateClauseOp privatizer =
1157
1164
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1158
- opInstClone , privSym);
1165
+ opInst , privSym);
1159
1166
1160
1167
// Clone the privatizer in case it is used by more than one parallel
1161
1168
// region. The privatizer is processed in-place (see below) before it
@@ -1192,9 +1199,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1192
1199
SmallVector<llvm::Value *, 1 > yieldedValues;
1193
1200
if (failed (inlineConvertOmpRegions (allocRegion, " omp.privatizer" , builder,
1194
1201
moduleTranslation, &yieldedValues))) {
1195
- opInstClone.emitError (
1196
- " failed to inline `alloc` region of an `omp.private` "
1197
- " op in the parallel region" );
1202
+ opInst.emitError (" failed to inline `alloc` region of an `omp.private` "
1203
+ " op in the parallel region" );
1198
1204
bodyGenStatus = failure ();
1199
1205
} else {
1200
1206
assert (yieldedValues.size () == 1 );
@@ -1213,13 +1219,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1213
1219
auto finiCB = [&](InsertPointTy codeGenIP) {};
1214
1220
1215
1221
llvm::Value *ifCond = nullptr ;
1216
- if (auto ifExprVar = opInstClone .getIfExprVar ())
1222
+ if (auto ifExprVar = opInst .getIfExprVar ())
1217
1223
ifCond = moduleTranslation.lookupValue (ifExprVar);
1218
1224
llvm::Value *numThreads = nullptr ;
1219
- if (auto numThreadsVar = opInstClone .getNumThreadsVar ())
1225
+ if (auto numThreadsVar = opInst .getNumThreadsVar ())
1220
1226
numThreads = moduleTranslation.lookupValue (numThreadsVar);
1221
1227
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
1222
- if (auto bind = opInstClone .getProcBindVal ())
1228
+ if (auto bind = opInst .getProcBindVal ())
1223
1229
pbKind = getProcBindKind (*bind);
1224
1230
// TODO: Is the Parallel construct cancellable?
1225
1231
bool isCancellable = false ;
@@ -1232,7 +1238,6 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1232
1238
ompBuilder->createParallel (ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
1233
1239
ifCond, numThreads, pbKind, isCancellable));
1234
1240
1235
- opInstClone.erase ();
1236
1241
return bodyGenStatus;
1237
1242
}
1238
1243
0 commit comments