1818
1919#include < llvm/Config/llvm-config.h>
2020
21+ #include " Utils.h"
2122#include " SCEV/ScalarEvolutionExpander.h"
2223
2324#include " llvm/Transforms/Utils/PromoteMemToReg.h"
@@ -141,12 +142,6 @@ std::string tostring(DIFFE_TYPE t) {
141142 }
142143}
143144
144- static inline FastMathFlags getFast () {
145- FastMathFlags f;
146- f.set ();
147- return f;
148- }
149-
150145Instruction *getNextNonDebugInstruction (Instruction* Z) {
151146 for (Instruction *I = Z->getNextNode (); I; I = I->getNextNode ())
152147 if (!isa<DbgInfoIntrinsic>(I))
@@ -349,21 +344,22 @@ bool isIntASecretFloat(Value* val) {
349344 assert (0 && " unsure if constant or not" );
350345}
351346
352- bool isIntPointerASecretFloat (Value* val) {
347+ // ! return the secret float type if found, otherwise nullptr
348+ Type* isIntPointerASecretFloat (Value* val) {
353349 assert (val->getType ()->isPointerTy ());
354350 assert (cast<PointerType>(val->getType ())->getElementType ()->isIntegerTy ());
355351
356- if (isa<UndefValue>(val)) return true ;
352+ if (isa<UndefValue>(val)) return nullptr ;
357353
358354 if (auto cint = dyn_cast<ConstantInt>(val)) {
359- if (!cint->isZero ()) return false ;
355+ if (!cint->isZero ()) return nullptr ;
360356 assert (0 && " unsure if constant or not because constantint" );
361357 // if (cint->isOne()) return cint;
362358 }
363359
364360
365361 if (auto inst = dyn_cast<Instruction>(val)) {
366- bool floatingUse = false ;
362+ Type* floatingUse = nullptr ;
367363 bool pointerUse = false ;
368364 SmallPtrSet<Value*, 4 > seen;
369365
@@ -374,7 +370,11 @@ bool isIntPointerASecretFloat(Value* val) {
374370 do {
375371 Type* let = cast<PointerType>(v->getType ())->getElementType ();
376372 if (let->isFloatingPointTy ()) {
377- floatingUse = true ;
373+ if (floatingUse == nullptr ) {
374+ floatingUse = let;
375+ } else {
376+ assert (floatingUse == let);
377+ }
378378 }
379379 if (auto ci = dyn_cast<CastInst>(v)) {
380380 if (auto cal = dyn_cast<CallInst>(ci->getOperand (0 ))) {
@@ -409,7 +409,11 @@ bool isIntPointerASecretFloat(Value* val) {
409409 llvm::errs () << " for val " << *v << *et << " \n " ;
410410
411411 if (et->isFloatingPointTy ()) {
412- floatingUse = true ;
412+ if (floatingUse == nullptr ) {
413+ floatingUse = et;
414+ } else {
415+ assert (floatingUse == et);
416+ }
413417 }
414418 if (et->isPointerTy ()) {
415419 pointerUse = true ;
@@ -431,8 +435,8 @@ bool isIntPointerASecretFloat(Value* val) {
431435 }
432436 }
433437
434- if (pointerUse && ! floatingUse) return false ;
435- if (!pointerUse && floatingUse) return true ;
438+ if (pointerUse && ( floatingUse == nullptr )) return nullptr ;
439+ if (!pointerUse && ( floatingUse != nullptr )) return floatingUse ;
436440 llvm::errs () << *inst->getParent ()->getParent () << " \n " ;
437441 llvm::errs () << " val:" << *val << " pointer:" << pointerUse << " floating:" << floatingUse << " \n " ;
438442 assert (0 && " ambiguous unsure if constant or not" );
@@ -4057,7 +4061,20 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
40574061 switch (op->getIntrinsicID ()) {
40584062 case Intrinsic::memcpy: {
40594063 if (gutils->isConstantInstruction (inst)) continue ;
4060- if (!isIntPointerASecretFloat (op->getOperand (0 )) ) {
4064+ if (Type* secretty = isIntPointerASecretFloat (op->getOperand (0 )) ) {
4065+ SmallVector<Value*, 4 > args;
4066+ auto secretpt = PointerType::getUnqual (secretty);
4067+
4068+ args.push_back (Builder2.CreatePointerCast (invertPointer (op->getOperand (0 )), secretpt));
4069+ args.push_back (Builder2.CreatePointerCast (invertPointer (op->getOperand (1 )), secretpt));
4070+ args.push_back (lookup (op->getOperand (2 )));
4071+ auto dmemcpy = getOrInsertDifferentialFloatMemcpy (*M, secretpt);
4072+ dmemcpy->dump ();
4073+ args[0 ]->dump ();
4074+ args[1 ]->dump ();
4075+ args[2 ]->dump ();
4076+ auto cal = Builder2.CreateCall (dmemcpy, args);
4077+ } else {
40614078 if (topLevel) {
40624079 SmallVector<Value*, 4 > args;
40634080 IRBuilder <>BuilderZ (op);
@@ -4072,32 +4089,6 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
40724089 cal->setCallingConv (op->getCallingConv ());
40734090 cal->setTailCallKind (op->getTailCallKind ());
40744091 }
4075- } else {
4076- // no change to forward pass if represents float
4077- // Zero the destination
4078- assert (0 && " TODO: memcpy has bug that needs fixing (per int double vs ptr)" );
4079- /*
4080- {
4081- TODO BECOME MEMSET
4082- SmallVector<Value*, 4> args;
4083- // source and dest are swapped
4084- args.push_back(invertPointer(op->getOperand(1)));
4085- args.push_back(invertPointer(op->getOperand(0)));
4086- args.push_back(lookup(op->getOperand(2)));
4087- args.push_back(lookup(op->getOperand(3)));
4088-
4089- Type *tys[] = {args[0]->getType(), args[1]->getType(), args[2]->getType()};
4090- auto cal = Builder2.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::memset, tys), args);
4091- cal->setAttributes(op->getAttributes());
4092- cal->setCallingConv(op->getCallingConv());
4093- cal->setTailCallKind(op->getTailCallKind());
4094- }
4095-
4096-
4097- {
4098-
4099- }
4100- */
41014092 }
41024093 break ;
41034094 }
0 commit comments