Skip to content

Commit 11c6c14

Browse files
authored
Fix LCSSA lookup scope bug (rust-lang#951)
* Fix LCSSA lookup scope bug * temp * tmp * Fixup * update test
1 parent 9eefd84 commit 11c6c14

File tree

6 files changed

+224
-77
lines changed

6 files changed

+224
-77
lines changed

enzyme/Enzyme/CacheUtility.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ class CacheUtility {
346346
lookupM(llvm::Value *val, llvm::IRBuilder<> &BuilderM,
347347
const llvm::ValueToValueMapTy &incoming_availalble =
348348
llvm::ValueToValueMapTy(),
349-
bool tryLegalityCheck = true) = 0;
349+
bool tryLegalityCheck = true, llvm::BasicBlock *scope = nullptr) = 0;
350350

351351
virtual bool assumeDynamicLoopOfSizeOne(llvm::Loop *L) const = 0;
352352

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 84 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
338338
}
339339
}
340340

341-
#define getOpFullest(Builder, vtmp, frominst, check) \
341+
#define getOpFullest(Builder, vtmp, frominst, lookupInst, check) \
342342
({ \
343343
Value *v = vtmp; \
344344
BasicBlock *origParent = frominst; \
@@ -362,24 +362,38 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
362362
if (!DT.dominates(opinst, &*Builder.GetInsertPoint())) \
363363
noLookup = true; \
364364
} \
365-
if (origParent) \
365+
origParent = lookupInst; \
366+
if (BasicBlock *forwardBlock = origParent) \
366367
if (auto opinst = dyn_cast<Instruction>(v)) { \
367-
v = fixLCSSA(opinst, origParent); \
368+
if (!isOriginalBlock(*forwardBlock)) { \
369+
forwardBlock = originalForReverseBlock(*forwardBlock); \
370+
} \
371+
if (isPotentialLastLoopValue(opinst, forwardBlock, LI)) { \
372+
v = fixLCSSA(opinst, forwardBlock); \
373+
origParent = nullptr; \
374+
} \
368375
} \
369376
if (!noLookup) \
370-
___res = lookupM(v, Builder, available, v != val); \
377+
___res = lookupM(v, Builder, available, v != val, origParent); \
371378
} \
372379
if (___res) \
373380
assert(___res->getType() == v->getType() && "uw"); \
374381
} else { \
375-
if (origParent) \
382+
origParent = lookupInst; \
383+
if (BasicBlock *forwardBlock = origParent) \
376384
if (auto opinst = dyn_cast<Instruction>(v)) { \
377-
v = fixLCSSA(opinst, origParent); \
385+
if (!isOriginalBlock(*forwardBlock)) { \
386+
forwardBlock = originalForReverseBlock(*forwardBlock); \
387+
} \
388+
if (isPotentialLastLoopValue(opinst, forwardBlock, LI)) { \
389+
v = fixLCSSA(opinst, forwardBlock); \
390+
origParent = nullptr; \
391+
} \
378392
} \
379393
assert(unwrapMode == UnwrapMode::AttemptSingleUnwrap); \
380394
auto found = available.find(v); \
381395
assert(found == available.end() || found->second); \
382-
___res = lookupM(v, Builder, available, v != val); \
396+
___res = lookupM(v, Builder, available, v != val, origParent); \
383397
if (___res && ___res->getType() != v->getType()) { \
384398
llvm::errs() << *newFunc << "\n"; \
385399
llvm::errs() << " v = " << *v << " res = " << *___res << "\n"; \
@@ -390,19 +404,25 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
390404
___res; \
391405
})
392406
#define getOpFull(Builder, vtmp, frominst) \
393-
getOpFullest(Builder, vtmp, frominst, true)
407+
({ \
408+
BasicBlock *parent = scope; \
409+
if (parent == nullptr) \
410+
if (auto originst = dyn_cast<Instruction>(val)) \
411+
parent = originst->getParent(); \
412+
getOpFullest(Builder, vtmp, frominst, parent, true); \
413+
})
394414
#define getOpUnchecked(vtmp) \
395415
({ \
396416
BasicBlock *parent = scope; \
397-
getOpFullest(BuilderM, vtmp, parent, false); \
417+
getOpFullest(BuilderM, vtmp, parent, parent, false); \
398418
})
399419
#define getOp(vtmp) \
400420
({ \
401421
BasicBlock *parent = scope; \
402422
if (parent == nullptr) \
403423
if (auto originst = dyn_cast<Instruction>(val)) \
404424
parent = originst->getParent(); \
405-
getOpFullest(BuilderM, vtmp, parent, true); \
425+
getOpFullest(BuilderM, vtmp, parent, parent, true); \
406426
})
407427

408428
if (isa<Argument>(val) || isa<Constant>(val)) {
@@ -1470,16 +1490,39 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
14701490
if (!DT.dominates(inst, &*B.GetInsertPoint()))
14711491
noLookup = true;
14721492
}
1473-
Value *v = fixLCSSA(inst, nextScope);
1474-
if (!noLookup)
1475-
___res = lookupM(v, B, prevAvailable, v != val);
1493+
if (!noLookup) {
1494+
BasicBlock *nS2 = nextScope;
1495+
Value *v = inst;
1496+
if (BasicBlock *forwardBlock = nextScope)
1497+
if (auto opinst = dyn_cast<Instruction>(v)) {
1498+
if (!isOriginalBlock(*forwardBlock)) {
1499+
forwardBlock = originalForReverseBlock(*forwardBlock);
1500+
}
1501+
if (isPotentialLastLoopValue(opinst, forwardBlock,
1502+
LI)) {
1503+
v = fixLCSSA(opinst, forwardBlock);
1504+
nS2 = nullptr;
1505+
}
1506+
}
1507+
___res = lookupM(v, B, prevAvailable, v != val, nS2);
1508+
}
14761509
}
14771510
if (___res)
14781511
assert(___res->getType() == inst->getType() && "uw");
14791512
} else {
1480-
Value *v = fixLCSSA(inst, nextScope);
1481-
assert(unwrapMode == UnwrapMode::AttemptSingleUnwrap);
1482-
___res = lookupM(v, B, prevAvailable, v != val);
1513+
BasicBlock *nS2 = nextScope;
1514+
Value *v = inst;
1515+
if (BasicBlock *forwardBlock = nextScope)
1516+
if (auto opinst = dyn_cast<Instruction>(v)) {
1517+
if (!isOriginalBlock(*forwardBlock)) {
1518+
forwardBlock = originalForReverseBlock(*forwardBlock);
1519+
}
1520+
if (isPotentialLastLoopValue(opinst, forwardBlock, LI)) {
1521+
v = fixLCSSA(opinst, forwardBlock);
1522+
nS2 = nullptr;
1523+
}
1524+
}
1525+
___res = lookupM(v, B, prevAvailable, v != val, nS2);
14831526
if (___res && ___res->getType() != v->getType()) {
14841527
llvm::errs() << *newFunc << "\n";
14851528
llvm::errs() << " v = " << *v << " res = " << *___res << "\n";
@@ -1771,12 +1814,19 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
17711814
return nullptr;
17721815
}
17731816
}
1774-
if (scope)
1817+
BasicBlock *nS2 = scope;
1818+
if (BasicBlock *forwardBlock = scope)
17751819
if (auto opinst = dyn_cast<Instruction>(nval)) {
1776-
nval = fixLCSSA(opinst, scope);
1820+
if (!isOriginalBlock(*forwardBlock)) {
1821+
forwardBlock = originalForReverseBlock(*forwardBlock);
1822+
}
1823+
if (isPotentialLastLoopValue(opinst, forwardBlock, LI)) {
1824+
nval = fixLCSSA(opinst, forwardBlock);
1825+
nS2 = nullptr;
1826+
}
17771827
}
1778-
auto toreturn =
1779-
lookupM(nval, BuilderM, available, /*tryLegalRecomputeCheck*/ false);
1828+
auto toreturn = lookupM(nval, BuilderM, available,
1829+
/*tryLegalRecomputeCheck*/ false, nS2);
17801830
assert(val->getType() == toreturn->getType());
17811831
return toreturn;
17821832
}
@@ -4974,7 +5024,7 @@ end:;
49745024

49755025
Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
49765026
const ValueToValueMapTy &incoming_available,
4977-
bool tryLegalRecomputeCheck) {
5027+
bool tryLegalRecomputeCheck, BasicBlock *scope) {
49785028

49795029
assert(mode == DerivativeMode::ReverseModePrimal ||
49805030
mode == DerivativeMode::ReverseModeGradient ||
@@ -5014,6 +5064,9 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
50145064
}
50155065
assert(inst->getParent()->getParent() == newFunc);
50165066
assert(BuilderM.GetInsertBlock()->getParent() == newFunc);
5067+
if (scope == nullptr)
5068+
scope = BuilderM.GetInsertBlock();
5069+
assert(scope->getParent() == newFunc);
50175070

50185071
bool reduceRegister = false;
50195072

@@ -5241,12 +5294,14 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
52415294
Instruction *prelcssaInst = inst;
52425295

52435296
assert(inst->getName() != "<badref>");
5244-
val = fixLCSSA(inst, BuilderM.GetInsertBlock());
5297+
val = fixLCSSA(inst, scope);
52455298
if (isa<UndefValue>(val)) {
52465299
llvm::errs() << *oldFunc << "\n";
52475300
llvm::errs() << *newFunc << "\n";
52485301
llvm::errs() << *BuilderM.GetInsertBlock() << "\n";
5302+
llvm::errs() << *scope << "\n";
52495303
llvm::errs() << *val << " inst " << *inst << "\n";
5304+
assert(0 && "undef value upon lcssa");
52505305
}
52515306
inst = cast<Instruction>(val);
52525307
assert(prelcssaInst->getType() == inst->getType());
@@ -5273,7 +5328,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
52735328
(lrc = legalRecompute(prelcssaInst, available, &BuilderM))) {
52745329
if ((src = shouldRecompute(prelcssaInst, available, &BuilderM))) {
52755330
auto op = unwrapM(prelcssaInst, BuilderM, available,
5276-
UnwrapMode::AttemptSingleUnwrap);
5331+
UnwrapMode::AttemptSingleUnwrap, scope);
52775332
if (op) {
52785333
assert(op);
52795334
assert(op->getType());
@@ -5571,8 +5626,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
55715626
}
55725627
Value *recomp = unwrapM(
55735628
getNewFromOriginal(SI->getValueOperand()), BuilderM,
5574-
ThreadLookup, UnwrapMode::AttemptFullUnwrap,
5575-
/*scope*/ nullptr,
5629+
ThreadLookup, UnwrapMode::AttemptFullUnwrap, scope,
55765630
/*permitCache*/ false);
55775631
if (recomp) {
55785632
resultValue = recomp;
@@ -6033,7 +6087,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
60336087
" tryLegalRecomputeCheck: ", tryLegalRecomputeCheck);
60346088
}
60356089

6036-
BasicBlock *scope = inst->getParent();
6090+
BasicBlock *scopeI = inst->getParent();
60376091
if (auto origInst = isOriginal(inst)) {
60386092
auto found = rematerializableAllocations.find(origInst);
60396093
if (found != rematerializableAllocations.end())
@@ -6048,7 +6102,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
60486102
// within the loop, force an entry-level scope so there is no need
60496103
// to cache.
60506104
if (!cacheWholeAllocation)
6051-
scope = &newFunc->getEntryBlock();
6105+
scopeI = &newFunc->getEntryBlock();
60526106
}
60536107
} else {
60546108
for (auto pair : backwardsOnlyShadows) {
@@ -6057,13 +6111,13 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
60576111
pair.second.LI->contains(pinst->getParent())) {
60586112
auto found = invertedPointers.find(pair.first);
60596113
if (found != invertedPointers.end() && found->second == inst) {
6060-
scope = &newFunc->getEntryBlock();
6114+
scopeI = &newFunc->getEntryBlock();
60616115

60626116
// Prevent the phi node from being stored into the cache by creating
60636117
// it before the ensureLookupCached.
60646118
if (scopeMap.find(inst) == scopeMap.end()) {
60656119
LimitContext lctx(/*ReverseLimit*/ reverseBlocks.size() > 0,
6066-
scope);
6120+
scopeI);
60676121

60686122
AllocaInst *cache = createCacheForScope(
60696123
lctx, inst->getType(), inst->getName(), /*shouldFree*/ true);
@@ -6078,7 +6132,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
60786132
}
60796133
}
60806134

6081-
ensureLookupCached(inst, /*shouldFree*/ true, scope,
6135+
ensureLookupCached(inst, /*shouldFree*/ true, scopeI,
60826136
inst->getMetadata(LLVMContext::MD_tbaa));
60836137
bool isi1 = inst->getType()->isIntegerTy() &&
60846138
cast<IntegerType>(inst->getType())->getBitWidth() == 1;

enzyme/Enzyme/GradientUtils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1338,6 +1338,8 @@ class GradientUtils : public CacheUtility {
13381338
for (auto &BB : *inst->getParent()->getParent()) {
13391339
if (!seen.count(&BB) || (inst->getParent() != &BB &&
13401340
DT.dominates(&BB, inst->getParent()))) {
1341+
// OrigPDT.dominates(isOriginal(inst->getParent()),
1342+
// isOriginal(&BB)))) {
13411343
lcssaFixes[inst][&BB] = UndefValue::get(inst->getType());
13421344
}
13431345
}
@@ -1426,7 +1428,8 @@ class GradientUtils : public CacheUtility {
14261428
Value *
14271429
lookupM(Value *val, IRBuilder<> &BuilderM,
14281430
const ValueToValueMapTy &incoming_availalble = ValueToValueMapTy(),
1429-
bool tryLegalRecomputeCheck = true) override;
1431+
bool tryLegalRecomputeCheck = true,
1432+
llvm::BasicBlock *scope = nullptr) override;
14301433

14311434
Value *invertPointerM(Value *val, IRBuilder<> &BuilderM,
14321435
bool nullShadow = false);
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s
2+
3+
source_filename = "text"
4+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:10:11:12:13"
5+
target triple = "x86_64-linux-gnu"
6+
7+
define internal fastcc void @f({} addrspace(10)* %arg) {
8+
bb:
9+
br label %bb54
10+
11+
bb54: ; preds = %bb91, %bb42
12+
%i55 = phi i64 [ 0, %bb ], [ %i69, %bb91 ]
13+
%i60 = icmp slt i64 %i55, 4
14+
br i1 %i60, label %bb66, label %bb92
15+
16+
bb66: ; preds = %bb86, %bb61
17+
%i67 = phi i64 [ %i55, %bb54 ], [ %i69, %bb84 ]
18+
%i68 = phi {} addrspace(10)* [ null, %bb54 ], [ %i90, %bb84 ]
19+
%i69 = add nsw i64 %i67, 1
20+
%i71 = icmp eq {} addrspace(10)* %i68, %arg
21+
br i1 %i71, label %bb72, label %bb91
22+
23+
bb72: ; preds = %bb66
24+
%i74 = call {}* @julia.pointer_from_objref({} addrspace(10)* %i68)
25+
%i79 = icmp eq {}* %i74, null
26+
br i1 %i79, label %bb84, label %bb91
27+
28+
bb84: ; preds = %bb72
29+
%i85 = icmp slt i64 %i67, 3
30+
%i90 = call {} addrspace(10)* @__dynamic_cast()
31+
br i1 %i85, label %bb66, label %bb92
32+
33+
bb91: ; preds = %bb66, %bb72
34+
br label %bb54
35+
36+
bb92: ; preds = %bb84, %bb54, %bb35
37+
ret void
38+
}
39+
40+
; Function Attrs: nofree nounwind readnone
41+
declare nonnull {}* @julia.pointer_from_objref({} addrspace(10)*) local_unnamed_addr #4
42+
43+
; Function Attrs: nofree readonly
44+
declare nonnull {} addrspace(10)* @__dynamic_cast()
45+
46+
47+
declare dso_local void @__enzyme_autodiff(...)
48+
49+
define void @dsquare() local_unnamed_addr {
50+
bb:
51+
call void (...) @__enzyme_autodiff(i8* bitcast (void ({} addrspace(10)*)* @f to i8*), metadata !"enzyme_dup", {} addrspace(10)* null, {} addrspace(10)* null)
52+
ret void
53+
}
54+
55+
attributes #4 = { nounwind readnone }
56+
57+
58+
; CHECK: invertbb91:
59+
; CHECK-NEXT: %[[i31:.+]] = load i64, i64* %"iv'ac"
60+
; CHECK-NEXT: %[[i32:.+]] = load {} addrspace(10)**, {} addrspace(10)*** %"i68!manual_lcssa_cache", align 8
61+
; CHECK-NEXT: %[[i33:.+]] = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)** %[[i32]], i64 %[[i31]]
62+
; CHECK-NEXT: %[[i34:.+]] = load {} addrspace(10)*, {} addrspace(10)** %[[i33]], align 8
63+
; CHECK-NEXT: %i71_unwrap = icmp eq {} addrspace(10)* %[[i34]], %arg
64+
; CHECK-NEXT: br i1 %i71_unwrap, label %mergeinvertbb66_bb91, label %mergeinvertbb66_bb911
65+
66+
67+
; OLD: invertbb91:
68+
; OLD-NEXT: %[[i78:.+]] = load i64, i64* %"iv'ac"
69+
; OLD-NEXT: %[[i79:.+]] = load i64*, i64** %loopLimit_cache, align 8
70+
; OLD-NEXT: %[[i80:.+]] = getelementptr inbounds i64, i64* %[[i79]], i64 %[[i78]]
71+
; OLD-NEXT: %[[i81:.+]] = load i64, i64* %[[i80]], align 8
72+
; OLD-NEXT: %[[i82:.+]] = icmp ne i64 %[[i81]], 0
73+
; OLD-NEXT: br i1 %[[i82]], label %invertbb91_phirc, label %invertbb91_phirc6
74+
75+
; OLD: invertbb91_phirc:
76+
; OLD-NEXT: %[[i83:.+]] = sub nuw i64 %[[i81]], 1
77+
; OLD-NEXT: %[[i84:.+]] = load {} addrspace(10)***, {} addrspace(10)**** %i90_cache, align 8
78+
; OLD-NEXT: %[[i85:.+]] = getelementptr inbounds {} addrspace(10)**, {} addrspace(10)*** %[[i84]], i64 %[[i78]]
79+
; OLD-NEXT: %[[i86:.+]] = load {} addrspace(10)**, {} addrspace(10)*** %[[i85]], align 8
80+
; OLD-NEXT: %[[i87:.+]] = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)** %[[i86]], i64 %[[i83]]
81+
; OLD-NEXT: %[[i88:.+]] = load {} addrspace(10)*, {} addrspace(10)** %[[i87]], align 8
82+
; OLD-NEXT: br label %invertbb91_phimerge
83+
84+
; OLD: invertbb91_phirc6:
85+
; OLD-NEXT: br label %invertbb91_phimerge
86+
87+
; OLD: invertbb91_phimerge:
88+
; OLD-NEXT: %[[i89:.+]] = phi {} addrspace(10)* [ %[[i88]], %invertbb91_phirc ], [ null, %invertbb91_phirc6 ]
89+
; OLD-NEXT: %i71_unwrap = icmp eq {} addrspace(10)* %[[i89]], %arg

0 commit comments

Comments
 (0)