@@ -396,9 +396,9 @@ collectReductionDecls(T loop,
396
396
397
397
// / Translates the blocks contained in the given region and appends them at
398
398
// / the current insertion point of `builder`. The operations of the entry block
399
- // / are appended to the current insertion block, which is not expected to have a
400
- // / terminator. If set, `continuationBlockArgs` is populated with translated
401
- // / values that correspond to the values omp.yield'ed from the region.
399
+ // / are appended to the current insertion block. If set, `continuationBlockArgs`
400
+ // / is populated with translated values that correspond to the values
401
+ // / omp.yield'ed from the region.
402
402
static LogicalResult inlineConvertOmpRegions (
403
403
Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
404
404
LLVM::ModuleTranslation &moduleTranslation,
@@ -409,7 +409,14 @@ static LogicalResult inlineConvertOmpRegions(
409
409
// Special case for single-block regions that don't create additional blocks:
410
410
// insert operations without creating additional blocks.
411
411
if (llvm::hasSingleElement (region)) {
412
+ llvm::Instruction *potentialTerminator =
413
+ builder.GetInsertBlock ()->empty () ? nullptr
414
+ : &builder.GetInsertBlock ()->back ();
415
+
416
+ if (potentialTerminator && potentialTerminator->isTerminator ())
417
+ potentialTerminator->removeFromParent ();
412
418
moduleTranslation.mapBlock (&region.front (), builder.GetInsertBlock ());
419
+
413
420
if (failed (moduleTranslation.convertBlock (
414
421
region.front (), /* ignoreArguments=*/ true , builder)))
415
422
return failure ();
@@ -423,6 +430,10 @@ static LogicalResult inlineConvertOmpRegions(
423
430
// Drop the mapping that is no longer necessary so that the same region can
424
431
// be processed multiple times.
425
432
moduleTranslation.forgetMapping (region);
433
+
434
+ if (potentialTerminator && potentialTerminator->isTerminator ())
435
+ potentialTerminator->insertAfter (&builder.GetInsertBlock ()->back ());
436
+
426
437
return success ();
427
438
}
428
439
@@ -1000,11 +1011,50 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
1000
1011
return success ();
1001
1012
}
1002
1013
1014
+ // / A RAII class that on construction replaces the region arguments of the
1015
+ // / parallel op (which correspond to private variables) with the actual private
1016
+ // / variables they correspond to. This prepares the parallel op so that it
1017
+ // / matches what is expected by the OMPIRBuilder.
1018
+ // /
1019
+ // / On destruction, it restores the original state of the operation so that on
1020
+ // / the MLIR side, the op is not affected by conversion to LLVM IR.
1021
+ class OmpParallelOpConversionManager {
1022
+ public:
1023
+ OmpParallelOpConversionManager (omp::ParallelOp opInst)
1024
+ : region(opInst.getRegion()), privateVars(opInst.getPrivateVars()),
1025
+ privateArgBeginIdx (opInst.getNumReductionVars()),
1026
+ privateArgEndIdx(privateArgBeginIdx + privateVars.size()) {
1027
+ auto privateVarsIt = privateVars.begin ();
1028
+
1029
+ for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1030
+ ++argIdx, ++privateVarsIt)
1031
+ mlir::replaceAllUsesInRegionWith (region.getArgument (argIdx),
1032
+ *privateVarsIt, region);
1033
+ }
1034
+
1035
+ ~OmpParallelOpConversionManager () {
1036
+ auto privateVarsIt = privateVars.begin ();
1037
+
1038
+ for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1039
+ ++argIdx, ++privateVarsIt)
1040
+ mlir::replaceAllUsesInRegionWith (*privateVarsIt,
1041
+ region.getArgument (argIdx), region);
1042
+ }
1043
+
1044
+ private:
1045
+ Region &region;
1046
+ OperandRange privateVars;
1047
+ unsigned privateArgBeginIdx;
1048
+ unsigned privateArgEndIdx;
1049
+ };
1050
+
1003
1051
// / Converts the OpenMP parallel operation to LLVM IR.
1004
1052
static LogicalResult
1005
1053
convertOmpParallel (omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1006
1054
LLVM::ModuleTranslation &moduleTranslation) {
1007
1055
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1056
+ OmpParallelOpConversionManager raii (opInst);
1057
+
1008
1058
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
1009
1059
// relying on captured variables.
1010
1060
LogicalResult bodyGenStatus = success ();
@@ -1086,12 +1136,81 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1086
1136
1087
1137
// TODO: Perform appropriate actions according to the data-sharing
1088
1138
// attribute (shared, private, firstprivate, ...) of variables.
1089
- // Currently defaults to shared .
1139
+ // Currently shared and private are supported .
1090
1140
auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
1091
1141
llvm::Value &, llvm::Value &vPtr,
1092
1142
llvm::Value *&replacementValue) -> InsertPointTy {
1093
1143
replacementValue = &vPtr;
1094
1144
1145
+ // If this is a private value, this lambda will return the corresponding
1146
+ // mlir value and its `PrivateClauseOp`. Otherwise, empty values are
1147
+ // returned.
1148
+ auto [privVar, privatizerClone] =
1149
+ [&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1150
+ if (!opInst.getPrivateVars ().empty ()) {
1151
+ auto privVars = opInst.getPrivateVars ();
1152
+ auto privatizers = opInst.getPrivatizers ();
1153
+
1154
+ for (auto [privVar, privatizerAttr] :
1155
+ llvm::zip_equal (privVars, *privatizers)) {
1156
+ // Find the MLIR private variable corresponding to the LLVM value
1157
+ // being privatized.
1158
+ llvm::Value *llvmPrivVar = moduleTranslation.lookupValue (privVar);
1159
+ if (llvmPrivVar != &vPtr)
1160
+ continue ;
1161
+
1162
+ SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
1163
+ omp::PrivateClauseOp privatizer =
1164
+ SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1165
+ opInst, privSym);
1166
+
1167
+ // Clone the privatizer in case it is used by more than one parallel
1168
+ // region. The privatizer is processed in-place (see below) before it
1169
+ // gets inlined in the parallel region and therefore processing the
1170
+ // original op is dangerous.
1171
+ return {privVar, privatizer.clone ()};
1172
+ }
1173
+ }
1174
+
1175
+ return {mlir::Value (), omp::PrivateClauseOp ()};
1176
+ }();
1177
+
1178
+ if (privVar) {
1179
+ if (privatizerClone.getDataSharingType () ==
1180
+ omp::DataSharingClauseType::FirstPrivate) {
1181
+ privatizerClone.emitOpError (
1182
+ " TODO: delayed privatization is not "
1183
+ " supported for `firstprivate` clauses yet." );
1184
+ bodyGenStatus = failure ();
1185
+ return codeGenIP;
1186
+ }
1187
+
1188
+ Region &allocRegion = privatizerClone.getAllocRegion ();
1189
+
1190
+ // Replace the privatizer block argument with mlir value being privatized.
1191
+ // This way, the body of the privatizer will be changed from using the
1192
+ // region/block argument to the value being privatized.
1193
+ auto allocRegionArg = allocRegion.getArgument (0 );
1194
+ replaceAllUsesInRegionWith (allocRegionArg, privVar, allocRegion);
1195
+
1196
+ auto oldIP = builder.saveIP ();
1197
+ builder.restoreIP (allocaIP);
1198
+
1199
+ SmallVector<llvm::Value *, 1 > yieldedValues;
1200
+ if (failed (inlineConvertOmpRegions (allocRegion, " omp.privatizer" , builder,
1201
+ moduleTranslation, &yieldedValues))) {
1202
+ opInst.emitError (" failed to inline `alloc` region of an `omp.private` "
1203
+ " op in the parallel region" );
1204
+ bodyGenStatus = failure ();
1205
+ } else {
1206
+ assert (yieldedValues.size () == 1 );
1207
+ replacementValue = yieldedValues.front ();
1208
+ }
1209
+
1210
+ privatizerClone.erase ();
1211
+ builder.restoreIP (oldIP);
1212
+ }
1213
+
1095
1214
return codeGenIP;
1096
1215
};
1097
1216
@@ -1635,7 +1754,7 @@ getRefPtrIfDeclareTarget(mlir::Value value,
1635
1754
// A small helper structure to contain data gathered
1636
1755
// for map lowering and coalesce it into one area and
1637
1756
// avoiding extra computations such as searches in the
1638
- // llvm module for lowered mapped varibles or checking
1757
+ // llvm module for lowered mapped variables or checking
1639
1758
// if something is declare target (and retrieving the
1640
1759
// value) more than necessary.
1641
1760
struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
@@ -2854,26 +2973,26 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
2854
2973
moduleTranslation);
2855
2974
return failure ();
2856
2975
})
2857
- .Case (
2858
- " omp.requires " ,
2859
- [&](Attribute attr) {
2860
- if ( auto requiresAttr = attr.dyn_cast <omp::ClauseRequiresAttr>()) {
2861
- using Requires = omp::ClauseRequires;
2862
- Requires flags = requiresAttr.getValue ();
2863
- llvm::OpenMPIRBuilderConfig &config =
2864
- moduleTranslation.getOpenMPBuilder ()->Config ;
2865
- config.setHasRequiresReverseOffload (
2866
- bitEnumContainsAll (flags, Requires::reverse_offload));
2867
- config.setHasRequiresUnifiedAddress (
2868
- bitEnumContainsAll (flags, Requires::unified_address));
2869
- config.setHasRequiresUnifiedSharedMemory (
2870
- bitEnumContainsAll (flags, Requires::unified_shared_memory));
2871
- config.setHasRequiresDynamicAllocators (
2872
- bitEnumContainsAll (flags, Requires::dynamic_allocators));
2873
- return success ();
2874
- }
2875
- return failure ();
2876
- })
2976
+ .Case (" omp.requires " ,
2977
+ [&](Attribute attr) {
2978
+ if ( auto requiresAttr =
2979
+ attr.dyn_cast <omp::ClauseRequiresAttr>()) {
2980
+ using Requires = omp::ClauseRequires;
2981
+ Requires flags = requiresAttr.getValue ();
2982
+ llvm::OpenMPIRBuilderConfig &config =
2983
+ moduleTranslation.getOpenMPBuilder ()->Config ;
2984
+ config.setHasRequiresReverseOffload (
2985
+ bitEnumContainsAll (flags, Requires::reverse_offload));
2986
+ config.setHasRequiresUnifiedAddress (
2987
+ bitEnumContainsAll (flags, Requires::unified_address));
2988
+ config.setHasRequiresUnifiedSharedMemory (
2989
+ bitEnumContainsAll (flags, Requires::unified_shared_memory));
2990
+ config.setHasRequiresDynamicAllocators (
2991
+ bitEnumContainsAll (flags, Requires::dynamic_allocators));
2992
+ return success ();
2993
+ }
2994
+ return failure ();
2995
+ })
2877
2996
.Default ([](Attribute) {
2878
2997
// Fall through for omp attributes that do not require lowering.
2879
2998
return success ();
@@ -2988,12 +3107,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
2988
3107
.Case ([&](omp::TargetOp) {
2989
3108
return convertOmpTarget (*op, builder, moduleTranslation);
2990
3109
})
2991
- .Case <omp::MapInfoOp, omp::DataBoundsOp>([&](auto op) {
2992
- // No-op, should be handled by relevant owning operations e.g.
2993
- // TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
2994
- // discarded
2995
- return success ();
2996
- })
3110
+ .Case <omp::MapInfoOp, omp::DataBoundsOp, omp::PrivateClauseOp>(
3111
+ [&](auto op) {
3112
+ // No-op, should be handled by relevant owning operations e.g.
3113
+ // TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
3114
+ // discarded
3115
+ return success ();
3116
+ })
2997
3117
.Default ([&](Operation *inst) {
2998
3118
return inst->emitError (" unsupported OpenMP operation: " )
2999
3119
<< inst->getName ();
0 commit comments