Skip to content

Commit b8ac440

Browse files
committed
fix invert ending
1 parent beccb01 commit b8ac440

File tree

7 files changed

+343
-11
lines changed

7 files changed

+343
-11
lines changed

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos
3939
//BB is a predecessor of branchingBlock
4040
BasicBlock* GradientUtils::getReverseOrLatchMerge(BasicBlock* BB, BasicBlock* branchingBlock) {
4141
assert(BB);
42+
if (reverseBlocks.find(BB) == reverseBlocks.end()) {
43+
llvm::errs() << *oldFunc << "\n";
44+
llvm::errs() << *newFunc << "\n";
45+
llvm::errs() << "BB: " << *BB << "\n";
46+
llvm::errs() << "branchingBlock: " << *branchingBlock << "\n";
47+
}
4248
assert(reverseBlocks.find(BB) != reverseBlocks.end());
4349
LoopContext lc;
4450
bool inLoop = getContext(BB, lc);
@@ -90,6 +96,7 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos
9096
lc.latchMerge->getInstList().push_front(lc.antivar);
9197

9298
IRBuilder<> mergeBuilder(lc.latchMerge);
99+
auto firstiter = mergeBuilder.CreatePHI(Type::getInt1Ty(mergeBuilder.getContext()), 1);
93100
auto sub = mergeBuilder.CreateSub(lc.antivar, ConstantInt::get(lc.antivar->getType(), 1));
94101

95102
auto latches = fake::SCEVExpander::getLatches(LI.getLoopFor(lc.header), lc.exitBlocks);
@@ -103,8 +110,11 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos
103110
lim = lookupM(lc.limit, tbuild);
104111
}
105112
lc.antivar->addIncoming(lim, reverseBlocks[exit]);
113+
firstiter->addIncoming(ConstantInt::getFalse(mergeBuilder.getContext()), reverseBlocks[exit]);
106114
}
115+
107116
lc.antivar->addIncoming(sub, reverseBlocks[lc.header]);
117+
firstiter->addIncoming(ConstantInt::getTrue(mergeBuilder.getContext()), reverseBlocks[lc.header]);
108118

109119
if (latches.size() == 1) {
110120
lc.latchMerge->takeName(reverseBlocks[latches[0]]);
@@ -118,10 +128,37 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos
118128
targetToPreds[reverseBlocks[latch]].push_back(latch);
119129
}
120130

131+
for(BasicBlock* exit : lc.exitBlocks) {
132+
std::vector<BasicBlock*> vec;
133+
for(auto pred : predecessors(exit)) {
134+
vec.push_back(pred);
135+
}
136+
if (vec.size() == 1) {
137+
auto fd = std::find(latches.begin(), latches.end(), vec[0]);
138+
if ( fd != latches.end()) {
139+
auto latch = *fd;
140+
targetToPreds[reverseBlocks[latch]].push_back(exit);
141+
}
142+
}
143+
}
144+
145+
BasicBlock* merger = BasicBlock::Create(newFunc->getContext(), "brancher", newFunc);
146+
BasicBlock* backlatch = nullptr;
147+
148+
for(auto blk : predecessors(lc.header)) {
149+
if (blk == lc.preheader) continue;
150+
assert(backlatch == nullptr);
151+
backlatch = blk;
152+
}
153+
assert(backlatch != nullptr);
154+
mergeBuilder.CreateCondBr(firstiter, merger, reverseBlocks[backlatch]);
155+
156+
mergeBuilder.SetInsertPoint(merger);
121157
this->branchToCorrespondingTarget(lc.preheader, mergeBuilder, targetToPreds);
122158
}
123159
}
124160
}
161+
125162
bool shouldRecompute(Value* val, const ValueToValueMapTy& available) {
126163
if (available.count(val)) return false;
127164
if (isa<Argument>(val) || isa<Constant>(val)) {

enzyme/Enzyme/GradientUtils.h

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,7 @@ class GradientUtils {
569569
}
570570
llvm::errs() << *newFunc << "\n";
571571
llvm::errs() << BB2 << "\n";
572+
assert(0 && "could not find original block for given reverse block");
572573
report_fatal_error("could not find original block for given reverse block");
573574
}
574575

@@ -1154,6 +1155,7 @@ class GradientUtils {
11541155
bool isChildLoop = false;
11551156

11561157
BasicBlock* forwardBlock = BuilderM.GetInsertBlock();
1158+
11571159
if (!isOriginalBlock(*forwardBlock)) {
11581160
forwardBlock = originalForReverseBlock(*forwardBlock);
11591161
}
@@ -1241,6 +1243,10 @@ class GradientUtils {
12411243
}
12421244
}
12431245

1246+
IntegerType* T = (targetToPreds.size() == 2) ? Type::getInt1Ty(BuilderM.getContext()) : Type::getInt8Ty(BuilderM.getContext());
1247+
CallInst* freeLocation;
1248+
AllocaInst* cache = createCacheForScope(ctx, T, "heresay", /*shouldFree*/&freeLocation, /*lastAlloca*/nullptr);
1249+
12441250
TerminatorInst* equivalentTerminator = nullptr;
12451251

12461252
for(auto pair : done) {
@@ -1267,20 +1273,28 @@ class GradientUtils {
12671273
}
12681274
goto nofast;
12691275

1276+
12701277
fast:;
12711278
assert(equivalentTerminator);
12721279

12731280
if (auto branch = dyn_cast<BranchInst>(equivalentTerminator)) {
12741281
assert(branch->getCondition());
1275-
1276-
//if (!isa<Instruction>(branch->getCondition()) || DT.dominates(cast<Instruction>(branch->getCondition()), BB)) {
1277-
1278-
Value* phi = lookupM(branch->getCondition(), BuilderM);
1282+
1283+
IRBuilder<> pbuilder(equivalentTerminator);
1284+
pbuilder.setFastMathFlags(getFast());
1285+
storeInstructionInCache(ctx, pbuilder, branch->getCondition(), cache);
1286+
1287+
Value* phi = lookupValueFromCache(BuilderM, ctx, cache);
12791288
BuilderM.CreateCondBr(phi, *done[branch->getSuccessor(0)].begin(), *done[branch->getSuccessor(1)].begin());
12801289
return;
12811290
} else if (auto si = dyn_cast<SwitchInst>(equivalentTerminator)) {
1282-
1283-
Value* phi = lookupM(si->getCondition(), BuilderM);
1291+
assert(branch->getCondition());
1292+
1293+
IRBuilder<> pbuilder(equivalentTerminator);
1294+
pbuilder.setFastMathFlags(getFast());
1295+
storeInstructionInCache(ctx, pbuilder, branch->getCondition(), cache);
1296+
1297+
Value* phi = lookupValueFromCache(BuilderM, ctx, cache);
12841298
auto swtch = BuilderM.CreateSwitch(phi, *done[si->getDefaultDest()].begin());
12851299
for (auto switchcase : si->cases()) {
12861300
swtch->addCase(switchcase.getCaseValue(), *done[switchcase.getCaseSuccessor()].begin());
@@ -1295,10 +1309,6 @@ class GradientUtils {
12951309

12961310
nofast:;
12971311

1298-
IntegerType* T = (targetToPreds.size() == 2) ? Type::getInt1Ty(BuilderM.getContext()) : Type::getInt8Ty(BuilderM.getContext());
1299-
CallInst* freeLocation;
1300-
AllocaInst* cache = createCacheForScope(ctx, T, "heresay", /*shouldFree*/&freeLocation, /*lastAlloca*/nullptr);
1301-
13021312
std::vector<BasicBlock*> targets;
13031313
{
13041314
size_t idx = 0;
24 KB
Binary file not shown.

enzyme/test/Enzyme/insertsort.ll

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -correlated-propagation -instsimplify -adce -loop-deletion -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline norecurse nounwind uwtable
4+
define dso_local void @insertion_sort_inner(float* nocapture %array, i32 %i) local_unnamed_addr #0 {
5+
entry:
6+
%cmp29 = icmp sgt i32 %i, 0
7+
br i1 %cmp29, label %land.rhs.preheader, label %while.end
8+
9+
land.rhs.preheader: ; preds = %entry
10+
%0 = sext i32 %i to i64
11+
br label %land.rhs
12+
13+
land.rhs: ; preds = %land.rhs.preheader, %while.body
14+
%indvars.iv = phi i64 [ %0, %land.rhs.preheader ], [ %indvars.iv.next, %while.body ]
15+
%indvars.iv.next = add nsw i64 %indvars.iv, -1
16+
%arrayidx = getelementptr inbounds float, float* %array, i64 %indvars.iv.next
17+
%1 = load float, float* %arrayidx, align 4
18+
%arrayidx2 = getelementptr inbounds float, float* %array, i64 %indvars.iv
19+
%2 = load float, float* %arrayidx2, align 4
20+
%cmp3 = fcmp ogt float %1, %2
21+
br i1 %cmp3, label %while.body, label %while.end
22+
23+
while.body: ; preds = %land.rhs
24+
store float %1, float* %arrayidx2, align 4
25+
store float %2, float* %arrayidx, align 4
26+
%cmp = icmp sgt i64 %indvars.iv, 1
27+
br i1 %cmp, label %land.rhs, label %while.end
28+
29+
while.end: ; preds = %land.rhs, %while.body, %entry
30+
ret void
31+
}
32+
33+
34+
define dso_local void @dsum(float* %x, float* %xp, i32 %n) {
35+
entry:
36+
%0 = tail call double (void (float*, i32)*, ...) @__enzyme_autodiff(void (float*, i32)* nonnull @insertion_sort_inner, float* %x, float* %xp, i32 %n)
37+
ret void
38+
}
39+
40+
declare double @__enzyme_autodiff(void (float*, i32)*, ...)
41+
42+
attributes #0 = { noinline norecurse nounwind uwtable }
43+
44+
; CHECK: define internal {} @diffeinsertion_sort_inner(float* nocapture %array, float* %"array'", i32 %i) local_unnamed_addr #0 {
45+
; CHECK-NEXT: entry:
46+
; CHECK-NEXT: %cmp29 = icmp sgt i32 %i, 0
47+
; CHECK-NEXT: br i1 %cmp29, label %land.rhs.preheader, label %while.end
48+
49+
; CHECK: land.rhs.preheader: ; preds = %entry
50+
; CHECK-NEXT: %0 = sext i32 %i to i64
51+
; CHECK-NEXT: br label %land.rhs
52+
53+
; CHECK-NEXT: land.rhs: ; preds = %while.body, %land.rhs.preheader
54+
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %while.body ], [ 0, %land.rhs.preheader ]
55+
; CHECK-NEXT: %1 = mul i64 %iv, -1
56+
; CHECK-NEXT: %2 = add i64 %0, %1
57+
; CHECK-NEXT: %iv.next = add nuw i64 %iv, 1
58+
; CHECK-NEXT: %indvars.iv.next = add nsw i64 %2, -1
59+
; CHECK-NEXT: %arrayidx = getelementptr inbounds float, float* %array, i64 %indvars.iv.next
60+
; CHECK-NEXT: %3 = load float, float* %arrayidx, align 4
61+
; CHECK-NEXT: %arrayidx2 = getelementptr inbounds float, float* %array, i64 %2
62+
; CHECK-NEXT: %4 = load float, float* %arrayidx2, align 4
63+
; CHECK-NEXT: %cmp3 = fcmp ogt float %3, %4
64+
; CHECK-NEXT: br i1 %cmp3, label %while.body, label %while.end.loopexit
65+
66+
; CHECK: while.body: ; preds = %land.rhs
67+
; CHECK-NEXT: store float %3, float* %arrayidx2, align 4
68+
; CHECK-NEXT: store float %4, float* %arrayidx, align 4
69+
; CHECK-NEXT: %cmp = icmp sgt i64 %2, 1
70+
; CHECK-NEXT: br i1 %cmp, label %land.rhs, label %while.end.loopexit
71+
72+
; CHECK: while.end.loopexit: ; preds = %while.body, %land.rhs
73+
; CHECK-NEXT: %"cmp3!manual_lcssa" = phi i1 [ %cmp3, %while.body ], [ %cmp3, %land.rhs ]
74+
; CHECK-NEXT: %5 = phi i8 [ 0, %while.body ], [ 1, %land.rhs ]
75+
; CHECK-NEXT: %6 = phi i64 [ %iv, %while.body ], [ %iv, %land.rhs ]
76+
; CHECK-NEXT: br label %while.end
77+
78+
; CHECK: while.end: ; preds = %while.end.loopexit, %entry
79+
; CHECK-NEXT: %"cmp3!manual_lcssa_cache.0" = phi i1 [ %"cmp3!manual_lcssa", %while.end.loopexit ], [ undef, %entry ]
80+
; CHECK-NEXT: %_cache1.0 = phi i8 [ %5, %while.end.loopexit ], [ undef, %entry ]
81+
; CHECK-NEXT: %_cache.0 = phi i64 [ %6, %while.end.loopexit ], [ undef, %entry ]
82+
; CHECK-NEXT: br label %invertwhile.end
83+
84+
; CHECK: invertentry: ; preds = %invertwhile.end, %invertland.rhs.preheader
85+
; CHECK-NEXT: ret {} undef
86+
87+
; CHECK: invertland.rhs.preheader: ; preds = %invertland.rhs
88+
; CHECK-NEXT: br label %invertentry
89+
90+
; CHECK: invertland.rhs: ; preds = %invertwhile.body, %loopMerge
91+
; CHECK-NEXT: %"'de2.0" = phi float [ 0.000000e+00, %loopMerge ], [ %30, %invertwhile.body ]
92+
; CHECK-NEXT: %"'de.0" = phi float [ 0.000000e+00, %loopMerge ], [ %24, %invertwhile.body ]
93+
; CHECK-NEXT: %_unwrap = sext i32 %i to i64
94+
; CHECK-NEXT: %7 = mul i64 %"iv'phi", -1
95+
; CHECK-NEXT: %8 = add i64 %_unwrap, %7
96+
; CHECK-NEXT: %"arrayidx2'ipg" = getelementptr float, float* %"array'", i64 %8
97+
; CHECK-NEXT: %9 = load float, float* %"arrayidx2'ipg"
98+
; CHECK-NEXT: %10 = fadd fast float %9, %"'de.0"
99+
; CHECK-NEXT: store float %10, float* %"arrayidx2'ipg"
100+
; CHECK-NEXT: %_unwrap3 = sext i32 %i to i64
101+
; CHECK-NEXT: %11 = mul i64 %"iv'phi", -1
102+
; CHECK-NEXT: %12 = add i64 %_unwrap3, %11
103+
; CHECK-NEXT: %13 = add i64 %12, -1
104+
; CHECK-NEXT: %"arrayidx'ipg" = getelementptr float, float* %"array'", i64 %13
105+
; CHECK-NEXT: %14 = load float, float* %"arrayidx'ipg"
106+
; CHECK-NEXT: %15 = fadd fast float %14, %"'de2.0"
107+
; CHECK-NEXT: store float %15, float* %"arrayidx'ipg"
108+
; CHECK-NEXT: %16 = icmp eq i64 %"iv'phi", 0
109+
; CHECK-NEXT: br i1 %16, label %invertland.rhs.preheader, label %loopMerge
110+
111+
; CHECK: invertwhile.body: ; preds = %loopMerge
112+
; CHECK-NEXT: %_unwrap4 = sext i32 %i to i64
113+
; CHECK-NEXT: %17 = mul i64 %"iv'phi", -1
114+
; CHECK-NEXT: %18 = add i64 %_unwrap4, %17
115+
; CHECK-NEXT: %19 = add i64 %18, -1
116+
; CHECK-NEXT: %"arrayidx'ipg5" = getelementptr float, float* %"array'", i64 %19
117+
; CHECK-NEXT: %20 = load float, float* %"arrayidx'ipg5"
118+
; CHECK-NEXT: %_unwrap6 = sext i32 %i to i64
119+
; CHECK-NEXT: %21 = mul i64 %"iv'phi", -1
120+
; CHECK-NEXT: %22 = add i64 %_unwrap6, %21
121+
; CHECK-NEXT: %23 = add i64 %22, -1
122+
; CHECK-NEXT: %"arrayidx'ipg7" = getelementptr float, float* %"array'", i64 %23
123+
; CHECK-NEXT: store float 0.000000e+00, float* %"arrayidx'ipg7"
124+
; CHECK-NEXT: %24 = fadd fast float 0.000000e+00, %20
125+
; CHECK-NEXT: %_unwrap8 = sext i32 %i to i64
126+
; CHECK-NEXT: %25 = mul i64 %"iv'phi", -1
127+
; CHECK-NEXT: %26 = add i64 %_unwrap8, %25
128+
; CHECK-NEXT: %"arrayidx2'ipg9" = getelementptr float, float* %"array'", i64 %26
129+
; CHECK-NEXT: %27 = load float, float* %"arrayidx2'ipg9"
130+
; CHECK-NEXT: %_unwrap10 = sext i32 %i to i64
131+
; CHECK-NEXT: %28 = mul i64 %"iv'phi", -1
132+
; CHECK-NEXT: %29 = add i64 %_unwrap10, %28
133+
; CHECK-NEXT: %"arrayidx2'ipg11" = getelementptr float, float* %"array'", i64 %29
134+
; CHECK-NEXT: store float 0.000000e+00, float* %"arrayidx2'ipg11"
135+
; CHECK-NEXT: %30 = fadd fast float 0.000000e+00, %27
136+
; CHECK-NEXT: br label %invertland.rhs
137+
138+
; CHECK: invertwhile.end.loopexit: ; preds = %invertwhile.end
139+
; CHECK-NEXT: br label %loopMerge
140+
141+
; CHECK: invertwhile.end: ; preds = %while.end
142+
; CHECK-NEXT: %31 = icmp sgt i32 %i, 0
143+
; CHECK-NEXT: br i1 %31, label %invertwhile.end.loopexit, label %invertentry
144+
145+
; CHECK: loopMerge: ; preds = %invertwhile.end.loopexit, %invertland.rhs
146+
; CHECK-NEXT: %"iv'phi" = phi i64 [ %_cache.0, %invertwhile.end.loopexit ], [ %32, %invertland.rhs ]
147+
; CHECK-NEXT: %32 = sub i64 %"iv'phi", 1
148+
; CHECK-NEXT: switch i8 %_cache1.0, label %invertland.rhs [
149+
; CHECK-NEXT: i8 0, label %invertwhile.body
150+
; CHECK-NEXT: ]
151+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)