Skip to content

Conversation

@ergawy
Copy link
Member

@ergawy ergawy commented Nov 26, 2025

Extends the work started in #165714 by supporting team reductions. Similar to what was done in #165714, this PR introduces proper allocations, loads, and stores for by-ref reductions in teams-related callbacks:

  • _omp_reduction_list_to_global_copy_func,
  • _omp_reduction_list_to_global_reduce_func,
  • _omp_reduction_global_to_list_copy_func, and
  • _omp_reduction_global_to_list_reduce_func.

@llvmbot llvmbot added mlir:llvm mlir flang:openmp clang:openmp OpenMP related changes to Clang labels Nov 26, 2025
@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-mlir-llvm

Author: Kareem Ergawy (ergawy)

Changes

Extends the work started in #165714 by supporting team reductions. Similar to what was done in #165714, this PR introduces proper allocations, loads, and stores for by-ref reductions in teams-related callbacks:

  • _omp_reduction_list_to_global_copy_func,
  • _omp_reduction_list_to_global_reduce_func,
  • _omp_reduction_global_to_list_copy_func, and
  • _omp_reduction_global_to_list_reduce_func.

Patch is 23.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169651.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+8-4)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+102-35)
  • (modified) mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir (+2)
  • (added) mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir (+121)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 7b097d1ac0ee0..27517903e3780 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1733,7 +1733,8 @@ class OpenMPIRBuilder {
   /// \return The ListToGlobalCopy function.
   Function *emitListToGlobalCopyFunction(ArrayRef<ReductionInfo> ReductionInfos,
                                          Type *ReductionsBufferTy,
-                                         AttributeList FuncAttrs);
+                                         AttributeList FuncAttrs,
+                                         ArrayRef<bool> IsByRef);
 
   /// This function emits a helper that copies all the reduction variables from
   /// the team into the provided global buffer for the reduction variables.
@@ -1750,7 +1751,8 @@ class OpenMPIRBuilder {
   /// \return The GlobalToList function.
   Function *emitGlobalToListCopyFunction(ArrayRef<ReductionInfo> ReductionInfos,
                                          Type *ReductionsBufferTy,
-                                         AttributeList FuncAttrs);
+                                         AttributeList FuncAttrs,
+                                         ArrayRef<bool> IsByRef);
 
   /// This function emits a helper that reduces all the reduction variables from
   /// the team into the provided global buffer for the reduction variables.
@@ -1772,7 +1774,8 @@ class OpenMPIRBuilder {
   Function *
   emitListToGlobalReduceFunction(ArrayRef<ReductionInfo> ReductionInfos,
                                  Function *ReduceFn, Type *ReductionsBufferTy,
-                                 AttributeList FuncAttrs);
+                                 AttributeList FuncAttrs,
+                                 ArrayRef<bool> IsByRef);
 
   /// This function emits a helper that reduces all the reduction variables from
   /// the team into the provided global buffer for the reduction variables.
@@ -1794,7 +1797,8 @@ class OpenMPIRBuilder {
   Function *
   emitGlobalToListReduceFunction(ArrayRef<ReductionInfo> ReductionInfos,
                                  Function *ReduceFn, Type *ReductionsBufferTy,
-                                 AttributeList FuncAttrs);
+                                 AttributeList FuncAttrs,
+                                 ArrayRef<bool> IsByRef);
 
   /// Get the function name of a reduction function.
   std::string getReductionFuncName(StringRef Name) const;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index c962368859730..01d369a8751be 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -3153,7 +3153,7 @@ Expected<Function *> OpenMPIRBuilder::emitShuffleAndReduceFunction(
 
 Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
     ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
-    AttributeList FuncAttrs) {
+    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
   LLVMContext &Ctx = M.getContext();
   FunctionType *FuncTy = FunctionType::get(
@@ -3223,7 +3223,15 @@ Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
 
     switch (RI.EvaluationKind) {
     case EvalKind::Scalar: {
-      Value *TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
+      Value *TargetElement = [&]() {
+        if (IsByRef.empty() || !IsByRef[En.index()])
+          return Builder.CreateLoad(RI.ElementType, ElemPtr);
+
+        cantFail(RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr));
+        ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtr);
+        return Builder.CreateLoad(RI.ByRefElementType, ElemPtr);
+      }();
+
       Builder.CreateStore(TargetElement, GlobVal);
       break;
     }
@@ -3263,7 +3271,7 @@ Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
 
 Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
     ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
-    Type *ReductionsBufferTy, AttributeList FuncAttrs) {
+    Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
   LLVMContext &Ctx = M.getContext();
   FunctionType *FuncTy = FunctionType::get(
@@ -3302,6 +3310,8 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
   Value *LocalReduceList =
       Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
 
+  InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};
+
   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
       BufferArgAlloca, Builder.getPtrTy(),
       BufferArgAlloca->getName() + ".ascast");
@@ -3323,6 +3333,20 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
   Type *IndexTy = Builder.getIndexTy(
       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
   for (auto En : enumerate(ReductionInfos)) {
+    const ReductionInfo &RI = En.value();
+    Value *ByRefAlloc;
+
+    if (!IsByRef.empty() && IsByRef[En.index()]) {
+      InsertPointTy OldIP = Builder.saveIP();
+      Builder.restoreIP(AllocaIP);
+
+      ByRefAlloc = Builder.CreateAlloca(RI.ByRefAllocatedType);
+      ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          ByRefAlloc, Builder.getPtrTy(), ByRefAlloc->getName() + ".ascast");
+
+      Builder.restoreIP(OldIP);
+    }
+
     Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
         RedListArrayTy, LocalReduceListAddrCast,
         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
@@ -3331,7 +3355,15 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
     // Global = Buffer.VD[Idx];
     Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
         ReductionsBufferTy, BufferVD, 0, En.index());
-    Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
+
+    if (!IsByRef.empty() && IsByRef[En.index()]) {
+      Value *ByRefDataPtr;
+      cantFail(RI.DataPtrPtrGen(Builder.saveIP(), ByRefAlloc, ByRefDataPtr));
+      Builder.CreateStore(GlobValPtr, ByRefDataPtr);
+      Builder.CreateStore(ByRefAlloc, TargetElementPtrPtr);
+    } else {
+      Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
+    }
   }
 
   // Call reduce_function(GlobalReduceList, ReduceList)
@@ -3346,30 +3378,30 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
 
 Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
     ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
-    AttributeList FuncAttrs) {
+    AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
   LLVMContext &Ctx = M.getContext();
   FunctionType *FuncTy = FunctionType::get(
       Builder.getVoidTy(),
       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
       /* IsVarArg */ false);
-  Function *LtGCFunc =
+  Function *GtLCFunc =
       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                        "_omp_reduction_global_to_list_copy_func", &M);
-  LtGCFunc->setAttributes(FuncAttrs);
-  LtGCFunc->addParamAttr(0, Attribute::NoUndef);
-  LtGCFunc->addParamAttr(1, Attribute::NoUndef);
-  LtGCFunc->addParamAttr(2, Attribute::NoUndef);
+  GtLCFunc->setAttributes(FuncAttrs);
+  GtLCFunc->addParamAttr(0, Attribute::NoUndef);
+  GtLCFunc->addParamAttr(1, Attribute::NoUndef);
+  GtLCFunc->addParamAttr(2, Attribute::NoUndef);
 
-  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
+  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", GtLCFunc);
   Builder.SetInsertPoint(EntryBlock);
 
   // Buffer: global reduction buffer.
-  Argument *BufferArg = LtGCFunc->getArg(0);
+  Argument *BufferArg = GtLCFunc->getArg(0);
   // Idx: index of the buffer.
-  Argument *IdxArg = LtGCFunc->getArg(1);
+  Argument *IdxArg = GtLCFunc->getArg(1);
   // ReduceList: thread local Reduce list.
-  Argument *ReduceListArg = LtGCFunc->getArg(2);
+  Argument *ReduceListArg = GtLCFunc->getArg(2);
 
   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
                                                 BufferArg->getName() + ".addr");
@@ -3413,7 +3445,15 @@ Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
 
     switch (RI.EvaluationKind) {
     case EvalKind::Scalar: {
-      Value *TargetElement = Builder.CreateLoad(RI.ElementType, GlobValPtr);
+      Type *ElemType = RI.ElementType;
+
+      if (!IsByRef.empty() && IsByRef[En.index()]) {
+        ElemType = RI.ByRefElementType;
+        cantFail(RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr));
+        ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtr);
+      }
+
+      Value *TargetElement = Builder.CreateLoad(ElemType, GlobValPtr);
       Builder.CreateStore(TargetElement, ElemPtr);
       break;
     }
@@ -3449,35 +3489,35 @@ Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
 
   Builder.CreateRetVoid();
   Builder.restoreIP(OldIP);
-  return LtGCFunc;
+  return GtLCFunc;
 }
 
 Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
     ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
-    Type *ReductionsBufferTy, AttributeList FuncAttrs) {
+    Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
   OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
   LLVMContext &Ctx = M.getContext();
   auto *FuncTy = FunctionType::get(
       Builder.getVoidTy(),
       {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
       /* IsVarArg */ false);
-  Function *LtGRFunc =
+  Function *GtLRFunc =
       Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                        "_omp_reduction_global_to_list_reduce_func", &M);
-  LtGRFunc->setAttributes(FuncAttrs);
-  LtGRFunc->addParamAttr(0, Attribute::NoUndef);
-  LtGRFunc->addParamAttr(1, Attribute::NoUndef);
-  LtGRFunc->addParamAttr(2, Attribute::NoUndef);
+  GtLRFunc->setAttributes(FuncAttrs);
+  GtLRFunc->addParamAttr(0, Attribute::NoUndef);
+  GtLRFunc->addParamAttr(1, Attribute::NoUndef);
+  GtLRFunc->addParamAttr(2, Attribute::NoUndef);
 
-  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
+  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", GtLRFunc);
   Builder.SetInsertPoint(EntryBlock);
 
   // Buffer: global reduction buffer.
-  Argument *BufferArg = LtGRFunc->getArg(0);
+  Argument *BufferArg = GtLRFunc->getArg(0);
   // Idx: index of the buffer.
-  Argument *IdxArg = LtGRFunc->getArg(1);
+  Argument *IdxArg = GtLRFunc->getArg(1);
   // ReduceList: thread local Reduce list.
-  Argument *ReduceListArg = LtGRFunc->getArg(2);
+  Argument *ReduceListArg = GtLRFunc->getArg(2);
 
   Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
                                                 BufferArg->getName() + ".addr");
@@ -3493,6 +3533,8 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
   Value *LocalReduceList =
       Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
 
+  InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};
+
   Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
       BufferArgAlloca, Builder.getPtrTy(),
       BufferArgAlloca->getName() + ".ascast");
@@ -3514,6 +3556,20 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
   Type *IndexTy = Builder.getIndexTy(
       M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
   for (auto En : enumerate(ReductionInfos)) {
+    const ReductionInfo &RI = En.value();
+    Value *ByRefAlloc;
+
+    if (IsByRef[En.index()]) {
+      InsertPointTy OldIP = Builder.saveIP();
+      Builder.restoreIP(AllocaIP);
+
+      ByRefAlloc = Builder.CreateAlloca(RI.ByRefAllocatedType);
+      ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          ByRefAlloc, Builder.getPtrTy(), ByRefAlloc->getName() + ".ascast");
+
+      Builder.restoreIP(OldIP);
+    }
+
     Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
         RedListArrayTy, ReductionList,
         {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
@@ -3522,7 +3578,15 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
         Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
     Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
         ReductionsBufferTy, BufferVD, 0, En.index());
-    Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
+
+    if (IsByRef[En.index()]) {
+      Value *ByRefDataPtr;
+      cantFail(RI.DataPtrPtrGen(Builder.saveIP(), ByRefAlloc, ByRefDataPtr));
+      Builder.CreateStore(GlobValPtr, ByRefDataPtr);
+      Builder.CreateStore(ByRefAlloc, TargetElementPtrPtr);
+    } else {
+      Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
+    }
   }
 
   // Call reduce_function(ReduceList, GlobalReduceList)
@@ -3532,7 +3596,7 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
       ->addFnAttr(Attribute::NoUnwind);
   Builder.CreateRetVoid();
   Builder.restoreIP(OldIP);
-  return LtGRFunc;
+  return GtLRFunc;
 }
 
 std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
@@ -3788,7 +3852,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
     auto Size = M.getDataLayout().getTypeStoreSize(En.value().ElementType);
     if (Size > MaxDataSize)
       MaxDataSize = Size;
-    ReductionTypeArgs.emplace_back(En.value().ElementType);
+    Type *RedTypeArg = (!IsByRef.empty() && IsByRef[En.index()])
+                           ? En.value().ByRefElementType
+                           : En.value().ElementType;
+    ReductionTypeArgs.emplace_back(RedTypeArg);
   }
   Value *ReductionDataSize =
       Builder.getInt64(MaxDataSize * ReductionInfos.size());
@@ -3806,20 +3873,20 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
     CodeGenIP = Builder.saveIP();
     StructType *ReductionsBufferTy = StructType::create(
         Ctx, ReductionTypeArgs, "struct._globalized_locals_ty");
-    Function *RedFixedBuferFn = getOrCreateRuntimeFunctionPtr(
+    Function *RedFixedBufferFn = getOrCreateRuntimeFunctionPtr(
         RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
     Function *LtGCFunc = emitListToGlobalCopyFunction(
-        ReductionInfos, ReductionsBufferTy, FuncAttrs);
+        ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
     Function *LtGRFunc = emitListToGlobalReduceFunction(
-        ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
+        ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
     Function *GtLCFunc = emitGlobalToListCopyFunction(
-        ReductionInfos, ReductionsBufferTy, FuncAttrs);
+        ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
     Function *GtLRFunc = emitGlobalToListReduceFunction(
-        ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
+        ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
     Builder.restoreIP(CodeGenIP);
 
     Value *KernelTeamsReductionPtr = createRuntimeFunctionCall(
-        RedFixedBuferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");
+        RedFixedBufferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");
 
     Value *Args3[] = {SrcLocInfo,
                       KernelTeamsReductionPtr,
diff --git a/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir
index df606150b760a..95d12f304aca0 100644
--- a/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir
+++ b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir
@@ -1,3 +1,5 @@
+// Tests single-team by-ref GPU reductions.
+
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
 
 module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
diff --git a/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir
new file mode 100644
index 0000000000000..1c73a49b0bf9f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir
@@ -0,0 +1,121 @@
+// Tests cross-teams by-ref GPU reductions.
+
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
+  omp.private {type = private} @_QFfooEi_private_i32 : i32
+  omp.declare_reduction @add_reduction_byref_box_heap_f32 : !llvm.ptr attributes {byref_element_type = f32} alloc {
+    %0 = llvm.mlir.constant(1 : i64) : i64
+    %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> : (i64) -> !llvm.ptr<5>
+    %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
+    omp.yield(%2 : !llvm.ptr)
+  } init {
+  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+    omp.yield(%arg1 : !llvm.ptr)
+  } combiner {
+  ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+    %0 = llvm.mlir.constant(1 : i32) : i32
+    %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+    %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
+    %3 = llvm.mlir.constant(1 : i32) : i32
+    %4 = llvm.alloca %3 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+    %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr
+    %6 = llvm.mlir.constant(24 : i32) : i32
+    "llvm.intr.memcpy"(%5, %arg0, %6) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+    %7 = llvm.mlir.constant(24 : i32) : i32
+    "llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+    %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+    %9 = llvm.load %8 : !llvm.ptr -> !llvm.ptr
+    %10 = llvm.getelementptr %2[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+    %11 = llvm.load %10 : !llvm.ptr -> !llvm.ptr
+    %12 = llvm.load %9 : !llvm.ptr -> f32
+    %13 = llvm.load %11 : !llvm.ptr -> f32
+    %14 = llvm.fadd %12, %13 {fastmathFlags = #llvm.fastmath<contract>} : f32
+    llvm.store %14, %9 : f32, !llvm.ptr
+    omp.yield(%arg0 : !llvm.ptr)
+  } data_ptr_ptr {
+  ^bb0(%arg0: !llvm.ptr):
+    %0 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+    omp.yield(%0 : !llvm.ptr)
+  }
+
+  llvm.func @foo_() {
+    %0 = llvm.mlir.constant(1 : i64) : i64
+    %4 = llvm.alloca %0 x i1 : (i64) -> !llvm.ptr<5>
+    %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr
+    %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+    %9 = omp.map.info var_ptr(%5 : !llvm.ptr, f32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr(%8 : !llvm.ptr) -> !llvm.ptr {name = ""}
+    %10 = omp.map.info var_ptr(%5 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(always, descriptor, to, attach) capture(ByRef) members(%9 : [0] : !llvm.ptr) -> !llvm.ptr {name = "scalar_alloc"}
+    omp.target map_entries(%10 -> %arg0 : !llvm.ptr) {
+      %14 = llvm.mlir.constant(1000000 : i32) : i32
+      %15 = llvm.mlir.constant(1 : i32) : i32
+      omp.teams reduction(byref @add_reduction_byref_box_heap_f32 %arg0 -> %arg3 : !llvm.ptr) {
+        omp.parallel {
+          omp.distribute {
+            omp.wsloop reduction(byref @add_reduction_byref_box_heap_f32 %arg3 -> %arg5 : !llvm.ptr) {
+              omp.loop_nest (%arg6) : i32 = (%15) to (%14) inclusive step (%15) {
+                omp.yield
+              }
+            } {omp.composite}
+          } {omp.composite}
+          omp.terminator
+        } {omp.composite}
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// CHECK: %[[GLOBALIZED_LOCALS:.*]] = type { float }
+
+// CHECK: define internal void @_omp_reduction_list_to_global_copy_func({{.*}}) {{.*}} {
+// CHECK:   %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0
+// CHECK:   %[[RED_ELEM_PTR:.*]] = load ptr, ptr %[[RED_ARR_LI...
[truncated]

@ergawy ergawy marked this pull request as draft November 26, 2025 13:27
@github-actions
Copy link

github-actions bot commented Nov 26, 2025

🐧 Linux x64 Test Results

  • 186643 tests passed
  • 4888 tests skipped

Extends the work started in #165714 by supporting team reductions.
Similar to what was done in #165714, this PR introduces proper
allocations, loads, and stores for by-ref reductions in teams-related
callbacks:
* `_omp_reduction_list_to_global_copy_func`,
* `_omp_reduction_list_to_global_reduce_func`,
* `_omp_reduction_global_to_list_copy_func`, and
* `_omp_reduction_global_to_list_reduce_func`.
@ergawy ergawy force-pushed the users/ergawy/allocatable_reduction_irbuilder2 branch from 9313342 to 008deca Compare November 27, 2025 09:10
@ergawy ergawy marked this pull request as ready for review November 27, 2025 10:07
@Meinersbur Meinersbur requested a review from tblah November 27, 2025 12:01
@ergawy ergawy changed the title [OpenMP][flang] Support GPU team reductions on allocatables [OpenMP][flang] Support GPU team-reductions on allocatables Nov 28, 2025
Copy link
Contributor

@jsjodin jsjodin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@ergawy
Copy link
Member Author

ergawy commented Dec 1, 2025

Thanks Jan for the review. If anyone else has any comments on this PR please let me know.

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM to the extent that I understand GPU things.

@ergawy ergawy merged commit 867d353 into main Dec 2, 2025
12 checks passed
@ergawy ergawy deleted the users/ergawy/allocatable_reduction_irbuilder2 branch December 2, 2025 04:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants