Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -10593,9 +10593,15 @@ class AdjointGenerator
CallInst *augmentcall = nullptr;
Value *cachereplace = nullptr;

#if LLVM_VERSION_MAJOR >= 16
std::optional<int> tapeIdx;
std::optional<int> returnIdx;
std::optional<int> differetIdx;
#else
Optional<int> tapeIdx;
Optional<int> returnIdx;
Optional<int> differetIdx;
#endif

if (modifyPrimal) {
Value *newcalled = nullptr;
Expand Down Expand Up @@ -10712,12 +10718,21 @@ class AdjointGenerator
if (!augmentcall->getType()->isVoidTy())
augmentcall->setName(call.getName() + "_augmented");

#if LLVM_VERSION_MAJOR >= 16
if (tapeIdx.has_value()) {
tape = (tapeIdx.value() == -1)
? augmentcall
: BuilderZ.CreateExtractValue(
augmentcall, {(unsigned)tapeIdx.value()},
"subcache");
#else
if (tapeIdx.hasValue()) {
tape = (tapeIdx.getValue() == -1)
? augmentcall
: BuilderZ.CreateExtractValue(
augmentcall, {(unsigned)tapeIdx.getValue()},
"subcache");
#endif
if (tape->getType()->isEmptyTy()) {
auto tt = tape->getType();
gutils->erase(cast<Instruction>(tape));
Expand All @@ -10733,10 +10748,17 @@ class AdjointGenerator
Value *dcall = nullptr;
assert(returnIdx);
assert(augmentcall);
#if LLVM_VERSION_MAJOR >= 16
dcall = (returnIdx.value() < 0)
? augmentcall
: BuilderZ.CreateExtractValue(
augmentcall, {(unsigned)returnIdx.value()});
#else
dcall = (returnIdx.getValue() < 0)
? augmentcall
: BuilderZ.CreateExtractValue(
augmentcall, {(unsigned)returnIdx.getValue()});
#endif
gutils->originalToNewFn[&call] = dcall;
gutils->newToOriginalFn.erase(newCall);
gutils->newToOriginalFn[dcall] = &call;
Expand Down Expand Up @@ -10799,12 +10821,21 @@ class AdjointGenerator
subdata->returns.end()) {
} else {
if (!tape) {
#if LLVM_VERSION_MAJOR >= 16
assert(tapeIdx.has_value());
tape = BuilderZ.CreatePHI(
(tapeIdx == -1) ? FT->getReturnType()
: cast<StructType>(FT->getReturnType())
->getElementType(tapeIdx.value()),
1, "tapeArg");
#else
assert(tapeIdx.hasValue());
tape = BuilderZ.CreatePHI(
(tapeIdx == -1) ? FT->getReturnType()
: cast<StructType>(FT->getReturnType())
->getElementType(tapeIdx.getValue()),
1, "tapeArg");
#endif
}
tape = gutils->cacheForReverse(BuilderZ, tape,
getIndex(&call, CacheType::Tape));
Expand Down Expand Up @@ -10851,11 +10882,19 @@ class AdjointGenerator
Value *newip = nullptr;
if (Mode == DerivativeMode::ReverseModeCombined ||
Mode == DerivativeMode::ReverseModePrimal) {
#if LLVM_VERSION_MAJOR >= 16
newip = (differetIdx.value() < 0)
? augmentcall
: BuilderZ.CreateExtractValue(
augmentcall, {(unsigned)differetIdx.value()},
call.getName() + "'ac");
#else
newip = (differetIdx.getValue() < 0)
? augmentcall
: BuilderZ.CreateExtractValue(
augmentcall, {(unsigned)differetIdx.getValue()},
call.getName() + "'ac");
#endif
assert(newip->getType() == call.getType());
placeholder->replaceAllUsesWith(newip);
if (placeholder == &*BuilderZ.GetInsertPoint()) {
Expand Down
16 changes: 16 additions & 0 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1697,9 +1697,15 @@ class EnzymeBase {
#endif
}

#if LLVM_VERSION_MAJOR >= 16
return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args,
byVal, constants, fn, mode, options.value(),
sizeOnly);
#else
return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args,
byVal, constants, fn, mode, options.getValue(),
sizeOnly);
#endif
}

bool HandleProbProg(CallInst *CI, ProbProgMode mode) {
Expand All @@ -1721,8 +1727,13 @@ class EnzymeBase {

SmallVector<Value *, 6> dargs = SmallVector(args);

#if LLVM_VERSION_MAJOR >= 16
if (!opt.has_value())
return false;
#else
if (!opt.hasValue())
return false;
#endif

auto dynamic_interface = opt->dynamic_interface;
auto trace = opt->trace.first;
Expand Down Expand Up @@ -1813,9 +1824,14 @@ class EnzymeBase {
#endif
}

#if LLVM_VERSION_MAJOR >= 16
bool status = HandleAutoDiff(
CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants,
newFunc, DerivativeMode::ReverseModeCombined, opt.value(), false);
#else
CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants,
newFunc, DerivativeMode::ReverseModeCombined, opt.getValue(), false);
#endif

delete interface;

Expand Down