Skip to content

Commit b007b57

Browse files
committed
Finish isconst transition to orig
1 parent 64470af commit b007b57

File tree

15 files changed

+128
-94
lines changed

15 files changed

+128
-94
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ bool is_value_needed_in_reverse(TypeResults &TR, const GradientUtils* gutils, Va
371371
if (!topLevel) {
372372
//Proving that none of the uses (or uses' uses) are used in control flow allows us to safely not do this load
373373

374+
//TODO make this more aggressive so we don't need to save the loop latch
374375
if (isa<BranchInst>(use) || isa<SwitchInst>(use) || isa<CallInst>(use)) {
375376
//llvm::errs() << " had to use in reverse since used in branch/switch " << *inst << " use: " << *use << "\n";
376377
return seen[inst] = true;
@@ -1397,7 +1398,7 @@ class DerivativeMaker : public llvm::InstVisitor<DerivativeMaker<AugmentedReturn
13971398
if (vdiff && !gutils->isConstantValue(orig_ops[1])) {
13981399
Value* cmp = Builder2.CreateFCmpOLT(lookup(ops[0], Builder2), lookup(ops[1], Builder2));
13991400
Value* dif1 = Builder2.CreateSelect(cmp, vdiff, ConstantFP::get(ops[0]->getType(), 0));
1400-
addToDiffe(orig_ops[0], dif1, Builder2, II.getType());
1401+
addToDiffe(orig_ops[1], dif1, Builder2, II.getType());
14011402
}
14021403
return;
14031404
}
@@ -1447,15 +1448,15 @@ class DerivativeMaker : public llvm::InstVisitor<DerivativeMaker<AugmentedReturn
14471448

14481449
case Intrinsic::exp: {
14491450
if (vdiff && !gutils->isConstantValue(orig_ops[0])) {
1450-
Value* dif0 = Builder2.CreateFMul(vdiff, lookup(&II, Builder2));
1451+
Value* dif0 = Builder2.CreateFMul(vdiff, lookup(gutils->getNewFromOriginal(&II), Builder2));
14511452
addToDiffe(orig_ops[0], dif0, Builder2, II.getType());
14521453
}
14531454
return;
14541455
}
14551456
case Intrinsic::exp2: {
14561457
if (vdiff && !gutils->isConstantValue(orig_ops[0])) {
14571458
Value* dif0 = Builder2.CreateFMul(
1458-
Builder2.CreateFMul(vdiff, lookup(&II, Builder2)), ConstantFP::get(II.getType(), 0.6931471805599453)
1459+
Builder2.CreateFMul(vdiff, lookup(gutils->getNewFromOriginal(&II), Builder2)), ConstantFP::get(II.getType(), 0.6931471805599453)
14591460
);
14601461
addToDiffe(orig_ops[0], dif0, Builder2, II.getType());
14611462
}
@@ -1488,7 +1489,7 @@ class DerivativeMaker : public llvm::InstVisitor<DerivativeMaker<AugmentedReturn
14881489
Type *tys[] = {ops[1]->getType()};
14891490

14901491
Value* dif1 = Builder2.CreateFMul(
1491-
Builder2.CreateFMul(vdiff, lookup(&II, Builder2)),
1492+
Builder2.CreateFMul(vdiff, lookup(gutils->getNewFromOriginal(&II), Builder2)),
14921493
Builder2.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::log, tys), args)
14931494
);
14941495
addToDiffe(orig_ops[1], dif1, Builder2, II.getType());
@@ -1681,7 +1682,7 @@ void calculateUnusedValues(Function& oldFunc, SmallPtrSetImpl<Instruction*> &val
16811682
}
16821683
if (!bad) continue;
16831684

1684-
//llvm::errs() << " cannot use value: " << *inst << " because of user " << *user_val << "\n";
1685+
llvm::errs() << " need to keep instruction: " << *inst << " because of user " << *user_val << "\n";
16851686
necessaryUse = true;
16861687
break;
16871688
}
@@ -1696,13 +1697,14 @@ void calculateUnusedValues(Function& oldFunc, SmallPtrSetImpl<Instruction*> &val
16961697
}
16971698
}
16981699

1699-
/*
1700+
#if 0
17001701
llvm::errs() << "Prepping values for: " << oldFunc.getName() << " returnValue: " << returnValue << "\n";
17011702
for(auto v : valuesOnlyUsedInUnnecessaryReturns) {
17021703
llvm::errs() << "valuesOnlyUsedInUnnecessaryReturns: " << *v << "\n";
17031704
}
17041705
llvm::errs() << "</end>\n";
1705-
*/
1706+
#endif
1707+
17061708
}
17071709

17081710
//! return structtype if recursive function
@@ -2420,10 +2422,16 @@ void handleAugmentedCallInst(TypeResults &TR, CallInst* op, GradientUtils* const
24202422
return;
24212423
}
24222424

2425+
bool subretused = (op->getNumUses() != 0) && (valuesOnlyUsedInUnnecessaryReturns.find(orig) == valuesOnlyUsedInUnnecessaryReturns.end() || is_value_needed_in_reverse(TR, gutils, orig, /*topLevel*/false));
2426+
24232427
if (gutils->isConstantInstruction(orig)) {
2424-
if (op->getNumUses() != 0 && !op->doesNotAccessMemory() && is_value_needed_in_reverse(TR, gutils, orig, /*topLevel*/false)) {
2425-
IRBuilder<> BuilderZ(op);
2426-
gutils->addMalloc(BuilderZ, op, getIndex(orig, CacheType::Self) );
2428+
2429+
// If we need this value and it is illegal to recompute it (it writes or may load uncacheable data)
2430+
// Store and reload it
2431+
if (/*!topLevel*/true && subretused && !op->doesNotAccessMemory()) {
2432+
IRBuilder<> BuilderZ(op);
2433+
gutils->addMalloc(BuilderZ, op, getIndex(orig, CacheType::Self));
2434+
return;
24272435
}
24282436
return;
24292437
}
@@ -2466,7 +2474,6 @@ void handleAugmentedCallInst(TypeResults &TR, CallInst* op, GradientUtils* const
24662474
}
24672475
}
24682476

2469-
bool subretused = (op->getNumUses() != 0) && (valuesOnlyUsedInUnnecessaryReturns.find(orig) == valuesOnlyUsedInUnnecessaryReturns.end() || is_value_needed_in_reverse(TR, gutils, orig, /*topLevel*/false));
24702477
//llvm::errs() << "aug subretused: " << subretused << " op: " << *op << "\n";
24712478

24722479
//We check uses of the original function as that includes potential uses in the return,
@@ -2996,14 +3003,28 @@ void handleGradientCallInst(TypeResults &TR, IRBuilder <>& Builder2, CallInst* o
29963003

29973004
//llvm::errs() << " considering op: " << *op << " isConstantInstruction:" << gutils->isConstantInstruction(orig) << " subretused: " << subretused << " !op->doesNotAccessMemory: " << !op->doesNotAccessMemory() << "\n";
29983005
if (gutils->isConstantInstruction(orig)) {
3006+
3007+
// If we need this value and it is illegal to recompute it (it writes or may load uncacheable data)
3008+
// Store and reload it
29993009
if (!topLevel && subretused && !op->doesNotAccessMemory()) {
3000-
if (is_value_needed_in_reverse(TR, gutils, orig, topLevel)) {
3001-
IRBuilder<> BuilderZ(op);
3002-
gutils->addMalloc(BuilderZ, op, getIndex(orig, CacheType::Self) );
3003-
} else {
3004-
op->replaceAllUsesWith(UndefValue::get(op->getType()));
3005-
gutils->erase(op);
3006-
}
3010+
IRBuilder<> BuilderZ(op);
3011+
gutils->addMalloc(BuilderZ, op, getIndex(orig, CacheType::Self));
3012+
return;
3013+
}
3014+
3015+
// If this call may write to memory and is a copy (in the reverse-only pass), erase it
3016+
// Any uses of it should be handled by the case above so it is safe to RAUW
3017+
if (op->mayWriteToMemory() && !topLevel) {
3018+
op->replaceAllUsesWith(UndefValue::get(op->getType()));
3019+
gutils->erase(op);
3020+
return;
3021+
}
3022+
3023+
// if call does not write memory and isn't used, we can erase it
3024+
if (!op->mayWriteToMemory() && !subretused) {
3025+
op->replaceAllUsesWith(UndefValue::get(op->getType()));
3026+
gutils->erase(op);
3027+
return;
30073028
}
30083029
return;
30093030
}
@@ -3512,10 +3533,20 @@ badfn:;
35123533

35133534
gutils->erase(op);
35143535

3515-
if (augmentcall)
3516-
gutils->replaceableCalls.insert(augmentcall);
35173536
} else {
3518-
gutils->replaceableCalls.insert(op);
3537+
3538+
if (!subretused) {
3539+
for(auto inst_orig : valuesOnlyUsedInUnnecessaryReturns) {
3540+
if (isa<ReturnInst>(inst_orig)) continue;
3541+
auto inst = gutils->getNewFromOriginal(inst_orig);
3542+
// Detach `op` from this user before `op` is erased: every operand slot that
// currently holds `op` is overwritten with an undef placeholder.
for(unsigned i=0; i<inst->getNumOperands(); i++) {
3543+
if (inst->getOperand(i) == op) {
3544+
// The placeholder must have the type of the operand being replaced (`op`),
// not the type of the user instruction, which may differ.
inst->setOperand(i, UndefValue::get(op->getType()));
3545+
}
3546+
}
3547+
}
3548+
gutils->erase(op);
3549+
}
35193550
}
35203551
}
35213552

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -685,27 +685,5 @@ void optimizeIntermediate(GradientUtils* gutils, bool topLevel, Function *F) {
685685

686686
SimplifyCFGOptions scfgo(/*unsigned BonusThreshold=*/1, /*bool ForwardSwitchCond=*/false, /*bool SwitchToLookup=*/false, /*bool CanonicalLoops=*/true, /*bool SinkCommon=*/true, /*AssumptionCache *AssumpCache=*/nullptr);
687687
SimplifyCFGPass(scfgo).run(*F, AM);
688-
689-
if (!topLevel) {
690-
for(BasicBlock& BB: *F) {
691-
692-
for (auto I = BB.begin(), E = BB.end(); I != E;) {
693-
Instruction* inst = &*I;
694-
assert(inst);
695-
I++;
696-
697-
if (gutils->originalInstructions.find(inst) == gutils->originalInstructions.end()) continue;
698-
699-
if (gutils->replaceableCalls.find(inst) != gutils->replaceableCalls.end()) {
700-
if (inst->getNumUses() != 0 && !cast<CallInst>(inst)->getCalledFunction()->hasFnAttribute(Attribute::ReadNone) ) {
701-
llvm::errs() << "found call ripe for replacement " << *inst;
702-
} else {
703-
gutils->erase(inst);
704-
continue;
705-
}
706-
}
707-
}
708-
}
709-
}
710688
//LCSSAPass().run(*NewF, AM);
711689
}

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,12 +1069,12 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B
10691069

10701070
assert(branch->getCondition()->getType() == T);
10711071

1072-
Value* phi = lookupM(branch->getCondition(), BuilderM);
10731072
if (replacePHIs == nullptr) {
10741073
assert(BuilderM.GetInsertBlock()->size() == 0 || !isa<BranchInst>(BuilderM.GetInsertBlock()->back()));
1075-
BuilderM.CreateCondBr(phi, *done[std::make_pair(block, branch->getSuccessor(0))].begin(), *done[std::make_pair(block, branch->getSuccessor(1))].begin());
1074+
BuilderM.CreateCondBr(lookupM(branch->getCondition(), BuilderM), *done[std::make_pair(block, branch->getSuccessor(0))].begin(), *done[std::make_pair(block, branch->getSuccessor(1))].begin());
10761075
} else {
10771076
for (auto pair : *replacePHIs) {
1077+
Value* phi = lookupM(branch->getCondition(), BuilderM);
10781078
Value* val = nullptr;
10791079
if (pair.first == *done[std::make_pair(block, branch->getSuccessor(0))].begin()) {
10801080
val = phi;
@@ -1102,21 +1102,17 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B
11021102
IRBuilder<> pbuilder(equivalentTerminator);
11031103
pbuilder.setFastMathFlags(getFast());
11041104

1105-
AllocaInst* cache = createCacheForScope(ctx, si->getCondition()->getType(), "", /*shouldFree*/true);
1106-
Value* condition = si->getCondition();
1107-
storeInstructionInCache(ctx, pbuilder, condition, cache);
1108-
1109-
Value* phi = lookupM(si->getCondition(), BuilderM);
11101105

11111106
if (replacePHIs == nullptr) {
1112-
SwitchInst* swtch = BuilderM.CreateSwitch(phi, *done[std::make_pair(block, si->getDefaultDest())].begin());
1107+
SwitchInst* swtch = BuilderM.CreateSwitch(lookupM(si->getCondition(), BuilderM), *done[std::make_pair(block, si->getDefaultDest())].begin());
11131108
for (auto switchcase : si->cases()) {
11141109
swtch->addCase(switchcase.getCaseValue(), *done[std::make_pair(block, switchcase.getCaseSuccessor())].begin());
11151110
}
11161111
} else {
11171112
for (auto pair : *replacePHIs) {
11181113
Value* cas = si->findCaseDest(pair.first);
11191114
Value* val = nullptr;
1115+
Value* phi = lookupM(si->getCondition(), BuilderM);
11201116
if (cas) {
11211117
val = BuilderM.CreateICmpEQ(cas, phi);
11221118
} else {

enzyme/Enzyme/GradientUtils.h

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,6 @@ class GradientUtils {
860860
return isconstantM(TR, val, constants, nonconstant, constant_values, nonconstant_values, AA);
861861
}
862862

863-
SmallPtrSet<Instruction*,4> replaceableCalls;
864863
void eraseStructuralStoresAndCalls() {
865864

866865
for(auto pp : fictiousPHIs) {
@@ -915,13 +914,6 @@ class GradientUtils {
915914
continue;
916915
}
917916
}
918-
if (replaceableCalls.find(inst) != replaceableCalls.end()) {
919-
if (inst->getNumUses() != 0) {
920-
} else {
921-
erase(inst);
922-
continue;
923-
}
924-
}
925917
}
926918
}
927919
}
@@ -1226,6 +1218,7 @@ class GradientUtils {
12261218
//if (isOriginal(op))
12271219
//llvm::errs() << "op: " << *op << " op0: " << *op0 << " SAFE(op,getOperand(0)):" << *SAFE(op, getOperand(0)) << " orig:" << *getOriginal(op) << "\n";
12281220
auto toreturn = BuilderM.CreateCast(op->getOpcode(), op0, op->getDestTy(), op->getName()+"_unwrap");
1221+
if (auto newi = dyn_cast<Instruction>(toreturn)) newi->copyIRFlags(op);
12291222
unwrap_cache[cidx] = toreturn;
12301223
assert(val->getType() == toreturn->getType());
12311224
return toreturn;
@@ -1234,6 +1227,7 @@ class GradientUtils {
12341227
if (op0 == nullptr) goto endCheck;
12351228
auto toreturn = BuilderM.CreateExtractValue(op0, op->getIndices(), op->getName()+"_unwrap");
12361229
unwrap_cache[cidx] = toreturn;
1230+
if (auto newi = dyn_cast<Instruction>(toreturn)) newi->copyIRFlags(op);
12371231
assert(val->getType() == toreturn->getType());
12381232
return toreturn;
12391233
} else if (auto op = dyn_cast<BinaryOperator>(val)) {
@@ -1242,7 +1236,7 @@ class GradientUtils {
12421236
auto op1 = getOp(SAFE(op,getOperand(1)));
12431237
if (op1 == nullptr) goto endCheck;
12441238
auto toreturn = BuilderM.CreateBinOp(op->getOpcode(), op0, op1, op->getName()+"_unwrap");
1245-
cast<BinaryOperator>(toreturn)->copyIRFlags(op);
1239+
if (auto newi = dyn_cast<Instruction>(toreturn)) newi->copyIRFlags(op);
12461240
unwrap_cache[cidx] = toreturn;
12471241
assert(val->getType() == toreturn->getType());
12481242
return toreturn;
@@ -1251,7 +1245,8 @@ class GradientUtils {
12511245
if (op0 == nullptr) goto endCheck;
12521246
auto op1 = getOp(SAFE(op,getOperand(1)));
12531247
if (op1 == nullptr) goto endCheck;
1254-
auto toreturn = BuilderM.CreateICmp(op->getPredicate(), op0, op1);
1248+
auto toreturn = BuilderM.CreateICmp(op->getPredicate(), op0, op1, op->getName()+"_unwrap");
1249+
if (auto newi = dyn_cast<Instruction>(toreturn)) newi->copyIRFlags(op);
12551250
unwrap_cache[cidx] = toreturn;
12561251
assert(val->getType() == toreturn->getType());
12571252
return toreturn;
@@ -1260,7 +1255,8 @@ class GradientUtils {
12601255
if (op0 == nullptr) goto endCheck;
12611256
auto op1 = getOp(SAFE(op,getOperand(1)));
12621257
if (op1 == nullptr) goto endCheck;
1263-
auto toreturn = BuilderM.CreateFCmp(op->getPredicate(), op0, op1);
1258+
auto toreturn = BuilderM.CreateFCmp(op->getPredicate(), op0, op1, op->getName()+"_unwrap");
1259+
if (auto newi = dyn_cast<Instruction>(toreturn)) newi->copyIRFlags(op);
12641260
unwrap_cache[cidx] = toreturn;
12651261
assert(val->getType() == toreturn->getType());
12661262
return toreturn;
@@ -1272,6 +1268,7 @@ class GradientUtils {
12721268
auto op2 = getOp(SAFE(op,getOperand(2)));
12731269
if (op2 == nullptr) goto endCheck;
12741270
auto toreturn = BuilderM.CreateSelect(op0, op1, op2);
1271+
if (auto newi = dyn_cast<Instruction>(toreturn)) newi->copyIRFlags(op);
12751272
unwrap_cache[cidx] = toreturn;
12761273
assert(val->getType() == toreturn->getType());
12771274
return toreturn;
@@ -1295,6 +1292,7 @@ class GradientUtils {
12951292
//llvm::errs() << "safe: " << *SAFE(inst, getPointerOperand()) << "\n";
12961293
//assert(0 && "illegal");
12971294
}
1295+
if (auto newi = dyn_cast<Instruction>(toreturn)) newi->copyIRFlags(inst);
12981296
unwrap_cache[cidx] = toreturn;
12991297
assert(val->getType() == toreturn->getType());
13001298
return toreturn;
@@ -1319,6 +1317,7 @@ class GradientUtils {
13191317
}
13201318
assert(idx->getType() == load->getOperand(0)->getType());
13211319
auto toreturn = BuilderM.CreateLoad(idx, load->getName()+"_unwrap");
1320+
if (auto newi = dyn_cast<Instruction>(toreturn)) newi->copyIRFlags(load);
13221321
toreturn->setAlignment(load->getAlignment());
13231322
toreturn->setVolatile(load->isVolatile());
13241323
toreturn->setOrdering(load->getOrdering());
@@ -1336,13 +1335,17 @@ class GradientUtils {
13361335
Value *args[] = {getOp(SAFE(op,getOperand(0)))};
13371336
if (args[0] == nullptr) goto endCheck;
13381337
Type *tys[] = {op->getOperand(0)->getType()};
1339-
return BuilderM.CreateCall(Intrinsic::getDeclaration(op->getParent()->getParent()->getParent(), Intrinsic::sin, tys), args);
1338+
auto toreturn = BuilderM.CreateCall(Intrinsic::getDeclaration(op->getParent()->getParent()->getParent(), Intrinsic::sin, tys), args);
1339+
if (auto newi = dyn_cast<Instruction>(toreturn)) newi->copyIRFlags(op);
1340+
return toreturn;
13401341
}
13411342
case Intrinsic::cos: {
13421343
Value *args[] = {getOp(SAFE(op,getOperand(0)))};
13431344
if (args[0] == nullptr) goto endCheck;
13441345
Type *tys[] = {op->getOperand(0)->getType()};
1345-
return BuilderM.CreateCall(Intrinsic::getDeclaration(op->getParent()->getParent()->getParent(), Intrinsic::cos, tys), args);
1346+
auto toreturn = BuilderM.CreateCall(Intrinsic::getDeclaration(op->getParent()->getParent()->getParent(), Intrinsic::cos, tys), args);
1347+
if (auto newi = dyn_cast<Instruction>(toreturn)) newi->copyIRFlags(op);
1348+
return toreturn;
13461349
}
13471350
default:;
13481351

@@ -1353,6 +1356,7 @@ class GradientUtils {
13531356
auto toreturn = getOp(SAFE(phi,getIncomingValue(0)));
13541357
if (toreturn == nullptr) goto endCheck;
13551358
assert(val->getType() == toreturn->getType());
1359+
if (auto newi = dyn_cast<Instruction>(toreturn)) newi->copyIRFlags(op);
13561360
return toreturn;
13571361
}
13581362
}
@@ -1784,6 +1788,9 @@ class GradientUtils {
17841788
}
17851789

17861790
void storeInstructionInCache(BasicBlock* ctx, IRBuilder <>& BuilderM, Value* val, AllocaInst* cache) {
1791+
assert(BuilderM.GetInsertBlock()->getParent() == newFunc);
1792+
if (auto inst = dyn_cast<Instruction>(val))
1793+
assert(inst->getParent()->getParent() == newFunc);
17871794
IRBuilder <> v(BuilderM);
17881795
v.setFastMathFlags(getFast());
17891796

enzyme/Enzyme/TypeAnalysis.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,10 +853,25 @@ void TypeAnalyzer::visitPHINode(PHINode& phi) {
853853
vals.push_back(op);
854854
}
855855

856+
std::vector<BinaryOperator*> bos;
857+
856858
while(vals.size()) {
857859
Value* todo = vals.front();
858860
vals.pop_front();
859861

862+
if (auto bo = dyn_cast<BinaryOperator>(todo)) {
863+
if (bo->getOpcode() == BinaryOperator::Add) {
864+
if (isa<ConstantInt>(bo->getOperand(0))) {
865+
bos.push_back(bo);
866+
todo = bo->getOperand(1);
867+
}
868+
if (isa<ConstantInt>(bo->getOperand(1))) {
869+
bos.push_back(bo);
870+
todo = bo->getOperand(0);
871+
}
872+
}
873+
}
874+
860875
if (seen.count(todo)) continue;
861876
seen.insert(todo);
862877

@@ -875,6 +890,14 @@ void TypeAnalyzer::visitPHINode(PHINode& phi) {
875890
//llvm::errs() << " + sub" << *todo << " ga: " << getAnalysis(todo).str() << "\n";
876891
consider(getAnalysis(todo));
877892
}
893+
894+
assert(set);
895+
// Fold every `add` with a ConstantInt operand that was collected in `bos`
// back into the accumulated result `vd`.
for(BinaryOperator* bo : bos) {
896+
// A ConstantInt operand uses its own analysis; any other operand falls back
// to the accumulated data `vd`.
ValueData vd1 = isa<ConstantInt>(bo->getOperand(0)) ? getAnalysis(bo->getOperand(0)) : vd;
897+
// Fallback is `vd`, matching vd1 above (initialising vd2 from itself would read an indeterminate object).
ValueData vd2 = isa<ConstantInt>(bo->getOperand(1)) ? getAnalysis(bo->getOperand(1)) : vd;
898+
vd1.pointerIntMerge(vd2, bo->getOpcode());
899+
vd &= vd1;
900+
}
878901
//llvm::errs() << " -- res" << vd.str() << "\n";
879902

880903
updateAnalysis(&phi, vd, &phi);

0 commit comments

Comments
 (0)