@@ -504,7 +504,7 @@ std::pair<PHINode*,Value*> insertNewCanonicalIV(Loop* L, Type* Ty) {
504504 return std::pair<PHINode*,Value*>(CanonicalIV,inc);
505505}
506506
507- void removeRedundantIVs (BasicBlock* Header, BasicBlock* Preheader, PHINode* CanonicalIV, ScalarEvolution &SE, GradientUtils &gutils, Value* increment= nullptr ) {
507+ void removeRedundantIVs (const Loop* L, BasicBlock* Header, BasicBlock* Preheader, PHINode* CanonicalIV, ScalarEvolution &SE, GradientUtils &gutils, Value* increment, const SmallVectorImpl<BasicBlock*>&& latches ) {
508508 assert (Header);
509509 assert (CanonicalIV);
510510
@@ -535,26 +535,30 @@ void removeRedundantIVs(BasicBlock* Header, BasicBlock* Preheader, PHINode* Cano
535535 gutils.erase (PN);
536536 }
537537
538+ if (latches.size () == 1 && isa<BranchInst>(latches[0 ]->getTerminator ()) && cast<BranchInst>(latches[0 ]->getTerminator ())->isConditional ())
538539 for (auto use : CanonicalIV->users ()) {
539540 if (auto cmp = dyn_cast<ICmpInst>(use)) {
540- if (cmp->isUnsigned ()) {
541- // Force i to be on LHS
542- if (cmp->getOperand (0 ) != CanonicalIV) {
543- // Below also swaps predicate correctly
544- cmp->swapOperands ();
545- }
546- assert (cmp->getOperand (0 ) == CanonicalIV);
541+ if (cast<BranchInst>(latches[0 ]->getTerminator ())->getCondition () != cmp) continue ;
542+ // Force i to be on LHS
543+ if (cmp->getOperand (0 ) != CanonicalIV) {
544+ // Below also swaps predicate correctly
545+ cmp->swapOperands ();
546+ }
547+ assert (cmp->getOperand (0 ) == CanonicalIV);
548+
549+ auto scv = SE.getSCEVAtScope (cmp->getOperand (1 ), L);
550+ if (cmp->isUnsigned () || (scv != SE.getCouldNotCompute () && SE.isKnownNonNegative (scv)) ) {
547551
548552 // valid replacements (since unsigned comparison and i starts at 0 counting up)
549553
550554 // * i < n => i != n, valid since first time i >= n occurs at i == n
551- if (cmp->getPredicate () == ICmpInst::ICMP_ULT) {
555+ if (cmp->getPredicate () == ICmpInst::ICMP_ULT || cmp-> getPredicate () == ICmpInst::ICMP_SLT ) {
552556 cmp->setPredicate (ICmpInst::ICMP_NE);
553557 goto cend;
554558 }
555559
556560 // * i <= n => i != n+1, valid since first time i > n occurs at i == n+1 [ which we assert is in bitrange as not infinite loop ]
557- if (cmp->getPredicate () == ICmpInst::ICMP_ULE) {
561+ if (cmp->getPredicate () == ICmpInst::ICMP_ULE || cmp-> getPredicate () == ICmpInst::ICMP_SLE ) {
558562 IRBuilder <>builder (Preheader->getTerminator ());
559563 if (auto inst = dyn_cast<Instruction>(cmp->getOperand (1 ))) {
560564 builder.SetInsertPoint (inst->getNextNode ());
@@ -565,13 +569,13 @@ void removeRedundantIVs(BasicBlock* Header, BasicBlock* Preheader, PHINode* Cano
565569 }
566570
567571 // * i >= n => i == n, valid since first time i >= n occurs at i == n
568- if (cmp->getPredicate () == ICmpInst::ICMP_UGE) {
572+ if (cmp->getPredicate () == ICmpInst::ICMP_UGE || cmp-> getPredicate () == ICmpInst::ICMP_SGE ) {
569573 cmp->setPredicate (ICmpInst::ICMP_EQ);
570574 goto cend;
571575 }
572576
573577 // * i > n => i == n+1, valid since first time i > n occurs at i == n+1 [ which we assert is in bitrange as not infinite loop ]
574- if (cmp->getPredicate () == ICmpInst::ICMP_UGT) {
578+ if (cmp->getPredicate () == ICmpInst::ICMP_UGT || cmp-> getPredicate () == ICmpInst::ICMP_SGT ) {
575579 IRBuilder <>builder (Preheader->getTerminator ());
576580 if (auto inst = dyn_cast<Instruction>(cmp->getOperand (1 ))) {
577581 builder.SetInsertPoint (inst->getNextNode ());
@@ -618,39 +622,44 @@ void removeRedundantIVs(BasicBlock* Header, BasicBlock* Preheader, PHINode* Cano
618622 gutils.erase (inst);
619623 }
620624
625+ if (latches.size () == 1 && isa<BranchInst>(latches[0 ]->getTerminator ()) && cast<BranchInst>(latches[0 ]->getTerminator ())->isConditional ())
621626 for (auto use : increment->users ()) {
622627 if (auto cmp = dyn_cast<ICmpInst>(use)) {
623- if (cmp->isUnsigned ()) {
624- // Force i+1 to be on LHS
625- if (cmp->getOperand (0 ) != increment) {
626- // Below also swaps predicate correctly
627- cmp->swapOperands ();
628- }
629- assert (cmp->getOperand (0 ) == increment);
628+ if (cast<BranchInst>(latches[0 ]->getTerminator ())->getCondition () != cmp) continue ;
629+
630+ // Force i+1 to be on LHS
631+ if (cmp->getOperand (0 ) != increment) {
632+ // Below also swaps predicate correctly
633+ cmp->swapOperands ();
634+ }
635+ assert (cmp->getOperand (0 ) == increment);
636+
637+ auto scv = SE.getSCEVAtScope (cmp->getOperand (1 ), L);
638+ if (cmp->isUnsigned () || (scv != SE.getCouldNotCompute () && SE.isKnownNonNegative (scv)) ) {
630639
631640 // valid replacements (since unsigned comparison and i starts at 0 counting up)
632641
633642 // * i+1 < n => i+1 != n, valid since first time i+1 >= n occurs at i+1 == n
634- if (cmp->getPredicate () == ICmpInst::ICMP_ULT) {
643+ if (cmp->getPredicate () == ICmpInst::ICMP_ULT || cmp-> getPredicate () == ICmpInst::ICMP_SLT ) {
635644 cmp->setPredicate (ICmpInst::ICMP_NE);
636645 continue ;
637646 }
638647
639648 // * i+1 <= n => i != n, valid since first time i+1 > n occurs at i+1 == n+1 => i == n
640- if (cmp->getPredicate () == ICmpInst::ICMP_ULE) {
649+ if (cmp->getPredicate () == ICmpInst::ICMP_ULE || cmp-> getPredicate () == ICmpInst::ICMP_SLE ) {
641650 cmp->setOperand (0 , CanonicalIV);
642651 cmp->setPredicate (ICmpInst::ICMP_NE);
643652 continue ;
644653 }
645654
646655 // * i+1 >= n => i+1 == n, valid since first time i+1 >= n occurs at i+1 == n
647- if (cmp->getPredicate () == ICmpInst::ICMP_UGE) {
656+ if (cmp->getPredicate () == ICmpInst::ICMP_UGE || cmp-> getPredicate () == ICmpInst::ICMP_SGE ) {
648657 cmp->setPredicate (ICmpInst::ICMP_EQ);
649658 continue ;
650659 }
651660
652661 // * i+1 > n => i == n, valid since first time i+1 > n occurs at i+1 == n+1 => i == n
653- if (cmp->getPredicate () == ICmpInst::ICMP_UGT) {
662+ if (cmp->getPredicate () == ICmpInst::ICMP_UGT || cmp-> getPredicate () == ICmpInst::ICMP_SGT ) {
654663 cmp->setOperand (0 , CanonicalIV);
655664 cmp->setPredicate (ICmpInst::ICMP_EQ);
656665 continue ;
@@ -686,7 +695,7 @@ bool getContextM(BasicBlock *BB, LoopContext &loopContext, std::map<Loop*,LoopCo
686695 auto pair = insertNewCanonicalIV (L, Type::getInt64Ty (BB->getContext ()));
687696 PHINode* CanonicalIV = pair.first ;
688697 assert (CanonicalIV);
689- removeRedundantIVs (loopContexts[L].header , loopContexts[L].preheader , CanonicalIV, SE, gutils, pair.second );
698+ removeRedundantIVs (L, loopContexts[L].header , loopContexts[L].preheader , CanonicalIV, SE, gutils, pair.second , fake::SCEVExpander::getLatches (L, loopContexts[L]. exitBlocks ) );
690699 loopContexts[L].var = CanonicalIV;
691700 loopContexts[L].antivar = PHINode::Create (CanonicalIV->getType (), CanonicalIV->getNumIncomingValues (), CanonicalIV->getName ()+" 'phi" );
692701
0 commit comments