@@ -371,6 +371,7 @@ bool is_value_needed_in_reverse(TypeResults &TR, const GradientUtils* gutils, Va
371371 if (!topLevel) {
372372 // Proving that none of the uses (or uses' uses) are used in control flow allows us to safely not do this load
373373
374+ // TODO make this more aggressive and dont need to save loop latch
374375 if (isa<BranchInst>(use) || isa<SwitchInst>(use) || isa<CallInst>(use)) {
375376 // llvm::errs() << " had to use in reverse since used in branch/switch " << *inst << " use: " << *use << "\n";
376377 return seen[inst] = true ;
@@ -1397,7 +1398,7 @@ class DerivativeMaker : public llvm::InstVisitor<DerivativeMaker<AugmentedReturn
13971398 if (vdiff && !gutils->isConstantValue (orig_ops[1 ])) {
13981399 Value* cmp = Builder2.CreateFCmpOLT (lookup (ops[0 ], Builder2), lookup (ops[1 ], Builder2));
13991400 Value* dif1 = Builder2.CreateSelect (cmp, vdiff, ConstantFP::get (ops[0 ]->getType (), 0 ));
1400- addToDiffe (orig_ops[0 ], dif1, Builder2, II.getType ());
1401+ addToDiffe (orig_ops[1 ], dif1, Builder2, II.getType ());
14011402 }
14021403 return ;
14031404 }
@@ -1447,15 +1448,15 @@ class DerivativeMaker : public llvm::InstVisitor<DerivativeMaker<AugmentedReturn
14471448
14481449 case Intrinsic::exp: {
14491450 if (vdiff && !gutils->isConstantValue (orig_ops[0 ])) {
1450- Value* dif0 = Builder2.CreateFMul (vdiff, lookup (&II, Builder2));
1451+ Value* dif0 = Builder2.CreateFMul (vdiff, lookup (gutils-> getNewFromOriginal ( &II) , Builder2));
14511452 addToDiffe (orig_ops[0 ], dif0, Builder2, II.getType ());
14521453 }
14531454 return ;
14541455 }
14551456 case Intrinsic::exp2: {
14561457 if (vdiff && !gutils->isConstantValue (orig_ops[0 ])) {
14571458 Value* dif0 = Builder2.CreateFMul (
1458- Builder2.CreateFMul (vdiff, lookup (&II, Builder2)), ConstantFP::get (II.getType (), 0.6931471805599453 )
1459+ Builder2.CreateFMul (vdiff, lookup (gutils-> getNewFromOriginal ( &II) , Builder2)), ConstantFP::get (II.getType (), 0.6931471805599453 )
14591460 );
14601461 addToDiffe (orig_ops[0 ], dif0, Builder2, II.getType ());
14611462 }
@@ -1488,7 +1489,7 @@ class DerivativeMaker : public llvm::InstVisitor<DerivativeMaker<AugmentedReturn
14881489 Type *tys[] = {ops[1 ]->getType ()};
14891490
14901491 Value* dif1 = Builder2.CreateFMul (
1491- Builder2.CreateFMul (vdiff, lookup (&II, Builder2)),
1492+ Builder2.CreateFMul (vdiff, lookup (gutils-> getNewFromOriginal ( &II) , Builder2)),
14921493 Builder2.CreateCall (Intrinsic::getDeclaration (M, Intrinsic::log, tys), args)
14931494 );
14941495 addToDiffe (orig_ops[1 ], dif1, Builder2, II.getType ());
@@ -1681,7 +1682,7 @@ void calculateUnusedValues(Function& oldFunc, SmallPtrSetImpl<Instruction*> &val
16811682 }
16821683 if (!bad) continue ;
16831684
1684- // llvm::errs() << " cannot use value : " << *inst << " because of user " << *user_val << "\n";
1685+ llvm::errs () << " need to keep instruction : " << *inst << " because of user " << *user_val << " \n " ;
16851686 necessaryUse = true ;
16861687 break ;
16871688 }
@@ -1696,13 +1697,14 @@ void calculateUnusedValues(Function& oldFunc, SmallPtrSetImpl<Instruction*> &val
16961697 }
16971698 }
16981699
1699- /*
1700+ # if 0
17001701 llvm::errs() << "Prepping values for: " << oldFunc.getName() << " returnValue: " << returnValue << "\n";
17011702 for(auto v : valuesOnlyUsedInUnnecessaryReturns) {
17021703 llvm::errs() << "valuesOnlyUsedInUnnecessaryReturns: " << *v << "\n";
17031704 }
17041705 llvm::errs() << "</end>\n";
1705- */
1706+ #endif
1707+
17061708}
17071709
17081710// ! return structtype if recursive function
@@ -2420,10 +2422,16 @@ void handleAugmentedCallInst(TypeResults &TR, CallInst* op, GradientUtils* const
24202422 return ;
24212423 }
24222424
2425+ bool subretused = (op->getNumUses () != 0 ) && (valuesOnlyUsedInUnnecessaryReturns.find (orig) == valuesOnlyUsedInUnnecessaryReturns.end () || is_value_needed_in_reverse (TR, gutils, orig, /* topLevel*/ false ));
2426+
24232427 if (gutils->isConstantInstruction (orig)) {
2424- if (op->getNumUses () != 0 && !op->doesNotAccessMemory () && is_value_needed_in_reverse (TR, gutils, orig, /* topLevel*/ false )) {
2425- IRBuilder<> BuilderZ (op);
2426- gutils->addMalloc (BuilderZ, op, getIndex (orig, CacheType::Self) );
2428+
2429+ // If we need this value and it is illegal to recompute it (it writes or may load uncacheable data)
2430+ // Store and reload it
2431+ if (/* !topLevel*/ true && subretused && !op->doesNotAccessMemory ()) {
2432+ IRBuilder<> BuilderZ (op);
2433+ gutils->addMalloc (BuilderZ, op, getIndex (orig, CacheType::Self));
2434+ return ;
24272435 }
24282436 return ;
24292437 }
@@ -2466,7 +2474,6 @@ void handleAugmentedCallInst(TypeResults &TR, CallInst* op, GradientUtils* const
24662474 }
24672475 }
24682476
2469- bool subretused = (op->getNumUses () != 0 ) && (valuesOnlyUsedInUnnecessaryReturns.find (orig) == valuesOnlyUsedInUnnecessaryReturns.end () || is_value_needed_in_reverse (TR, gutils, orig, /* topLevel*/ false ));
24702477 // llvm::errs() << "aug subretused: " << subretused << " op: " << *op << "\n";
24712478
24722479 // We check uses of the original function as that includes potential uses in the return,
@@ -2996,14 +3003,28 @@ void handleGradientCallInst(TypeResults &TR, IRBuilder <>& Builder2, CallInst* o
29963003
29973004 // llvm::errs() << " considering op: " << *op << " isConstantInstruction:" << gutils->isConstantInstruction(orig) << " subretused: " << subretused << " !op->doesNotAccessMemory: " << !op->doesNotAccessMemory() << "\n";
29983005 if (gutils->isConstantInstruction (orig)) {
3006+
3007+ // If we need this value and it is illegal to recompute it (it writes or may load uncacheable data)
3008+ // Store and reload it
29993009 if (!topLevel && subretused && !op->doesNotAccessMemory ()) {
3000- if (is_value_needed_in_reverse (TR, gutils, orig, topLevel)) {
3001- IRBuilder<> BuilderZ (op);
3002- gutils->addMalloc (BuilderZ, op, getIndex (orig, CacheType::Self) );
3003- } else {
3004- op->replaceAllUsesWith (UndefValue::get (op->getType ()));
3005- gutils->erase (op);
3006- }
3010+ IRBuilder<> BuilderZ (op);
3011+ gutils->addMalloc (BuilderZ, op, getIndex (orig, CacheType::Self));
3012+ return ;
3013+ }
3014+
3015+ // If this call may write to memory and is a copy (in the just reverse pass), erase it
3016+ // Any uses of it should be handled by the case above so it is safe to RAUW
3017+ if (op->mayWriteToMemory () && !topLevel) {
3018+ op->replaceAllUsesWith (UndefValue::get (op->getType ()));
3019+ gutils->erase (op);
3020+ return ;
3021+ }
3022+
3023+ // if call does not write memory and isn't used, we can erase it
3024+ if (!op->mayWriteToMemory () && !subretused) {
3025+ op->replaceAllUsesWith (UndefValue::get (op->getType ()));
3026+ gutils->erase (op);
3027+ return ;
30073028 }
30083029 return ;
30093030 }
@@ -3512,10 +3533,20 @@ badfn:;
35123533
35133534 gutils->erase (op);
35143535
3515- if (augmentcall)
3516- gutils->replaceableCalls .insert (augmentcall);
35173536 } else {
3518- gutils->replaceableCalls .insert (op);
3537+
3538+ if (!subretused) {
3539+ for (auto inst_orig : valuesOnlyUsedInUnnecessaryReturns) {
3540+ if (isa<ReturnInst>(inst_orig)) continue ;
3541+ auto inst = gutils->getNewFromOriginal (inst_orig);
3542+ for (unsigned i=0 ; i<inst->getNumOperands (); i++) {
3543+ if (inst->getOperand (i) == op) {
3544+ inst->setOperand (i, UndefValue::get (inst->getType ()));
3545+ }
3546+ }
3547+ }
3548+ gutils->erase (op);
3549+ }
35193550 }
35203551}
35213552
0 commit comments