Skip to content

Commit 566a1be

Browse files
committed
latch count
1 parent 51e0fd3 commit 566a1be

File tree

5 files changed

+111
-22
lines changed

5 files changed

+111
-22
lines changed

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 100 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,18 +61,18 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos
6161
// if this is a branch into the loop, this certainly should go to the merged
6262
// block as this represents starting the loop
6363
if (!inLoopContext || !isParentOrSameContext(branchingContext, lc) ) {
64-
llvm::errs() << "LC BB:" << BB->getName() << " branchingBlock:" << branchingBlock->getName() << "\n";
64+
//llvm::errs() << "LC BB:" << BB->getName() << " branchingBlock:" << branchingBlock->getName() << "\n";
6565
return lc.latchMerge;
6666
}
6767

6868
// if we branch from inside the loop, we only need to go to the merged loop
6969
// if the original branch is to the header (otherwise it's an internal branch in the loop)
7070
if (branchingBlock == lc.header) {
71-
llvm::errs() << "LH BB:" << BB->getName() << " branchingBlock:" << branchingBlock->getName() << "\n";
71+
//llvm::errs() << "LH BB:" << BB->getName() << " branchingBlock:" << branchingBlock->getName() << "\n";
7272
return lc.latchMerge;
7373
}
7474

75-
llvm::errs() << " BB:" << BB->getName() << " branchingBlock:" << branchingBlock->getName() << "\n";
75+
//llvm::errs() << " BB:" << BB->getName() << " branchingBlock:" << branchingBlock->getName() << "\n";
7676
return reverseBlocks[BB];
7777
}
7878

@@ -487,7 +487,7 @@ std::pair<PHINode*,Value*> insertNewCanonicalIV(Loop* L, Type* Ty) {
487487
return std::pair<PHINode*,Value*>(CanonicalIV,inc);
488488
}
489489

490-
void removeRedundantIVs(BasicBlock* Header, PHINode* CanonicalIV, ScalarEvolution &SE, GradientUtils &gutils, Value* increment=nullptr) {
490+
void removeRedundantIVs(BasicBlock* Header, BasicBlock* Preheader, PHINode* CanonicalIV, ScalarEvolution &SE, GradientUtils &gutils, Value* increment=nullptr) {
491491
assert(Header);
492492
assert(CanonicalIV);
493493

@@ -518,6 +518,59 @@ void removeRedundantIVs(BasicBlock* Header, PHINode* CanonicalIV, ScalarEvolutio
518518
gutils.erase(PN);
519519
}
520520

521+
for (auto use : CanonicalIV->users()) {
522+
if (auto cmp = dyn_cast<ICmpInst>(use)) {
523+
if (cmp->isUnsigned()) {
524+
// Force i to be on LHS
525+
if (cmp->getOperand(0) != CanonicalIV) {
526+
//Below also swaps predicate correctly
527+
cmp->swapOperands();
528+
}
529+
assert(cmp->getOperand(0) == CanonicalIV);
530+
531+
// valid replacements (since unsigned comparison and i starts at 0 counting up)
532+
533+
// * i < n => i != n, valid since first time i >= n occurs at i == n
534+
if (cmp->getPredicate() == ICmpInst::ICMP_ULT) {
535+
cmp->setPredicate(ICmpInst::ICMP_NE);
536+
goto cend;
537+
}
538+
539+
// * i <= n => i != n+1, valid since first time i > n occurs at i == n+1 [ which we assert is in bitrange as not infinite loop ]
540+
if (cmp->getPredicate() == ICmpInst::ICMP_ULE) {
541+
IRBuilder <>builder (Preheader->getTerminator());
542+
if (auto inst = dyn_cast<Instruction>(cmp->getOperand(1))) {
543+
builder.SetInsertPoint(inst->getNextNode());
544+
}
545+
cmp->setOperand(1, builder.CreateNUWAdd(cmp->getOperand(1), ConstantInt::get(cmp->getOperand(1)->getType(), 1, false)));
546+
cmp->setPredicate(ICmpInst::ICMP_NE);
547+
goto cend;
548+
}
549+
550+
// * i >= n => i == n, valid since first time i >= n occurs at i == n
551+
if (cmp->getPredicate() == ICmpInst::ICMP_UGE) {
552+
cmp->setPredicate(ICmpInst::ICMP_EQ);
553+
goto cend;
554+
}
555+
556+
// * i > n => i == n+1, valid since first time i > n occurs at i == n+1 [ which we assert is in bitrange as not infinite loop ]
557+
if (cmp->getPredicate() == ICmpInst::ICMP_UGT) {
558+
IRBuilder <>builder (Preheader->getTerminator());
559+
if (auto inst = dyn_cast<Instruction>(cmp->getOperand(1))) {
560+
builder.SetInsertPoint(inst->getNextNode());
561+
}
562+
cmp->setOperand(1, builder.CreateNUWAdd(cmp->getOperand(1), ConstantInt::get(cmp->getOperand(1)->getType(), 1, false)));
563+
cmp->setPredicate(ICmpInst::ICMP_EQ);
564+
goto cend;
565+
}
566+
}
567+
cend:;
568+
if (cmp->getPredicate() == ICmpInst::ICMP_NE) {
569+
570+
}
571+
}
572+
}
573+
521574

522575
// Replace previous increment usage with new increment value
523576
if (increment) {
@@ -547,6 +600,48 @@ void removeRedundantIVs(BasicBlock* Header, PHINode* CanonicalIV, ScalarEvolutio
547600
for(auto inst: toerase) {
548601
gutils.erase(inst);
549602
}
603+
604+
for (auto use : increment->users()) {
605+
if (auto cmp = dyn_cast<ICmpInst>(use)) {
606+
if (cmp->isUnsigned()) {
607+
// Force i+1 to be on LHS
608+
if (cmp->getOperand(0) != increment) {
609+
//Below also swaps predicate correctly
610+
cmp->swapOperands();
611+
}
612+
assert(cmp->getOperand(0) == increment);
613+
614+
// valid replacements (since unsigned comparison and i starts at 0 counting up)
615+
616+
// * i+1 < n => i+1 != n, valid since first time i+1 >= n occurs at i+1 == n
617+
if (cmp->getPredicate() == ICmpInst::ICMP_ULT) {
618+
cmp->setPredicate(ICmpInst::ICMP_NE);
619+
continue;
620+
}
621+
622+
// * i+1 <= n => i != n, valid since first time i+1 > n occurs at i+1 == n+1 => i == n
623+
if (cmp->getPredicate() == ICmpInst::ICMP_ULE) {
624+
cmp->setOperand(0, CanonicalIV);
625+
cmp->setPredicate(ICmpInst::ICMP_NE);
626+
continue;
627+
}
628+
629+
// * i+1 >= n => i+1 == n, valid since first time i+1 >= n occurs at i+1 == n
630+
if (cmp->getPredicate() == ICmpInst::ICMP_UGE) {
631+
cmp->setPredicate(ICmpInst::ICMP_EQ);
632+
continue;
633+
}
634+
635+
// * i+1 > n => i == n, valid since first time i+1 > n occurs at i+1 == n+1 => i == n
636+
if (cmp->getPredicate() == ICmpInst::ICMP_UGT) {
637+
cmp->setOperand(0, CanonicalIV);
638+
cmp->setPredicate(ICmpInst::ICMP_EQ);
639+
continue;
640+
}
641+
}
642+
}
643+
}
644+
550645
}
551646
}
552647

@@ -574,15 +669,11 @@ bool getContextM(BasicBlock *BB, LoopContext &loopContext, std::map<Loop*,LoopCo
574669
auto pair = insertNewCanonicalIV(L, Type::getInt64Ty(BB->getContext()));
575670
PHINode* CanonicalIV = pair.first;
576671
assert(CanonicalIV);
577-
removeRedundantIVs(loopContexts[L].header, CanonicalIV, SE, gutils, pair.second);
672+
removeRedundantIVs(loopContexts[L].header, loopContexts[L].preheader, CanonicalIV, SE, gutils, pair.second);
578673
loopContexts[L].var = CanonicalIV;
579674
loopContexts[L].antivar = PHINode::Create(CanonicalIV->getType(), CanonicalIV->getNumIncomingValues(), CanonicalIV->getName()+"'phi");
580675

581676
PredicatedScalarEvolution PSE(SE, *L);
582-
auto scev = PSE.getAsAddRec(CanonicalIV);
583-
gutils.newFunc->dump();
584-
L->dump();
585-
scev->dump();
586677
//predicate.addPredicate(SE.getWrapPredicate(SE.getSCEV(CanonicalIV), SCEVWrapPredicate::IncrementNoWrapMask));
587678
// Note exitcount needs the true latch (e.g. the one that branches back to header)
588679
// rather than the latch that contains the branch (as we define latch)

enzyme/test/Enzyme/sumbr3.ll

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ for.body:
1717
%add = fadd fast double %0, %total.07
1818
%indvars.iv.next = add nuw i64 %indvars.iv, 1
1919
%exitcond = icmp ult i64 %indvars.iv, %n
20-
br i1 %exitcond, label %for.cond.cleanup, label %extra
20+
br i1 %exitcond, label %extra, label %for.cond.cleanup
2121

2222
extra:
2323
br label %for.body
@@ -42,16 +42,14 @@ attributes #2 = { nounwind }
4242
; CHECK-NEXT: entry:
4343
; CHECK-NEXT: br label %invertfor.body.i
4444
; CHECK: invertfor.body.i:
45-
; CHECK-NEXT: %"indvars.iv'phi.i" = phi i64 [ %n, %entry ], [ %0, %invertextra.i ]
46-
; CHECK-NEXT: %0 = sub i64 %"indvars.iv'phi.i", 1
47-
; CHECK-NEXT: %"arrayidx'ipg.i" = getelementptr double, double* %xp, i64 %"indvars.iv'phi.i"
45+
; CHECK-NEXT: %[[antivar:.+]] = phi i64 [ %n, %entry ], [ %[[sub:.+]], %invertfor.body.i ]
46+
; CHECK-NEXT: %[[sub]] = sub i64 %[[antivar]], 1
47+
; CHECK-NEXT: %"arrayidx'ipg.i" = getelementptr double, double* %xp, i64 %[[antivar]]
4848
; CHECK-NEXT: %1 = load double, double* %"arrayidx'ipg.i"
4949
; CHECK-NEXT: %2 = fadd fast double %1, 1.000000e+00
5050
; CHECK-NEXT: store double %2, double* %"arrayidx'ipg.i"
51-
; CHECK-NEXT: %3 = icmp ne i64 %"indvars.iv'phi.i", 0
52-
; CHECK-NEXT: br i1 %3, label %invertextra.i, label %diffesum.exit
53-
; CHECK: invertextra.i:
54-
; CHECK-NEXT: br label %invertfor.body.i
55-
; CHECK: diffesum.exit: ; preds = %invertfor.body.i
51+
; CHECK-NEXT: %3 = icmp eq i64 %[[antivar]], 0
52+
; CHECK-NEXT: br i1 %3, label %diffesum.exit, label %invertfor.body.i
53+
; CHECK: diffesum.exit:
5654
; CHECK-NEXT: ret void
5755
; CHECK-NEXT: }

enzyme/test/Enzyme/sumsimple.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ attributes #0 = { noinline nounwind uwtable }
5555
; CHECK-NEXT: %4 = load double, double* %3
5656
; CHECK-NEXT: %add = fadd fast double %4, %1
5757
; CHECK-NEXT: store double %add, double* %3
58-
; CHECK-NEXT: %cmp = icmp ule i64 %iv.next, %n
58+
; CHECK-NEXT: %cmp = icmp ne i64 %iv, %n
5959
; CHECK-NEXT: br i1 %cmp, label %for.body, label %invertfor.body
6060

6161
; CHECK: invertentry: ; preds = %invertfor.body

enzyme/test/Enzyme/sumsimpleoptnone.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ attributes #0 = { noinline nounwind uwtable optnone }
5555
; CHECK-NEXT: %4 = load double, double* %3
5656
; CHECK-NEXT: %add = fadd fast double %4, %1
5757
; CHECK-NEXT: store double %add, double* %3
58-
; CHECK-NEXT: %cmp = icmp ule i64 %iv.next, %n
58+
; CHECK-NEXT: %cmp = icmp ne i64 %iv, %n
5959
; CHECK-NEXT: br i1 %cmp, label %for.body, label %invertfor.body
6060

6161
; CHECK: invertentry: ; preds = %invertfor.body

enzyme/test/Enzyme/sumwithbreak.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ attributes #0 = { noinline nounwind uwtable }
5757
; CHECK-NEXT: %arrayidx4 = getelementptr inbounds double, double* %x, i64 %iv
5858
; CHECK-NEXT: %0 = load double, double* %arrayidx4, align 8
5959
; CHECK-NEXT: %add5 = fadd fast double %0, %data.016
60-
; CHECK-NEXT: %cmp = icmp ult i64 %iv, %n
61-
; CHECK-NEXT: br i1 %cmp, label %for.body, label %loopMerge.peel
60+
; CHECK-NEXT: %cmp = icmp eq i64 %iv, %n
61+
; CHECK-NEXT: br i1 %cmp, label %loopMerge.peel, label %for.body
6262

6363
; CHECK: invertentry:
6464
; CHECK-NEXT: ret {} undef

0 commit comments

Comments
 (0)