Skip to content

Commit aa4f76d

Browse files
committed
further calling convention fixes
1 parent 3741228 commit aa4f76d

File tree

3 files changed

+185
-38
lines changed

3 files changed

+185
-38
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ llvm::cl::opt<bool> enzyme_print("enzyme_print", cl::init(false), cl::Hidden,
4848

4949
//! return structtype if recursive function
5050
std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResults &AA, const std::set<unsigned>& constant_args, TargetLibraryInfo &TLI, bool differentialReturn, bool returnUsed) {
51-
static std::map<std::tuple<Function*,std::set<unsigned>, bool/*differentialReturn*/>, std::pair<Function*,StructType*>> cachedfunctions;
52-
static std::map<std::tuple<Function*,std::set<unsigned>, bool/*differentialReturn*/>, bool> cachedfinished;
53-
auto tup = std::make_tuple(todiff, std::set<unsigned>(constant_args.begin(), constant_args.end()), differentialReturn);
51+
static std::map<std::tuple<Function*,std::set<unsigned>, bool/*differentialReturn*/, bool/*returnUsed*/>, std::pair<Function*,StructType*>> cachedfunctions;
52+
static std::map<std::tuple<Function*,std::set<unsigned>, bool/*differentialReturn*/, bool/*returnUsed*/>, bool> cachedfinished;
53+
auto tup = std::make_tuple(todiff, std::set<unsigned>(constant_args.begin(), constant_args.end()), differentialReturn, returnUsed);
5454
if (cachedfunctions.find(tup) != cachedfunctions.end()) {
5555
return cachedfunctions[tup];
5656
}
@@ -110,13 +110,19 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
110110
gutils->forceAugmentedReturns();
111111

112112
//! Explicitly handle all returns first to ensure that all instructions know whether or not they are used
113+
SmallPtrSet<Instruction*, 4> returnuses;
114+
113115
for(BasicBlock* BB: gutils->originalBlocks) {
114116
if(auto ri = dyn_cast<ReturnInst>(BB->getTerminator())) {
115117
auto oldval = ri->getReturnValue();
116118
Value* rt = UndefValue::get(gutils->newFunc->getReturnType());
117119
IRBuilder <>ib(ri);
118-
if (oldval && returnUsed)
120+
if (oldval && returnUsed) {
119121
rt = ib.CreateInsertValue(rt, oldval, {1});
122+
if (Instruction* inst = dyn_cast<Instruction>(rt)) {
123+
returnuses.insert(inst);
124+
}
125+
}
120126
ib.CreateRet(rt);
121127
gutils->erase(ri);
122128
/*
@@ -333,6 +339,16 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
333339

334340
bool subretused = op->getNumUses() != 0;
335341
bool subdifferentialreturn = !gutils->isConstantValue(op) && subretused;
342+
343+
//! We only need to cache something if it is used in a non return setting (since the backard pass doesnt need to use it if just returned)
344+
bool shouldCache = false;//outermostAugmentation;
345+
for(auto use : op->users()) {
346+
if (!isa<Instruction>(use) || returnuses.find(cast<Instruction>(use)) == returnuses.end()) {
347+
llvm::errs() << "shouldCache for " << *op << " use " << *use << "\n";
348+
shouldCache = true;
349+
}
350+
}
351+
336352
auto newcalled = CreateAugmentedPrimal(dyn_cast<Function>(called), AA, subconstant_args, TLI, /*differentialReturn*/subdifferentialreturn, /*return is used*/subretused).first;
337353
auto augmentcall = BuilderZ.CreateCall(newcalled, args);
338354
assert(augmentcall->getType()->isStructTy());
@@ -348,6 +364,7 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
348364
gutils->erase(cast<Instruction>(tp));
349365
tp = UndefValue::get(tpt);
350366
}
367+
351368
gutils->addMalloc(BuilderZ, tp);
352369

353370
if (subretused) {
@@ -359,19 +376,25 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
359376
}
360377
assert(op->getType() == rv->getType());
361378

362-
gutils->addMalloc(BuilderZ, rv);
379+
if (shouldCache) {
380+
gutils->addMalloc(BuilderZ, rv);
381+
}
363382

364383
if ((op->getType()->isPointerTy() || op->getType()->isIntegerTy()) && subdifferentialreturn) {
365-
assert(cast<StructType>(augmentcall->getType())->getNumElements() == 3);
366-
367-
auto antiptr = cast<Instruction>(BuilderZ.CreateExtractValue(augmentcall, {2}, "antiptr_" + op->getName() ));
368384
auto placeholder = cast<PHINode>(gutils->invertedPointers[op]);
369385
if (I != E && placeholder == &*I) I++;
370386
gutils->invertedPointers.erase(op);
387+
388+
assert(cast<StructType>(augmentcall->getType())->getNumElements() == 3);
389+
auto antiptr = cast<Instruction>(BuilderZ.CreateExtractValue(augmentcall, {2}, "antiptr_" + op->getName() ));
390+
gutils->invertedPointers[rv] = antiptr;
371391
placeholder->replaceAllUsesWith(antiptr);
392+
393+
if (shouldCache) {
394+
gutils->addMalloc(BuilderZ, antiptr);
395+
}
396+
372397
gutils->erase(placeholder);
373-
gutils->invertedPointers[rv] = antiptr;
374-
gutils->addMalloc(BuilderZ, antiptr);
375398
} else {
376399
if (cast<StructType>(augmentcall->getType())->getNumElements() != 2) {
377400
llvm::errs() << "old called: " << *called << "\n";
@@ -380,6 +403,7 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
380403
llvm::errs() << "op subdifferentialreturn: " << subdifferentialreturn << "\n";
381404
}
382405
assert(cast<StructType>(augmentcall->getType())->getNumElements() == 2);
406+
383407
}
384408

385409
gutils->replaceAWithB(op,rv);
@@ -1197,8 +1221,9 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r
11971221

11981222
//TODO consider what to do if called == nullptr for augmentation
11991223
if (modifyPrimal && called) {
1200-
bool subdifferentialreturn = !gutils->isConstantValue(op);
1201-
auto fnandtapetype = CreateAugmentedPrimal(cast<Function>(called), AA, subconstant_args, TLI, /*differentialReturns*/subdifferentialreturn, /*return is used*/ op->getNumUses() != 0 && !op->doesNotAccessMemory());
1224+
bool subretused = op->getNumUses() != 0;
1225+
bool subdifferentialreturn = !gutils->isConstantValue(op) && subretused;
1226+
auto fnandtapetype = CreateAugmentedPrimal(cast<Function>(called), AA, subconstant_args, TLI, /*differentialReturns*/subdifferentialreturn, /*return is used*/subretused);
12021227
if (topLevel) {
12031228
Function* newcalled = fnandtapetype.first;
12041229
augmentcall = BuilderZ.CreateCall(newcalled, pre_args);
@@ -1233,12 +1258,12 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r
12331258
}
12341259
} else {
12351260
tape = gutils->addMalloc(BuilderZ, tape);
1261+
12361262
if (!tape->getType()->isStructTy()) {
12371263
llvm::errs() << "newFunc: " << *gutils->newFunc << "\n";
12381264
llvm::errs() << "augment: " << *fnandtapetype.first << "\n";
12391265
llvm::errs() << "op: " << *op << "\n";
1240-
llvm::errs() << "tape: " << *tape << "\n";
1241-
1266+
llvm::errs() << "tape: " << *tape << "\n";
12421267
}
12431268
assert(tape->getType()->isStructTy());
12441269

enzyme/test/Enzyme/badcallused.ll

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,12 @@ attributes #1 = { noinline nounwind uwtable }
4343

4444
; CHECK: define internal {} @diffef(double* nocapture %x, double* %"x'")
4545
; CHECK-NEXT: entry:
46-
; CHECK-NEXT: %0 = call { { {}, i1, i1 }, i1, i1 } @augmented_subf(double* %x, double* %"x'")
47-
; CHECK-NEXT: %1 = extractvalue { { {}, i1, i1 }, i1, i1 } %0, 0
48-
; CHECK-NEXT: %2 = extractvalue { { {}, i1, i1 }, i1, i1 } %0, 1
49-
; CHECK-NEXT: %sel = select i1 %2, double 2.000000e+00, double 3.000000e+00
46+
; CHECK-NEXT: %0 = call { { {} }, i1, i1 } @augmented_subf(double* %x, double* %"x'")
47+
; CHECK-NEXT: %1 = extractvalue { { {} }, i1, i1 } %0, 1
48+
; CHECK-NEXT: %sel = select i1 %1, double 2.000000e+00, double 3.000000e+00
5049
; CHECK-NEXT: store double %sel, double* %x, align 8
5150
; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
52-
; CHECK-NEXT: %[[dsubf:.+]] = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, i1, i1 } %1)
51+
; CHECK-NEXT: %[[dsubf:.+]] = call {} @diffesubf(double* nonnull %x, double* %"x'", { {} } undef)
5352
; CHECK-NEXT: ret {} undef
5453
; CHECK-NEXT: }
5554

@@ -66,29 +65,24 @@ attributes #1 = { noinline nounwind uwtable }
6665
; CHECK-NEXT: ret { {}, i1, i1 } %3
6766
; CHECK-NEXT: }
6867

69-
; CHECK: define internal { { {}, i1, i1 }, i1, i1 } @augmented_subf(double* nocapture %x, double* %"x'")
68+
; CHECK: define internal { { {} }, i1, i1 } @augmented_subf(double* nocapture %x, double* %"x'")
7069
; CHECK-NEXT: entry:
71-
; CHECK-NEXT: %0 = alloca { { {}, i1, i1 }, i1, i1 }
72-
; CHECK-NEXT: %1 = getelementptr { { {}, i1, i1 }, i1, i1 }, { { {}, i1, i1 }, i1, i1 }* %0, i32 0, i32 0
73-
; CHECK-NEXT: %2 = load double, double* %x, align 8
74-
; CHECK-NEXT: %mul = fmul fast double %2, 2.000000e+00
70+
; CHECK-NEXT: %0 = alloca { { {} }, i1, i1 }
71+
; CHECK-NEXT: %1 = load double, double* %x, align 8
72+
; CHECK-NEXT: %mul = fmul fast double %1, 2.000000e+00
7573
; CHECK-NEXT: store double %mul, double* %x, align 8
76-
; CHECK-NEXT: %3 = call { {}, i1, i1 } @augmented_metasubf(double* %x, double* %"x'")
77-
; CHECK-NEXT: %4 = extractvalue { {}, i1, i1 } %3, 1
78-
; CHECK-NEXT: %5 = getelementptr { {}, i1, i1 }, { {}, i1, i1 }* %1, i32 0, i32 1
79-
; CHECK-NEXT: store i1 %4, i1* %5
80-
; CHECK-NEXT: %antiptr_call = extractvalue { {}, i1, i1 } %3, 2
81-
; CHECK-NEXT: %6 = getelementptr { {}, i1, i1 }, { {}, i1, i1 }* %1, i32 0, i32 2
82-
; CHECK-NEXT: store i1 %antiptr_call, i1* %6
83-
; CHECK-NEXT: %7 = getelementptr { { {}, i1, i1 }, i1, i1 }, { { {}, i1, i1 }, i1, i1 }* %0, i32 0, i32 1
84-
; CHECK-NEXT: store i1 %4, i1* %7
85-
; CHECK-NEXT: %8 = getelementptr { { {}, i1, i1 }, i1, i1 }, { { {}, i1, i1 }, i1, i1 }* %0, i32 0, i32 2
86-
; CHECK-NEXT: store i1 %antiptr_call, i1* %8
87-
; CHECK-NEXT: %[[toret:.+]] = load { { {}, i1, i1 }, i1, i1 }, { { {}, i1, i1 }, i1, i1 }* %0
88-
; CHECK-NEXT: ret { { {}, i1, i1 }, i1, i1 } %[[toret]]
74+
; CHECK-NEXT: %2 = call { {}, i1, i1 } @augmented_metasubf(double* %x, double* %"x'")
75+
; CHECK-NEXT: %3 = extractvalue { {}, i1, i1 } %2, 1
76+
; CHECK-NEXT: %antiptr_call = extractvalue { {}, i1, i1 } %2, 2
77+
; CHECK-NEXT: %4 = getelementptr { { {} }, i1, i1 }, { { {} }, i1, i1 }* %0, i32 0, i32 1
78+
; CHECK-NEXT: store i1 %3, i1* %4
79+
; CHECK-NEXT: %5 = getelementptr { { {} }, i1, i1 }, { { {} }, i1, i1 }* %0, i32 0, i32 2
80+
; CHECK-NEXT: store i1 %antiptr_call, i1* %5
81+
; CHECK-NEXT: %[[toret:.+]] = load { { {} }, i1, i1 }, { { {} }, i1, i1 }* %0
82+
; CHECK-NEXT: ret { { {} }, i1, i1 } %[[toret]]
8983
; CHECK-NEXT: }
9084

91-
; CHECK: define internal {} @diffesubf(double* nocapture %x, double* %"x'", { {}, i1, i1 } %tapeArg)
85+
; CHECK: define internal {} @diffesubf(double* nocapture %x, double* %"x'", { {} } %tapeArg)
9286
; CHECK-NEXT: entry:
9387
; CHECK-NEXT: %0 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef)
9488
; CHECK-NEXT: %1 = load double, double* %"x'"

enzyme/test/Enzyme/badcallused2.ll

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline norecurse nounwind uwtable
4+
define dso_local zeroext i1 @metasubf(double* nocapture %x) local_unnamed_addr #0 {
5+
entry:
6+
%arrayidx = getelementptr inbounds double, double* %x, i64 1
7+
store double 3.000000e+00, double* %arrayidx, align 8
8+
%0 = load double, double* %x, align 8
9+
%cmp = fcmp fast oeq double %0, 2.000000e+00
10+
ret i1 %cmp
11+
}
12+
13+
define dso_local zeroext i1 @omegasubf(double* nocapture %x) local_unnamed_addr #0 {
14+
entry:
15+
%arrayidx = getelementptr inbounds double, double* %x, i64 1
16+
store double 3.000000e+00, double* %arrayidx, align 8
17+
%0 = load double, double* %x, align 8
18+
%cmp = fcmp fast oeq double %0, 2.000000e+00
19+
ret i1 %cmp
20+
}
21+
22+
; Function Attrs: noinline norecurse nounwind uwtable
23+
define dso_local zeroext i1 @subf(double* nocapture %x) local_unnamed_addr #0 {
24+
entry:
25+
%0 = load double, double* %x, align 8
26+
%mul = fmul fast double %0, 2.000000e+00
27+
store double %mul, double* %x, align 8
28+
%call = tail call zeroext i1 @omegasubf(double* %x)
29+
%call2 = tail call zeroext i1 @metasubf(double* %x)
30+
ret i1 %call2
31+
}
32+
33+
; Function Attrs: noinline norecurse nounwind uwtable
34+
define dso_local void @f(double* nocapture %x) #0 {
35+
entry:
36+
%call = tail call zeroext i1 @subf(double* %x)
37+
%sel = select i1 %call, double 2.000000e+00, double 3.000000e+00
38+
store double %sel, double* %x, align 8
39+
ret void
40+
}
41+
42+
; Function Attrs: noinline nounwind uwtable
43+
define dso_local double @dsumsquare(double* %x, double* %xp) local_unnamed_addr #1 {
44+
entry:
45+
%call = tail call fast double @__enzyme_autodiff(i8* bitcast (void (double*)* @f to i8*), double* %x, double* %xp)
46+
ret double %call
47+
}
48+
49+
declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed_addr
50+
51+
attributes #0 = { noinline norecurse nounwind uwtable }
52+
attributes #1 = { noinline nounwind uwtable }
53+
54+
; CHECK: define internal {} @diffef(double* nocapture %x, double* %"x'")
55+
; CHECK-NEXT: entry:
56+
; CHECK-NEXT: %0 = call { { {}, {} }, i1, i1 } @augmented_subf(double* %x, double* %"x'")
57+
; CHECK-NEXT: %1 = extractvalue { { {}, {} }, i1, i1 } %0, 1
58+
; CHECK-NEXT: %sel = select i1 %1, double 2.000000e+00, double 3.000000e+00
59+
; CHECK-NEXT: store double %sel, double* %x, align 8
60+
; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
61+
; CHECK-NEXT: %[[dsubf:.+]] = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, {} } undef)
62+
; CHECK-NEXT: ret {} undef
63+
; CHECK-NEXT: }
64+
65+
; CHECK: define internal { {}, i1, i1 } @augmented_metasubf(double* nocapture %x, double* %"x'")
66+
; CHECK-NEXT: entry:
67+
; CHECK-NEXT: %0 = alloca { {}, i1, i1 }
68+
; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %x, i64 1
69+
; CHECK-NEXT: store double 3.000000e+00, double* %arrayidx, align 8
70+
; CHECK-NEXT: %1 = load double, double* %x, align 8
71+
; CHECK-NEXT: %cmp = fcmp fast oeq double %1, 2.000000e+00
72+
; CHECK-NEXT: %2 = getelementptr { {}, i1, i1 }, { {}, i1, i1 }* %0, i32 0, i32 1
73+
; CHECK-NEXT: store i1 %cmp, i1* %2
74+
; CHECK-NEXT: %3 = load { {}, i1, i1 }, { {}, i1, i1 }* %0
75+
; CHECK-NEXT: ret { {}, i1, i1 } %3
76+
; CHECK-NEXT: }
77+
78+
; CHECK: define internal { {} } @augmented_omegasubf(double* nocapture %x, double* %"x'")
79+
; CHECK-NEXT: entry:
80+
; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %x, i64 1
81+
; CHECK-NEXT: store double 3.000000e+00, double* %arrayidx, align 8
82+
; CHECK-NEXT: ret { {} } undef
83+
; CHECK-NEXT: }
84+
85+
; CHECK: define internal { { {}, {} }, i1, i1 } @augmented_subf(double* nocapture %x, double* %"x'")
86+
; CHECK-NEXT: entry:
87+
; CHECK-NEXT: %0 = alloca { { {}, {} }, i1, i1 }
88+
; CHECK-NEXT: %1 = load double, double* %x, align 8
89+
; CHECK-NEXT: %mul = fmul fast double %1, 2.000000e+00
90+
; CHECK-NEXT: store double %mul, double* %x, align 8
91+
; CHECK-NEXT: %2 = call { {} } @augmented_omegasubf(double* %x, double* %"x'")
92+
; CHECK-NEXT: %3 = call { {}, i1, i1 } @augmented_metasubf(double* %x, double* %"x'")
93+
; CHECK-NEXT: %4 = extractvalue { {}, i1, i1 } %3, 1
94+
; CHECK-NEXT: %antiptr_call2 = extractvalue { {}, i1, i1 } %3, 2
95+
; CHECK-NEXT: %5 = getelementptr { { {}, {} }, i1, i1 }, { { {}, {} }, i1, i1 }* %0, i32 0, i32 1
96+
; CHECK-NEXT: store i1 %4, i1* %5
97+
; CHECK-NEXT: %6 = getelementptr { { {}, {} }, i1, i1 }, { { {}, {} }, i1, i1 }* %0, i32 0, i32 2
98+
; CHECK-NEXT: store i1 %antiptr_call2, i1* %6
99+
; CHECK-NEXT: %[[toret:.+]] = load { { {}, {} }, i1, i1 }, { { {}, {} }, i1, i1 }* %0
100+
; CHECK-NEXT: ret { { {}, {} }, i1, i1 } %[[toret]]
101+
; CHECK-NEXT: }
102+
103+
; CHECK: define internal {} @diffesubf(double* nocapture %x, double* %"x'", { {}, {} } %tapeArg)
104+
; CHECK-NEXT: entry:
105+
; CHECK-NEXT: %0 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef)
106+
; CHECK-NEXT: %1 = call {} @diffeomegasubf(double* %x, double* %"x'", {} undef)
107+
; CHECK-NEXT: %2 = load double, double* %"x'"
108+
; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
109+
; CHECK-NEXT: %m0diffe = fmul fast double %2, 2.000000e+00
110+
; CHECK-NEXT: %3 = load double, double* %"x'"
111+
; CHECK-NEXT: %4 = fadd fast double %3, %m0diffe
112+
; CHECK-NEXT: store double %4, double* %"x'"
113+
; CHECK-NEXT: ret {} undef
114+
; CHECK-NEXT: }
115+
116+
; CHECK: define internal {} @diffemetasubf(double* nocapture %x, double* %"x'", {} %tapeArg)
117+
; CHECK-NEXT: entry:
118+
; CHECK-NEXT: %[[tostore:.+]] = getelementptr inbounds double, double* %"x'", i64 1
119+
; CHECK-NEXT: store double 0.000000e+00, double* %[[tostore]], align 8
120+
; CHECK-NEXT: ret {} undef
121+
; CHECK-NEXT: }
122+
123+
; CHECK: define internal {} @diffeomegasubf(double* nocapture %x, double* %"x'", {} %tapeArg)
124+
; CHECK-NEXT: entry:
125+
; CHECK-NEXT: %[[tostore:.+]] = getelementptr inbounds double, double* %"x'", i64 1
126+
; CHECK-NEXT: store double 0.000000e+00, double* %[[tostore]], align 8
127+
; CHECK-NEXT: ret {} undef
128+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)