1818
1919#include < llvm/Config/llvm-config.h>
2020
21+ #include " Utils.h"
2122#include " SCEV/ScalarEvolutionExpander.h"
2223
2324#include " llvm/Transforms/Utils/PromoteMemToReg.h"
@@ -141,12 +142,6 @@ std::string tostring(DIFFE_TYPE t) {
141142 }
142143}
143144
144- static inline FastMathFlags getFast () {
145- FastMathFlags f;
146- f.set ();
147- return f;
148- }
149-
150145Instruction *getNextNonDebugInstruction (Instruction* Z) {
151146 for (Instruction *I = Z->getNextNode (); I; I = I->getNextNode ())
152147 if (!isa<DbgInfoIntrinsic>(I))
@@ -349,21 +344,22 @@ bool isIntASecretFloat(Value* val) {
349344 assert (0 && " unsure if constant or not" );
350345}
351346
352- bool isIntPointerASecretFloat (Value* val) {
347+ // ! return the secret float type if found, otherwise nullptr
348+ Type* isIntPointerASecretFloat (Value* val) {
353349 assert (val->getType ()->isPointerTy ());
354350 assert (cast<PointerType>(val->getType ())->getElementType ()->isIntegerTy ());
355351
356- if (isa<UndefValue>(val)) return true ;
352+ if (isa<UndefValue>(val)) return nullptr ;
357353
358354 if (auto cint = dyn_cast<ConstantInt>(val)) {
359- if (!cint->isZero ()) return false ;
355+ if (!cint->isZero ()) return nullptr ;
360356 assert (0 && " unsure if constant or not because constantint" );
361357 // if (cint->isOne()) return cint;
362358 }
363359
364360
365361 if (auto inst = dyn_cast<Instruction>(val)) {
366- bool floatingUse = false ;
362+ Type* floatingUse = nullptr ;
367363 bool pointerUse = false ;
368364 SmallPtrSet<Value*, 4 > seen;
369365
@@ -374,7 +370,11 @@ bool isIntPointerASecretFloat(Value* val) {
374370 do {
375371 Type* let = cast<PointerType>(v->getType ())->getElementType ();
376372 if (let->isFloatingPointTy ()) {
377- floatingUse = true ;
373+ if (floatingUse == nullptr ) {
374+ floatingUse = let;
375+ } else {
376+ assert (floatingUse == let);
377+ }
378378 }
379379 if (auto ci = dyn_cast<CastInst>(v)) {
380380 if (auto cal = dyn_cast<CallInst>(ci->getOperand (0 ))) {
@@ -409,7 +409,11 @@ bool isIntPointerASecretFloat(Value* val) {
409409 llvm::errs () << " for val " << *v << *et << " \n " ;
410410
411411 if (et->isFloatingPointTy ()) {
412- floatingUse = true ;
412+ if (floatingUse == nullptr ) {
413+ floatingUse = et;
414+ } else {
415+ assert (floatingUse == et);
416+ }
413417 }
414418 if (et->isPointerTy ()) {
415419 pointerUse = true ;
@@ -431,8 +435,8 @@ bool isIntPointerASecretFloat(Value* val) {
431435 }
432436 }
433437
434- if (pointerUse && ! floatingUse) return false ;
435- if (!pointerUse && floatingUse) return true ;
438+ if (pointerUse && ( floatingUse == nullptr )) return nullptr ;
439+ if (!pointerUse && ( floatingUse != nullptr )) return floatingUse ;
436440 llvm::errs () << *inst->getParent ()->getParent () << " \n " ;
437441 llvm::errs () << " val:" << *val << " pointer:" << pointerUse << " floating:" << floatingUse << " \n " ;
438442 assert (0 && " ambiguous unsure if constant or not" );
@@ -894,6 +898,7 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
894898 nullptr );
895899 NewF->setAttributes (F->getAttributes ());
896900
901+ if (enzyme_preopt) {
897902 {
898903 FunctionAnalysisManager AM;
899904 AM.registerPass ([] { return LoopAnalysis (); });
@@ -908,8 +913,6 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
908913
909914 }
910915
911- if (enzyme_preopt) {
912-
913916 if (autodiff_inline) {
914917 llvm::errs () << " running inlining process\n " ;
915918 forceRecursiveInlining (NewF, F);
@@ -1098,52 +1101,32 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
10981101 LoopAnalysisManager LAM;
10991102 AM.registerPass ([&] { return LoopAnalysisManagerFunctionProxy (LAM); });
11001103 LAM.registerPass ([&] { return FunctionAnalysisManagerLoopProxy (AM); });
1101-
1102- SimplifyCFGOptions scfgo (/* unsigned BonusThreshold=*/ 1 , /* bool ForwardSwitchCond=*/ false , /* bool SwitchToLookup=*/ false , /* bool CanonicalLoops=*/ true , /* bool SinkCommon=*/ true , /* AssumptionCache *AssumpCache=*/ nullptr );
1103- SimplifyCFGPass (scfgo).run (*NewF, AM);
1104- LoopSimplifyPass ().run (*NewF, AM);
1105-
1106- if (autodiff_inline) {
1107- createFunctionToLoopPassAdaptor (LoopIdiomRecognizePass ()).run (*NewF, AM);
1108- }
1109- DSEPass ().run (*NewF, AM);
1110- LoopSimplifyPass ().run (*NewF, AM);
1111-
1112- }
1113- }
1114-
1115- {
1116- FunctionAnalysisManager AM;
1117- AM.registerPass ([] { return AAManager (); });
1118- AM.registerPass ([] { return ScalarEvolutionAnalysis (); });
1119- AM.registerPass ([] { return AssumptionAnalysis (); });
1120- AM.registerPass ([] { return TargetLibraryAnalysis (); });
1121- AM.registerPass ([] { return TargetIRAnalysis (); });
1122- AM.registerPass ([] { return LoopAnalysis (); });
1123- AM.registerPass ([] { return MemorySSAAnalysis (); });
1124- AM.registerPass ([] { return DominatorTreeAnalysis (); });
1125- AM.registerPass ([] { return MemoryDependenceAnalysis (); });
1126- #if LLVM_VERSION_MAJOR > 6
1127- AM.registerPass ([] { return PhiValuesAnalysis (); });
1128- #endif
1129- #if LLVM_VERSION_MAJOR >= 8
1130- AM.registerPass ([] { return PassInstrumentationAnalysis (); });
1131- #endif
11321104
11331105 ModuleAnalysisManager MAM;
11341106 AM.registerPass ([&] { return ModuleAnalysisManagerFunctionProxy (MAM); });
11351107 MAM.registerPass ([&] { return FunctionAnalysisManagerModuleProxy (AM); });
11361108
1109+ SimplifyCFGOptions scfgo (/* unsigned BonusThreshold=*/ 1 , /* bool ForwardSwitchCond=*/ false , /* bool SwitchToLookup=*/ false , /* bool CanonicalLoops=*/ true , /* bool SinkCommon=*/ true , /* AssumptionCache *AssumpCache=*/ nullptr );
1110+ SimplifyCFGPass (scfgo).run (*NewF, AM);
1111+ LoopSimplifyPass ().run (*NewF, AM);
1112+
1113+ // AAManager().run(*NewF, AM)
11371114 BasicAA ba;
11381115 auto baa = new BasicAAResult (ba.run (*NewF, AM));
11391116 AA.addAAResult (*baa);
11401117
11411118 ScopedNoAliasAA sa;
11421119 auto saa = new ScopedNoAliasAAResult (sa.run (*NewF, AM));
11431120 AA.addAAResult (*saa);
1144-
1121+ if (autodiff_inline) {
1122+ createFunctionToLoopPassAdaptor (LoopIdiomRecognizePass ()).run (*NewF, AM);
11451123 }
1124+ DSEPass ().run (*NewF, AM);
1125+ LoopSimplifyPass ().run (*NewF, AM);
11461126
1127+ }
1128+ }
1129+
11471130 if (autodiff_print)
11481131 llvm::errs () << " after simplification :\n " << *NewF << " \n " ;
11491132
@@ -3983,11 +3966,10 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
39833966 tbuild.SetInsertPoint (&gutils->reverseBlocks [loopContext.exit ]->back ());
39843967 }
39853968
3969+ loopContext.antivar ->addIncoming (gutils->lookupM (loopContext.limit , tbuild), gutils->reverseBlocks [loopContext.exit ]);
39863970 auto sub = Builder2.CreateSub (loopContext.antivar , ConstantInt::get (loopContext.antivar ->getType (), 1 ));
39873971 for (BasicBlock* in: successors (loopContext.latch ) ) {
3988- if (loopContext.exit == in) {
3989- loopContext.antivar ->addIncoming (gutils->lookupM (loopContext.limit , tbuild), gutils->reverseBlocks [in]);
3990- } else if (gutils->LI .getLoopFor (in) == gutils->LI .getLoopFor (BB)) {
3972+ if (gutils->LI .getLoopFor (in) == gutils->LI .getLoopFor (BB)) {
39913973 loopContext.antivar ->addIncoming (sub, gutils->reverseBlocks [in]);
39923974 }
39933975 }
@@ -4057,7 +4039,16 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
40574039 switch (op->getIntrinsicID ()) {
40584040 case Intrinsic::memcpy: {
40594041 if (gutils->isConstantInstruction (inst)) continue ;
4060- if (!isIntPointerASecretFloat (op->getOperand (0 )) ) {
4042+ if (Type* secretty = isIntPointerASecretFloat (op->getOperand (0 )) ) {
4043+ SmallVector<Value*, 4 > args;
4044+ auto secretpt = PointerType::getUnqual (secretty);
4045+
4046+ args.push_back (Builder2.CreatePointerCast (invertPointer (op->getOperand (0 )), secretpt));
4047+ args.push_back (Builder2.CreatePointerCast (invertPointer (op->getOperand (1 )), secretpt));
4048+ args.push_back (lookup (op->getOperand (2 )));
4049+ auto dmemcpy = getOrInsertDifferentialFloatMemcpy (*M, secretpt);
4050+ auto cal = Builder2.CreateCall (dmemcpy, args);
4051+ } else {
40614052 if (topLevel) {
40624053 SmallVector<Value*, 4 > args;
40634054 IRBuilder <>BuilderZ (op);
@@ -4072,32 +4063,6 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
40724063 cal->setCallingConv (op->getCallingConv ());
40734064 cal->setTailCallKind (op->getTailCallKind ());
40744065 }
4075- } else {
4076- // no change to forward pass if represents float
4077- // Zero the destination
4078- assert (0 && " TODO: memcpy has bug that needs fixing (per int double vs ptr)" );
4079- /*
4080- {
4081- TODO BECOME MEMSET
4082- SmallVector<Value*, 4> args;
4083- // source and dest are swapped
4084- args.push_back(invertPointer(op->getOperand(1)));
4085- args.push_back(invertPointer(op->getOperand(0)));
4086- args.push_back(lookup(op->getOperand(2)));
4087- args.push_back(lookup(op->getOperand(3)));
4088-
4089- Type *tys[] = {args[0]->getType(), args[1]->getType(), args[2]->getType()};
4090- auto cal = Builder2.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::memset, tys), args);
4091- cal->setAttributes(op->getAttributes());
4092- cal->setCallingConv(op->getCallingConv());
4093- cal->setTailCallKind(op->getTailCallKind());
4094- }
4095-
4096-
4097- {
4098-
4099- }
4100- */
41014066 }
41024067 break ;
41034068 }
0 commit comments