Skip to content

Commit a8bdf3b

Browse files
committed
revamp loop calc
1 parent 1498c0e commit a8bdf3b

File tree

6 files changed

+50
-41
lines changed

6 files changed

+50
-41
lines changed

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ std::pair<PHINode*,Value*> insertNewCanonicalIV(Loop* L, Type* Ty) {
504504
return std::pair<PHINode*,Value*>(CanonicalIV,inc);
505505
}
506506

507-
void removeRedundantIVs(BasicBlock* Header, BasicBlock* Preheader, PHINode* CanonicalIV, ScalarEvolution &SE, GradientUtils &gutils, Value* increment=nullptr) {
507+
void removeRedundantIVs(const Loop* L, BasicBlock* Header, BasicBlock* Preheader, PHINode* CanonicalIV, ScalarEvolution &SE, GradientUtils &gutils, Value* increment, const SmallVectorImpl<BasicBlock*>&& latches) {
508508
assert(Header);
509509
assert(CanonicalIV);
510510

@@ -535,26 +535,30 @@ void removeRedundantIVs(BasicBlock* Header, BasicBlock* Preheader, PHINode* Cano
535535
gutils.erase(PN);
536536
}
537537

538+
if (latches.size() == 1 && isa<BranchInst>(latches[0]->getTerminator()) && cast<BranchInst>(latches[0]->getTerminator())->isConditional())
538539
for (auto use : CanonicalIV->users()) {
539540
if (auto cmp = dyn_cast<ICmpInst>(use)) {
540-
if (cmp->isUnsigned()) {
541-
// Force i to be on LHS
542-
if (cmp->getOperand(0) != CanonicalIV) {
543-
//Below also swaps predicate correctly
544-
cmp->swapOperands();
545-
}
546-
assert(cmp->getOperand(0) == CanonicalIV);
541+
if (cast<BranchInst>(latches[0]->getTerminator())->getCondition() != cmp) continue;
542+
// Force i to be on LHS
543+
if (cmp->getOperand(0) != CanonicalIV) {
544+
//Below also swaps predicate correctly
545+
cmp->swapOperands();
546+
}
547+
assert(cmp->getOperand(0) == CanonicalIV);
548+
549+
auto scv = SE.getSCEVAtScope(cmp->getOperand(1), L);
550+
if (cmp->isUnsigned() || (scv != SE.getCouldNotCompute() && SE.isKnownNonNegative(scv)) ) {
547551

548552
// valid replacements (since unsigned comparison and i starts at 0 counting up)
549553

550554
// * i < n => i != n, valid since first time i >= n occurs at i == n
551-
if (cmp->getPredicate() == ICmpInst::ICMP_ULT) {
555+
if (cmp->getPredicate() == ICmpInst::ICMP_ULT || cmp->getPredicate() == ICmpInst::ICMP_SLT) {
552556
cmp->setPredicate(ICmpInst::ICMP_NE);
553557
goto cend;
554558
}
555559

556560
// * i <= n => i != n+1, valid since first time i > n occurs at i == n+1 [ which we assert is in bitrange as not infinite loop ]
557-
if (cmp->getPredicate() == ICmpInst::ICMP_ULE) {
561+
if (cmp->getPredicate() == ICmpInst::ICMP_ULE || cmp->getPredicate() == ICmpInst::ICMP_SLE) {
558562
IRBuilder <>builder (Preheader->getTerminator());
559563
if (auto inst = dyn_cast<Instruction>(cmp->getOperand(1))) {
560564
builder.SetInsertPoint(inst->getNextNode());
@@ -565,13 +569,13 @@ void removeRedundantIVs(BasicBlock* Header, BasicBlock* Preheader, PHINode* Cano
565569
}
566570

567571
// * i >= n => i == n, valid since first time i >= n occurs at i == n
568-
if (cmp->getPredicate() == ICmpInst::ICMP_UGE) {
572+
if (cmp->getPredicate() == ICmpInst::ICMP_UGE || cmp->getPredicate() == ICmpInst::ICMP_SGE) {
569573
cmp->setPredicate(ICmpInst::ICMP_EQ);
570574
goto cend;
571575
}
572576

573577
// * i > n => i == n+1, valid since first time i > n occurs at i == n+1 [ which we assert is in bitrange as not infinite loop ]
574-
if (cmp->getPredicate() == ICmpInst::ICMP_UGT) {
578+
if (cmp->getPredicate() == ICmpInst::ICMP_UGT || cmp->getPredicate() == ICmpInst::ICMP_SGT) {
575579
IRBuilder <>builder (Preheader->getTerminator());
576580
if (auto inst = dyn_cast<Instruction>(cmp->getOperand(1))) {
577581
builder.SetInsertPoint(inst->getNextNode());
@@ -618,39 +622,44 @@ void removeRedundantIVs(BasicBlock* Header, BasicBlock* Preheader, PHINode* Cano
618622
gutils.erase(inst);
619623
}
620624

625+
if (latches.size() == 1 && isa<BranchInst>(latches[0]->getTerminator()) && cast<BranchInst>(latches[0]->getTerminator())->isConditional())
621626
for (auto use : increment->users()) {
622627
if (auto cmp = dyn_cast<ICmpInst>(use)) {
623-
if (cmp->isUnsigned()) {
624-
// Force i+1 to be on LHS
625-
if (cmp->getOperand(0) != increment) {
626-
//Below also swaps predicate correctly
627-
cmp->swapOperands();
628-
}
629-
assert(cmp->getOperand(0) == increment);
628+
if (cast<BranchInst>(latches[0]->getTerminator())->getCondition() != cmp) continue;
629+
630+
// Force i+1 to be on LHS
631+
if (cmp->getOperand(0) != increment) {
632+
//Below also swaps predicate correctly
633+
cmp->swapOperands();
634+
}
635+
assert(cmp->getOperand(0) == increment);
636+
637+
auto scv = SE.getSCEVAtScope(cmp->getOperand(1), L);
638+
if (cmp->isUnsigned() || (scv != SE.getCouldNotCompute() && SE.isKnownNonNegative(scv)) ) {
630639

631640
// valid replacements (since unsigned comparison and i starts at 0 counting up)
632641

633642
// * i+1 < n => i+1 != n, valid since first time i+1 >= n occurs at i+1 == n
634-
if (cmp->getPredicate() == ICmpInst::ICMP_ULT) {
643+
if (cmp->getPredicate() == ICmpInst::ICMP_ULT || cmp->getPredicate() == ICmpInst::ICMP_SLT) {
635644
cmp->setPredicate(ICmpInst::ICMP_NE);
636645
continue;
637646
}
638647

639648
// * i+1 <= n => i != n, valid since first time i+1 > n occurs at i+1 == n+1 => i == n
640-
if (cmp->getPredicate() == ICmpInst::ICMP_ULE) {
649+
if (cmp->getPredicate() == ICmpInst::ICMP_ULE || cmp->getPredicate() == ICmpInst::ICMP_SLE) {
641650
cmp->setOperand(0, CanonicalIV);
642651
cmp->setPredicate(ICmpInst::ICMP_NE);
643652
continue;
644653
}
645654

646655
// * i+1 >= n => i+1 == n, valid since first time i+1 >= n occurs at i+1 == n
647-
if (cmp->getPredicate() == ICmpInst::ICMP_UGE) {
656+
if (cmp->getPredicate() == ICmpInst::ICMP_UGE || cmp->getPredicate() == ICmpInst::ICMP_SGE) {
648657
cmp->setPredicate(ICmpInst::ICMP_EQ);
649658
continue;
650659
}
651660

652661
// * i+1 > n => i == n, valid since first time i+1 > n occurs at i+1 == n+1 => i == n
653-
if (cmp->getPredicate() == ICmpInst::ICMP_UGT) {
662+
if (cmp->getPredicate() == ICmpInst::ICMP_UGT || cmp->getPredicate() == ICmpInst::ICMP_SGT) {
654663
cmp->setOperand(0, CanonicalIV);
655664
cmp->setPredicate(ICmpInst::ICMP_EQ);
656665
continue;
@@ -686,7 +695,7 @@ bool getContextM(BasicBlock *BB, LoopContext &loopContext, std::map<Loop*,LoopCo
686695
auto pair = insertNewCanonicalIV(L, Type::getInt64Ty(BB->getContext()));
687696
PHINode* CanonicalIV = pair.first;
688697
assert(CanonicalIV);
689-
removeRedundantIVs(loopContexts[L].header, loopContexts[L].preheader, CanonicalIV, SE, gutils, pair.second);
698+
removeRedundantIVs(L, loopContexts[L].header, loopContexts[L].preheader, CanonicalIV, SE, gutils, pair.second, fake::SCEVExpander::getLatches(L, loopContexts[L].exitBlocks));
690699
loopContexts[L].var = CanonicalIV;
691700
loopContexts[L].antivar = PHINode::Create(CanonicalIV->getType(), CanonicalIV->getNumIncomingValues(), CanonicalIV->getName()+"'phi");
692701

enzyme/test/Enzyme/llist.ll

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,17 +172,17 @@ attributes #4 = { nounwind }
172172
; CHECK-NEXT: %"'ipl" = load %struct.n*, %struct.n** %"next'ipg", align 8
173173
; CHECK-NEXT: %[[loadst]] = load %struct.n*, %struct.n** %next, align 8, !tbaa !8
174174
; CHECK-NEXT: %cmp = icmp eq %struct.n* %[[loadst]], null
175-
; CHECK-NEXT: br i1 %cmp, label %invertfor.body, label %for.body
175+
; CHECK-NEXT: br i1 %cmp, label %[[antiloop:.+]], label %for.body
176176

177177
; CHECK: invertentry:
178178
; CHECK-NEXT: ret {} undef
179179

180-
; CHECK: invertfor.body.preheader: ; preds = %invertfor.body
180+
; CHECK: invertfor.body.preheader:
181181
; CHECK-NEXT: tail call void @free(i8* nonnull %_realloccache)
182182
; CHECK-NEXT: br label %invertentry
183183

184-
; CHECK: invertfor.body:
185-
; CHECK-NEXT: %[[antivar:.+]] = phi i64 [ %[[subidx:.+]], %invertfor.body ], [ %[[preidx]], %for.body ]
184+
; CHECK: [[antiloop]]:
185+
; CHECK-NEXT: %[[antivar:.+]] = phi i64 [ %[[subidx:.+]], %[[antiloop]] ], [ %[[preidx]], %for.body ]
186186
; CHECK-NEXT: %[[subidx]] = add i64 %[[antivar]], -1
187187
; CHECK-NEXT: %[[structptr:.+]] = getelementptr %struct.n*, %struct.n** %[[bcalloc]], i64 %[[antivar]]
188188
; CHECK-NEXT: %[[struct:.+]] = load %struct.n*, %struct.n** %[[structptr]]
@@ -191,5 +191,5 @@ attributes #4 = { nounwind }
191191
; CHECK-NEXT: %[[addval:.+]] = fadd fast double %[[val0]], %[[differet]]
192192
; CHECK-NEXT: store double %[[addval]], double* %"value'ipg"
193193
; CHECK-NEXT: %[[cmpeq:.+]] = icmp eq i64 %[[antivar]], 0
194-
; CHECK-NEXT: br i1 %[[cmpeq]], label %invertfor.body.preheader, label %invertfor.body
194+
; CHECK-NEXT: br i1 %[[cmpeq]], label %invertfor.body.preheader, label %[[antiloop]]
195195
; CHECK-NEXT: }

enzyme/test/Enzyme/nllist.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ attributes #4 = { nounwind }
332332
; CHECK-NEXT: %[[dstructload]] = load %struct.n*, %struct.n** %"next'ipg", align 8
333333
; CHECK-NEXT: %[[nextstruct]] = load %struct.n*, %struct.n** %next, align 8, !tbaa !7
334334
; CHECK-NEXT: %[[mycmp:.+]] = icmp eq %struct.n* %[[nextstruct]], null
335-
; CHECK-NEXT: br i1 %[[mycmp]], label %invertfor.cond.cleanup4, label %for.cond1.preheader
335+
; CHECK-NEXT: br i1 %[[mycmp]], label %[[invertforcondcleanup:.+]], label %for.cond1.preheader
336336

337337
; CHECK: for.body5: ; preds = %for.body5, %for.cond1.preheader
338338
; CHECK-NEXT: %[[iv:.+]] = phi i64 [ %[[ivnext:.+]], %for.body5 ], [ 0, %for.cond1.preheader ]
@@ -349,17 +349,17 @@ attributes #4 = { nounwind }
349349

350350
; CHECK: invertfor.cond1.preheader: ; preds = %invertfor.body5
351351
; CHECK-NEXT: %[[icmp:.+]] = icmp eq i64 %[[antivar:.+]], 0
352-
; CHECK-NEXT: br i1 %[[icmp]], label %invertfor.cond1.preheader.preheader, label %invertfor.cond.cleanup4
352+
; CHECK-NEXT: br i1 %[[icmp]], label %invertfor.cond1.preheader.preheader, label %[[invertforcondcleanup]]
353353

354-
; CHECK: invertfor.cond.cleanup4:
354+
; CHECK: [[invertforcondcleanup]]:
355355
; CHECK-NEXT: %[[antivar]] = phi i64 [ %[[isub:.+]], %invertfor.cond1.preheader ], [ %[[preidx]], %for.cond.cleanup4 ]
356356
; CHECK-NEXT: %[[isub]] = add i64 %[[antivar]], -1
357357
; CHECK-NEXT: %[[toload:.+]] = getelementptr double*, double** %[[todoublep]], i64 %[[antivar]]
358358
; CHECK-NEXT: %[[loadediv:.+]] = load double*, double** %[[toload]], align 8, !invariant.load
359359
; CHECK-NEXT: br label %invertfor.body5
360360

361361
; CHECK: invertfor.body5:
362-
; CHECK-NEXT: %[[mantivar:.+]] = phi i64 [ %times, %invertfor.cond.cleanup4 ], [ %[[idxsub:.+]], %invertfor.body5 ]
362+
; CHECK-NEXT: %[[mantivar:.+]] = phi i64 [ %times, %[[invertforcondcleanup]] ], [ %[[idxsub:.+]], %invertfor.body5 ]
363363
; CHECK-NEXT: %[[idxsub]] = add i64 %[[mantivar]], -1
364364
; CHECK-NEXT: %"arrayidx'ipg" = getelementptr double, double* %[[loadediv]], i64 %[[mantivar]]
365365
; CHECK-NEXT: %[[arrayload:.+]] = load double, double* %"arrayidx'ipg"

enzyme/test/Enzyme/sumbr2.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -instcombine -simplifycfg -S | FileCheck %s
1+
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -instcombine -simplifycfg -S -jump-threading -instsimplify -simplifycfg -adce -loop-deletion -simplifycfg | FileCheck %s
22

33
; Function Attrs: norecurse nounwind readonly uwtable
44
define dso_local double @sum(double* nocapture readonly %x, i64 %n) #0 {
@@ -45,15 +45,15 @@ attributes #2 = { nounwind }
4545
; CHECK-NEXT: br i1 %[[exists]], label %diffesum.exit, label %[[antiloop:.+]]
4646

4747
; CHECK: [[antiloop]]:
48-
; CHECK-NEXT: %"add'de.0.i" = phi double [ %[[m0dadd:.+]], %[[antiloop]] ], [ 1.000000e+00, %entry ]
48+
; CHECK-NEXT: %[[dadd:.+]] = phi double [ %[[m0dadd:.+]], %[[antiloop]] ], [ 1.000000e+00, %entry ]
4949
; CHECK-NEXT: %[[antivar:.+]] = phi i64 [ %[[sub:.+]], %[[antiloop]] ], [ %n, %entry ]
5050
; CHECK-NEXT: %[[sub]] = add i64 %[[antivar]], -1
5151
; CHECK-NEXT: %"arrayidx'ipg.i" = getelementptr double, double* %xp, i64 %[[antivar]]
5252
; CHECK-NEXT: %[[toload:.+]] = load double, double* %"arrayidx'ipg.i", align 8
53-
; CHECK-NEXT: %[[tostore:.+]] = fadd fast double %[[toload]], %"add'de.0.i"
53+
; CHECK-NEXT: %[[tostore:.+]] = fadd fast double %[[toload]], %[[dadd]]
5454
; CHECK-NEXT: store double %[[tostore]], double* %"arrayidx'ipg.i", align 8
5555
; CHECK-NEXT: %res_unwrap.i = uitofp i64 %[[sub]] to double
56-
; CHECK-NEXT: %[[m0dadd]] = fmul fast double %"add'de.0.i", %res_unwrap.i
56+
; CHECK-NEXT: %[[m0dadd]] = fmul fast double %[[dadd]], %res_unwrap.i
5757
; CHECK-NEXT: %[[itercmp:.+]] = icmp eq i64 %[[sub]], 0
5858
; CHECK-NEXT: br i1 %[[itercmp]], label %diffesum.exit, label %invertextra.i
5959

enzyme/test/Enzyme/sumsimple.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S -early-cse | FileCheck %s
1+
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S -early-cse -simplifycfg | FileCheck %s
22

33
; Function Attrs: noinline nounwind uwtable
44
define dso_local void @f(double* %x, double** %y, i64 %n) #0 {

enzyme/test/Enzyme/sumwithbreak.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -mem2reg -instcombine -correlated-propagation -adce -instcombine -simplifycfg -early-cse -simplifycfg -loop-unroll -instcombine -simplifycfg -gvn -jump-threading -instcombine -S | FileCheck %s
1+
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -mem2reg -instcombine -correlated-propagation -adce -instcombine -simplifycfg -early-cse -simplifycfg -loop-unroll -instcombine -simplifycfg -gvn -jump-threading -instcombine -simplifycfg -S | FileCheck %s
22

33
; Function Attrs: noinline nounwind uwtable
44
define dso_local double @f(double* nocapture readonly %x, i64 %n) #0 {
@@ -57,8 +57,8 @@ attributes #0 = { noinline nounwind uwtable }
5757
; CHECK-NEXT: %arrayidx4 = getelementptr inbounds double, double* %x, i64 %iv
5858
; CHECK-NEXT: %0 = load double, double* %arrayidx4, align 8
5959
; CHECK-NEXT: %add5 = fadd fast double %0, %data.016
60-
; CHECK-NEXT: %cmp = icmp eq i64 %iv, %n
61-
; CHECK-NEXT: br i1 %cmp, label %invertif.end.peel, label %for.body
60+
; CHECK-NEXT: %cmp = icmp ult i64 %iv, %n
61+
; CHECK-NEXT: br i1 %cmp, label %for.body, label %invertif.end.peel
6262

6363
; CHECK: invertentry:
6464
; CHECK-NEXT: ret {} undef

0 commit comments

Comments
 (0)