Skip to content

Commit 04ef6a7

Browse files
wsmosesvchuravy
andcommitted
Conditional Register Reduction (#179)
* lu cuda test opt * tmp * re-enable work * nocache * Conditional enable register reduction * Update enzyme/Enzyme/GradientUtils.cpp Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com> * Update enzyme/Enzyme/GradientUtils.cpp Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com> Co-authored-by: Valentin Churavy <v.churavy@gmail.com> Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com>
1 parent 87325c3 commit 04ef6a7

File tree

1 file changed

+104
-70
lines changed

1 file changed

+104
-70
lines changed

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 104 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ llvm::cl::opt<bool> EnzymeInactiveDynamic(
7070
llvm::cl::opt<bool>
7171
EnzymeSharedForward("enzyme-shared-forward", cl::init(false), cl::Hidden,
7272
cl::desc("Forward Shared Memory from definitions"));
73+
74+
llvm::cl::opt<bool>
75+
EnzymeRegisterReduce("enzyme-register-reduce", cl::init(false), cl::Hidden,
76+
cl::desc("Reduce the amount of register reduce"));
7377
}
7478

7579
bool isPotentialLastLoopValue(Value *val, const BasicBlock *loc,
@@ -2515,78 +2519,107 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
25152519
assert(inst->getParent()->getParent() == newFunc);
25162520
assert(BuilderM.GetInsertBlock()->getParent() == newFunc);
25172521

2518-
if (isOriginalBlock(*BuilderM.GetInsertBlock())) {
2519-
if (BuilderM.GetInsertBlock()->size() &&
2520-
BuilderM.GetInsertPoint() != BuilderM.GetInsertBlock()->end()) {
2521-
Instruction *use = &*BuilderM.GetInsertPoint();
2522-
while (isa<PHINode>(use))
2523-
use = use->getNextNode();
2524-
if (DT.dominates(inst, use)) {
2525-
return inst;
2526-
} else {
2527-
llvm::errs() << *BuilderM.GetInsertBlock()->getParent() << "\n";
2528-
llvm::errs() << "didnt dominate inst: " << *inst
2529-
<< " point: " << *BuilderM.GetInsertPoint()
2530-
<< "\nbb: " << *BuilderM.GetInsertBlock() << "\n";
2522+
bool reduceRegister = false;
2523+
2524+
if (EnzymeRegisterReduce) {
2525+
if (auto II = dyn_cast<IntrinsicInst>(inst)) {
2526+
switch (II->getIntrinsicID()) {
2527+
case Intrinsic::nvvm_ldu_global_i:
2528+
case Intrinsic::nvvm_ldu_global_p:
2529+
case Intrinsic::nvvm_ldu_global_f:
2530+
case Intrinsic::nvvm_ldg_global_i:
2531+
case Intrinsic::nvvm_ldg_global_p:
2532+
case Intrinsic::nvvm_ldg_global_f:
2533+
reduceRegister = true;
2534+
break;
2535+
default:
2536+
break;
25312537
}
2532-
} else {
2533-
if (inst->getParent() == BuilderM.GetInsertBlock() ||
2534-
DT.dominates(inst, BuilderM.GetInsertBlock())) {
2535-
// allowed from block domination
2536-
return inst;
2537-
} else {
2538-
llvm::errs() << *BuilderM.GetInsertBlock()->getParent() << "\n";
2539-
llvm::errs() << "didnt dominate inst: " << *inst
2540-
<< "\nbb: " << *BuilderM.GetInsertBlock() << "\n";
2541-
}
2542-
}
2543-
// This is a reverse block
2544-
} else if (BuilderM.GetInsertBlock() != inversionAllocs) {
2545-
// Something in the entry (or anything that dominates all returns, doesn't
2546-
// need caching)
2547-
2548-
BasicBlock *forwardBlock =
2549-
originalForReverseBlock(*BuilderM.GetInsertBlock());
2550-
2551-
// Don't allow this if we're not definitely using the last iteration of this
2552-
// value
2553-
// + either because the value isn't in a loop
2554-
// + or because the forward of the block usage location isn't in a loop
2555-
// (thus last iteration)
2556-
// + or because the loop nests share no ancestry
2557-
2558-
bool loopLegal = true;
2559-
for (Loop *idx = LI.getLoopFor(inst->getParent()); idx != nullptr;
2560-
idx = idx->getParentLoop()) {
2561-
for (Loop *fdx = LI.getLoopFor(forwardBlock); fdx != nullptr;
2562-
fdx = fdx->getParentLoop()) {
2563-
if (idx == fdx) {
2564-
loopLegal = false;
2565-
break;
2566-
}
2538+
}
2539+
if (auto LI = dyn_cast<LoadInst>(inst)) {
2540+
if (cast<PointerType>(LI->getPointerOperand()->getType())
2541+
->getAddressSpace() == 3) {
2542+
reduceRegister |= tryLegalRecomputeCheck &&
2543+
legalRecompute(LI, incoming_available, &BuilderM) &&
2544+
shouldRecompute(LI, incoming_available, &BuilderM);
25672545
}
25682546
}
2547+
}
25692548

2570-
if (loopLegal) {
2571-
if (inst->getParent() == &newFunc->getEntryBlock()) {
2572-
return inst;
2549+
if (!reduceRegister) {
2550+
if (isOriginalBlock(*BuilderM.GetInsertBlock())) {
2551+
if (BuilderM.GetInsertBlock()->size() &&
2552+
BuilderM.GetInsertPoint() != BuilderM.GetInsertBlock()->end()) {
2553+
Instruction *use = &*BuilderM.GetInsertPoint();
2554+
while (isa<PHINode>(use))
2555+
use = use->getNextNode();
2556+
if (DT.dominates(inst, use)) {
2557+
return inst;
2558+
} else {
2559+
llvm::errs() << *BuilderM.GetInsertBlock()->getParent() << "\n";
2560+
llvm::errs() << "didn't dominate inst: " << *inst
2561+
<< " point: " << *BuilderM.GetInsertPoint()
2562+
<< "\nbb: " << *BuilderM.GetInsertBlock() << "\n";
2563+
}
2564+
} else {
2565+
if (inst->getParent() == BuilderM.GetInsertBlock() ||
2566+
DT.dominates(inst, BuilderM.GetInsertBlock())) {
2567+
// allowed from block domination
2568+
return inst;
2569+
} else {
2570+
llvm::errs() << *BuilderM.GetInsertBlock()->getParent() << "\n";
2571+
llvm::errs() << "didn't dominate inst: " << *inst
2572+
<< "\nbb: " << *BuilderM.GetInsertBlock() << "\n";
2573+
}
25732574
}
2574-
// TODO upgrade this to be all returns that this could enter from
2575-
bool legal = true;
2576-
for (auto &BB : *oldFunc) {
2577-
if (isa<ReturnInst>(BB.getTerminator())) {
2578-
BasicBlock *returningBlock =
2579-
cast<BasicBlock>(getNewFromOriginal(&BB));
2580-
if (inst->getParent() == returningBlock)
2581-
continue;
2582-
if (!DT.dominates(inst, returningBlock)) {
2583-
legal = false;
2575+
// This is a reverse block
2576+
} else if (BuilderM.GetInsertBlock() != inversionAllocs) {
2577+
// Something in the entry (or anything that dominates all returns, doesn't
2578+
// need caching)
2579+
2580+
BasicBlock *forwardBlock =
2581+
originalForReverseBlock(*BuilderM.GetInsertBlock());
2582+
2583+
// Don't allow this if we're not definitely using the last iteration of
2584+
// this value
2585+
// + either because the value isn't in a loop
2586+
// + or because the forward of the block usage location isn't in a loop
2587+
// (thus last iteration)
2588+
// + or because the loop nests share no ancestry
2589+
2590+
bool loopLegal = true;
2591+
for (Loop *idx = LI.getLoopFor(inst->getParent()); idx != nullptr;
2592+
idx = idx->getParentLoop()) {
2593+
for (Loop *fdx = LI.getLoopFor(forwardBlock); fdx != nullptr;
2594+
fdx = fdx->getParentLoop()) {
2595+
if (idx == fdx) {
2596+
loopLegal = false;
25842597
break;
25852598
}
25862599
}
25872600
}
2588-
if (legal) {
2589-
return inst;
2601+
2602+
if (loopLegal) {
2603+
if (inst->getParent() == &newFunc->getEntryBlock()) {
2604+
return inst;
2605+
}
2606+
// TODO upgrade this to be all returns that this could enter from
2607+
bool legal = true;
2608+
for (auto &BB : *oldFunc) {
2609+
if (isa<ReturnInst>(BB.getTerminator())) {
2610+
BasicBlock *returningBlock =
2611+
cast<BasicBlock>(getNewFromOriginal(&BB));
2612+
if (inst->getParent() == returningBlock)
2613+
continue;
2614+
if (!DT.dominates(inst, returningBlock)) {
2615+
legal = false;
2616+
break;
2617+
}
2618+
}
2619+
}
2620+
if (legal) {
2621+
return inst;
2622+
}
25902623
}
25912624
}
25922625
}
@@ -2713,7 +2746,8 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
27132746
}
27142747
}
27152748
assert(op->getType() == inst->getType());
2716-
lookup_cache[idx] = op;
2749+
if (!reduceRegister)
2750+
lookup_cache[idx] = op;
27172751
return op;
27182752
}
27192753
} else {
@@ -2948,7 +2982,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
29482982
}
29492983
ValueToValueMapTy ThreadLookup;
29502984
bool legal = true;
2951-
for (int i = 0; i < svals.size(); i++) {
2985+
for (size_t i = 0; i < svals.size(); i++) {
29522986
auto ss = OrigSE.getSCEV(svals[i]);
29532987
auto ls = OrigSE.getSCEV(lvals[i]);
29542988
if (cast<IntegerType>(ss->getType())->getBitWidth() >
@@ -3942,10 +3976,10 @@ void GradientUtils::computeMinCache(
39423976
if (auto BO = dyn_cast<BinaryOperator>(
39433977
PN->getIncomingValueForBlock(B))) {
39443978
if (BO->getOpcode() == BinaryOperator::Add) {
3945-
if (BO->getOperand(0) == PN &&
3946-
invariant(BO->getOperand(1)) ||
3947-
BO->getOperand(1) == PN &&
3948-
invariant(BO->getOperand(0))) {
3979+
if ((BO->getOperand(0) == PN &&
3980+
invariant(BO->getOperand(1))) ||
3981+
(BO->getOperand(1) == PN &&
3982+
invariant(BO->getOperand(0)))) {
39493983
Increment.insert(BO);
39503984
} else {
39513985
legal = false;
@@ -4054,4 +4088,4 @@ void GradientUtils::computeMinCache(
40544088
knownRecomputeHeuristic[V] = !MinReq.count(V);
40554089
}
40564090
}
4057-
}
4091+
}

0 commit comments

Comments
 (0)