Skip to content

Commit f6ca557

Browse files
authored
Fix the incredibly stupid plug AssertingVH bug (#4)
* start canonical fixups and call it a night
* Horrible horrible hacks
* Working tests
* working LLVM 6?
* don't forget to prepend fake
* Fix default behavior for unknown constant intrinsic
* cleanup prints
1 parent d82ac16 commit f6ca557

File tree

12 files changed

+4930
-197
lines changed

12 files changed

+4930
-197
lines changed

enzyme/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@ project(Enzyme)
33
SET(CMAKE_CXX_FLAGS "-Wall")
44
SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g")
55
SET(CMAKE_CXX_FLAGS_RELEASE "-O2")
6-
SET(CMAKE_CXX_FLAGS_DEBUG "-O0 -g")
6+
7+
SET(CMAKE_CXX_FLAGS_DEBUG "-O0 -g -fno-omit-frame-pointer")
8+
9+
#SET(CMAKE_CXX_FLAGS_DEBUG "-O0 -g -fno-omit-frame-pointer -fsanitize=address")
10+
#SET(CMAKE_LINKER_FLAGS_DEBUG "${CMAKE_LINKER_FLAGS_DEBUG} -fno-omit-frame-pointer -fsanitize=address")
711

812
set(CMAKE_CXX_STANDARD 11)
913
cmake_minimum_required(VERSION 3.5)

enzyme/Enzyme/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
if (${LLVM_VERSION_MAJOR} LESS 8)
66
add_llvm_loadable_module( LLVMEnzyme-${LLVM_VERSION_MAJOR}
77
Enzyme.cpp
8+
SCEV/ScalarEvolutionExpander.cpp
89
DEPENDS
910
intrinsics_gen
1011
PLUGIN_TOOL
@@ -13,6 +14,7 @@ if (${LLVM_VERSION_MAJOR} LESS 8)
1314
else()
1415
add_llvm_library( LLVMEnzyme-${LLVM_VERSION_MAJOR}
1516
Enzyme.cpp
17+
SCEV/ScalarEvolutionExpander.cpp
1618
MODULE
1719
DEPENDS
1820
intrinsics_gen

enzyme/Enzyme/Enzyme.cpp

Lines changed: 49 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
#include <llvm/Config/llvm-config.h>
2020

21+
#include "SCEV/ScalarEvolutionExpander.h"
22+
2123
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
2224
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
2325
#include "llvm/Transforms/Scalar/GVN.h"
@@ -827,7 +829,11 @@ void forceRecursiveInlining(Function *NewF, const Function* F) {
827829
}
828830
}
829831

830-
Function* preprocessForClone(Function *F, AAResults &AA) {
832+
class GradientUtils;
833+
834+
PHINode* canonicalizeIVs(Type *Ty, Loop *L, ScalarEvolution &SE, DominatorTree &DT, GradientUtils* gutils);
835+
836+
Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI) {
831837
static std::map<Function*,Function*> cache;
832838
if (cache.find(F) != cache.end()) return cache[F];
833839

@@ -1071,7 +1077,7 @@ Function* preprocessForClone(Function *F, AAResults &AA) {
10711077
DSEPass().run(*NewF, AM);
10721078
LoopSimplifyPass().run(*NewF, AM);
10731079

1074-
}
1080+
}
10751081

10761082
if (autodiff_print)
10771083
llvm::errs() << "after simplification :\n" << *NewF << "\n";
@@ -1084,9 +1090,9 @@ Function* preprocessForClone(Function *F, AAResults &AA) {
10841090
return NewF;
10851091
}
10861092

1087-
Function *CloneFunctionWithReturns(Function *&F, AAResults &AA, ValueToValueMapTy& ptrInputs, const std::set<unsigned>& constant_args, SmallPtrSetImpl<Value*> &constants, SmallPtrSetImpl<Value*> &nonconstant, SmallPtrSetImpl<Value*> &returnvals, ReturnType returnValue, bool differentialReturn, Twine name, ValueToValueMapTy *VMapO, bool diffeReturnArg, llvm::Type* additionalArg = nullptr) {
1093+
Function *CloneFunctionWithReturns(Function *&F, AAResults &AA, TargetLibraryInfo &TLI, ValueToValueMapTy& ptrInputs, const std::set<unsigned>& constant_args, SmallPtrSetImpl<Value*> &constants, SmallPtrSetImpl<Value*> &nonconstant, SmallPtrSetImpl<Value*> &returnvals, ReturnType returnValue, bool differentialReturn, Twine name, ValueToValueMapTy *VMapO, bool diffeReturnArg, llvm::Type* additionalArg = nullptr) {
10881094
assert(!F->empty());
1089-
F = preprocessForClone(F, AA);
1095+
F = preprocessForClone(F, AA, TLI);
10901096
diffeReturnArg &= differentialReturn;
10911097
std::vector<Type*> RetTypes;
10921098
if (returnValue == ReturnType::ArgsWithReturn)
@@ -1257,15 +1263,6 @@ Function *CloneFunctionWithReturns(Function *&F, AAResults &AA, ValueToValueMapT
12571263
#include "llvm/IR/Constant.h"
12581264
#include <deque>
12591265
#include "llvm/IR/CFG.h"
1260-
class GradientUtils;
1261-
1262-
PHINode* canonicalizeIVs(Type *Ty, Loop *L, ScalarEvolution &SE, DominatorTree &DT, GradientUtils *gutils);
1263-
1264-
/// \brief Replace the latch of the loop to check that IV is always less than or
1265-
/// equal to the limit.
1266-
///
1267-
/// This method assumes that the loop has a single loop latch.
1268-
Value* canonicalizeLoopLatch(PHINode *IV, Value *Limit, Loop* L, ScalarEvolution &SE, BasicBlock* ExitBlock, GradientUtils *gutils);
12691266

12701267
bool shouldRecompute(Value* val, const ValueToValueMapTy& available) {
12711268
if (available.count(val)) return false;
@@ -1943,7 +1940,7 @@ class GradientUtils {
19431940
SmallPtrSet<Value*,20> nonconstant;
19441941
SmallPtrSet<Value*,2> returnvals;
19451942
ValueToValueMapTy originalToNew;
1946-
auto newFunc = CloneFunctionWithReturns(todiff, AA, invertedPointers, constant_args, constants, nonconstant, returnvals, /*returnValue*/returnValue, /*differentialReturn*/differentialReturn, "fakeaugmented_"+todiff->getName(), &originalToNew, /*diffeReturnArg*/false, additionalArg);
1943+
auto newFunc = CloneFunctionWithReturns(todiff, AA, TLI, invertedPointers, constant_args, constants, nonconstant, returnvals, /*returnValue*/returnValue, /*differentialReturn*/differentialReturn, "fakeaugmented_"+todiff->getName(), &originalToNew, /*diffeReturnArg*/false, additionalArg);
19471944
auto res = new GradientUtils(newFunc, AA, TLI, invertedPointers, constants, nonconstant, returnvals, originalToNew);
19481945
res->oldFunc = todiff;
19491946
return res;
@@ -2230,7 +2227,6 @@ class GradientUtils {
22302227
IRBuilder <> v(putafter);
22312228
v.setFastMathFlags(getFast());
22322229
v.CreateStore(inst, scopeMap[inst]);
2233-
llvm::errs() << " place foo\n"; dumpSet(originalInstructions);
22342230
} else {
22352231

22362232
ValueToValueMapTy valmap;
@@ -2722,7 +2718,7 @@ class DiffeGradientUtils : public GradientUtils {
27222718
SmallPtrSet<Value*,20> nonconstant;
27232719
SmallPtrSet<Value*,2> returnvals;
27242720
ValueToValueMapTy originalToNew;
2725-
auto newFunc = CloneFunctionWithReturns(todiff, AA, invertedPointers, constant_args, constants, nonconstant, returnvals, returnValue, differentialReturn, "diffe"+todiff->getName(), &originalToNew, /*diffeReturnArg*/true, additionalArg);
2721+
auto newFunc = CloneFunctionWithReturns(todiff, AA, TLI, invertedPointers, constant_args, constants, nonconstant, returnvals, returnValue, differentialReturn, "diffe"+todiff->getName(), &originalToNew, /*diffeReturnArg*/true, additionalArg);
27262722
auto res = new DiffeGradientUtils(newFunc, AA, TLI, invertedPointers, constants, nonconstant, returnvals, originalToNew);
27272723
res->oldFunc = todiff;
27282724
return res;
@@ -3062,6 +3058,7 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
30623058
case Intrinsic::cos:
30633059
break;
30643060
default:
3061+
if (gutils->isConstantInstruction(inst)) continue;
30653062
assert(inst);
30663063
llvm::errs() << "cannot handle (augmented) unknown intrinsic\n" << *inst;
30673064
report_fatal_error("(augmented) unknown intrinsic");
@@ -3695,7 +3692,7 @@ std::pair<SmallVector<Type*,4>,SmallVector<Type*,4>> getDefaultFunctionTypeForGr
36953692
return std::pair<SmallVector<Type*,4>,SmallVector<Type*,4>>(args, outs);
36963693
}
36973694

3698-
Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& constant_args, TargetLibraryInfo &TLI, AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg) {
3695+
Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& constant_args, TargetLibraryInfo &TLI, AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg) {
36993696
static std::map<std::tuple<Function*,std::set<unsigned>, bool/*retval*/, bool/*differentialReturn*/, bool/*topLevel*/, llvm::Type*>, Function*> cachedfunctions;
37003697
auto tup = std::make_tuple(todiff, std::set<unsigned>(constant_args.begin(), constant_args.end()), returnValue, differentialReturn, topLevel, additionalArg);
37013698
if (cachedfunctions.find(tup) != cachedfunctions.end()) {
@@ -3783,10 +3780,10 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
37833780

37843781
DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone(todiff, AA, TLI, constant_args, returnValue ? ReturnType::ArgsWithReturn : ReturnType::Args, differentialReturn, additionalArg);
37853782
cachedfunctions[tup] = gutils->newFunc;
3786-
3783+
37873784
gutils->forceContexts();
37883785
gutils->forceAugmentedReturns();
3789-
3786+
37903787
Argument* additionalValue = nullptr;
37913788
if (additionalArg) {
37923789
auto v = gutils->newFunc->arg_end();
@@ -3810,7 +3807,6 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
38103807

38113808
std::map<ReturnInst*,StoreInst*> replacedReturns;
38123809

3813-
38143810
for(BasicBlock* BB: gutils->originalBlocks) {
38153811

38163812
LoopContext loopContext;
@@ -4141,6 +4137,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
41414137
break;
41424138
}
41434139
default:
4140+
if (gutils->isConstantInstruction(inst)) continue;
41444141
assert(inst);
41454142
llvm::errs() << "cannot handle unknown intrinsic\n" << *inst;
41464143
report_fatal_error("unknown intrinsic");
@@ -4239,7 +4236,6 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
42394236
if (auto dc = dyn_cast<CallInst>(val)) {
42404237
if (dc->getCalledFunction()->getName() == "malloc") {
42414238
gutils->erase(op);
4242-
llvm::errs() << " place free\n"; dumpSet(gutils->originalInstructions);
42434239
continue;
42444240
}
42454241
}
@@ -4903,6 +4899,8 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
49034899
}
49044900

49054901
void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, LoopInfo& LI, DominatorTree& DT) {
4902+
4903+
49064904
Value* fn = CI->getArgOperand(0);
49074905

49084906
while (auto ci = dyn_cast<CastInst>(fn)) {
@@ -4916,7 +4914,7 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo
49164914
}
49174915
auto FT = cast<Function>(fn)->getFunctionType();
49184916
assert(fn);
4919-
4917+
49204918
if (autodiff_print)
49214919
llvm::errs() << "prefn:\n" << *fn << "\n";
49224920

@@ -5006,6 +5004,7 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo
50065004
}
50075005

50085006
bool differentialReturn = cast<Function>(fn)->getReturnType()->isFPOrFPVectorTy();
5007+
50095008
auto newFunc = CreatePrimalAndGradient(cast<Function>(fn), constants, TLI, AA, /*should return*/false, differentialReturn, /*topLevel*/true, /*addedType*/nullptr);//, LI, DT);
50105009

50115010
if (differentialReturn)
@@ -5029,13 +5028,14 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo
50295028
}
50305029

50315030
static bool lowerAutodiffIntrinsic(Function &F, TargetLibraryInfo &TLI, AAResults &AA) {//, LoopInfo& LI, DominatorTree& DT) {
5031+
50325032
bool Changed = false;
50335033

5034+
reset:
50345035
for (BasicBlock &BB : F) {
50355036

5036-
for (auto BI = BB.rbegin(), BE = BB.rend(); BI != BE;) {
5037-
Instruction *Inst = &*BI++;
5038-
CallInst *CI = dyn_cast_or_null<CallInst>(Inst);
5037+
for (auto BI = BB.rbegin(), BE = BB.rend(); BI != BE; BI++) {
5038+
CallInst *CI = dyn_cast<CallInst>(&*BI);
50395039
if (!CI) continue;
50405040

50415041
Function *Fn = CI->getCalledFunction();
@@ -5049,103 +5049,33 @@ static bool lowerAutodiffIntrinsic(Function &F, TargetLibraryInfo &TLI, AAResult
50495049
if (Fn && ( Fn->getName() == "__enzyme_autodiff" || Fn->getName().startswith("__enzyme_autodiff")) ) {
50505050
HandleAutoDiff(CI, TLI, AA);//, LI, DT);
50515051
Changed = true;
5052+
goto reset;
50525053
}
50535054
}
50545055
}
50555056

50565057
return Changed;
50575058
}
50585059

5059-
PHINode* canonicalizeIVs(Type *Ty, Loop *L, ScalarEvolution &SE, DominatorTree &DT, GradientUtils *gutils) {
5060-
//PHINode* pn = L->getCanonicalInductionVariable();
5061-
//assert( pn && "canonical IV");
5062-
//return pn;
5063-
5060+
PHINode* canonicalizeIVs(Type *Ty, Loop *L, ScalarEvolution &SE, DominatorTree &DT, GradientUtils* gutils) {
50645061

5065-
PHINode *CanonicalIV;
5066-
5067-
/*
5068-
{
5069-
SCEVExpander e(SE, L->getHeader()->getParent()->getParent()->getDataLayout(), "ad");
5070-
5071-
assert(Ty->isIntegerTy() && "Can only insert integer induction variables!");
5072-
5073-
// Build a SCEV for {0,+,1}<L>.
5074-
// Conservatively use FlagAnyWrap for now.
5075-
const SCEV *H = SE.getAddRecExpr(SE.getConstant(Ty, 0),
5076-
SE.getConstant(Ty, 1), L, SCEV::FlagAnyWrap);
5077-
5078-
// Emit code for it.
5079-
e.setInsertPoint(&L->getHeader()->front());
5080-
Value *V = e.expand(H);
5081-
5082-
CanonicalIV = cast<PHINode>(V); //e.expandCodeFor(H, nullptr));
5083-
}
5084-
*/
5085-
5086-
BasicBlock* Header = L->getHeader();
5087-
Module* M = Header->getParent()->getParent();
5088-
const DataLayout &DL = M->getDataLayout();
5089-
SmallVector<Instruction*, 16> DeadInsts;
5090-
5091-
{
5092-
SCEVExpander Exp(SE, DL, "ad");
5093-
5094-
CanonicalIV = Exp.getOrInsertCanonicalInductionVariable(L, Ty);
5062+
fake::SCEVExpander e(SE, L->getHeader()->getParent()->getParent()->getDataLayout(), "ad");
5063+
5064+
PHINode *CanonicalIV = e.getOrInsertCanonicalInductionVariable(L, Ty);
5065+
5066+
assert (CanonicalIV && "canonicalizing IV");
50955067

5096-
assert (CanonicalIV && "canonicalizing IV");
5097-
//DEBUG(dbgs() << "Canonical induction variable " << *CanonicalIV << "\n");
5098-
50995068
SmallVector<WeakTrackingVH, 16> DeadInst0;
5100-
Exp.replaceCongruentIVs(L, &DT, DeadInst0);
5069+
e.replaceCongruentIVs(L, &DT, DeadInst0);
51015070
for (WeakTrackingVH V : DeadInst0) {
5102-
//DeadInsts.push_back(cast<Instruction>(V));
5103-
}
5104-
5071+
gutils->erase(cast<Instruction>(V)); //->eraseFromParent();
51055072
}
51065073

5107-
for (Instruction* I : DeadInsts) {
5108-
if (gutils) gutils->erase(I);
5109-
}
51105074

51115075
return CanonicalIV;
51125076

51135077
}
51145078

5115-
Value* canonicalizeLoopLatch(PHINode *IV, Value *Limit, Loop* L, ScalarEvolution &SE, BasicBlock* ExitBlock, GradientUtils *gutils) {
5116-
Value *NewCondition;
5117-
BasicBlock *Header = L->getHeader();
5118-
BasicBlock *Latch = L->getLoopLatch();
5119-
assert(Latch && "No single loop latch found for loop.");
5120-
5121-
IRBuilder<> Builder(&*Latch->getFirstInsertionPt());
5122-
Builder.setFastMathFlags(getFast());
5123-
5124-
// This process assumes that IV's increment is in Latch.
5125-
5126-
// Create comparison between IV and Limit at top of Latch.
5127-
NewCondition = Builder.CreateICmpULT(IV, Limit);
5128-
5129-
// Replace the conditional branch at the end of Latch.
5130-
BranchInst *LatchBr = dyn_cast_or_null<BranchInst>(Latch->getTerminator());
5131-
assert(LatchBr && LatchBr->isConditional() &&
5132-
"Latch does not terminate with a conditional branch.");
5133-
Builder.SetInsertPoint(Latch->getTerminator());
5134-
Builder.CreateCondBr(NewCondition, Header, ExitBlock);
5135-
5136-
// Erase the old conditional branch.
5137-
Value *OldCond = LatchBr->getCondition();
5138-
if (gutils) gutils->erase(LatchBr);
5139-
5140-
if (!OldCond->hasNUsesOrMore(1))
5141-
if (Instruction *OldCondInst = dyn_cast<Instruction>(OldCond)) {
5142-
if (gutils) gutils->erase(OldCondInst);
5143-
}
5144-
5145-
5146-
return NewCondition;
5147-
}
5148-
51495079
bool getContextM(BasicBlock *BB, LoopContext &loopContext, std::map<Loop*,LoopContext> &loopContexts, LoopInfo &LI,ScalarEvolution &SE,DominatorTree &DT, GradientUtils &gutils) {
51505080
if (auto L = LI.getLoopFor(BB)) {
51515081
if (loopContexts.find(L) != loopContexts.end()) {
@@ -5234,13 +5164,10 @@ bool getContextM(BasicBlock *BB, LoopContext &loopContext, std::map<Loop*,LoopCo
52345164
CanonicalSCEV, Limit) &&
52355165
"Loop backedge is not guarded by canonical comparison with limit.");
52365166

5237-
SCEVExpander Exp(SE, Preheader->getParent()->getParent()->getDataLayout(), "ad");
5167+
fake::SCEVExpander Exp(SE, Preheader->getParent()->getParent()->getDataLayout(), "ad");
52385168
LimitVar = Exp.expandCodeFor(Limit, CanonicalIV->getType(),
52395169
Preheader->getTerminator());
52405170

5241-
// Canonicalize the loop latch.
5242-
canonicalizeLoopLatch(CanonicalIV, LimitVar, L, SE, ExitBlock, &gutils);
5243-
52445171
loopContext.dynamic = false;
52455172
} else {
52465173

@@ -5273,7 +5200,7 @@ bool getContextM(BasicBlock *BB, LoopContext &loopContext, std::map<Loop*,LoopCo
52735200

52745201
// Remove Canonicalizable IV's
52755202
{
5276-
SCEVExpander Exp(SE, Preheader->getParent()->getParent()->getDataLayout(), "ad");
5203+
fake::SCEVExpander Exp(SE, Preheader->getParent()->getParent()->getDataLayout(), "ad");
52775204
for (BasicBlock::iterator II = Header->begin(); isa<PHINode>(II); ++II) {
52785205
PHINode *PN = cast<PHINode>(II);
52795206
if (PN == CanonicalIV) continue;
@@ -5336,12 +5263,24 @@ class Enzyme : public FunctionPass {
53365263
AU.addRequired<AAResultsWrapperPass>();
53375264
AU.addRequired<GlobalsAAWrapperPass>();
53385265
AU.addRequiredID(LoopSimplifyID);
5339-
AU.addRequiredID(LCSSAID);
5266+
//AU.addRequiredID(LCSSAID);
5267+
5268+
AU.addRequired<LoopInfoWrapperPass>();
5269+
AU.addRequired<DominatorTreeWrapperPass>();
5270+
AU.addRequired<ScalarEvolutionWrapperPass>();
53405271
}
53415272

53425273
bool runOnFunction(Function &F) override {
53435274
auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
53445275
auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
5276+
5277+
/*
5278+
auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
5279+
auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
5280+
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
5281+
*/
5282+
5283+
53455284
return lowerAutodiffIntrinsic(F, TLI, AA);
53465285
}
53475286
};

0 commit comments

Comments
 (0)