Skip to content

Commit 9313342

Browse files
committed
[OpenMP][flang] Support GPU team reductions on allocatables
Extends the work started in #165714 by supporting team reductions. Similar to what was done in #165714, this PR introduces proper allocations, loads, and stores for by-ref reductions in the teams-related callbacks:

* `_omp_reduction_list_to_global_copy_func`,
* `_omp_reduction_list_to_global_reduce_func`,
* `_omp_reduction_global_to_list_copy_func`, and
* `_omp_reduction_global_to_list_reduce_func`.
1 parent f481f5b commit 9313342

File tree

4 files changed

+233
-39
lines changed

4 files changed

+233
-39
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1733,7 +1733,8 @@ class OpenMPIRBuilder {
17331733
/// \return The ListToGlobalCopy function.
17341734
Function *emitListToGlobalCopyFunction(ArrayRef<ReductionInfo> ReductionInfos,
17351735
Type *ReductionsBufferTy,
1736-
AttributeList FuncAttrs);
1736+
AttributeList FuncAttrs,
1737+
ArrayRef<bool> IsByRef);
17371738

17381739
/// This function emits a helper that copies all the reduction variables from
17391740
/// the team into the provided global buffer for the reduction variables.
@@ -1750,7 +1751,8 @@ class OpenMPIRBuilder {
17501751
/// \return The GlobalToList function.
17511752
Function *emitGlobalToListCopyFunction(ArrayRef<ReductionInfo> ReductionInfos,
17521753
Type *ReductionsBufferTy,
1753-
AttributeList FuncAttrs);
1754+
AttributeList FuncAttrs,
1755+
ArrayRef<bool> IsByRef);
17541756

17551757
/// This function emits a helper that reduces all the reduction variables from
17561758
/// the team into the provided global buffer for the reduction variables.
@@ -1772,7 +1774,8 @@ class OpenMPIRBuilder {
17721774
Function *
17731775
emitListToGlobalReduceFunction(ArrayRef<ReductionInfo> ReductionInfos,
17741776
Function *ReduceFn, Type *ReductionsBufferTy,
1775-
AttributeList FuncAttrs);
1777+
AttributeList FuncAttrs,
1778+
ArrayRef<bool> IsByRef);
17761779

17771780
/// This function emits a helper that reduces all the reduction variables from
17781781
/// the team into the provided global buffer for the reduction variables.
@@ -1794,7 +1797,8 @@ class OpenMPIRBuilder {
17941797
Function *
17951798
emitGlobalToListReduceFunction(ArrayRef<ReductionInfo> ReductionInfos,
17961799
Function *ReduceFn, Type *ReductionsBufferTy,
1797-
AttributeList FuncAttrs);
1800+
AttributeList FuncAttrs,
1801+
ArrayRef<bool> IsByRef);
17981802

17991803
/// Get the function name of a reduction function.
18001804
std::string getReductionFuncName(StringRef Name) const;

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 102 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3153,7 +3153,7 @@ Expected<Function *> OpenMPIRBuilder::emitShuffleAndReduceFunction(
31533153

31543154
Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
31553155
ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3156-
AttributeList FuncAttrs) {
3156+
AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
31573157
OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
31583158
LLVMContext &Ctx = M.getContext();
31593159
FunctionType *FuncTy = FunctionType::get(
@@ -3223,7 +3223,15 @@ Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
32233223

32243224
switch (RI.EvaluationKind) {
32253225
case EvalKind::Scalar: {
3226-
Value *TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
3226+
Value *TargetElement = [&]() {
3227+
if (IsByRef.empty() || !IsByRef[En.index()])
3228+
return Builder.CreateLoad(RI.ElementType, ElemPtr);
3229+
3230+
cantFail(RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr));
3231+
ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtr);
3232+
return Builder.CreateLoad(RI.ByRefElementType, ElemPtr);
3233+
}();
3234+
32273235
Builder.CreateStore(TargetElement, GlobVal);
32283236
break;
32293237
}
@@ -3263,7 +3271,7 @@ Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
32633271

32643272
Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
32653273
ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
3266-
Type *ReductionsBufferTy, AttributeList FuncAttrs) {
3274+
Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
32673275
OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
32683276
LLVMContext &Ctx = M.getContext();
32693277
FunctionType *FuncTy = FunctionType::get(
@@ -3302,6 +3310,8 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
33023310
Value *LocalReduceList =
33033311
Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
33043312

3313+
InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};
3314+
33053315
Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
33063316
BufferArgAlloca, Builder.getPtrTy(),
33073317
BufferArgAlloca->getName() + ".ascast");
@@ -3323,6 +3333,20 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
33233333
Type *IndexTy = Builder.getIndexTy(
33243334
M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
33253335
for (auto En : enumerate(ReductionInfos)) {
3336+
const ReductionInfo &RI = En.value();
3337+
Value *ByRefAlloc;
3338+
3339+
if (!IsByRef.empty() && IsByRef[En.index()]) {
3340+
InsertPointTy OldIP = Builder.saveIP();
3341+
Builder.restoreIP(AllocaIP);
3342+
3343+
ByRefAlloc = Builder.CreateAlloca(RI.ByRefAllocatedType);
3344+
ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
3345+
ByRefAlloc, Builder.getPtrTy(), ByRefAlloc->getName() + ".ascast");
3346+
3347+
Builder.restoreIP(OldIP);
3348+
}
3349+
33263350
Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
33273351
RedListArrayTy, LocalReduceListAddrCast,
33283352
{ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
@@ -3331,7 +3355,15 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
33313355
// Global = Buffer.VD[Idx];
33323356
Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
33333357
ReductionsBufferTy, BufferVD, 0, En.index());
3334-
Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
3358+
3359+
if (!IsByRef.empty() && IsByRef[En.index()]) {
3360+
Value *ByRefDataPtr;
3361+
cantFail(RI.DataPtrPtrGen(Builder.saveIP(), ByRefAlloc, ByRefDataPtr));
3362+
Builder.CreateStore(GlobValPtr, ByRefDataPtr);
3363+
Builder.CreateStore(ByRefAlloc, TargetElementPtrPtr);
3364+
} else {
3365+
Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
3366+
}
33353367
}
33363368

33373369
// Call reduce_function(GlobalReduceList, ReduceList)
@@ -3346,30 +3378,30 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
33463378

33473379
Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
33483380
ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3349-
AttributeList FuncAttrs) {
3381+
AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
33503382
OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
33513383
LLVMContext &Ctx = M.getContext();
33523384
FunctionType *FuncTy = FunctionType::get(
33533385
Builder.getVoidTy(),
33543386
{Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
33553387
/* IsVarArg */ false);
3356-
Function *LtGCFunc =
3388+
Function *GtLCFunc =
33573389
Function::Create(FuncTy, GlobalVariable::InternalLinkage,
33583390
"_omp_reduction_global_to_list_copy_func", &M);
3359-
LtGCFunc->setAttributes(FuncAttrs);
3360-
LtGCFunc->addParamAttr(0, Attribute::NoUndef);
3361-
LtGCFunc->addParamAttr(1, Attribute::NoUndef);
3362-
LtGCFunc->addParamAttr(2, Attribute::NoUndef);
3391+
GtLCFunc->setAttributes(FuncAttrs);
3392+
GtLCFunc->addParamAttr(0, Attribute::NoUndef);
3393+
GtLCFunc->addParamAttr(1, Attribute::NoUndef);
3394+
GtLCFunc->addParamAttr(2, Attribute::NoUndef);
33633395

3364-
BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
3396+
BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", GtLCFunc);
33653397
Builder.SetInsertPoint(EntryBlock);
33663398

33673399
// Buffer: global reduction buffer.
3368-
Argument *BufferArg = LtGCFunc->getArg(0);
3400+
Argument *BufferArg = GtLCFunc->getArg(0);
33693401
// Idx: index of the buffer.
3370-
Argument *IdxArg = LtGCFunc->getArg(1);
3402+
Argument *IdxArg = GtLCFunc->getArg(1);
33713403
// ReduceList: thread local Reduce list.
3372-
Argument *ReduceListArg = LtGCFunc->getArg(2);
3404+
Argument *ReduceListArg = GtLCFunc->getArg(2);
33733405

33743406
Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
33753407
BufferArg->getName() + ".addr");
@@ -3413,7 +3445,15 @@ Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
34133445

34143446
switch (RI.EvaluationKind) {
34153447
case EvalKind::Scalar: {
3416-
Value *TargetElement = Builder.CreateLoad(RI.ElementType, GlobValPtr);
3448+
Type *ElemType = RI.ElementType;
3449+
3450+
if (!IsByRef.empty() && IsByRef[En.index()]) {
3451+
ElemType = RI.ByRefElementType;
3452+
cantFail(RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr));
3453+
ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtr);
3454+
}
3455+
3456+
Value *TargetElement = Builder.CreateLoad(ElemType, GlobValPtr);
34173457
Builder.CreateStore(TargetElement, ElemPtr);
34183458
break;
34193459
}
@@ -3449,35 +3489,35 @@ Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
34493489

34503490
Builder.CreateRetVoid();
34513491
Builder.restoreIP(OldIP);
3452-
return LtGCFunc;
3492+
return GtLCFunc;
34533493
}
34543494

34553495
Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
34563496
ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
3457-
Type *ReductionsBufferTy, AttributeList FuncAttrs) {
3497+
Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
34583498
OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
34593499
LLVMContext &Ctx = M.getContext();
34603500
auto *FuncTy = FunctionType::get(
34613501
Builder.getVoidTy(),
34623502
{Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
34633503
/* IsVarArg */ false);
3464-
Function *LtGRFunc =
3504+
Function *GtLRFunc =
34653505
Function::Create(FuncTy, GlobalVariable::InternalLinkage,
34663506
"_omp_reduction_global_to_list_reduce_func", &M);
3467-
LtGRFunc->setAttributes(FuncAttrs);
3468-
LtGRFunc->addParamAttr(0, Attribute::NoUndef);
3469-
LtGRFunc->addParamAttr(1, Attribute::NoUndef);
3470-
LtGRFunc->addParamAttr(2, Attribute::NoUndef);
3507+
GtLRFunc->setAttributes(FuncAttrs);
3508+
GtLRFunc->addParamAttr(0, Attribute::NoUndef);
3509+
GtLRFunc->addParamAttr(1, Attribute::NoUndef);
3510+
GtLRFunc->addParamAttr(2, Attribute::NoUndef);
34713511

3472-
BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
3512+
BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", GtLRFunc);
34733513
Builder.SetInsertPoint(EntryBlock);
34743514

34753515
// Buffer: global reduction buffer.
3476-
Argument *BufferArg = LtGRFunc->getArg(0);
3516+
Argument *BufferArg = GtLRFunc->getArg(0);
34773517
// Idx: index of the buffer.
3478-
Argument *IdxArg = LtGRFunc->getArg(1);
3518+
Argument *IdxArg = GtLRFunc->getArg(1);
34793519
// ReduceList: thread local Reduce list.
3480-
Argument *ReduceListArg = LtGRFunc->getArg(2);
3520+
Argument *ReduceListArg = GtLRFunc->getArg(2);
34813521

34823522
Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
34833523
BufferArg->getName() + ".addr");
@@ -3493,6 +3533,8 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
34933533
Value *LocalReduceList =
34943534
Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
34953535

3536+
InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};
3537+
34963538
Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
34973539
BufferArgAlloca, Builder.getPtrTy(),
34983540
BufferArgAlloca->getName() + ".ascast");
@@ -3514,6 +3556,20 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
35143556
Type *IndexTy = Builder.getIndexTy(
35153557
M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
35163558
for (auto En : enumerate(ReductionInfos)) {
3559+
const ReductionInfo &RI = En.value();
3560+
Value *ByRefAlloc;
3561+
3562+
if (IsByRef[En.index()]) {
3563+
InsertPointTy OldIP = Builder.saveIP();
3564+
Builder.restoreIP(AllocaIP);
3565+
3566+
ByRefAlloc = Builder.CreateAlloca(RI.ByRefAllocatedType);
3567+
ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
3568+
ByRefAlloc, Builder.getPtrTy(), ByRefAlloc->getName() + ".ascast");
3569+
3570+
Builder.restoreIP(OldIP);
3571+
}
3572+
35173573
Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
35183574
RedListArrayTy, ReductionList,
35193575
{ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
@@ -3522,7 +3578,15 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
35223578
Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
35233579
Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
35243580
ReductionsBufferTy, BufferVD, 0, En.index());
3525-
Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
3581+
3582+
if (IsByRef[En.index()]) {
3583+
Value *ByRefDataPtr;
3584+
cantFail(RI.DataPtrPtrGen(Builder.saveIP(), ByRefAlloc, ByRefDataPtr));
3585+
Builder.CreateStore(GlobValPtr, ByRefDataPtr);
3586+
Builder.CreateStore(ByRefAlloc, TargetElementPtrPtr);
3587+
} else {
3588+
Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
3589+
}
35263590
}
35273591

35283592
// Call reduce_function(ReduceList, GlobalReduceList)
@@ -3532,7 +3596,7 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
35323596
->addFnAttr(Attribute::NoUnwind);
35333597
Builder.CreateRetVoid();
35343598
Builder.restoreIP(OldIP);
3535-
return LtGRFunc;
3599+
return GtLRFunc;
35363600
}
35373601

35383602
std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
@@ -3788,7 +3852,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
37883852
auto Size = M.getDataLayout().getTypeStoreSize(En.value().ElementType);
37893853
if (Size > MaxDataSize)
37903854
MaxDataSize = Size;
3791-
ReductionTypeArgs.emplace_back(En.value().ElementType);
3855+
Type *RedTypeArg = (!IsByRef.empty() && IsByRef[En.index()])
3856+
? En.value().ByRefElementType
3857+
: En.value().ElementType;
3858+
ReductionTypeArgs.emplace_back(RedTypeArg);
37923859
}
37933860
Value *ReductionDataSize =
37943861
Builder.getInt64(MaxDataSize * ReductionInfos.size());
@@ -3806,20 +3873,20 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
38063873
CodeGenIP = Builder.saveIP();
38073874
StructType *ReductionsBufferTy = StructType::create(
38083875
Ctx, ReductionTypeArgs, "struct._globalized_locals_ty");
3809-
Function *RedFixedBuferFn = getOrCreateRuntimeFunctionPtr(
3876+
Function *RedFixedBufferFn = getOrCreateRuntimeFunctionPtr(
38103877
RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
38113878
Function *LtGCFunc = emitListToGlobalCopyFunction(
3812-
ReductionInfos, ReductionsBufferTy, FuncAttrs);
3879+
ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
38133880
Function *LtGRFunc = emitListToGlobalReduceFunction(
3814-
ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
3881+
ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
38153882
Function *GtLCFunc = emitGlobalToListCopyFunction(
3816-
ReductionInfos, ReductionsBufferTy, FuncAttrs);
3883+
ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
38173884
Function *GtLRFunc = emitGlobalToListReduceFunction(
3818-
ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
3885+
ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
38193886
Builder.restoreIP(CodeGenIP);
38203887

38213888
Value *KernelTeamsReductionPtr = createRuntimeFunctionCall(
3822-
RedFixedBuferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");
3889+
RedFixedBufferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");
38233890

38243891
Value *Args3[] = {SrcLocInfo,
38253892
KernelTeamsReductionPtr,

mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
// Tests single-team by-ref GPU reductions.
2+
13
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
24

35
module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {

0 commit comments

Comments
 (0)