[OpenMP][flang] Support GPU team-reductions on allocatables #169651
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm

Author: Kareem Ergawy (ergawy)

Changes

Extends the work started in #165714 by supporting team reductions. Similar to what was done in #165714, this PR introduces proper allocations, loads, and stores for by-ref reductions in teams-related callbacks:
* `_omp_reduction_list_to_global_copy_func`,
* `_omp_reduction_list_to_global_reduce_func`,
* `_omp_reduction_global_to_list_copy_func`, and
* `_omp_reduction_global_to_list_reduce_func`.
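As a rough orientation before the patch itself, the following hand-written C++ sketch shows what the patched `_omp_reduction_list_to_global_copy_func` conceptually does at runtime for a single by-ref `f32` reduction. The `Descriptor` and `GlobalizedLocals` types, the function signature, and the single-field layout are illustrative assumptions, not the actual generated IR.

```cpp
// Illustrative stand-ins (assumptions): a minimal Fortran-style descriptor and
// the per-team buffer element that the generated IR names
// "struct._globalized_locals_ty".
struct Descriptor { float *data; /* rank, element size, flags, ... */ };
struct GlobalizedLocals { float red0; }; // one field per reduction variable

// Conceptual behaviour of the list-to-global copy helper for one by-ref
// reduction: the reduce-list slot holds a pointer to a descriptor, so the
// helper follows descriptor->data (via DataPtrPtrGen in the generated code)
// before loading the element. A by-val reduction would load the float
// directly from the slot instead.
void list_to_global_copy(void *global_buffer, int idx, void **reduce_list) {
  GlobalizedLocals *team_slot =
      &static_cast<GlobalizedLocals *>(global_buffer)[idx];
  Descriptor *desc = static_cast<Descriptor *>(reduce_list[0]);
  team_slot->red0 = *desc->data; // copy the thread's value into the team buffer
}
```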
Patch is 23.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169651.diff

4 Files Affected:
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 7b097d1ac0ee0..27517903e3780 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1733,7 +1733,8 @@ class OpenMPIRBuilder {
/// \return The ListToGlobalCopy function.
Function *emitListToGlobalCopyFunction(ArrayRef<ReductionInfo> ReductionInfos,
Type *ReductionsBufferTy,
- AttributeList FuncAttrs);
+ AttributeList FuncAttrs,
+ ArrayRef<bool> IsByRef);
/// This function emits a helper that copies all the reduction variables from
/// the team into the provided global buffer for the reduction variables.
@@ -1750,7 +1751,8 @@ class OpenMPIRBuilder {
/// \return The GlobalToList function.
Function *emitGlobalToListCopyFunction(ArrayRef<ReductionInfo> ReductionInfos,
Type *ReductionsBufferTy,
- AttributeList FuncAttrs);
+ AttributeList FuncAttrs,
+ ArrayRef<bool> IsByRef);
/// This function emits a helper that reduces all the reduction variables from
/// the team into the provided global buffer for the reduction variables.
@@ -1772,7 +1774,8 @@ class OpenMPIRBuilder {
Function *
emitListToGlobalReduceFunction(ArrayRef<ReductionInfo> ReductionInfos,
Function *ReduceFn, Type *ReductionsBufferTy,
- AttributeList FuncAttrs);
+ AttributeList FuncAttrs,
+ ArrayRef<bool> IsByRef);
/// This function emits a helper that reduces all the reduction variables from
/// the team into the provided global buffer for the reduction variables.
@@ -1794,7 +1797,8 @@ class OpenMPIRBuilder {
Function *
emitGlobalToListReduceFunction(ArrayRef<ReductionInfo> ReductionInfos,
Function *ReduceFn, Type *ReductionsBufferTy,
- AttributeList FuncAttrs);
+ AttributeList FuncAttrs,
+ ArrayRef<bool> IsByRef);
/// Get the function name of a reduction function.
std::string getReductionFuncName(StringRef Name) const;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index c962368859730..01d369a8751be 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -3153,7 +3153,7 @@ Expected<Function *> OpenMPIRBuilder::emitShuffleAndReduceFunction(
Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
- AttributeList FuncAttrs) {
+ AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
LLVMContext &Ctx = M.getContext();
FunctionType *FuncTy = FunctionType::get(
@@ -3223,7 +3223,15 @@ Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
switch (RI.EvaluationKind) {
case EvalKind::Scalar: {
- Value *TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
+ Value *TargetElement = [&]() {
+ if (IsByRef.empty() || !IsByRef[En.index()])
+ return Builder.CreateLoad(RI.ElementType, ElemPtr);
+
+ cantFail(RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr));
+ ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtr);
+ return Builder.CreateLoad(RI.ByRefElementType, ElemPtr);
+ }();
+
Builder.CreateStore(TargetElement, GlobVal);
break;
}
@@ -3263,7 +3271,7 @@ Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
- Type *ReductionsBufferTy, AttributeList FuncAttrs) {
+ Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
LLVMContext &Ctx = M.getContext();
FunctionType *FuncTy = FunctionType::get(
@@ -3302,6 +3310,8 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
Value *LocalReduceList =
Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
+ InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};
+
Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
BufferArgAlloca, Builder.getPtrTy(),
BufferArgAlloca->getName() + ".ascast");
@@ -3323,6 +3333,20 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
Type *IndexTy = Builder.getIndexTy(
M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
for (auto En : enumerate(ReductionInfos)) {
+ const ReductionInfo &RI = En.value();
+ Value *ByRefAlloc;
+
+ if (!IsByRef.empty() && IsByRef[En.index()]) {
+ InsertPointTy OldIP = Builder.saveIP();
+ Builder.restoreIP(AllocaIP);
+
+ ByRefAlloc = Builder.CreateAlloca(RI.ByRefAllocatedType);
+ ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ ByRefAlloc, Builder.getPtrTy(), ByRefAlloc->getName() + ".ascast");
+
+ Builder.restoreIP(OldIP);
+ }
+
Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
RedListArrayTy, LocalReduceListAddrCast,
{ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
@@ -3331,7 +3355,15 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
// Global = Buffer.VD[Idx];
Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
ReductionsBufferTy, BufferVD, 0, En.index());
- Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
+
+ if (!IsByRef.empty() && IsByRef[En.index()]) {
+ Value *ByRefDataPtr;
+ cantFail(RI.DataPtrPtrGen(Builder.saveIP(), ByRefAlloc, ByRefDataPtr));
+ Builder.CreateStore(GlobValPtr, ByRefDataPtr);
+ Builder.CreateStore(ByRefAlloc, TargetElementPtrPtr);
+ } else {
+ Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
+ }
}
// Call reduce_function(GlobalReduceList, ReduceList)
@@ -3346,30 +3378,30 @@ Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
- AttributeList FuncAttrs) {
+ AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
LLVMContext &Ctx = M.getContext();
FunctionType *FuncTy = FunctionType::get(
Builder.getVoidTy(),
{Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
/* IsVarArg */ false);
- Function *LtGCFunc =
+ Function *GtLCFunc =
Function::Create(FuncTy, GlobalVariable::InternalLinkage,
"_omp_reduction_global_to_list_copy_func", &M);
- LtGCFunc->setAttributes(FuncAttrs);
- LtGCFunc->addParamAttr(0, Attribute::NoUndef);
- LtGCFunc->addParamAttr(1, Attribute::NoUndef);
- LtGCFunc->addParamAttr(2, Attribute::NoUndef);
+ GtLCFunc->setAttributes(FuncAttrs);
+ GtLCFunc->addParamAttr(0, Attribute::NoUndef);
+ GtLCFunc->addParamAttr(1, Attribute::NoUndef);
+ GtLCFunc->addParamAttr(2, Attribute::NoUndef);
- BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
+ BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", GtLCFunc);
Builder.SetInsertPoint(EntryBlock);
// Buffer: global reduction buffer.
- Argument *BufferArg = LtGCFunc->getArg(0);
+ Argument *BufferArg = GtLCFunc->getArg(0);
// Idx: index of the buffer.
- Argument *IdxArg = LtGCFunc->getArg(1);
+ Argument *IdxArg = GtLCFunc->getArg(1);
// ReduceList: thread local Reduce list.
- Argument *ReduceListArg = LtGCFunc->getArg(2);
+ Argument *ReduceListArg = GtLCFunc->getArg(2);
Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
BufferArg->getName() + ".addr");
@@ -3413,7 +3445,15 @@ Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
switch (RI.EvaluationKind) {
case EvalKind::Scalar: {
- Value *TargetElement = Builder.CreateLoad(RI.ElementType, GlobValPtr);
+ Type *ElemType = RI.ElementType;
+
+ if (!IsByRef.empty() && IsByRef[En.index()]) {
+ ElemType = RI.ByRefElementType;
+ cantFail(RI.DataPtrPtrGen(Builder.saveIP(), ElemPtr, ElemPtr));
+ ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtr);
+ }
+
+ Value *TargetElement = Builder.CreateLoad(ElemType, GlobValPtr);
Builder.CreateStore(TargetElement, ElemPtr);
break;
}
@@ -3449,35 +3489,35 @@ Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
Builder.CreateRetVoid();
Builder.restoreIP(OldIP);
- return LtGCFunc;
+ return GtLCFunc;
}
Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
- Type *ReductionsBufferTy, AttributeList FuncAttrs) {
+ Type *ReductionsBufferTy, AttributeList FuncAttrs, ArrayRef<bool> IsByRef) {
OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
LLVMContext &Ctx = M.getContext();
auto *FuncTy = FunctionType::get(
Builder.getVoidTy(),
{Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
/* IsVarArg */ false);
- Function *LtGRFunc =
+ Function *GtLRFunc =
Function::Create(FuncTy, GlobalVariable::InternalLinkage,
"_omp_reduction_global_to_list_reduce_func", &M);
- LtGRFunc->setAttributes(FuncAttrs);
- LtGRFunc->addParamAttr(0, Attribute::NoUndef);
- LtGRFunc->addParamAttr(1, Attribute::NoUndef);
- LtGRFunc->addParamAttr(2, Attribute::NoUndef);
+ GtLRFunc->setAttributes(FuncAttrs);
+ GtLRFunc->addParamAttr(0, Attribute::NoUndef);
+ GtLRFunc->addParamAttr(1, Attribute::NoUndef);
+ GtLRFunc->addParamAttr(2, Attribute::NoUndef);
- BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
+ BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", GtLRFunc);
Builder.SetInsertPoint(EntryBlock);
// Buffer: global reduction buffer.
- Argument *BufferArg = LtGRFunc->getArg(0);
+ Argument *BufferArg = GtLRFunc->getArg(0);
// Idx: index of the buffer.
- Argument *IdxArg = LtGRFunc->getArg(1);
+ Argument *IdxArg = GtLRFunc->getArg(1);
// ReduceList: thread local Reduce list.
- Argument *ReduceListArg = LtGRFunc->getArg(2);
+ Argument *ReduceListArg = GtLRFunc->getArg(2);
Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
BufferArg->getName() + ".addr");
@@ -3493,6 +3533,8 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
Value *LocalReduceList =
Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");
+ InsertPointTy AllocaIP{EntryBlock, EntryBlock->begin()};
+
Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
BufferArgAlloca, Builder.getPtrTy(),
BufferArgAlloca->getName() + ".ascast");
@@ -3514,6 +3556,20 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
Type *IndexTy = Builder.getIndexTy(
M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
for (auto En : enumerate(ReductionInfos)) {
+ const ReductionInfo &RI = En.value();
+ Value *ByRefAlloc;
+
+ if (IsByRef[En.index()]) {
+ InsertPointTy OldIP = Builder.saveIP();
+ Builder.restoreIP(AllocaIP);
+
+ ByRefAlloc = Builder.CreateAlloca(RI.ByRefAllocatedType);
+ ByRefAlloc = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ ByRefAlloc, Builder.getPtrTy(), ByRefAlloc->getName() + ".ascast");
+
+ Builder.restoreIP(OldIP);
+ }
+
Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
RedListArrayTy, ReductionList,
{ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
@@ -3522,7 +3578,15 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
ReductionsBufferTy, BufferVD, 0, En.index());
- Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
+
+ if (IsByRef[En.index()]) {
+ Value *ByRefDataPtr;
+ cantFail(RI.DataPtrPtrGen(Builder.saveIP(), ByRefAlloc, ByRefDataPtr));
+ Builder.CreateStore(GlobValPtr, ByRefDataPtr);
+ Builder.CreateStore(ByRefAlloc, TargetElementPtrPtr);
+ } else {
+ Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
+ }
}
// Call reduce_function(ReduceList, GlobalReduceList)
@@ -3532,7 +3596,7 @@ Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
->addFnAttr(Attribute::NoUnwind);
Builder.CreateRetVoid();
Builder.restoreIP(OldIP);
- return LtGRFunc;
+ return GtLRFunc;
}
std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
@@ -3788,7 +3852,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
auto Size = M.getDataLayout().getTypeStoreSize(En.value().ElementType);
if (Size > MaxDataSize)
MaxDataSize = Size;
- ReductionTypeArgs.emplace_back(En.value().ElementType);
+ Type *RedTypeArg = (!IsByRef.empty() && IsByRef[En.index()])
+ ? En.value().ByRefElementType
+ : En.value().ElementType;
+ ReductionTypeArgs.emplace_back(RedTypeArg);
}
Value *ReductionDataSize =
Builder.getInt64(MaxDataSize * ReductionInfos.size());
@@ -3806,20 +3873,20 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
CodeGenIP = Builder.saveIP();
StructType *ReductionsBufferTy = StructType::create(
Ctx, ReductionTypeArgs, "struct._globalized_locals_ty");
- Function *RedFixedBuferFn = getOrCreateRuntimeFunctionPtr(
+ Function *RedFixedBufferFn = getOrCreateRuntimeFunctionPtr(
RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
Function *LtGCFunc = emitListToGlobalCopyFunction(
- ReductionInfos, ReductionsBufferTy, FuncAttrs);
+ ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
Function *LtGRFunc = emitListToGlobalReduceFunction(
- ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
+ ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
Function *GtLCFunc = emitGlobalToListCopyFunction(
- ReductionInfos, ReductionsBufferTy, FuncAttrs);
+ ReductionInfos, ReductionsBufferTy, FuncAttrs, IsByRef);
Function *GtLRFunc = emitGlobalToListReduceFunction(
- ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
+ ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs, IsByRef);
Builder.restoreIP(CodeGenIP);
Value *KernelTeamsReductionPtr = createRuntimeFunctionCall(
- RedFixedBuferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");
+ RedFixedBufferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");
Value *Args3[] = {SrcLocInfo,
KernelTeamsReductionPtr,
diff --git a/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir
index df606150b760a..95d12f304aca0 100644
--- a/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir
+++ b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir
@@ -1,3 +1,5 @@
+// Tests single-team by-ref GPU reductions.
+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
diff --git a/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir
new file mode 100644
index 0000000000000..1c73a49b0bf9f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir
@@ -0,0 +1,121 @@
+// Tests cross-teams by-ref GPU reductions.
+
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<"dlti.alloca_memory_space" = 5 : ui64, "dlti.global_memory_space" = 1 : ui64>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {
+ omp.private {type = private} @_QFfooEi_private_i32 : i32
+ omp.declare_reduction @add_reduction_byref_box_heap_f32 : !llvm.ptr attributes {byref_element_type = f32} alloc {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> : (i64) -> !llvm.ptr<5>
+ %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
+ omp.yield(%2 : !llvm.ptr)
+ } init {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ omp.yield(%arg1 : !llvm.ptr)
+ } combiner {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+ %2 = llvm.addrspacecast %1 : !llvm.ptr<5> to !llvm.ptr
+ %3 = llvm.mlir.constant(1 : i32) : i32
+ %4 = llvm.alloca %3 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+ %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr
+ %6 = llvm.mlir.constant(24 : i32) : i32
+ "llvm.intr.memcpy"(%5, %arg0, %6) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ %7 = llvm.mlir.constant(24 : i32) : i32
+ "llvm.intr.memcpy"(%2, %arg1, %7) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+ %9 = llvm.load %8 : !llvm.ptr -> !llvm.ptr
+ %10 = llvm.getelementptr %2[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+ %11 = llvm.load %10 : !llvm.ptr -> !llvm.ptr
+ %12 = llvm.load %9 : !llvm.ptr -> f32
+ %13 = llvm.load %11 : !llvm.ptr -> f32
+ %14 = llvm.fadd %12, %13 {fastmathFlags = #llvm.fastmath<contract>} : f32
+ llvm.store %14, %9 : f32, !llvm.ptr
+ omp.yield(%arg0 : !llvm.ptr)
+ } data_ptr_ptr {
+ ^bb0(%arg0: !llvm.ptr):
+ %0 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+ omp.yield(%0 : !llvm.ptr)
+ }
+
+ llvm.func @foo_() {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %4 = llvm.alloca %0 x i1 : (i64) -> !llvm.ptr<5>
+ %5 = llvm.addrspacecast %4 : !llvm.ptr<5> to !llvm.ptr
+ %8 = llvm.getelementptr %5[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+ %9 = omp.map.info var_ptr(%5 : !llvm.ptr, f32) map_clauses(tofrom) capture(ByRef) var_ptr_ptr(%8 : !llvm.ptr) -> !llvm.ptr {name = ""}
+ %10 = omp.map.info var_ptr(%5 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(always, descriptor, to, attach) capture(ByRef) members(%9 : [0] : !llvm.ptr) -> !llvm.ptr {name = "scalar_alloc"}
+ omp.target map_entries(%10 -> %arg0 : !llvm.ptr) {
+ %14 = llvm.mlir.constant(1000000 : i32) : i32
+ %15 = llvm.mlir.constant(1 : i32) : i32
+ omp.teams reduction(byref @add_reduction_byref_box_heap_f32 %arg0 -> %arg3 : !llvm.ptr) {
+ omp.parallel {
+ omp.distribute {
+ omp.wsloop reduction(byref @add_reduction_byref_box_heap_f32 %arg3 -> %arg5 : !llvm.ptr) {
+ omp.loop_nest (%arg6) : i32 = (%15) to (%14) inclusive step (%15) {
+ omp.yield
+ }
+ } {omp.composite}
+ } {omp.composite}
+ omp.terminator
+ } {omp.composite}
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+ }
+}
+
+// CHECK: %[[GLOBALIZED_LOCALS:.*]] = type { float }
+
+// CHECK: define internal void @_omp_reduction_list_to_global_copy_func({{.*}}) {{.*}} {
+// CHECK: %[[RED_ARR_LIST:.*]] = getelementptr inbounds [1 x ptr], ptr %{{.*}}, i64 0, i64 0
+// CHECK: %[[RED_ELEM_PTR:.*]] = load ptr, ptr %[[RED_ARR_LI...
[truncated]
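The `%struct._globalized_locals_ty = type { float }` CHECK line above reflects the related change in `createReductionsGPU`: for by-ref reductions the team-buffer field is typed on `ByRefElementType` (here `f32`) rather than on the descriptor box. In the illustrative C++ terms used earlier, the buffer element is simply:

```cpp
// Stand-in (assumption) for %struct._globalized_locals_ty after this patch:
// one field of the underlying element type per by-ref reduction variable,
// rather than the full descriptor struct.
struct GlobalizedLocals {
  float red0;
};
```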
🐧 Linux x64 Test Results
Extends the work started in #165714 by supporting team reductions. Similar to what was done in #165714, this PR introduces proper allocations, loads, and stores for by-ref reductions in teams-related callbacks:
* `_omp_reduction_list_to_global_copy_func`,
* `_omp_reduction_list_to_global_reduce_func`,
* `_omp_reduction_global_to_list_copy_func`, and
* `_omp_reduction_global_to_list_reduce_func`.
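For the reduce-direction helpers, a similar hand-written sketch (same illustrative `Descriptor`/`GlobalizedLocals` stand-ins and hypothetical signature as above, not the generated IR) shows the new temporary-descriptor step: the team-buffer slot is wrapped in a stack descriptor so that the generic reduce function, which expects descriptor pointers for by-ref reductions, can combine into it.

```cpp
struct Descriptor { float *data; /* rank, element size, flags, ... */ };
struct GlobalizedLocals { float red0; };

// Conceptual behaviour of the list-to-global reduce helper for one by-ref
// reduction. By-val reductions keep the old behaviour: the address of the
// buffer slot is stored directly into the local reduce list.
void list_to_global_reduce(void *global_buffer, int idx, void *reduce_list,
                           void (*reduce_fn)(void *, void *)) {
  GlobalizedLocals *team_slot =
      &static_cast<GlobalizedLocals *>(global_buffer)[idx];

  // New in this patch: a stack descriptor whose data pointer aliases the
  // team-buffer slot (mirrors the ByRefAlloc + DataPtrPtrGen logic).
  Descriptor tmp;
  tmp.data = &team_slot->red0;

  void *local_list[1];
  local_list[0] = &tmp; // by-val would store &team_slot->red0 here instead

  // reduce_function(GlobalReduceList, ReduceList)
  reduce_fn(local_list, reduce_list);
}
```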
9313342 to 008deca
jsjodin left a comment:
LGTM!
Thanks Jan for the review. If anyone else has any comments on this PR, please let me know.
tblah left a comment:
LGTM to the extent that I understand GPU things.