1818
1919#include < llvm/Config/llvm-config.h>
2020
21+ #include " SCEV/ScalarEvolutionExpander.h"
22+
2123#include " llvm/Transforms/Utils/PromoteMemToReg.h"
2224#include " llvm/Transforms/Utils/BasicBlockUtils.h"
2325#include " llvm/Transforms/Scalar/GVN.h"
@@ -827,7 +829,11 @@ void forceRecursiveInlining(Function *NewF, const Function* F) {
827829 }
828830}
829831
830- Function* preprocessForClone (Function *F, AAResults &AA) {
832+ class GradientUtils ;
833+
834+ PHINode* canonicalizeIVs (Type *Ty, Loop *L, ScalarEvolution &SE, DominatorTree &DT, GradientUtils* gutils);
835+
836+ Function* preprocessForClone (Function *F, AAResults &AA, TargetLibraryInfo &TLI) {
831837 static std::map<Function*,Function*> cache;
832838 if (cache.find (F) != cache.end ()) return cache[F];
833839
@@ -1071,7 +1077,7 @@ Function* preprocessForClone(Function *F, AAResults &AA) {
10711077 DSEPass ().run (*NewF, AM);
10721078 LoopSimplifyPass ().run (*NewF, AM);
10731079
1074- }
1080+ }
10751081
10761082 if (autodiff_print)
10771083 llvm::errs () << " after simplification :\n " << *NewF << " \n " ;
@@ -1084,9 +1090,9 @@ Function* preprocessForClone(Function *F, AAResults &AA) {
10841090 return NewF;
10851091}
10861092
1087- Function *CloneFunctionWithReturns (Function *&F, AAResults &AA, ValueToValueMapTy& ptrInputs, const std::set<unsigned >& constant_args, SmallPtrSetImpl<Value*> &constants, SmallPtrSetImpl<Value*> &nonconstant, SmallPtrSetImpl<Value*> &returnvals, ReturnType returnValue, bool differentialReturn, Twine name, ValueToValueMapTy *VMapO, bool diffeReturnArg, llvm::Type* additionalArg = nullptr ) {
1093+ Function *CloneFunctionWithReturns (Function *&F, AAResults &AA, TargetLibraryInfo &TLI, ValueToValueMapTy& ptrInputs, const std::set<unsigned >& constant_args, SmallPtrSetImpl<Value*> &constants, SmallPtrSetImpl<Value*> &nonconstant, SmallPtrSetImpl<Value*> &returnvals, ReturnType returnValue, bool differentialReturn, Twine name, ValueToValueMapTy *VMapO, bool diffeReturnArg, llvm::Type* additionalArg = nullptr ) {
10881094 assert (!F->empty ());
1089- F = preprocessForClone (F, AA);
1095+ F = preprocessForClone (F, AA, TLI );
10901096 diffeReturnArg &= differentialReturn;
10911097 std::vector<Type*> RetTypes;
10921098 if (returnValue == ReturnType::ArgsWithReturn)
@@ -1257,15 +1263,6 @@ Function *CloneFunctionWithReturns(Function *&F, AAResults &AA, ValueToValueMapT
12571263#include " llvm/IR/Constant.h"
12581264#include < deque>
12591265#include " llvm/IR/CFG.h"
1260- class GradientUtils ;
1261-
1262- PHINode* canonicalizeIVs (Type *Ty, Loop *L, ScalarEvolution &SE, DominatorTree &DT, GradientUtils *gutils);
1263-
1264- // / \brief Replace the latch of the loop to check that IV is always less than or
1265- // / equal to the limit.
1266- // /
1267- // / This method assumes that the loop has a single loop latch.
1268- Value* canonicalizeLoopLatch (PHINode *IV, Value *Limit, Loop* L, ScalarEvolution &SE, BasicBlock* ExitBlock, GradientUtils *gutils);
12691266
12701267bool shouldRecompute (Value* val, const ValueToValueMapTy& available) {
12711268 if (available.count (val)) return false ;
@@ -1943,7 +1940,7 @@ class GradientUtils {
19431940 SmallPtrSet<Value*,20 > nonconstant;
19441941 SmallPtrSet<Value*,2 > returnvals;
19451942 ValueToValueMapTy originalToNew;
1946- auto newFunc = CloneFunctionWithReturns (todiff, AA, invertedPointers, constant_args, constants, nonconstant, returnvals, /* returnValue*/ returnValue, /* differentialReturn*/ differentialReturn, " fakeaugmented_" +todiff->getName (), &originalToNew, /* diffeReturnArg*/ false , additionalArg);
1943+ auto newFunc = CloneFunctionWithReturns (todiff, AA, TLI, invertedPointers, constant_args, constants, nonconstant, returnvals, /* returnValue*/ returnValue, /* differentialReturn*/ differentialReturn, " fakeaugmented_" +todiff->getName (), &originalToNew, /* diffeReturnArg*/ false , additionalArg);
19471944 auto res = new GradientUtils (newFunc, AA, TLI, invertedPointers, constants, nonconstant, returnvals, originalToNew);
19481945 res->oldFunc = todiff;
19491946 return res;
@@ -2230,7 +2227,6 @@ class GradientUtils {
22302227 IRBuilder <> v (putafter);
22312228 v.setFastMathFlags (getFast ());
22322229 v.CreateStore (inst, scopeMap[inst]);
2233- llvm::errs () << " place foo\n " ; dumpSet (originalInstructions);
22342230 } else {
22352231
22362232 ValueToValueMapTy valmap;
@@ -2722,7 +2718,7 @@ class DiffeGradientUtils : public GradientUtils {
27222718 SmallPtrSet<Value*,20 > nonconstant;
27232719 SmallPtrSet<Value*,2 > returnvals;
27242720 ValueToValueMapTy originalToNew;
2725- auto newFunc = CloneFunctionWithReturns (todiff, AA, invertedPointers, constant_args, constants, nonconstant, returnvals, returnValue, differentialReturn, " diffe" +todiff->getName (), &originalToNew, /* diffeReturnArg*/ true , additionalArg);
2721+ auto newFunc = CloneFunctionWithReturns (todiff, AA, TLI, invertedPointers, constant_args, constants, nonconstant, returnvals, returnValue, differentialReturn, " diffe" +todiff->getName (), &originalToNew, /* diffeReturnArg*/ true , additionalArg);
27262722 auto res = new DiffeGradientUtils (newFunc, AA, TLI, invertedPointers, constants, nonconstant, returnvals, originalToNew);
27272723 res->oldFunc = todiff;
27282724 return res;
@@ -3062,6 +3058,7 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
30623058 case Intrinsic::cos:
30633059 break ;
30643060 default :
3061+ if (gutils->isConstantInstruction (inst)) continue ;
30653062 assert (inst);
30663063 llvm::errs () << " cannot handle (augmented) unknown intrinsic\n " << *inst;
30673064 report_fatal_error (" (augmented) unknown intrinsic" );
@@ -3695,7 +3692,7 @@ std::pair<SmallVector<Type*,4>,SmallVector<Type*,4>> getDefaultFunctionTypeForGr
36953692 return std::pair<SmallVector<Type*,4 >,SmallVector<Type*,4 >>(args, outs);
36963693}
36973694
3698- Function* CreatePrimalAndGradient (Function* todiff, const std::set<unsigned >& constant_args, TargetLibraryInfo &TLI, AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg) {
3695+ Function* CreatePrimalAndGradient (Function* todiff, const std::set<unsigned >& constant_args, TargetLibraryInfo &TLI, AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg) {
36993696 static std::map<std::tuple<Function*,std::set<unsigned >, bool /* retval*/ , bool /* differentialReturn*/ , bool /* topLevel*/ , llvm::Type*>, Function*> cachedfunctions;
37003697 auto tup = std::make_tuple (todiff, std::set<unsigned >(constant_args.begin (), constant_args.end ()), returnValue, differentialReturn, topLevel, additionalArg);
37013698 if (cachedfunctions.find (tup) != cachedfunctions.end ()) {
@@ -3783,10 +3780,10 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
37833780
37843781 DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone (todiff, AA, TLI, constant_args, returnValue ? ReturnType::ArgsWithReturn : ReturnType::Args, differentialReturn, additionalArg);
37853782 cachedfunctions[tup] = gutils->newFunc ;
3786-
3783+
37873784 gutils->forceContexts ();
37883785 gutils->forceAugmentedReturns ();
3789-
3786+
37903787 Argument* additionalValue = nullptr ;
37913788 if (additionalArg) {
37923789 auto v = gutils->newFunc ->arg_end ();
@@ -3810,7 +3807,6 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
38103807
38113808 std::map<ReturnInst*,StoreInst*> replacedReturns;
38123809
3813-
38143810 for (BasicBlock* BB: gutils->originalBlocks ) {
38153811
38163812 LoopContext loopContext;
@@ -4141,6 +4137,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
41414137 break ;
41424138 }
41434139 default :
4140+ if (gutils->isConstantInstruction (inst)) continue ;
41444141 assert (inst);
41454142 llvm::errs () << " cannot handle unknown intrinsic\n " << *inst;
41464143 report_fatal_error (" unknown intrinsic" );
@@ -4239,7 +4236,6 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
42394236 if (auto dc = dyn_cast<CallInst>(val)) {
42404237 if (dc->getCalledFunction ()->getName () == " malloc" ) {
42414238 gutils->erase (op);
4242- llvm::errs () << " place free\n " ; dumpSet (gutils->originalInstructions );
42434239 continue ;
42444240 }
42454241 }
@@ -4903,6 +4899,8 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
49034899}
49044900
49054901void HandleAutoDiff (CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {// , LoopInfo& LI, DominatorTree& DT) {
4902+
4903+
49064904 Value* fn = CI->getArgOperand (0 );
49074905
49084906 while (auto ci = dyn_cast<CastInst>(fn)) {
@@ -4916,7 +4914,7 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo
49164914 }
49174915 auto FT = cast<Function>(fn)->getFunctionType ();
49184916 assert (fn);
4919-
4917+
49204918 if (autodiff_print)
49214919 llvm::errs () << " prefn:\n " << *fn << " \n " ;
49224920
@@ -5006,6 +5004,7 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo
50065004 }
50075005
50085006 bool differentialReturn = cast<Function>(fn)->getReturnType ()->isFPOrFPVectorTy ();
5007+
50095008 auto newFunc = CreatePrimalAndGradient (cast<Function>(fn), constants, TLI, AA, /* should return*/ false , differentialReturn, /* topLevel*/ true , /* addedType*/ nullptr );// , LI, DT);
50105009
50115010 if (differentialReturn)
@@ -5029,13 +5028,14 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo
50295028}
50305029
50315030static bool lowerAutodiffIntrinsic (Function &F, TargetLibraryInfo &TLI, AAResults &AA) {// , LoopInfo& LI, DominatorTree& DT) {
5031+
50325032 bool Changed = false ;
50335033
5034+ reset:
50345035 for (BasicBlock &BB : F) {
50355036
5036- for (auto BI = BB.rbegin (), BE = BB.rend (); BI != BE;) {
5037- Instruction *Inst = &*BI++;
5038- CallInst *CI = dyn_cast_or_null<CallInst>(Inst);
5037+ for (auto BI = BB.rbegin (), BE = BB.rend (); BI != BE; BI++) {
5038+ CallInst *CI = dyn_cast<CallInst>(&*BI);
50395039 if (!CI) continue ;
50405040
50415041 Function *Fn = CI->getCalledFunction ();
@@ -5049,103 +5049,33 @@ static bool lowerAutodiffIntrinsic(Function &F, TargetLibraryInfo &TLI, AAResult
50495049 if (Fn && ( Fn->getName () == " __enzyme_autodiff" || Fn->getName ().startswith (" __enzyme_autodiff" )) ) {
50505050 HandleAutoDiff (CI, TLI, AA);// , LI, DT);
50515051 Changed = true ;
5052+ goto reset;
50525053 }
50535054 }
50545055 }
50555056
50565057 return Changed;
50575058}
50585059
5059- PHINode* canonicalizeIVs (Type *Ty, Loop *L, ScalarEvolution &SE, DominatorTree &DT, GradientUtils *gutils) {
5060- // PHINode* pn = L->getCanonicalInductionVariable();
5061- // assert( pn && "canonical IV");
5062- // return pn;
5063-
5060+ PHINode* canonicalizeIVs (Type *Ty, Loop *L, ScalarEvolution &SE, DominatorTree &DT, GradientUtils* gutils) {
50645061
5065- PHINode *CanonicalIV;
5066-
5067- /*
5068- {
5069- SCEVExpander e(SE, L->getHeader()->getParent()->getParent()->getDataLayout(), "ad");
5070-
5071- assert(Ty->isIntegerTy() && "Can only insert integer induction variables!");
5072-
5073- // Build a SCEV for {0,+,1}<L>.
5074- // Conservatively use FlagAnyWrap for now.
5075- const SCEV *H = SE.getAddRecExpr(SE.getConstant(Ty, 0),
5076- SE.getConstant(Ty, 1), L, SCEV::FlagAnyWrap);
5077-
5078- // Emit code for it.
5079- e.setInsertPoint(&L->getHeader()->front());
5080- Value *V = e.expand(H);
5081-
5082- CanonicalIV = cast<PHINode>(V); //e.expandCodeFor(H, nullptr));
5083- }
5084- */
5085-
5086- BasicBlock* Header = L->getHeader ();
5087- Module* M = Header->getParent ()->getParent ();
5088- const DataLayout &DL = M->getDataLayout ();
5089- SmallVector<Instruction*, 16 > DeadInsts;
5090-
5091- {
5092- SCEVExpander Exp (SE, DL, " ad" );
5093-
5094- CanonicalIV = Exp.getOrInsertCanonicalInductionVariable (L, Ty);
5062+ fake::SCEVExpander e (SE, L->getHeader ()->getParent ()->getParent ()->getDataLayout (), " ad" );
5063+
5064+ PHINode *CanonicalIV = e.getOrInsertCanonicalInductionVariable (L, Ty);
5065+
5066+ assert (CanonicalIV && " canonicalizing IV" );
50955067
5096- assert (CanonicalIV && " canonicalizing IV" );
5097- // DEBUG(dbgs() << "Canonical induction variable " << *CanonicalIV << "\n");
5098-
50995068 SmallVector<WeakTrackingVH, 16 > DeadInst0;
5100- Exp .replaceCongruentIVs (L, &DT, DeadInst0);
5069+ e .replaceCongruentIVs (L, &DT, DeadInst0);
51015070 for (WeakTrackingVH V : DeadInst0) {
5102- // DeadInsts.push_back(cast<Instruction>(V));
5103- }
5104-
5071+ gutils->erase (cast<Instruction>(V)); // ->eraseFromParent();
51055072 }
51065073
5107- for (Instruction* I : DeadInsts) {
5108- if (gutils) gutils->erase (I);
5109- }
51105074
51115075 return CanonicalIV;
51125076
51135077}
51145078
5115- Value* canonicalizeLoopLatch (PHINode *IV, Value *Limit, Loop* L, ScalarEvolution &SE, BasicBlock* ExitBlock, GradientUtils *gutils) {
5116- Value *NewCondition;
5117- BasicBlock *Header = L->getHeader ();
5118- BasicBlock *Latch = L->getLoopLatch ();
5119- assert (Latch && " No single loop latch found for loop." );
5120-
5121- IRBuilder<> Builder (&*Latch->getFirstInsertionPt ());
5122- Builder.setFastMathFlags (getFast ());
5123-
5124- // This process assumes that IV's increment is in Latch.
5125-
5126- // Create comparison between IV and Limit at top of Latch.
5127- NewCondition = Builder.CreateICmpULT (IV, Limit);
5128-
5129- // Replace the conditional branch at the end of Latch.
5130- BranchInst *LatchBr = dyn_cast_or_null<BranchInst>(Latch->getTerminator ());
5131- assert (LatchBr && LatchBr->isConditional () &&
5132- " Latch does not terminate with a conditional branch." );
5133- Builder.SetInsertPoint (Latch->getTerminator ());
5134- Builder.CreateCondBr (NewCondition, Header, ExitBlock);
5135-
5136- // Erase the old conditional branch.
5137- Value *OldCond = LatchBr->getCondition ();
5138- if (gutils) gutils->erase (LatchBr);
5139-
5140- if (!OldCond->hasNUsesOrMore (1 ))
5141- if (Instruction *OldCondInst = dyn_cast<Instruction>(OldCond)) {
5142- if (gutils) gutils->erase (OldCondInst);
5143- }
5144-
5145-
5146- return NewCondition;
5147- }
5148-
51495079bool getContextM (BasicBlock *BB, LoopContext &loopContext, std::map<Loop*,LoopContext> &loopContexts, LoopInfo &LI,ScalarEvolution &SE,DominatorTree &DT, GradientUtils &gutils) {
51505080 if (auto L = LI.getLoopFor (BB)) {
51515081 if (loopContexts.find (L) != loopContexts.end ()) {
@@ -5234,13 +5164,10 @@ bool getContextM(BasicBlock *BB, LoopContext &loopContext, std::map<Loop*,LoopCo
52345164 CanonicalSCEV, Limit) &&
52355165 " Loop backedge is not guarded by canonical comparison with limit." );
52365166
5237- SCEVExpander Exp (SE, Preheader->getParent ()->getParent ()->getDataLayout (), " ad" );
5167+ fake:: SCEVExpander Exp (SE, Preheader->getParent ()->getParent ()->getDataLayout (), " ad" );
52385168 LimitVar = Exp.expandCodeFor (Limit, CanonicalIV->getType (),
52395169 Preheader->getTerminator ());
52405170
5241- // Canonicalize the loop latch.
5242- canonicalizeLoopLatch (CanonicalIV, LimitVar, L, SE, ExitBlock, &gutils);
5243-
52445171 loopContext.dynamic = false ;
52455172 } else {
52465173
@@ -5273,7 +5200,7 @@ bool getContextM(BasicBlock *BB, LoopContext &loopContext, std::map<Loop*,LoopCo
52735200
52745201 // Remove Canonicalizable IV's
52755202 {
5276- SCEVExpander Exp (SE, Preheader->getParent ()->getParent ()->getDataLayout (), " ad" );
5203+ fake:: SCEVExpander Exp (SE, Preheader->getParent ()->getParent ()->getDataLayout (), " ad" );
52775204 for (BasicBlock::iterator II = Header->begin (); isa<PHINode>(II); ++II) {
52785205 PHINode *PN = cast<PHINode>(II);
52795206 if (PN == CanonicalIV) continue ;
@@ -5336,12 +5263,24 @@ class Enzyme : public FunctionPass {
53365263 AU.addRequired <AAResultsWrapperPass>();
53375264 AU.addRequired <GlobalsAAWrapperPass>();
53385265 AU.addRequiredID (LoopSimplifyID);
5339- AU.addRequiredID (LCSSAID);
5266+ // AU.addRequiredID(LCSSAID);
5267+
5268+ AU.addRequired <LoopInfoWrapperPass>();
5269+ AU.addRequired <DominatorTreeWrapperPass>();
5270+ AU.addRequired <ScalarEvolutionWrapperPass>();
53405271 }
53415272
53425273 bool runOnFunction (Function &F) override {
53435274 auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI ();
53445275 auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults ();
5276+
5277+ /*
5278+ auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
5279+ auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
5280+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
5281+ */
5282+
5283+
53455284 return lowerAutodiffIntrinsic (F, TLI, AA);
53465285 }
53475286};
0 commit comments