Skip to content

Commit 71809bb

Browse files
Fix reductions on gpu (#25)
- Fix codegen to avoid instruction selection errors from wrong address space of reduction privates - Use atomics by default on gpu - Update python module - Add tests
1 parent 75f85cc commit 71809bb

File tree

5 files changed

+334
-187
lines changed

5 files changed

+334
-187
lines changed

src/numba/openmp/__init__.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2410,7 +2410,7 @@ def add_mapped_to_ins(ins, tags):
24102410
if DEBUG_OPENMP >= 1:
24112411
print("ins:", ins, type(ins))
24122412
print("outs:", outs, type(outs))
2413-
print("args:", state.args)
2413+
# print("args:", state.args)
24142414
print("rettype:", state.return_type, type(state.return_type))
24152415
print("target_args_unordered:", target_args_unordered)
24162416
# Re-use Numba loop lifting code to extract the target region as
@@ -3712,21 +3712,15 @@ def is_target_arg(name):
37123712

37133713

37143714
def is_pointer_target_arg(name, typ):
3715+
if name.startswith("QUAL.OMP.REDUCTION"):
3716+
return True
37153717
if name.startswith("QUAL.OMP.MAP"):
3716-
if isinstance(typ, types.npytypes.Array):
3717-
return True
3718-
else:
3719-
return True
3720-
if name in ["QUAL.OMP.FIRSTPRIVATE", "QUAL.OMP.PRIVATE"]:
3721-
return False
3718+
return True
37223719
if name in ["QUAL.OMP.TARGET.IMPLICIT"]:
37233720
if isinstance(typ, types.npytypes.Array):
37243721
return True
3725-
else:
3726-
return False
3722+
37273723
return False
3728-
# print("is_pointer_target_arg:", name, typ, type(typ))
3729-
assert False
37303724

37313725

37323726
def is_internal_var(var):
@@ -7404,6 +7398,7 @@ def NUMBER(self, args):
74047398
| data_default_clause
74057399
| data_sharing_clause
74067400
// | reduction_default_only_clause
7401+
| reduction_clause
74077402
| ompx_attribute
74087403
74097404
target_teams_distribute_simd_directive: TARGET TEAMS DISTRIBUTE SIMD [target_teams_distribute_simd_clause*]

src/numba/openmp/libs/pass/CGIntrinsicsOpenMP.cpp

Lines changed: 104 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "llvm/Transforms/Utils/CodeExtractor.h"
1212
#include "llvm/Transforms/Utils/ModuleUtils.h"
1313
#include "llvm/Transforms/Utils/ValueMapper.h"
14+
#include <llvm/Frontend/OpenMP/OMPIRBuilder.h>
1415
#include <llvm/IR/BasicBlock.h>
1516
#include <llvm/IR/Constants.h>
1617
#include <stdexcept>
@@ -65,6 +66,44 @@ static CallInst *checkCreateCall(IRBuilderBase &Builder, FunctionCallee &Fn,
6566

6667
} // namespace
6768

69+
InsertPointTy CGIntrinsicsOpenMP::emitReductions(
70+
const OpenMPIRBuilder::LocationDescription &Loc, InsertPointTy AllocaIP,
71+
ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos) {
72+
// If targeting the host runtime, use the OpenMP IR builder.
73+
if (!isOpenMPDeviceRuntime())
74+
return OMPBuilder.createReductions(Loc, AllocaIP, ReductionInfos);
75+
76+
// Reductions for the GPU runtime use atomics in global memory.
77+
// TODO: optimize with hierarchical processing: warp -> block -> grid.
78+
BasicBlock *InsertBlock = Loc.IP.getBlock();
79+
BasicBlock *ContinuationBlock =
80+
InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
81+
InsertBlock->getTerminator()->eraseFromParent();
82+
83+
OMPBuilder.Builder.SetInsertPoint(InsertBlock, InsertBlock->end());
84+
85+
for (auto &RI : ReductionInfos) {
86+
assert(RI.Variable && "Expected non-null variable");
87+
assert(RI.PrivateVariable && "eExpected non-null private variable");
88+
assert(RI.AtomicReductionGen &&
89+
"Expected non-null atomic reduction generator callback");
90+
assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
91+
"Expected variables and their private equivalents to have the same "
92+
"type");
93+
assert(RI.Variable->getType()->isPointerTy() &&
94+
"Expected variables to be pointers");
95+
96+
OMPBuilder.Builder.restoreIP(
97+
RI.AtomicReductionGen(OMPBuilder.Builder.saveIP(), RI.ElementType,
98+
RI.Variable, RI.PrivateVariable));
99+
}
100+
101+
OMPBuilder.Builder.CreateBr(ContinuationBlock);
102+
103+
OMPBuilder.Builder.SetInsertPoint(ContinuationBlock);
104+
return OMPBuilder.Builder.saveIP();
105+
}
106+
68107
void CGIntrinsicsOpenMP::setDeviceGlobalizedValues(
69108
const ArrayRef<Value *> GlobalizedValues) {
70109
DeviceGlobalizedValues.clear();
@@ -89,7 +128,8 @@ Value *CGIntrinsicsOpenMP::createScalarCast(Value *V, Type *DestTy) {
89128
Function *CGIntrinsicsOpenMP::createOutlinedFunction(
90129
DSAValueMapTy &DSAValueMap, ValueToValueMapTy *VMap, Function *OuterFn,
91130
BasicBlock *StartBB, BasicBlock *EndBB,
92-
SmallVectorImpl<Value *> &CapturedVars, StringRef Suffix) {
131+
SmallVectorImpl<Value *> &CapturedVars, StringRef Suffix,
132+
bool EmitReductions) {
93133
SmallVector<Value *, 16> Privates;
94134
SmallVector<Value *, 16> CapturedShared;
95135
SmallVector<Value *, 16> CapturedFirstprivate;
@@ -362,32 +402,42 @@ Function *CGIntrinsicsOpenMP::createOutlinedFunction(
362402
SetVector<Use *> Uses;
363403
CollectUses(V, Uses);
364404

365-
if (VMap)
366-
(*VMap)[V] = AI;
367-
368-
InsertPointTy AllocaIP(OutlinedEntryBB,
369-
OutlinedEntryBB->getFirstInsertionPt());
405+
Value *ReplacementValue = nullptr;
406+
407+
// Private the reduction variable and initialize it.
408+
if (EmitReductions) {
409+
InsertPointTy AllocaIP(OutlinedEntryBB,
410+
OutlinedEntryBB->getFirstInsertionPt());
411+
412+
Value *Priv = nullptr;
413+
switch (DSAValueMap[V].Type) {
414+
case DSA_REDUCTION_ADD:
415+
Priv = CGReduction::emitInitAndAppendInfo<DSA_REDUCTION_ADD>(
416+
OMPBuilder.Builder, AllocaIP, AI, ReductionInfos);
417+
break;
418+
case DSA_REDUCTION_SUB:
419+
Priv = CGReduction::emitInitAndAppendInfo<DSA_REDUCTION_SUB>(
420+
OMPBuilder.Builder, AllocaIP, AI, ReductionInfos);
421+
break;
422+
case DSA_REDUCTION_MUL:
423+
Priv = CGReduction::emitInitAndAppendInfo<DSA_REDUCTION_MUL>(
424+
OMPBuilder.Builder, AllocaIP, AI, ReductionInfos);
425+
break;
426+
default:
427+
FATAL_ERROR("Unsupported reduction");
428+
}
370429

371-
Value *Priv = nullptr;
372-
switch (DSAValueMap[V].Type) {
373-
case DSA_REDUCTION_ADD:
374-
Priv = CGReduction::emitInitAndAppendInfo<DSA_REDUCTION_ADD>(
375-
OMPBuilder.Builder, AllocaIP, AI, ReductionInfos);
376-
break;
377-
case DSA_REDUCTION_SUB:
378-
Priv = CGReduction::emitInitAndAppendInfo<DSA_REDUCTION_SUB>(
379-
OMPBuilder.Builder, AllocaIP, AI, ReductionInfos);
380-
break;
381-
case DSA_REDUCTION_MUL:
382-
Priv = CGReduction::emitInitAndAppendInfo<DSA_REDUCTION_MUL>(
383-
OMPBuilder.Builder, AllocaIP, AI, ReductionInfos);
384-
break;
385-
default:
386-
FATAL_ERROR("Unsupported reduction");
430+
assert(Priv && "Expected non-null private reduction variable");
431+
ReplacementValue = Priv;
432+
} else {
433+
ReplacementValue = AI;
387434
}
388435

389-
assert(Priv && "Expected non-null private reduction variable");
390-
ReplaceUses(Uses, Priv);
436+
assert(ReplacementValue && "Expected non-null replacement value");
437+
if (VMap)
438+
(*VMap)[V] = ReplacementValue;
439+
440+
ReplaceUses(Uses, ReplacementValue);
391441

392442
++AI;
393443
}
@@ -397,11 +447,12 @@ Function *CGIntrinsicsOpenMP::createOutlinedFunction(
397447
EndBB->getTerminator()->setSuccessor(0, OutlinedExitBB);
398448
OMPBuilder.Builder.SetInsertPoint(OutlinedExitBB);
399449
OMPBuilder.Builder.CreateRetVoid();
400-
if (!ReductionInfos.empty())
401-
OMPBuilder.createReductions(
402-
InsertPointTy(OutlinedExitBB, OutlinedExitBB->begin()),
403-
InsertPointTy(OutlinedEntryBB, OutlinedEntryBB->begin()),
404-
ReductionInfos);
450+
451+
if (EmitReductions)
452+
if (!ReductionInfos.empty())
453+
emitReductions(InsertPointTy(OutlinedExitBB, OutlinedExitBB->begin()),
454+
InsertPointTy(OutlinedEntryBB, OutlinedEntryBB->begin()),
455+
ReductionInfos);
405456

406457
// Deterministic insertion of BBs, BlockVector needs ExitBB to move to the
407458
// outlined function.
@@ -471,9 +522,9 @@ void CGIntrinsicsOpenMP::emitOMPParallelHostRuntime(
471522
Value *ThreadID = OMPBuilder.getOrCreateThreadID(Ident);
472523

473524
SmallVector<Value *, 16> CapturedVars;
474-
Function *OutlinedFn =
475-
createOutlinedFunction(DSAValueMap, VMap, Fn, StartBB, EndBB,
476-
CapturedVars, ".omp_outlined_parallel");
525+
Function *OutlinedFn = createOutlinedFunction(
526+
DSAValueMap, VMap, Fn, StartBB, EndBB, CapturedVars,
527+
".omp_outlined_parallel", ParRegionInfo.EmitReductions);
477528

478529
auto EmitForkCall = [&](InsertPointTy InsertIP) {
479530
OMPBuilder.Builder.restoreIP(InsertIP);
@@ -743,7 +794,7 @@ void CGIntrinsicsOpenMP::emitOMPParallelHostRuntimeOMPIRBuilder(
743794
/* IsCancellable */ false);
744795

745796
if (!ReductionInfos.empty())
746-
OMPBuilder.createReductions(BodyIP, BodyAllocaIP, ReductionInfos);
797+
emitReductions(BodyIP, BodyAllocaIP, ReductionInfos);
747798

748799
BranchInst::Create(AfterBB, AfterIP.getBlock());
749800

@@ -761,9 +812,9 @@ void CGIntrinsicsOpenMP::emitOMPParallelDeviceRuntime(
761812
ParRegionInfoStruct &ParRegionInfo) {
762813
// Extract parallel region
763814
SmallVector<Value *, 16> CapturedVars;
764-
Function *OutlinedFn =
765-
createOutlinedFunction(DSAValueMap, VMap, Fn, StartBB, EndBB,
766-
CapturedVars, ".omp_outlined_parallel");
815+
Function *OutlinedFn = createOutlinedFunction(
816+
DSAValueMap, VMap, Fn, StartBB, EndBB, CapturedVars,
817+
".omp_outlined_parallel", ParRegionInfo.EmitReductions);
767818

768819
// Create wrapper for worker threads
769820
SmallVector<Type *, 2> Params;
@@ -1462,9 +1513,9 @@ void CGIntrinsicsOpenMP::emitLoop(DSAValueMapTy &DSAValueMap,
14621513
PrivatizeWithReductions();
14631514
if (!ReductionInfos.empty()) {
14641515
OMPBuilder.Builder.SetInsertPoint(ForEndBB->getTerminator());
1465-
OMPBuilder.createReductions(OpenMPIRBuilder::LocationDescription(
1466-
OMPBuilder.Builder.saveIP(), Loc.DL),
1467-
AllocaIP, ReductionInfos);
1516+
emitReductions(OpenMPIRBuilder::LocationDescription(
1517+
OMPBuilder.Builder.saveIP(), Loc.DL),
1518+
AllocaIP, ReductionInfos);
14681519
}
14691520

14701521
OMPBuilder.Builder.SetInsertPoint(ExitBB->getTerminator());
@@ -1869,6 +1920,9 @@ void CGIntrinsicsOpenMP::emitOMPOffloadingMappings(
18691920
if (IsTargetRegion)
18701921
MapType |= OMP_TGT_MAPTYPE_TARGET_PARAM;
18711922
break;
1923+
case DSA_REDUCTION_ADD:
1924+
case DSA_REDUCTION_SUB:
1925+
case DSA_REDUCTION_MUL:
18721926
case DSA_MAP_TOFROM:
18731927
MapType = OMP_TGT_MAPTYPE_TO | OMP_TGT_MAPTYPE_FROM;
18741928
if (IsTargetRegion)
@@ -1914,6 +1968,9 @@ void CGIntrinsicsOpenMP::emitOMPOffloadingMappings(
19141968
case DSA_MAP_TO:
19151969
case DSA_MAP_FROM:
19161970
case DSA_MAP_TOFROM:
1971+
case DSA_REDUCTION_ADD:
1972+
case DSA_REDUCTION_SUB:
1973+
case DSA_REDUCTION_MUL:
19171974
Size = ConstantInt::get(OMPBuilder.SizeTy,
19181975
M.getDataLayout().getTypeAllocSize(V->getType()));
19191976
EmitMappingEntry(Size, GetMapType(DSA), V, V);
@@ -1995,7 +2052,7 @@ void CGIntrinsicsOpenMP::emitOMPOffloadingMappings(
19952052
break;
19962053
}
19972054
default:
1998-
FATAL_ERROR("Unknown mapping type");
2055+
FATAL_ERROR("Unsupported mapping type " + toString(DSA));
19992056
}
20002057
}
20012058

@@ -2561,9 +2618,9 @@ void CGIntrinsicsOpenMP::emitOMPTeamsDeviceRuntime(
25612618
Function *Fn, BasicBlock *BBEntry, BasicBlock *StartBB, BasicBlock *EndBB,
25622619
BasicBlock *AfterBB, TeamsInfoStruct &TeamsInfo) {
25632620
SmallVector<Value *, 16> CapturedVars;
2564-
Function *OutlinedFn =
2565-
createOutlinedFunction(DSAValueMap, VMap, Fn, StartBB, EndBB,
2566-
CapturedVars, ".omp_outlined_teams");
2621+
Function *OutlinedFn = createOutlinedFunction(
2622+
DSAValueMap, VMap, Fn, StartBB, EndBB, CapturedVars,
2623+
".omp_outlined_teams", TeamsInfo.EmitReductions);
25672624

25682625
// Set up the call to the teams outlined function.
25692626
BBEntry->getTerminator()->eraseFromParent();
@@ -2653,7 +2710,7 @@ void CGIntrinsicsOpenMP::emitOMPTeamsHostRuntime(
26532710
SmallVector<Value *, 16> CapturedVars;
26542711
Function *OutlinedFn = createOutlinedFunction(
26552712
DSAValueMap, /*ValueToValueMapTy */ VMap, Fn, StartBB, EndBB,
2656-
CapturedVars, ".omp_outlined_teams");
2713+
CapturedVars, ".omp_outlined_teams", TeamsInfo.EmitReductions);
26572714

26582715
// Set up the call to the teams outlined function.
26592716
BBEntry->getTerminator()->eraseFromParent();
@@ -2958,71 +3015,12 @@ void CGIntrinsicsOpenMP::emitOMPDistributeParallelFor(
29583015
}
29593016
}
29603017

2961-
void CGIntrinsicsOpenMP::emitOMPTargetTeamsDistributeParallelFor(
2962-
DSAValueMapTy &DSAValueMap, const DebugLoc &DL, Function *Fn,
2963-
BasicBlock *EntryBB, BasicBlock *StartBB, BasicBlock *EndBB,
2964-
BasicBlock *ExitBB, BasicBlock *AfterBB, OMPLoopInfoStruct &OMPLoopInfo,
2965-
ParRegionInfoStruct &ParRegionInfo, TargetInfoStruct &TargetInfo,
2966-
StructMapTy &StructMappingInfoMap, bool IsDeviceTargetRegion) {
2967-
2968-
emitOMPDistributeParallelFor(DSAValueMap, StartBB, ExitBB, OMPLoopInfo,
2969-
ParRegionInfo,
2970-
/* isStandalone */ false);
2971-
2972-
emitOMPTargetTeams(DSAValueMap, nullptr, DL, Fn, EntryBB, StartBB, EndBB,
2973-
AfterBB, TargetInfo, &OMPLoopInfo, StructMappingInfoMap,
2974-
IsDeviceTargetRegion);
2975-
2976-
// Alternative codegen, starting from top-down and renaming values using the
2977-
// ValueToValueMap.
2978-
#if 0
2979-
ValueToValueMapTy VMap;
2980-
// Lower target_teams.
2981-
emitOMPTargetTeams(DSAValueMap, &VMap, DL, Fn, EntryBB, StartBB, EndBB, AfterBB,
2982-
TargetInfo, &OMPLoopInfo, StructMappingInfoMap,
2983-
IsDeviceTargetRegion);
2984-
2985-
dbgs() << "=== VMap\n";
2986-
for(auto VV : VMap) {
2987-
dbgs() << "V " << *VV.first << " -> " << *VV.second << "\n";
2988-
}
2989-
dbgs() << "=== End of VMap\n";
2990-
getchar();
2991-
2992-
// Update DSAValueMap
2993-
SmallVector<Value *, 8> ToDelete;
2994-
for(auto &It : DSAValueMap) {
2995-
Value *V = It.first;
2996-
if(!VMap.count(V))
2997-
continue;
2998-
2999-
DSAValueMap[VMap[V]] = It.second;
3000-
dbgs() << "Update DSAValueMap " << *VMap[V] << " ~> " << It.second.Type << "\n";
3001-
ToDelete.push_back(V);
3002-
}
3003-
for(auto *V : ToDelete) {
3004-
dbgs() << "Update DSAValueMAp delete " << *V << "\n";
3005-
DSAValueMap.erase(V);
3006-
}
3007-
3008-
// Update OMPLoopInfo
3009-
OMPLoopInfo.IV = VMap[OMPLoopInfo.IV];
3010-
OMPLoopInfo.Start = VMap[OMPLoopInfo.Start];
3011-
OMPLoopInfo.LB = VMap[OMPLoopInfo.LB];
3012-
OMPLoopInfo.UB = VMap[OMPLoopInfo.UB];
3013-
3014-
emitOMPDistributeParallelFor(DSAValueMap, StartBB, ExitBB, OMPLoopInfo,
3015-
ParRegionInfo,
3016-
/* isStandalone */ false);
3017-
#endif
3018-
}
3019-
30203018
void CGIntrinsicsOpenMP::emitOMPTargetTeams(
30213019
DSAValueMapTy &DSAValueMap, ValueToValueMapTy *VMap, const DebugLoc &DL,
30223020
Function *Fn, BasicBlock *EntryBB, BasicBlock *StartBB, BasicBlock *EndBB,
30233021
BasicBlock *AfterBB, TargetInfoStruct &TargetInfo,
3024-
OMPLoopInfoStruct *OMPLoopInfo, StructMapTy &StructMappingInfoMap,
3025-
bool IsDeviceTargetRegion) {
3022+
TeamsInfoStruct &TeamsInfo, OMPLoopInfoStruct *OMPLoopInfo,
3023+
StructMapTy &StructMappingInfoMap, bool IsDeviceTargetRegion) {
30263024

30273025
BasicBlock *TeamsEntryBB = SplitBlock(EntryBB, EntryBB->getTerminator());
30283026
TeamsEntryBB->setName("omp.teams.entry");
@@ -3032,10 +3030,7 @@ void CGIntrinsicsOpenMP::emitOMPTargetTeams(
30323030
BasicBlock *TeamsEndBB =
30333031
splitBlockBefore(EndBB, &*EndBB->getFirstInsertionPt(), nullptr, nullptr,
30343032
nullptr, "omp.teams.end");
3035-
// TargetInfo contains teams info.
3036-
TeamsInfoStruct TeamsInfo;
3037-
TeamsInfo.NumTeams = TargetInfo.NumTeams;
3038-
TeamsInfo.ThreadLimit = TargetInfo.ThreadLimit;
3033+
30393034
emitOMPTeams(DSAValueMap, VMap, DL, Fn, TeamsEntryBB, TeamsStartBB,
30403035
TeamsEndBB, EndBB, TeamsInfo);
30413036

0 commit comments

Comments
 (0)