Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2125,7 +2125,6 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
insert_or_assign(AugmentedCachedFunctions, tup,
AugmentedReturn(gutils->newFunc, nullptr, {}, returnMapping,
uncacheable_args_map, can_modref_map));
AugmentedCachedFinished[tup] = false;

auto getIndex = [&](Instruction *I, CacheType u) -> unsigned {
return gutils->getIndex(
Expand Down Expand Up @@ -2708,7 +2707,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
AugmentedCachedFunctions.find(tup)->second.fn = NewF;
if (recursive || (omp && !noTape))
AugmentedCachedFunctions.find(tup)->second.tapeType = tapeType;
insert_or_assign(AugmentedCachedFinished, tup, true);
AugmentedCachedFunctions.find(tup)->second.isComplete = true;

for (auto pair : gfnusers) {
auto GV = pair.first;
Expand Down Expand Up @@ -3226,8 +3225,14 @@ Function *EnzymeLogic::CreatePrimalAndGradient(

if (key.retType != DIFFE_TYPE::CONSTANT)
assert(!key.todiff->getReturnType()->isVoidTy());

Function *prevFunction = nullptr;
if (ReverseCachedFunctions.find(key) != ReverseCachedFunctions.end()) {
return ReverseCachedFunctions.find(key)->second;
prevFunction = ReverseCachedFunctions.find(key)->second;
if (!hasMetadata(prevFunction, "enzyme_placeholder"))
return prevFunction;
if (augmenteddata && !augmenteddata->isComplete)
return prevFunction;
}

if (key.returnUsed)
Expand Down Expand Up @@ -3641,6 +3646,14 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
insert_or_assign2<ReverseCacheKey, Function *>(ReverseCachedFunctions, key,
gutils->newFunc);

if (augmenteddata && !augmenteddata->isComplete) {
auto nf = gutils->newFunc;
delete gutils;
assert(!prevFunction);
nf->setMetadata("enzyme_placeholder", MDTuple::get(nf->getContext(), {}));
return nf;
}

const SmallPtrSet<BasicBlock *, 4> guaranteedUnreachable =
getGuaranteedUnreachable(gutils->oldFunc);

Expand Down Expand Up @@ -4020,6 +4033,11 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
if (Arch == Triple::nvptx || Arch == Triple::nvptx64)
PPC.ReplaceReallocs(nf, /*mem2reg*/ true);

if (prevFunction) {
prevFunction->replaceAllUsesWith(nf);
prevFunction->eraseFromParent();
}

// Do not run post processing optimizations if the body of an openmp
// parallel so the adjointgenerator can successfully extract the allocation
// and frees and hoist them into the parent. Optimizing before then may
Expand Down Expand Up @@ -4714,6 +4732,5 @@ llvm::Function *EnzymeLogic::CreateBatch(Function *tobatch, unsigned width,
/// Clears the state held by PPC and by the function caches named below.
void EnzymeLogic::clear() {
PPC.clear();
AugmentedCachedFunctions.clear();
// NOTE(review): the hunk header for this function reports 6 lines before the
// change and 5 after, so one of the lines shown here was removed; presumably
// it is the AugmentedCachedFinished.clear() call below -- confirm against the
// merged file.
AugmentedCachedFinished.clear();
ReverseCachedFunctions.clear();
}
5 changes: 3 additions & 2 deletions enzyme/Enzyme/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ class AugmentedReturn {

std::set<ssize_t> tapeIndiciesToFree;

bool isComplete;

AugmentedReturn(
llvm::Function *fn, llvm::Type *tapeType,
std::map<std::pair<llvm::Instruction *, CacheType>, int> tapeIndices,
Expand All @@ -126,7 +128,7 @@ class AugmentedReturn {
std::map<llvm::Instruction *, bool> can_modref_map)
: fn(fn), tapeType(tapeType), tapeIndices(tapeIndices), returns(returns),
uncacheable_args_map(uncacheable_args_map),
can_modref_map(can_modref_map) {}
can_modref_map(can_modref_map), isComplete(false) {}
};

struct ReverseCacheKey {
Expand Down Expand Up @@ -329,7 +331,6 @@ class EnzymeLogic {
};

std::map<AugmentedCacheKey, AugmentedReturn> AugmentedCachedFunctions;
std::map<AugmentedCacheKey, bool> AugmentedCachedFinished;

/// Create an augmented forward pass.
/// \p todiff is the function to differentiate
Expand Down
8 changes: 8 additions & 0 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2186,6 +2186,14 @@ class DiffeGradientUtils final : public GradientUtils {
assert(dif->getType() == old->getType());
Value *res = nullptr;
if (old->getType()->isIntOrIntVectorTy()) {
if (!addingType) {
if (looseTypeAnalysis) {
if (old->getType()->isIntegerTy(64))
addingType = Type::getDoubleTy(old->getContext());
else if (old->getType()->isIntegerTy(32))
addingType = Type::getFloatTy(old->getContext());
}
}
if (!addingType) {
llvm::errs() << "module: " << *oldFunc->getParent() << "\n";
llvm::errs() << "oldFunc: " << *oldFunc << "\n";
Expand Down
28 changes: 28 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/cacheErr.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -simplifycfg -S | FileCheck %s

%DynamicsStruct = type { i8**, void (%DynamicsStruct*)** }

@someGlobal = internal constant i8* bitcast (void (%DynamicsStruct*)* @asdf to i8*)

; Function under test: stores the address of @someGlobal into field 0 of %arg,
; then makes an indirect call through the function pointer reached via field 1.
define internal void @asdf(%DynamicsStruct* %arg) {
bb:
; Address of field 0 (the i8** member) of the struct pointed to by %arg.
%i = getelementptr inbounds %DynamicsStruct, %DynamicsStruct* %arg, i64 0, i32 0
store i8** @someGlobal, i8*** %i, align 8
; Address of field 1, which holds a pointer to a function pointer.
%i5 = getelementptr inbounds %DynamicsStruct, %DynamicsStruct* %arg, i64 0, i32 1
; Two dependent loads: first the pointer stored in field 1, then the function
; pointer that it points to.
%i6 = load void (%DynamicsStruct*)**, void (%DynamicsStruct*)*** %i5, align 8
%i8 = load void (%DynamicsStruct*)*, void (%DynamicsStruct*)** %i6, align 8
; Indirect tail call of the loaded function, forwarding %arg unchanged.
tail call void %i8(%DynamicsStruct* %arg)
ret void
}

declare i8* @_Z17__enzyme_virtualreversePv(...)

; Test driver: passes @asdf to the variadic @_Z17__enzyme_virtualreversePv
; declaration and discards the returned i8*.
define internal void @_Z19testSensitivitiesADv() {
bb40:
call i8* (...) @_Z17__enzyme_virtualreversePv(void (%DynamicsStruct*)* @asdf)
ret void
}

; CHECK: define internal i8* @augmented_asdf(%DynamicsStruct* %arg, %DynamicsStruct* %"arg'")

; CHECK: define internal void @diffeasdf.1(%DynamicsStruct* %arg, %DynamicsStruct* %"arg'", i8* %tapeArg)