@@ -396,9 +396,9 @@ collectReductionDecls(T loop,
396
396
397
397
// / Translates the blocks contained in the given region and appends them to at
398
398
// / the current insertion point of `builder`. The operations of the entry block
399
- // / are appended to the current insertion block, which is not expected to have a
400
- // / terminator. If set, `continuationBlockArgs` is populated with translated
401
- // / values that correspond to the values omp.yield'ed from the region.
399
+ // / are appended to the current insertion block. If set, `continuationBlockArgs`
400
+ // / is populated with translated values that correspond to the values
401
+ // / omp.yield'ed from the region.
402
402
static LogicalResult inlineConvertOmpRegions (
403
403
Region ®ion, StringRef blockName, llvm::IRBuilderBase &builder,
404
404
LLVM::ModuleTranslation &moduleTranslation,
@@ -409,7 +409,14 @@ static LogicalResult inlineConvertOmpRegions(
409
409
// Special case for single-block regions that don't create additional blocks:
410
410
// insert operations without creating additional blocks.
411
411
if (llvm::hasSingleElement (region)) {
412
+ llvm::Instruction *potentialTerminator =
413
+ builder.GetInsertBlock ()->empty () ? nullptr
414
+ : &builder.GetInsertBlock ()->back ();
415
+
416
+ if (potentialTerminator && potentialTerminator->isTerminator ())
417
+ potentialTerminator->removeFromParent ();
412
418
moduleTranslation.mapBlock (®ion.front (), builder.GetInsertBlock ());
419
+
413
420
if (failed (moduleTranslation.convertBlock (
414
421
region.front (), /* ignoreArguments=*/ true , builder)))
415
422
return failure ();
@@ -423,6 +430,10 @@ static LogicalResult inlineConvertOmpRegions(
423
430
// Drop the mapping that is no longer necessary so that the same region can
424
431
// be processed multiple times.
425
432
moduleTranslation.forgetMapping (region);
433
+
434
+ if (potentialTerminator && potentialTerminator->isTerminator ())
435
+ potentialTerminator->insertAfter (&builder.GetInsertBlock ()->back ());
436
+
426
437
return success ();
427
438
}
428
439
@@ -1000,11 +1011,41 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
1000
1011
return success ();
1001
1012
}
1002
1013
1014
+ // / Replace the region arguments of the parallel op (which correspond to private
1015
+ // / variables) with the actual private variables they correspond to. This
1016
+ // / prepares the parallel op so that it matches what is expected by the
1017
+ // / OMPIRBuilder. Instead of editing the original op in-place, this function
1018
+ // / does the required changes to a cloned version which should then be erased by
1019
+ // / the caller.
1020
+ static omp::ParallelOp
1021
+ prepareOmpParallelForPrivatization (omp::ParallelOp opInst) {
1022
+ Region ®ion = opInst.getRegion ();
1023
+ auto privateVars = opInst.getPrivateVars ();
1024
+
1025
+ auto privateVarsIt = privateVars.begin ();
1026
+ // Reduction precede private arguments, so skip them first.
1027
+ unsigned privateArgBeginIdx = opInst.getNumReductionVars ();
1028
+ unsigned privateArgEndIdx = privateArgBeginIdx + privateVars.size ();
1029
+
1030
+ mlir::IRMapping mapping;
1031
+ for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1032
+ ++argIdx, ++privateVarsIt)
1033
+ mapping.map (region.getArgument (argIdx), *privateVarsIt);
1034
+
1035
+ mlir::OpBuilder cloneBuilder (opInst);
1036
+ omp::ParallelOp opInstClone =
1037
+ llvm::cast<omp::ParallelOp>(cloneBuilder.clone (*opInst, mapping));
1038
+
1039
+ return opInstClone;
1040
+ }
1041
+
1003
1042
// / Converts the OpenMP parallel operation to LLVM IR.
1004
1043
static LogicalResult
1005
1044
convertOmpParallel (omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1006
1045
LLVM::ModuleTranslation &moduleTranslation) {
1007
1046
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1047
+ omp::ParallelOp opInstClone = prepareOmpParallelForPrivatization (opInst);
1048
+
1008
1049
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
1009
1050
// relying on captured variables.
1010
1051
LogicalResult bodyGenStatus = success ();
@@ -1013,12 +1054,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1013
1054
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1014
1055
// Collect reduction declarations
1015
1056
SmallVector<omp::ReductionDeclareOp> reductionDecls;
1016
- collectReductionDecls (opInst , reductionDecls);
1057
+ collectReductionDecls (opInstClone , reductionDecls);
1017
1058
1018
1059
// Allocate reduction vars
1019
1060
SmallVector<llvm::Value *> privateReductionVariables;
1020
1061
DenseMap<Value, llvm::Value *> reductionVariableMap;
1021
- allocReductionVars (opInst , builder, moduleTranslation, allocaIP,
1062
+ allocReductionVars (opInstClone , builder, moduleTranslation, allocaIP,
1022
1063
reductionDecls, privateReductionVariables,
1023
1064
reductionVariableMap);
1024
1065
@@ -1030,7 +1071,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1030
1071
1031
1072
// Initialize reduction vars
1032
1073
builder.restoreIP (allocaIP);
1033
- for (unsigned i = 0 ; i < opInst .getNumReductionVars (); ++i) {
1074
+ for (unsigned i = 0 ; i < opInstClone .getNumReductionVars (); ++i) {
1034
1075
SmallVector<llvm::Value *> phis;
1035
1076
if (failed (inlineConvertOmpRegions (
1036
1077
reductionDecls[i].getInitializerRegion (), " omp.reduction.neutral" ,
@@ -1051,18 +1092,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1051
1092
// ParallelOp has only one region associated with it.
1052
1093
builder.restoreIP (codeGenIP);
1053
1094
auto regionBlock =
1054
- convertOmpOpRegions (opInst .getRegion (), " omp.par.region" , builder,
1095
+ convertOmpOpRegions (opInstClone .getRegion (), " omp.par.region" , builder,
1055
1096
moduleTranslation, bodyGenStatus);
1056
1097
1057
1098
// Process the reductions if required.
1058
- if (opInst .getNumReductionVars () > 0 ) {
1099
+ if (opInstClone .getNumReductionVars () > 0 ) {
1059
1100
// Collect reduction info
1060
1101
SmallVector<OwningReductionGen> owningReductionGens;
1061
1102
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1062
1103
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1063
- collectReductionInfo (opInst, builder, moduleTranslation, reductionDecls,
1064
- owningReductionGens, owningAtomicReductionGens,
1065
- privateReductionVariables, reductionInfos);
1104
+ collectReductionInfo (opInstClone, builder, moduleTranslation,
1105
+ reductionDecls, owningReductionGens,
1106
+ owningAtomicReductionGens, privateReductionVariables,
1107
+ reductionInfos);
1066
1108
1067
1109
// Move to region cont block
1068
1110
builder.SetInsertPoint (regionBlock->getTerminator ());
@@ -1075,7 +1117,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1075
1117
ompBuilder->createReductions (builder.saveIP (), allocaIP,
1076
1118
reductionInfos, false );
1077
1119
if (!contInsertPoint.getBlock ()) {
1078
- bodyGenStatus = opInst->emitOpError () << " failed to convert reductions" ;
1120
+ bodyGenStatus = opInstClone->emitOpError ()
1121
+ << " failed to convert reductions" ;
1079
1122
return ;
1080
1123
}
1081
1124
@@ -1086,12 +1129,82 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1086
1129
1087
1130
// TODO: Perform appropriate actions according to the data-sharing
1088
1131
// attribute (shared, private, firstprivate, ...) of variables.
1089
- // Currently defaults to shared .
1132
+ // Currently shared and private are supported .
1090
1133
auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
1091
1134
llvm::Value &, llvm::Value &vPtr,
1092
1135
llvm::Value *&replacementValue) -> InsertPointTy {
1093
1136
replacementValue = &vPtr;
1094
1137
1138
+ // If this is a private value, this lambda will return the corresponding
1139
+ // mlir value and its `PrivateClauseOp`. Otherwise, empty values are
1140
+ // returned.
1141
+ auto [privVar, privatizerClone] =
1142
+ [&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1143
+ if (!opInstClone.getPrivateVars ().empty ()) {
1144
+ auto privVars = opInstClone.getPrivateVars ();
1145
+ auto privatizers = opInstClone.getPrivatizers ();
1146
+
1147
+ for (auto [privVar, privatizerAttr] :
1148
+ llvm::zip_equal (privVars, *privatizers)) {
1149
+ // Find the MLIR private variable corresponding to the LLVM value
1150
+ // being privatized.
1151
+ llvm::Value *llvmPrivVar = moduleTranslation.lookupValue (privVar);
1152
+ if (llvmPrivVar != &vPtr)
1153
+ continue ;
1154
+
1155
+ SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
1156
+ omp::PrivateClauseOp privatizer =
1157
+ SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1158
+ opInstClone, privSym);
1159
+
1160
+ // Clone the privatizer in case it is used by more than one parallel
1161
+ // region. The privatizer is processed in-place (see below) before it
1162
+ // gets inlined in the parallel region and therefore processing the
1163
+ // original op is dangerous.
1164
+ return {privVar, privatizer.clone ()};
1165
+ }
1166
+ }
1167
+
1168
+ return {mlir::Value (), omp::PrivateClauseOp ()};
1169
+ }();
1170
+
1171
+ if (privVar) {
1172
+ if (privatizerClone.getDataSharingType () ==
1173
+ omp::DataSharingClauseType::FirstPrivate) {
1174
+ privatizerClone.emitOpError (
1175
+ " TODO: delayed privatization is not "
1176
+ " supported for `firstprivate` clauses yet." );
1177
+ bodyGenStatus = failure ();
1178
+ return codeGenIP;
1179
+ }
1180
+
1181
+ Region &allocRegion = privatizerClone.getAllocRegion ();
1182
+
1183
+ // Replace the privatizer block argument with mlir value being privatized.
1184
+ // This way, the body of the privatizer will be changed from using the
1185
+ // region/block argument to the value being privatized.
1186
+ auto allocRegionArg = allocRegion.getArgument (0 );
1187
+ replaceAllUsesInRegionWith (allocRegionArg, privVar, allocRegion);
1188
+
1189
+ auto oldIP = builder.saveIP ();
1190
+ builder.restoreIP (allocaIP);
1191
+
1192
+ SmallVector<llvm::Value *, 1 > yieldedValues;
1193
+ if (failed (inlineConvertOmpRegions (allocRegion, " omp.privatizer" , builder,
1194
+ moduleTranslation, &yieldedValues))) {
1195
+ opInstClone.emitError (
1196
+ " failed to inline `alloc` region of an `omp.private` "
1197
+ " op in the parallel region" );
1198
+ bodyGenStatus = failure ();
1199
+ } else {
1200
+ assert (yieldedValues.size () == 1 );
1201
+ replacementValue = yieldedValues.front ();
1202
+ }
1203
+
1204
+ privatizerClone.erase ();
1205
+ builder.restoreIP (oldIP);
1206
+ }
1207
+
1095
1208
return codeGenIP;
1096
1209
};
1097
1210
@@ -1100,13 +1213,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1100
1213
auto finiCB = [&](InsertPointTy codeGenIP) {};
1101
1214
1102
1215
llvm::Value *ifCond = nullptr ;
1103
- if (auto ifExprVar = opInst .getIfExprVar ())
1216
+ if (auto ifExprVar = opInstClone .getIfExprVar ())
1104
1217
ifCond = moduleTranslation.lookupValue (ifExprVar);
1105
1218
llvm::Value *numThreads = nullptr ;
1106
- if (auto numThreadsVar = opInst .getNumThreadsVar ())
1219
+ if (auto numThreadsVar = opInstClone .getNumThreadsVar ())
1107
1220
numThreads = moduleTranslation.lookupValue (numThreadsVar);
1108
1221
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
1109
- if (auto bind = opInst .getProcBindVal ())
1222
+ if (auto bind = opInstClone .getProcBindVal ())
1110
1223
pbKind = getProcBindKind (*bind);
1111
1224
// TODO: Is the Parallel construct cancellable?
1112
1225
bool isCancellable = false ;
@@ -1119,6 +1232,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1119
1232
ompBuilder->createParallel (ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
1120
1233
ifCond, numThreads, pbKind, isCancellable));
1121
1234
1235
+ opInstClone.erase ();
1122
1236
return bodyGenStatus;
1123
1237
}
1124
1238
@@ -1635,7 +1749,7 @@ getRefPtrIfDeclareTarget(mlir::Value value,
1635
1749
// A small helper structure to contain data gathered
1636
1750
// for map lowering and coalese it into one area and
1637
1751
// avoiding extra computations such as searches in the
1638
- // llvm module for lowered mapped varibles or checking
1752
+ // llvm module for lowered mapped variables or checking
1639
1753
// if something is declare target (and retrieving the
1640
1754
// value) more than neccessary.
1641
1755
struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
@@ -2854,26 +2968,26 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
2854
2968
moduleTranslation);
2855
2969
return failure ();
2856
2970
})
2857
- .Case (
2858
- " omp.requires " ,
2859
- [&](Attribute attr) {
2860
- if ( auto requiresAttr = attr.dyn_cast <omp::ClauseRequiresAttr>()) {
2861
- using Requires = omp::ClauseRequires;
2862
- Requires flags = requiresAttr.getValue ();
2863
- llvm::OpenMPIRBuilderConfig &config =
2864
- moduleTranslation.getOpenMPBuilder ()->Config ;
2865
- config.setHasRequiresReverseOffload (
2866
- bitEnumContainsAll (flags, Requires::reverse_offload));
2867
- config.setHasRequiresUnifiedAddress (
2868
- bitEnumContainsAll (flags, Requires::unified_address));
2869
- config.setHasRequiresUnifiedSharedMemory (
2870
- bitEnumContainsAll (flags, Requires::unified_shared_memory));
2871
- config.setHasRequiresDynamicAllocators (
2872
- bitEnumContainsAll (flags, Requires::dynamic_allocators));
2873
- return success ();
2874
- }
2875
- return failure ();
2876
- })
2971
+ .Case (" omp.requires " ,
2972
+ [&](Attribute attr) {
2973
+ if ( auto requiresAttr =
2974
+ attr.dyn_cast <omp::ClauseRequiresAttr>()) {
2975
+ using Requires = omp::ClauseRequires;
2976
+ Requires flags = requiresAttr.getValue ();
2977
+ llvm::OpenMPIRBuilderConfig &config =
2978
+ moduleTranslation.getOpenMPBuilder ()->Config ;
2979
+ config.setHasRequiresReverseOffload (
2980
+ bitEnumContainsAll (flags, Requires::reverse_offload));
2981
+ config.setHasRequiresUnifiedAddress (
2982
+ bitEnumContainsAll (flags, Requires::unified_address));
2983
+ config.setHasRequiresUnifiedSharedMemory (
2984
+ bitEnumContainsAll (flags, Requires::unified_shared_memory));
2985
+ config.setHasRequiresDynamicAllocators (
2986
+ bitEnumContainsAll (flags, Requires::dynamic_allocators));
2987
+ return success ();
2988
+ }
2989
+ return failure ();
2990
+ })
2877
2991
.Default ([](Attribute) {
2878
2992
// Fall through for omp attributes that do not require lowering.
2879
2993
return success ();
@@ -2988,12 +3102,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
2988
3102
.Case ([&](omp::TargetOp) {
2989
3103
return convertOmpTarget (*op, builder, moduleTranslation);
2990
3104
})
2991
- .Case <omp::MapInfoOp, omp::DataBoundsOp>([&](auto op) {
2992
- // No-op, should be handled by relevant owning operations e.g.
2993
- // TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
2994
- // discarded
2995
- return success ();
2996
- })
3105
+ .Case <omp::MapInfoOp, omp::DataBoundsOp, omp::PrivateClauseOp>(
3106
+ [&](auto op) {
3107
+ // No-op, should be handled by relevant owning operations e.g.
3108
+ // TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
3109
+ // discarded
3110
+ return success ();
3111
+ })
2997
3112
.Default ([&](Operation *inst) {
2998
3113
return inst->emitError (" unsupported OpenMP operation: " )
2999
3114
<< inst->getName ();
0 commit comments