Skip to content

Commit fc9159b

Browse files
stepasitewsmoses
authored andcommitted
fix build with llvm-16 (#5)
1 parent 0503d38 commit fc9159b

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9823,9 +9823,15 @@ class AdjointGenerator
98239823
CallInst *augmentcall = nullptr;
98249824
Value *cachereplace = nullptr;
98259825

9826+
#if LLVM_VERSION_MAJOR >= 16
9827+
std::optional<int> tapeIdx;
9828+
std::optional<int> returnIdx;
9829+
std::optional<int> differetIdx;
9830+
#else
98269831
Optional<int> tapeIdx;
98279832
Optional<int> returnIdx;
98289833
Optional<int> differetIdx;
9834+
#endif
98299835

98309836
if (modifyPrimal) {
98319837
Value *newcalled = nullptr;
@@ -9942,12 +9948,21 @@ class AdjointGenerator
99429948
if (!augmentcall->getType()->isVoidTy())
99439949
augmentcall->setName(call.getName() + "_augmented");
99449950

9951+
#if LLVM_VERSION_MAJOR >= 16
9952+
if (tapeIdx.has_value()) {
9953+
tape = (tapeIdx.value() == -1)
9954+
? augmentcall
9955+
: BuilderZ.CreateExtractValue(
9956+
augmentcall, {(unsigned)tapeIdx.value()},
9957+
"subcache");
9958+
#else
99459959
if (tapeIdx.hasValue()) {
99469960
tape = (tapeIdx.getValue() == -1)
99479961
? augmentcall
99489962
: BuilderZ.CreateExtractValue(
99499963
augmentcall, {(unsigned)tapeIdx.getValue()},
99509964
"subcache");
9965+
#endif
99519966
if (tape->getType()->isEmptyTy()) {
99529967
auto tt = tape->getType();
99539968
gutils->erase(cast<Instruction>(tape));
@@ -9963,10 +9978,17 @@ class AdjointGenerator
99639978
Value *dcall = nullptr;
99649979
assert(returnIdx);
99659980
assert(augmentcall);
9981+
#if LLVM_VERSION_MAJOR >= 16
9982+
dcall = (returnIdx.value() < 0)
9983+
? augmentcall
9984+
: BuilderZ.CreateExtractValue(
9985+
augmentcall, {(unsigned)returnIdx.value()});
9986+
#else
99669987
dcall = (returnIdx.getValue() < 0)
99679988
? augmentcall
99689989
: BuilderZ.CreateExtractValue(
99699990
augmentcall, {(unsigned)returnIdx.getValue()});
9991+
#endif
99709992
gutils->originalToNewFn[&call] = dcall;
99719993
gutils->newToOriginalFn.erase(newCall);
99729994
gutils->newToOriginalFn[dcall] = &call;
@@ -10029,12 +10051,21 @@ class AdjointGenerator
1002910051
subdata->returns.end()) {
1003010052
} else {
1003110053
if (!tape) {
10054+
#if LLVM_VERSION_MAJOR >= 16
10055+
assert(tapeIdx.has_value());
10056+
tape = BuilderZ.CreatePHI(
10057+
(tapeIdx == -1) ? FT->getReturnType()
10058+
: cast<StructType>(FT->getReturnType())
10059+
->getElementType(tapeIdx.value()),
10060+
1, "tapeArg");
10061+
#else
1003210062
assert(tapeIdx.hasValue());
1003310063
tape = BuilderZ.CreatePHI(
1003410064
(tapeIdx == -1) ? FT->getReturnType()
1003510065
: cast<StructType>(FT->getReturnType())
1003610066
->getElementType(tapeIdx.getValue()),
1003710067
1, "tapeArg");
10068+
#endif
1003810069
}
1003910070
tape = gutils->cacheForReverse(BuilderZ, tape,
1004010071
getIndex(&call, CacheType::Tape));
@@ -10081,11 +10112,19 @@ class AdjointGenerator
1008110112
Value *newip = nullptr;
1008210113
if (Mode == DerivativeMode::ReverseModeCombined ||
1008310114
Mode == DerivativeMode::ReverseModePrimal) {
10115+
#if LLVM_VERSION_MAJOR >= 16
10116+
newip = (differetIdx.value() < 0)
10117+
? augmentcall
10118+
: BuilderZ.CreateExtractValue(
10119+
augmentcall, {(unsigned)differetIdx.value()},
10120+
call.getName() + "'ac");
10121+
#else
1008410122
newip = (differetIdx.getValue() < 0)
1008510123
? augmentcall
1008610124
: BuilderZ.CreateExtractValue(
1008710125
augmentcall, {(unsigned)differetIdx.getValue()},
1008810126
call.getName() + "'ac");
10127+
#endif
1008910128
assert(newip->getType() == call.getType());
1009010129
placeholder->replaceAllUsesWith(newip);
1009110130
if (placeholder == &*BuilderZ.GetInsertPoint()) {

enzyme/Enzyme/Enzyme.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1701,9 +1701,15 @@ class EnzymeBase {
17011701
#endif
17021702
}
17031703

1704+
#if LLVM_VERSION_MAJOR >= 16
1705+
return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args,
1706+
byVal, constants, fn, mode, options.value(),
1707+
sizeOnly);
1708+
#else
17041709
return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args,
17051710
byVal, constants, fn, mode, options.getValue(),
17061711
sizeOnly);
1712+
#endif
17071713
}
17081714

17091715
bool HandleProbProg(CallInst *CI, ProbProgMode mode) {
@@ -1725,8 +1731,13 @@ class EnzymeBase {
17251731

17261732
SmallVector<Value *, 6> dargs = SmallVector(args);
17271733

1734+
#if LLVM_VERSION_MAJOR >= 16
1735+
if (!opt.has_value())
1736+
return false;
1737+
#else
17281738
if (!opt.hasValue())
17291739
return false;
1740+
#endif
17301741

17311742
auto dynamic_interface = opt->dynamic_interface;
17321743
auto trace = opt->trace.first;
@@ -1817,9 +1828,14 @@ class EnzymeBase {
18171828
#endif
18181829
}
18191830

1831+
#if LLVM_VERSION_MAJOR >= 16
18201832
bool status = HandleAutoDiff(
1833+
CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants,
1834+
newFunc, DerivativeMode::ReverseModeCombined, opt.value(), false);
1835+
#else
18211836
CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants,
18221837
newFunc, DerivativeMode::ReverseModeCombined, opt.getValue(), false);
1838+
#endif
18231839

18241840
delete interface;
18251841

0 commit comments

Comments
 (0)