@@ -70,6 +70,10 @@ llvm::cl::opt<bool> EnzymeInactiveDynamic(
// Hidden boolean command-line option, off by default (cl::init(false)).
// Its help text describes it as forwarding shared memory from definitions;
// the code that consumes the flag is not visible in this excerpt.
7070llvm::cl::opt<bool >
7171 EnzymeSharedForward (" enzyme-shared-forward" , cl::init(false ), cl::Hidden,
7272 cl::desc(" Forward Shared Memory from definitions" ));
73+
// Hidden boolean command-line option, off by default (cl::init(false)).
// When set, GradientUtils::lookupM flags two kinds of instruction:
//   - results of the NVVM ldu/ldg global-load intrinsics, and
//   - loads from address space 3 that pass the legalRecompute and
//     shouldRecompute checks (only when tryLegalRecomputeCheck is true).
// For a flagged instruction, lookupM skips the early "return inst" reuse
// paths and does not store the looked-up value in lookup_cache.
// Presumably this trades recomputation for lower register pressure -- TODO
// confirm against the rest of lookupM.
// NOTE(review): the help text below ("Reduce the amount of register reduce")
// reads as garbled; consider rewording it to state what the flag does.
74+ llvm::cl::opt<bool >
75+ EnzymeRegisterReduce (" enzyme-register-reduce" , cl::init(false ), cl::Hidden,
76+ cl::desc(" Reduce the amount of register reduce" ));
7377}
7478
7579bool isPotentialLastLoopValue (Value *val, const BasicBlock *loc,
@@ -2515,78 +2519,107 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
25152519 assert (inst->getParent ()->getParent () == newFunc);
25162520 assert (BuilderM.GetInsertBlock ()->getParent () == newFunc);
25172521
2518- if (isOriginalBlock (*BuilderM.GetInsertBlock ())) {
2519- if (BuilderM.GetInsertBlock ()->size () &&
2520- BuilderM.GetInsertPoint () != BuilderM.GetInsertBlock ()->end ()) {
2521- Instruction *use = &*BuilderM.GetInsertPoint ();
2522- while (isa<PHINode>(use))
2523- use = use->getNextNode ();
2524- if (DT.dominates (inst, use)) {
2525- return inst;
2526- } else {
2527- llvm::errs () << *BuilderM.GetInsertBlock ()->getParent () << " \n " ;
2528- llvm::errs () << " didnt dominate inst: " << *inst
2529- << " point: " << *BuilderM.GetInsertPoint ()
2530- << " \n bb: " << *BuilderM.GetInsertBlock () << " \n " ;
2522+ bool reduceRegister = false ;
2523+
2524+ if (EnzymeRegisterReduce) {
2525+ if (auto II = dyn_cast<IntrinsicInst>(inst)) {
2526+ switch (II->getIntrinsicID ()) {
2527+ case Intrinsic::nvvm_ldu_global_i:
2528+ case Intrinsic::nvvm_ldu_global_p:
2529+ case Intrinsic::nvvm_ldu_global_f:
2530+ case Intrinsic::nvvm_ldg_global_i:
2531+ case Intrinsic::nvvm_ldg_global_p:
2532+ case Intrinsic::nvvm_ldg_global_f:
2533+ reduceRegister = true ;
2534+ break ;
2535+ default :
2536+ break ;
25312537 }
2532- } else {
2533- if (inst->getParent () == BuilderM.GetInsertBlock () ||
2534- DT.dominates (inst, BuilderM.GetInsertBlock ())) {
2535- // allowed from block domination
2536- return inst;
2537- } else {
2538- llvm::errs () << *BuilderM.GetInsertBlock ()->getParent () << " \n " ;
2539- llvm::errs () << " didnt dominate inst: " << *inst
2540- << " \n bb: " << *BuilderM.GetInsertBlock () << " \n " ;
2541- }
2542- }
2543- // This is a reverse block
2544- } else if (BuilderM.GetInsertBlock () != inversionAllocs) {
2545- // Something in the entry (or anything that dominates all returns, doesn't
2546- // need caching)
2547-
2548- BasicBlock *forwardBlock =
2549- originalForReverseBlock (*BuilderM.GetInsertBlock ());
2550-
2551- // Don't allow this if we're not definitely using the last iteration of this
2552- // value
2553- // + either because the value isn't in a loop
2554- // + or because the forward of the block usage location isn't in a loop
2555- // (thus last iteration)
2556- // + or because the loop nests share no ancestry
2557-
2558- bool loopLegal = true ;
2559- for (Loop *idx = LI.getLoopFor (inst->getParent ()); idx != nullptr ;
2560- idx = idx->getParentLoop ()) {
2561- for (Loop *fdx = LI.getLoopFor (forwardBlock); fdx != nullptr ;
2562- fdx = fdx->getParentLoop ()) {
2563- if (idx == fdx) {
2564- loopLegal = false ;
2565- break ;
2566- }
2538+ }
2539+ if (auto LI = dyn_cast<LoadInst>(inst)) {
2540+ if (cast<PointerType>(LI->getPointerOperand ()->getType ())
2541+ ->getAddressSpace () == 3 ) {
2542+ reduceRegister |= tryLegalRecomputeCheck &&
2543+ legalRecompute (LI, incoming_available, &BuilderM) &&
2544+ shouldRecompute (LI, incoming_available, &BuilderM);
25672545 }
25682546 }
2547+ }
25692548
2570- if (loopLegal) {
2571- if (inst->getParent () == &newFunc->getEntryBlock ()) {
2572- return inst;
2549+ if (!reduceRegister) {
2550+ if (isOriginalBlock (*BuilderM.GetInsertBlock ())) {
2551+ if (BuilderM.GetInsertBlock ()->size () &&
2552+ BuilderM.GetInsertPoint () != BuilderM.GetInsertBlock ()->end ()) {
2553+ Instruction *use = &*BuilderM.GetInsertPoint ();
2554+ while (isa<PHINode>(use))
2555+ use = use->getNextNode ();
2556+ if (DT.dominates (inst, use)) {
2557+ return inst;
2558+ } else {
2559+ llvm::errs () << *BuilderM.GetInsertBlock ()->getParent () << " \n " ;
2560+ llvm::errs () << " didn't dominate inst: " << *inst
2561+ << " point: " << *BuilderM.GetInsertPoint ()
2562+ << " \n bb: " << *BuilderM.GetInsertBlock () << " \n " ;
2563+ }
2564+ } else {
2565+ if (inst->getParent () == BuilderM.GetInsertBlock () ||
2566+ DT.dominates (inst, BuilderM.GetInsertBlock ())) {
2567+ // allowed from block domination
2568+ return inst;
2569+ } else {
2570+ llvm::errs () << *BuilderM.GetInsertBlock ()->getParent () << " \n " ;
2571+ llvm::errs () << " didn't dominate inst: " << *inst
2572+ << " \n bb: " << *BuilderM.GetInsertBlock () << " \n " ;
2573+ }
25732574 }
2574- // TODO upgrade this to be all returns that this could enter from
2575- bool legal = true ;
2576- for (auto &BB : *oldFunc) {
2577- if (isa<ReturnInst>(BB.getTerminator ())) {
2578- BasicBlock *returningBlock =
2579- cast<BasicBlock>(getNewFromOriginal (&BB));
2580- if (inst->getParent () == returningBlock)
2581- continue ;
2582- if (!DT.dominates (inst, returningBlock)) {
2583- legal = false ;
2575+ // This is a reverse block
2576+ } else if (BuilderM.GetInsertBlock () != inversionAllocs) {
2577+ // Something in the entry (or anything that dominates all returns, doesn't
2578+ // need caching)
2579+
2580+ BasicBlock *forwardBlock =
2581+ originalForReverseBlock (*BuilderM.GetInsertBlock ());
2582+
2583+ // Don't allow this if we're not definitely using the last iteration of
2584+ // this value
2585+ // + either because the value isn't in a loop
2586+ // + or because the forward of the block usage location isn't in a loop
2587+ // (thus last iteration)
2588+ // + or because the loop nests share no ancestry
2589+
2590+ bool loopLegal = true ;
2591+ for (Loop *idx = LI.getLoopFor (inst->getParent ()); idx != nullptr ;
2592+ idx = idx->getParentLoop ()) {
2593+ for (Loop *fdx = LI.getLoopFor (forwardBlock); fdx != nullptr ;
2594+ fdx = fdx->getParentLoop ()) {
2595+ if (idx == fdx) {
2596+ loopLegal = false ;
25842597 break ;
25852598 }
25862599 }
25872600 }
2588- if (legal) {
2589- return inst;
2601+
2602+ if (loopLegal) {
2603+ if (inst->getParent () == &newFunc->getEntryBlock ()) {
2604+ return inst;
2605+ }
2606+ // TODO upgrade this to be all returns that this could enter from
2607+ bool legal = true ;
2608+ for (auto &BB : *oldFunc) {
2609+ if (isa<ReturnInst>(BB.getTerminator ())) {
2610+ BasicBlock *returningBlock =
2611+ cast<BasicBlock>(getNewFromOriginal (&BB));
2612+ if (inst->getParent () == returningBlock)
2613+ continue ;
2614+ if (!DT.dominates (inst, returningBlock)) {
2615+ legal = false ;
2616+ break ;
2617+ }
2618+ }
2619+ }
2620+ if (legal) {
2621+ return inst;
2622+ }
25902623 }
25912624 }
25922625 }
@@ -2713,7 +2746,8 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
27132746 }
27142747 }
27152748 assert (op->getType () == inst->getType ());
2716- lookup_cache[idx] = op;
2749+ if (!reduceRegister)
2750+ lookup_cache[idx] = op;
27172751 return op;
27182752 }
27192753 } else {
@@ -2948,7 +2982,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
29482982 }
29492983 ValueToValueMapTy ThreadLookup;
29502984 bool legal = true ;
2951- for (int i = 0 ; i < svals.size (); i++) {
2985+ for (size_t i = 0 ; i < svals.size (); i++) {
29522986 auto ss = OrigSE.getSCEV (svals[i]);
29532987 auto ls = OrigSE.getSCEV (lvals[i]);
29542988 if (cast<IntegerType>(ss->getType ())->getBitWidth () >
@@ -3942,10 +3976,10 @@ void GradientUtils::computeMinCache(
39423976 if (auto BO = dyn_cast<BinaryOperator>(
39433977 PN->getIncomingValueForBlock (B))) {
39443978 if (BO->getOpcode () == BinaryOperator::Add) {
3945- if (BO->getOperand (0 ) == PN &&
3946- invariant (BO->getOperand (1 )) ||
3947- BO->getOperand (1 ) == PN &&
3948- invariant (BO->getOperand (0 ))) {
3979+ if (( BO->getOperand (0 ) == PN &&
3980+ invariant (BO->getOperand (1 ) )) ||
3981+ ( BO->getOperand (1 ) == PN &&
3982+ invariant (BO->getOperand (0 ) ))) {
39493983 Increment.insert (BO);
39503984 } else {
39513985 legal = false ;
@@ -4054,4 +4088,4 @@ void GradientUtils::computeMinCache(
40544088 knownRecomputeHeuristic[V] = !MinReq.count (V);
40554089 }
40564090 }
4057- }
4091+ }
0 commit comments