@@ -3153,7 +3153,7 @@ Expected<Function *> OpenMPIRBuilder::emitShuffleAndReduceFunction(
31533153
31543154Function *OpenMPIRBuilder::emitListToGlobalCopyFunction (
31553155 ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3156- AttributeList FuncAttrs) {
3156+ AttributeList FuncAttrs, ArrayRef< bool > IsByRef ) {
31573157 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP ();
31583158 LLVMContext &Ctx = M.getContext ();
31593159 FunctionType *FuncTy = FunctionType::get (
@@ -3223,7 +3223,15 @@ Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
32233223
32243224 switch (RI.EvaluationKind ) {
32253225 case EvalKind::Scalar: {
3226- Value *TargetElement = Builder.CreateLoad (RI.ElementType , ElemPtr);
3226+ Value *TargetElement = [&]() {
3227+ if (IsByRef.empty () || !IsByRef[En.index ()])
3228+ return Builder.CreateLoad (RI.ElementType , ElemPtr);
3229+
3230+ cantFail (RI.DataPtrPtrGen (Builder.saveIP (), ElemPtr, ElemPtr));
3231+ ElemPtr = Builder.CreateLoad (Builder.getPtrTy (), ElemPtr);
3232+ return Builder.CreateLoad (RI.ByRefElementType , ElemPtr);
3233+ }();
3234+
32273235 Builder.CreateStore (TargetElement, GlobVal);
32283236 break ;
32293237 }
@@ -3263,7 +3271,7 @@ Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
32633271
32643272Function *OpenMPIRBuilder::emitListToGlobalReduceFunction (
32653273 ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
3266- Type *ReductionsBufferTy, AttributeList FuncAttrs) {
3274+ Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef< bool > IsByRef ) {
32673275 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP ();
32683276 LLVMContext &Ctx = M.getContext ();
32693277 FunctionType *FuncTy = FunctionType::get (
@@ -3302,6 +3310,8 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
33023310 Value *LocalReduceList =
33033311 Builder.CreateAlloca (RedListArrayTy, nullptr , " .omp.reduction.red_list" );
33043312
3313+ InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin ()};
3314+
33053315 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast (
33063316 BufferArgAlloca, Builder.getPtrTy (),
33073317 BufferArgAlloca->getName () + " .ascast" );
@@ -3323,6 +3333,20 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
33233333 Type *IndexTy = Builder.getIndexTy (
33243334 M.getDataLayout (), M.getDataLayout ().getDefaultGlobalsAddressSpace ());
33253335 for (auto En : enumerate(ReductionInfos)) {
3336+ const ReductionInfo &RI = En.value ();
3337+ Value *ByRefAlloc;
3338+
3339+ if (!IsByRef.empty () && IsByRef[En.index ()]) {
3340+ InsertPointTy OldIP = Builder.saveIP ();
3341+ Builder.restoreIP (AllocaIP);
3342+
3343+ ByRefAlloc = Builder.CreateAlloca (RI.ByRefAllocatedType );
3344+ ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast (
3345+ ByRefAlloc, Builder.getPtrTy (), ByRefAlloc->getName () + " .ascast" );
3346+
3347+ Builder.restoreIP (OldIP);
3348+ }
3349+
33263350 Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP (
33273351 RedListArrayTy, LocalReduceListAddrCast,
33283352 {ConstantInt::get (IndexTy, 0 ), ConstantInt::get (IndexTy, En.index ())});
@@ -3331,7 +3355,15 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
33313355 // Global = Buffer.VD[Idx];
33323356 Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32 (
33333357 ReductionsBufferTy, BufferVD, 0 , En.index ());
3334- Builder.CreateStore (GlobValPtr, TargetElementPtrPtr);
3358+
3359+ if (!IsByRef.empty () && IsByRef[En.index ()]) {
3360+ Value *ByRefDataPtr;
3361+ cantFail (RI.DataPtrPtrGen (Builder.saveIP (), ByRefAlloc, ByRefDataPtr));
3362+ Builder.CreateStore (GlobValPtr, ByRefDataPtr);
3363+ Builder.CreateStore (ByRefAlloc, TargetElementPtrPtr);
3364+ } else {
3365+ Builder.CreateStore (GlobValPtr, TargetElementPtrPtr);
3366+ }
33353367 }
33363368
33373369 // Call reduce_function(GlobalReduceList, ReduceList)
@@ -3346,30 +3378,30 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
33463378
33473379Function *OpenMPIRBuilder::emitGlobalToListCopyFunction (
33483380 ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
3349- AttributeList FuncAttrs) {
3381+ AttributeList FuncAttrs, ArrayRef< bool > IsByRef ) {
33503382 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP ();
33513383 LLVMContext &Ctx = M.getContext ();
33523384 FunctionType *FuncTy = FunctionType::get (
33533385 Builder.getVoidTy (),
33543386 {Builder.getPtrTy (), Builder.getInt32Ty (), Builder.getPtrTy ()},
33553387 /* IsVarArg */ false );
3356- Function *LtGCFunc =
3388+ Function *GtLCFunc =
33573389 Function::Create (FuncTy, GlobalVariable::InternalLinkage,
33583390 " _omp_reduction_global_to_list_copy_func" , &M);
3359- LtGCFunc ->setAttributes (FuncAttrs);
3360- LtGCFunc ->addParamAttr (0 , Attribute::NoUndef);
3361- LtGCFunc ->addParamAttr (1 , Attribute::NoUndef);
3362- LtGCFunc ->addParamAttr (2 , Attribute::NoUndef);
3391+ GtLCFunc ->setAttributes (FuncAttrs);
3392+ GtLCFunc ->addParamAttr (0 , Attribute::NoUndef);
3393+ GtLCFunc ->addParamAttr (1 , Attribute::NoUndef);
3394+ GtLCFunc ->addParamAttr (2 , Attribute::NoUndef);
33633395
3364- BasicBlock *EntryBlock = BasicBlock::Create (Ctx, " entry" , LtGCFunc );
3396+ BasicBlock *EntryBlock = BasicBlock::Create (Ctx, " entry" , GtLCFunc );
33653397 Builder.SetInsertPoint (EntryBlock);
33663398
33673399 // Buffer: global reduction buffer.
3368- Argument *BufferArg = LtGCFunc ->getArg (0 );
3400+ Argument *BufferArg = GtLCFunc ->getArg (0 );
33693401 // Idx: index of the buffer.
3370- Argument *IdxArg = LtGCFunc ->getArg (1 );
3402+ Argument *IdxArg = GtLCFunc ->getArg (1 );
33713403 // ReduceList: thread local Reduce list.
3372- Argument *ReduceListArg = LtGCFunc ->getArg (2 );
3404+ Argument *ReduceListArg = GtLCFunc ->getArg (2 );
33733405
33743406 Value *BufferArgAlloca = Builder.CreateAlloca (Builder.getPtrTy (), nullptr ,
33753407 BufferArg->getName () + " .addr" );
@@ -3413,7 +3445,15 @@ Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
34133445
34143446 switch (RI.EvaluationKind ) {
34153447 case EvalKind::Scalar: {
3416- Value *TargetElement = Builder.CreateLoad (RI.ElementType , GlobValPtr);
3448+ Type *ElemType = RI.ElementType ;
3449+
3450+ if (!IsByRef.empty () && IsByRef[En.index ()]) {
3451+ ElemType = RI.ByRefElementType ;
3452+ cantFail (RI.DataPtrPtrGen (Builder.saveIP (), ElemPtr, ElemPtr));
3453+ ElemPtr = Builder.CreateLoad (Builder.getPtrTy (), ElemPtr);
3454+ }
3455+
3456+ Value *TargetElement = Builder.CreateLoad (ElemType, GlobValPtr);
34173457 Builder.CreateStore (TargetElement, ElemPtr);
34183458 break ;
34193459 }
@@ -3449,35 +3489,35 @@ Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
34493489
34503490 Builder.CreateRetVoid ();
34513491 Builder.restoreIP (OldIP);
3452- return LtGCFunc ;
3492+ return GtLCFunc ;
34533493}
34543494
34553495Function *OpenMPIRBuilder::emitGlobalToListReduceFunction (
34563496 ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
3457- Type *ReductionsBufferTy, AttributeList FuncAttrs) {
3497+ Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef< bool > IsByRef ) {
34583498 OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP ();
34593499 LLVMContext &Ctx = M.getContext ();
34603500 auto *FuncTy = FunctionType::get (
34613501 Builder.getVoidTy (),
34623502 {Builder.getPtrTy (), Builder.getInt32Ty (), Builder.getPtrTy ()},
34633503 /* IsVarArg */ false );
3464- Function *LtGRFunc =
3504+ Function *GtLRFunc =
34653505 Function::Create (FuncTy, GlobalVariable::InternalLinkage,
34663506 " _omp_reduction_global_to_list_reduce_func" , &M);
3467- LtGRFunc ->setAttributes (FuncAttrs);
3468- LtGRFunc ->addParamAttr (0 , Attribute::NoUndef);
3469- LtGRFunc ->addParamAttr (1 , Attribute::NoUndef);
3470- LtGRFunc ->addParamAttr (2 , Attribute::NoUndef);
3507+ GtLRFunc ->setAttributes (FuncAttrs);
3508+ GtLRFunc ->addParamAttr (0 , Attribute::NoUndef);
3509+ GtLRFunc ->addParamAttr (1 , Attribute::NoUndef);
3510+ GtLRFunc ->addParamAttr (2 , Attribute::NoUndef);
34713511
3472- BasicBlock *EntryBlock = BasicBlock::Create (Ctx, " entry" , LtGRFunc );
3512+ BasicBlock *EntryBlock = BasicBlock::Create (Ctx, " entry" , GtLRFunc );
34733513 Builder.SetInsertPoint (EntryBlock);
34743514
34753515 // Buffer: global reduction buffer.
3476- Argument *BufferArg = LtGRFunc ->getArg (0 );
3516+ Argument *BufferArg = GtLRFunc ->getArg (0 );
34773517 // Idx: index of the buffer.
3478- Argument *IdxArg = LtGRFunc ->getArg (1 );
3518+ Argument *IdxArg = GtLRFunc ->getArg (1 );
34793519 // ReduceList: thread local Reduce list.
3480- Argument *ReduceListArg = LtGRFunc ->getArg (2 );
3520+ Argument *ReduceListArg = GtLRFunc ->getArg (2 );
34813521
34823522 Value *BufferArgAlloca = Builder.CreateAlloca (Builder.getPtrTy (), nullptr ,
34833523 BufferArg->getName () + " .addr" );
@@ -3493,6 +3533,8 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
34933533 Value *LocalReduceList =
34943534 Builder.CreateAlloca (RedListArrayTy, nullptr , " .omp.reduction.red_list" );
34953535
3536+ InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin ()};
3537+
34963538 Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast (
34973539 BufferArgAlloca, Builder.getPtrTy (),
34983540 BufferArgAlloca->getName () + " .ascast" );
@@ -3514,6 +3556,20 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
35143556 Type *IndexTy = Builder.getIndexTy (
35153557 M.getDataLayout (), M.getDataLayout ().getDefaultGlobalsAddressSpace ());
35163558 for (auto En : enumerate(ReductionInfos)) { // one red_list slot is filled per reduction
3559+ const ReductionInfo &RI = En.value ();
3560+ Value *ByRefAlloc;
3561+
3562+ if (!IsByRef.empty () && IsByRef[En.index ()]) { // an empty IsByRef means no by-ref entries; never index it
3563+ InsertPointTy OldIP = Builder.saveIP ();
3564+ Builder.restoreIP (AllocaIP); // emit the alloca at the entry-block insertion point
3565+
3566+ ByRefAlloc = Builder.CreateAlloca (RI.ByRefAllocatedType );
3567+ ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast (
3568+ ByRefAlloc, Builder.getPtrTy (), ByRefAlloc->getName () + " .ascast" );
3569+
3570+ Builder.restoreIP (OldIP);
3571+ }
3572+
35173573 Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP (
35183574 RedListArrayTy, ReductionList,
35193575 {ConstantInt::get (IndexTy, 0 ), ConstantInt::get (IndexTy, En.index ())});
@@ -3522,7 +3578,15 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
35223578 Builder.CreateInBoundsGEP (ReductionsBufferTy, BufferVal, Idxs);
35233579 Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32 (
35243580 ReductionsBufferTy, BufferVD, 0 , En.index ());
3525- Builder.CreateStore (GlobValPtr, TargetElementPtrPtr);
3581+
3582+ if (!IsByRef.empty () && IsByRef[En.index ()]) { // same guard as above, so ByRefAlloc is only read when it was set
3583+ Value *ByRefDataPtr;
3584+ cantFail (RI.DataPtrPtrGen (Builder.saveIP (), ByRefAlloc, ByRefDataPtr));
3585+ Builder.CreateStore (GlobValPtr, ByRefDataPtr);
3586+ Builder.CreateStore (ByRefAlloc, TargetElementPtrPtr);
3587+ } else {
3588+ Builder.CreateStore (GlobValPtr, TargetElementPtrPtr);
3589+ }
35263590 }
35273591
35283592 // Call reduce_function(ReduceList, GlobalReduceList)
@@ -3532,7 +3596,7 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
35323596 ->addFnAttr (Attribute::NoUnwind);
35333597 Builder.CreateRetVoid ();
35343598 Builder.restoreIP (OldIP);
3535- return LtGRFunc ;
3599+ return GtLRFunc ;
35363600}
35373601
35383602std::string OpenMPIRBuilder::getReductionFuncName (StringRef Name) const {
@@ -3788,7 +3852,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
37883852 auto Size = M.getDataLayout ().getTypeStoreSize (En.value ().ElementType );
37893853 if (Size > MaxDataSize)
37903854 MaxDataSize = Size;
3791- ReductionTypeArgs.emplace_back (En.value ().ElementType );
3855+ Type *RedTypeArg = (!IsByRef.empty () && IsByRef[En.index ()])
3856+ ? En.value ().ByRefElementType
3857+ : En.value ().ElementType ;
3858+ ReductionTypeArgs.emplace_back (RedTypeArg);
37923859 }
37933860 Value *ReductionDataSize =
37943861 Builder.getInt64 (MaxDataSize * ReductionInfos.size ());
@@ -3806,20 +3873,20 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
38063873 CodeGenIP = Builder.saveIP ();
38073874 StructType *ReductionsBufferTy = StructType::create (
38083875 Ctx, ReductionTypeArgs, " struct._globalized_locals_ty" );
3809- Function *RedFixedBuferFn = getOrCreateRuntimeFunctionPtr (
3876+ Function *RedFixedBufferFn = getOrCreateRuntimeFunctionPtr (
38103877 RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
38113878 Function *LtGCFunc = emitListToGlobalCopyFunction (
3812- ReductionInfos, ReductionsBufferTy, FuncAttrs);
3879+ ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef );
38133880 Function *LtGRFunc = emitListToGlobalReduceFunction (
3814- ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
3881+ ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef );
38153882 Function *GtLCFunc = emitGlobalToListCopyFunction (
3816- ReductionInfos, ReductionsBufferTy, FuncAttrs);
3883+ ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef );
38173884 Function *GtLRFunc = emitGlobalToListReduceFunction (
3818- ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
3885+ ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef );
38193886 Builder.restoreIP (CodeGenIP);
38203887
38213888 Value *KernelTeamsReductionPtr = createRuntimeFunctionCall (
3822- RedFixedBuferFn , {}, " _openmp_teams_reductions_buffer_$_$ptr" );
3889+ RedFixedBufferFn , {}, " _openmp_teams_reductions_buffer_$_$ptr" );
38233890
38243891 Value *Args3[] = {SrcLocInfo,
38253892 KernelTeamsReductionPtr,
0 commit comments