@@ -61,18 +61,18 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos
6161 // if this is a branch into the loop, this certainly should go to the merged
6262 // block as this represents starting the loop
6363 if (!inLoopContext || !isParentOrSameContext (branchingContext, lc) ) {
64- llvm::errs () << " LC BB:" << BB->getName () << " branchingBlock:" << branchingBlock->getName () << " \n " ;
64+ // llvm::errs() << "LC BB:" << BB->getName() << " branchingBlock:" << branchingBlock->getName() << "\n";
6565 return lc.latchMerge ;
6666 }
6767
6868 // if we branch from inside the loop, we only need to go to the merged loop
6969 // if the original branch is to the header (otherwise its an internal branch in the loop)
7070 if (branchingBlock == lc.header ) {
71- llvm::errs () << " LH BB:" << BB->getName () << " branchingBlock:" << branchingBlock->getName () << " \n " ;
71+ // llvm::errs() << "LH BB:" << BB->getName() << " branchingBlock:" << branchingBlock->getName() << "\n";
7272 return lc.latchMerge ;
7373 }
7474
75- llvm::errs () << " BB:" << BB->getName () << " branchingBlock:" << branchingBlock->getName () << " \n " ;
75+ // llvm::errs() << " BB:" << BB->getName() << " branchingBlock:" << branchingBlock->getName() << "\n";
7676 return reverseBlocks[BB];
7777 }
7878
@@ -487,7 +487,7 @@ std::pair<PHINode*,Value*> insertNewCanonicalIV(Loop* L, Type* Ty) {
487487 return std::pair<PHINode*,Value*>(CanonicalIV,inc);
488488}
489489
490- void removeRedundantIVs (BasicBlock* Header, PHINode* CanonicalIV, ScalarEvolution &SE, GradientUtils &gutils, Value* increment=nullptr ) {
490+ void removeRedundantIVs (BasicBlock* Header, BasicBlock* Preheader, PHINode* CanonicalIV, ScalarEvolution &SE, GradientUtils &gutils, Value* increment=nullptr ) {
491491 assert (Header);
492492 assert (CanonicalIV);
493493
@@ -518,6 +518,59 @@ void removeRedundantIVs(BasicBlock* Header, PHINode* CanonicalIV, ScalarEvolutio
518518 gutils.erase (PN);
519519 }
520520
521+ for (auto use : CanonicalIV->users ()) {
522+ if (auto cmp = dyn_cast<ICmpInst>(use)) {
523+ if (cmp->isUnsigned ()) {
524+ // Force i to be on LHS
525+ if (cmp->getOperand (0 ) != CanonicalIV) {
526+ // Below also swaps predicate correctly
527+ cmp->swapOperands ();
528+ }
529+ assert (cmp->getOperand (0 ) == CanonicalIV);
530+
531+ // valid replacements (since unsigned comparison and i starts at 0 counting up)
532+
533+ // * i < n => i != n, valid since first time i >= n occurs at i == n
534+ if (cmp->getPredicate () == ICmpInst::ICMP_ULT) {
535+ cmp->setPredicate (ICmpInst::ICMP_NE);
536+ goto cend;
537+ }
538+
539+ // * i <= n => i != n+1, valid since first time i > n occurs at i == n+1 [ which we assert is in bitrange as not infinite loop ]
540+ if (cmp->getPredicate () == ICmpInst::ICMP_ULE) {
541+ IRBuilder <>builder (Preheader->getTerminator ());
542+ if (auto inst = dyn_cast<Instruction>(cmp->getOperand (1 ))) {
543+ builder.SetInsertPoint (inst->getNextNode ());
544+ }
545+ cmp->setOperand (1 , builder.CreateNUWAdd (cmp->getOperand (1 ), ConstantInt::get (cmp->getOperand (1 )->getType (), 1 , false )));
546+ cmp->setPredicate (ICmpInst::ICMP_NE);
547+ goto cend;
548+ }
549+
550+ // * i >= n => i == n, valid since first time i >= n occurs at i == n
551+ if (cmp->getPredicate () == ICmpInst::ICMP_UGE) {
552+ cmp->setPredicate (ICmpInst::ICMP_EQ);
553+ goto cend;
554+ }
555+
556+ // * i > n => i == n+1, valid since first time i > n occurs at i == n+1 [ which we assert is in bitrange as not infinite loop ]
557+ if (cmp->getPredicate () == ICmpInst::ICMP_UGT) {
558+ IRBuilder <>builder (Preheader->getTerminator ());
559+ if (auto inst = dyn_cast<Instruction>(cmp->getOperand (1 ))) {
560+ builder.SetInsertPoint (inst->getNextNode ());
561+ }
562+ cmp->setOperand (1 , builder.CreateNUWAdd (cmp->getOperand (1 ), ConstantInt::get (cmp->getOperand (1 )->getType (), 1 , false )));
563+ cmp->setPredicate (ICmpInst::ICMP_EQ);
564+ goto cend;
565+ }
566+ }
567+ cend:;
568+ if (cmp->getPredicate () == ICmpInst::ICMP_NE) {
569+
570+ }
571+ }
572+ }
573+
521574
522575 // Replace previous increment usage with new increment value
523576 if (increment) {
@@ -547,6 +600,48 @@ void removeRedundantIVs(BasicBlock* Header, PHINode* CanonicalIV, ScalarEvolutio
547600 for (auto inst: toerase) {
548601 gutils.erase (inst);
549602 }
603+
604+ for (auto use : increment->users ()) {
605+ if (auto cmp = dyn_cast<ICmpInst>(use)) {
606+ if (cmp->isUnsigned ()) {
607+ // Force i+1 to be on LHS
608+ if (cmp->getOperand (0 ) != increment) {
609+ // Below also swaps predicate correctly
610+ cmp->swapOperands ();
611+ }
612+ assert (cmp->getOperand (0 ) == increment);
613+
614+ // valid replacements (since unsigned comparison and i starts at 0 counting up)
615+
616+ // * i+1 < n => i+1 != n, valid since first time i+1 >= n occurs at i+1 == n
617+ if (cmp->getPredicate () == ICmpInst::ICMP_ULT) {
618+ cmp->setPredicate (ICmpInst::ICMP_NE);
619+ continue ;
620+ }
621+
622+ // * i+1 <= n => i != n, valid since first time i+1 > n occurs at i+1 == n+1 => i == n
623+ if (cmp->getPredicate () == ICmpInst::ICMP_ULE) {
624+ cmp->setOperand (0 , CanonicalIV);
625+ cmp->setPredicate (ICmpInst::ICMP_NE);
626+ continue ;
627+ }
628+
629+ // * i+1 >= n => i+1 == n, valid since first time i+1 >= n occurs at i+1 == n
630+ if (cmp->getPredicate () == ICmpInst::ICMP_UGE) {
631+ cmp->setPredicate (ICmpInst::ICMP_EQ);
632+ continue ;
633+ }
634+
635+ // * i+1 > n => i == n, valid since first time i+1 > n occurs at i+1 == n+1 => i == n
636+ if (cmp->getPredicate () == ICmpInst::ICMP_UGT) {
637+ cmp->setOperand (0 , CanonicalIV);
638+ cmp->setPredicate (ICmpInst::ICMP_EQ);
639+ continue ;
640+ }
641+ }
642+ }
643+ }
644+
550645 }
551646}
552647
@@ -574,15 +669,11 @@ bool getContextM(BasicBlock *BB, LoopContext &loopContext, std::map<Loop*,LoopCo
574669 auto pair = insertNewCanonicalIV (L, Type::getInt64Ty (BB->getContext ()));
575670 PHINode* CanonicalIV = pair.first ;
576671 assert (CanonicalIV);
577- removeRedundantIVs (loopContexts[L].header , CanonicalIV, SE, gutils, pair.second );
672+ removeRedundantIVs (loopContexts[L].header , loopContexts[L]. preheader , CanonicalIV, SE, gutils, pair.second );
578673 loopContexts[L].var = CanonicalIV;
579674 loopContexts[L].antivar = PHINode::Create (CanonicalIV->getType (), CanonicalIV->getNumIncomingValues (), CanonicalIV->getName ()+" 'phi" );
580675
581676 PredicatedScalarEvolution PSE (SE, *L);
582- auto scev = PSE.getAsAddRec (CanonicalIV);
583- gutils.newFunc ->dump ();
584- L->dump ();
585- scev->dump ();
586677 // predicate.addPredicate(SE.getWrapPredicate(SE.getSCEV(CanonicalIV), SCEVWrapPredicate::IncrementNoWrapMask));
587678 // Note exitcount needs the true latch (e.g. the one that branches back to header)
588679 // tather than the latch that contains the branch (as we define latch)
0 commit comments