Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions enzyme/Enzyme/ActiveVariable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ cl::opt<bool> printconst(
"enzyme_printconst", cl::init(false), cl::Hidden,
cl::desc("Print constant detection algorithm"));

// Command-line flag: when enabled, activity analysis treats any global
// variable that lacks an explicit "enzyme_shadow" metadata marker as
// inactive (constant) in isconstantValueM. Off by default.
cl::opt<bool> nonmarkedglobals_inactive(
    "enzyme_nonmarkedglobals_inactive", cl::init(false), cl::Hidden,
    cl::desc("Consider all nonmarked globals to be inactive"));

bool isIntASecretFloat(Value* val) {
assert(val->getType()->isIntegerTy());

Expand Down Expand Up @@ -199,8 +203,7 @@ Type* isIntPointerASecretFloat(Value* val) {
continue;
}
if (auto gep = dyn_cast<GetElementPtrInst>(v)) {
v = gep->getOperand(0);
continue;
trackPointer(gep->getOperand(0));
}
if (auto phi = dyn_cast<PHINode>(v)) {
for(auto &a : phi->incoming_values()) {
Expand All @@ -218,6 +221,10 @@ Type* isIntPointerASecretFloat(Value* val) {
et = st->getTypeAtIndex((unsigned int)0);
continue;
}
if (auto st = dyn_cast<ArrayType>(et)) {
et = st->getElementType();
continue;
}
break;
} while(1);
llvm::errs() << " for val " << *v << *et << "\n";
Expand Down Expand Up @@ -401,6 +408,7 @@ bool isconstantM(Instruction* inst, SmallPtrSetImpl<Value*> &constants, SmallPtr
}
if (auto call = dyn_cast<CallInst>(a)) {
auto fnp = call->getCalledFunction();
// For known library functions, special case how derivatives flow to allow for more aggressive active variable detection
if (fnp) {
auto fn = fnp->getName();
// todo realloc consider?
Expand All @@ -410,6 +418,8 @@ bool isconstantM(Instruction* inst, SmallPtrSetImpl<Value*> &constants, SmallPtr
continue;
if (fnp->getIntrinsicID() == Intrinsic::memcpy && call->getArgOperand(0) != inst && call->getArgOperand(1) != inst)
continue;
if (fnp->getIntrinsicID() == Intrinsic::memmove && call->getArgOperand(0) != inst && call->getArgOperand(1) != inst)
continue;
}
}

Expand Down Expand Up @@ -556,6 +566,36 @@ bool isconstantValueM(Value* val, SmallPtrSetImpl<Value*> &constants, SmallPtrSe
llvm::errs() << *val << "\n";
assert(0 && "must've put arguments in constant/nonconstant");
}

if (auto gi = dyn_cast<GlobalVariable>(val)) {
if (!hasMetadata(gi, "enzyme_shadow") && nonmarkedglobals_inactive) {
constants.insert(val);
return true;
}
//TODO consider this more
if (gi->isConstant() && isconstantValueM(gi->getInitializer(), constants, nonconstant, retvals, originalInstructions, directions)) {
constants.insert(val);
return true;
}
}

if (auto ce = dyn_cast<ConstantExpr>(val)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comments needed, especially before conditionals that involve recursive calls.

if (ce->isCast()) {
if (isconstantValueM(ce->getOperand(0), constants, nonconstant, retvals, originalInstructions, directions)) {
constants.insert(val);
return true;
}
}
if (ce->isGEPWithNoNotionalOverIndexing()) {
if (isconstantValueM(ce->getOperand(0), constants, nonconstant, retvals, originalInstructions, directions)) {
constants.insert(val);
return true;
}
if (auto gi = dyn_cast<GlobalVariable>(val)) {

}
}
}

if (auto inst = dyn_cast<Instruction>(val)) {
if (isconstantM(inst, constants, nonconstant, retvals, originalInstructions, directions)) return true;
Expand Down Expand Up @@ -589,6 +629,8 @@ bool isconstantValueM(Value* val, SmallPtrSetImpl<Value*> &constants, SmallPtrSe
continue;
if (fnp->getIntrinsicID() == Intrinsic::memcpy && call->getArgOperand(0) != val && call->getArgOperand(1) != val)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For skipped intrinsics, a comment should be present describing the property the intrinsic has that lets us skip it. This lets people understand when they can add a "missing" intrinsic.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not skipping the intrinsic, but rather special-casing memcpy/memmove so that the size argument is not made active even if the other arguments are active

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point was: if you are handling a bunch of things similarly (e.g. intrinsics) and it's too cumbersome to give an explicit comment for each one, then an explanation for the group is good enough --- and it should exist anyway so that folks understand what determines membership in the group of similar cases.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed re adding a comment, just wanted to clarify what code was doing

continue;
if (fnp->getIntrinsicID() == Intrinsic::memmove && call->getArgOperand(0) != val && call->getArgOperand(1) != val)
continue;
}
}

Expand Down
39 changes: 35 additions & 4 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ cl::opt<bool> nonmarkedglobals_inactiveloads(
"enzyme_nonmarkedglobals_inactiveloads", cl::init(true), cl::Hidden,
cl::desc("Consider loads of nonmarked globals to be inactive"));


// Computes a map of LoadInst -> boolean for a function indicating whether that load is "uncacheable".
// A load is considered "uncacheable" if the data at the loaded memory location can be modified after
// the load instruction.
Expand Down Expand Up @@ -509,7 +508,8 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul

if(auto op = dyn_cast_or_null<IntrinsicInst>(inst)) {
switch(op->getIntrinsicID()) {
case Intrinsic::memcpy: {
case Intrinsic::memcpy:
case Intrinsic::memmove: {
if (gutils->isConstantInstruction(inst)) continue;

if (!isIntPointerASecretFloat(op->getOperand(0)) ) {
Expand All @@ -521,7 +521,7 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul
args.push_back(op->getOperand(3));

Type *tys[] = {args[0]->getType(), args[1]->getType(), args[2]->getType()};
auto cal = BuilderZ.CreateCall(Intrinsic::getDeclaration(gutils->newFunc->getParent(), Intrinsic::memcpy, tys), args);
auto cal = BuilderZ.CreateCall(Intrinsic::getDeclaration(gutils->newFunc->getParent(), op->getIntrinsicID(), tys), args);
cal->setAttributes(op->getAttributes());
cal->setCallingConv(op->getCallingConv());
cal->setTailCallKind(op->getTailCallKind());
Expand Down Expand Up @@ -659,7 +659,6 @@ std::pair<Function*,StructType*> CreateAugmentedPrimal(Function* todiff, AAResul

for(unsigned i=0;i<op->getNumArgOperands(); i++) {
args.push_back(op->getArgOperand(i));

if (gutils->isConstantValue(op->getArgOperand(i)) && !called->empty()) {
subconstant_args.insert(i);
argsInverted.push_back(DIFFE_TYPE::CONSTANT);
Expand Down Expand Up @@ -2177,6 +2176,38 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set<unsigned>& co
}
break;
}
case Intrinsic::memmove: {
if (gutils->isConstantInstruction(inst)) continue;
if (Type* secretty = isIntPointerASecretFloat(op->getOperand(0)) ) {
SmallVector<Value*, 4> args;
auto secretpt = PointerType::getUnqual(secretty);

args.push_back(Builder2.CreatePointerCast(invertPointer(op->getOperand(0)), secretpt));
args.push_back(Builder2.CreatePointerCast(invertPointer(op->getOperand(1)), secretpt));
args.push_back(Builder2.CreateUDiv(lookup(op->getOperand(2)),

ConstantInt::get(op->getOperand(2)->getType(), Builder2.GetInsertBlock()->getParent()->getParent()->getDataLayout().getTypeAllocSizeInBits(secretty)/8)
));
auto dmemmove = getOrInsertDifferentialFloatMemmove(*M, secretpt);
Builder2.CreateCall(dmemmove, args);
} else {
if (topLevel) {
SmallVector<Value*, 4> args;
IRBuilder <>BuilderZ(op);
args.push_back(gutils->invertPointerM(op->getOperand(0), BuilderZ));
args.push_back(gutils->invertPointerM(op->getOperand(1), BuilderZ));
args.push_back(op->getOperand(2));
args.push_back(op->getOperand(3));

Type *tys[] = {args[0]->getType(), args[1]->getType(), args[2]->getType()};
auto cal = BuilderZ.CreateCall(Intrinsic::getDeclaration(gutils->newFunc->getParent(), Intrinsic::memmove, tys), args);
cal->setAttributes(op->getAttributes());
cal->setCallingConv(op->getCallingConv());
cal->setTailCallKind(op->getTailCallKind());
}
}
break;
}
case Intrinsic::memset: {
if (gutils->isConstantInstruction(inst)) continue;
if (!gutils->isConstantValue(op->getOperand(1))) {
Expand Down
4 changes: 3 additions & 1 deletion enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,14 @@ static bool promoteMemoryToRegister(Function &F, DominatorTree &DT,
}

void forceRecursiveInlining(Function *NewF, const Function* F) {
int count = 0;
static int count = 0;
remover:
SmallPtrSet<Instruction*, 10> originalInstructions;
for (inst_iterator I = inst_begin(NewF), E = inst_end(NewF); I != E; ++I) {
originalInstructions.insert(&*I);
}
if (count >= autodiff_inline_count)
return;
for (inst_iterator I = inst_begin(NewF), E = inst_end(NewF); I != E; ++I)
if (auto call = dyn_cast<CallInst>(&*I)) {
//if (isconstantM(call, constants, nonconstant, returnvals, originalInstructions)) continue;
Expand Down
26 changes: 25 additions & 1 deletion enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include "FunctionUtils.h"

#include "llvm/IR/Constants.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Transforms/Utils/SimplifyIndVar.h"

Expand Down Expand Up @@ -88,6 +89,7 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos
if (setupMerge) {
for(auto pair : loopContexts) {
auto &lc = pair.second;
assert(lc.exitBlocks.size() > 0);

lc.latchMerge = BasicBlock::Create(newFunc->getContext(), "loopMerge", newFunc);
loopContexts[pair.first].latchMerge = lc.latchMerge;
Expand Down Expand Up @@ -136,6 +138,7 @@ static bool isParentOrSameContext(LoopContext & possibleChild, LoopContext & pos
}
}
}
assert(targetToPreds.size() > 0);

BasicBlock* backlatch = nullptr;
for(auto blk : predecessors(lc.header)) {
Expand Down Expand Up @@ -359,6 +362,15 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) {
} else if (auto arg = dyn_cast<CastInst>(val)) {
auto result = BuilderM.CreateCast(arg->getOpcode(), invertPointerM(arg->getOperand(0), BuilderM), arg->getDestTy(), arg->getName()+"'ipc");
return result;
} else if (auto arg = dyn_cast<ConstantExpr>(val)) {
if (arg->isCast()) {
auto result = ConstantExpr::getCast(arg->getOpcode(), cast<Constant>(invertPointerM(arg->getOperand(0), BuilderM)), arg->getType());
return result;
} else if (arg->isGEPWithNoNotionalOverIndexing()) {
auto result = arg->getWithOperandReplaced(0, cast<Constant>(invertPointerM(arg->getOperand(0), BuilderM)));
return result;
}
goto end;
} else if (auto arg = dyn_cast<ExtractValueInst>(val)) {
IRBuilder<> bb(arg);
auto result = bb.CreateExtractValue(invertPointerM(arg->getOperand(0), bb), arg->getIndices(), arg->getName()+"'ipev");
Expand Down Expand Up @@ -478,6 +490,8 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) {
return lookupM(which, BuilderM);
}
}

end:;
assert(BuilderM.GetInsertBlock());
assert(BuilderM.GetInsertBlock()->getParent());
assert(val);
Expand Down Expand Up @@ -699,7 +713,12 @@ bool getContextM(BasicBlock *BB, LoopContext &loopContext, std::map<Loop*,LoopCo

loopContexts[L].latchMerge = nullptr;

fake::SCEVExpander::getExitBlocks(L, loopContexts[L].exitBlocks);
fake::SCEVExpander::getExitBlocks(L, loopContexts[L].exitBlocks);
if (loopContexts[L].exitBlocks.size() == 0) {
llvm::errs() << "newFunc: " << *BB->getParent() << "\n";
llvm::errs() << "L: " << *L << "\n";
}
assert(loopContexts[L].exitBlocks.size() > 0);

auto pair = insertNewCanonicalIV(L, Type::getInt64Ty(BB->getContext()));
PHINode* CanonicalIV = pair.first;
Expand Down Expand Up @@ -858,6 +877,7 @@ bool GradientUtils::getContext(BasicBlock* BB, LoopContext& loopContext) {
// * If replacePHIs is null (usual case), this function does the branch
// * If replacePHIs isn't null, do not perform the branch and instead replace the PHI's with the derived condition as to whether we should branch to a particular target
void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& BuilderM, const std::map<BasicBlock*, std::vector<std::pair</*pred*/BasicBlock*,/*successor*/BasicBlock*>>> &targetToPreds, const std::map<BasicBlock*,PHINode*>* replacePHIs) {
assert(targetToPreds.size() > 0);
if (replacePHIs) {
if (replacePHIs->size() == 0) return;

Expand Down Expand Up @@ -1060,6 +1080,7 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B
targets.push_back(pair.first);
idx++;
}
assert(targets.size() > 0);

for(const auto &pair: storing) {
assert(pair.first->getTerminator());
Expand Down Expand Up @@ -1091,6 +1112,9 @@ void GradientUtils::branchToCorrespondingTarget(BasicBlock* ctx, IRBuilder <>& B
assert(BuilderM.GetInsertBlock()->size() == 0 || !isa<BranchInst>(BuilderM.GetInsertBlock()->back()));
BuilderM.CreateCondBr(which, /*true*/targets[1], /*false*/targets[0]);
} else {
assert(targets.size() > 0);
llvm::errs() << "which: " << *which << "\n";
llvm::errs() << "targets.back(): " << *targets.back() << "\n";
auto swit = BuilderM.CreateSwitch(which, targets.back(), targets.size()-1);
for(unsigned i=0; i<targets.size()-1; i++) {
swit->addCase(ConstantInt::get(T, i), targets[i]);
Expand Down
6 changes: 4 additions & 2 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,9 @@ class GradientUtils {
}
for( auto u : users) {
if (auto li = dyn_cast<LoadInst>(u)) {
li->replaceAllUsesWith(ret);
IRBuilder<> lb(li);
ValueToValueMapTy empty;
li->replaceAllUsesWith(unwrapM(ret, lb, empty, /*lookupifable*/false));
erase(li);
} else if (auto si = dyn_cast<StoreInst>(u)) {
erase(si);
Expand Down Expand Up @@ -656,7 +658,7 @@ class GradientUtils {
}
} else {
if (auto inti = dyn_cast<IntrinsicInst>(inst)) {
if (inti->getIntrinsicID() == Intrinsic::memset || inti->getIntrinsicID() == Intrinsic::memcpy) {
if (inti->getIntrinsicID() == Intrinsic::memset || inti->getIntrinsicID() == Intrinsic::memcpy || inti->getIntrinsicID() == Intrinsic::memmove) {
erase(inst);
continue;
}
Expand Down
6 changes: 6 additions & 0 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,9 @@ Function* getOrInsertDifferentialFloatMemcpy(Module& M, PointerType* T) {
}
return F;
}

//TODO implement differential memmove
// Returns a function computing the derivative of a floating-point memmove.
// Not yet implemented as such: this currently falls back to the differential
// memcpy, whose semantics coincide with memmove only when the source and
// destination regions do not overlap. For overlapping regions the derivative
// may be wrong — hence the runtime warning emitted below.
Function* getOrInsertDifferentialFloatMemmove(Module& M, PointerType* T) {
  llvm::errs() << "warning: didn't implement memmove, using memcpy as fallback which can result in errors\n";
  return getOrInsertDifferentialFloatMemcpy(M, T);
}
3 changes: 3 additions & 0 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,4 +258,7 @@ static inline bool isCertainPrintMallocOrFree(llvm::Function* called) {
//! Create function for type that performs the derivative memcpy on floating point memory
llvm::Function* getOrInsertDifferentialFloatMemcpy(llvm::Module& M, llvm::PointerType* T);

//! Create function for type that performs the derivative memmove on floating point memory
//! NOTE(review): the current definition (Utils.cpp) falls back to the
//! differential memcpy, which can be incorrect for overlapping regions.
llvm::Function* getOrInsertDifferentialFloatMemmove(llvm::Module& M, llvm::PointerType* T);

#endif
82 changes: 82 additions & 0 deletions enzyme/test/Enzyme/cachelocations.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -mem2reg -sroa -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s

; Verifies that a value needed by the reverse pass is cached in a tape when a
; callee is split into augmented-forward/reverse halves: @augmented_subf must
; return the loaded %0 inside a { double } tape (phi'd with undef on the
; untaken path), and @diffesubf must recover it via extractvalue from
; %tapeArg rather than re-loading, since the primal store to %arrayidx in @f
; happens between the two passes.

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local void @subf(i1 zeroext %z, double* nocapture %x) local_unnamed_addr #0 {
entry:
  br i1 %z, label %if.then, label %if.end

if.then:                                          ; preds = %entry
  %0 = load double, double* %x, align 8
  %mul = fmul fast double %0, %0
  store double %mul, double* %x, align 8
  br label %if.end

if.end:                                           ; preds = %if.then, %entry
  ret void
}

; Function Attrs: noinline norecurse nounwind uwtable
define dso_local void @f(i1 zeroext %z, double* nocapture %x) #0 {
entry:
  tail call void @subf(i1 zeroext %z, double* %x)
  %arrayidx = getelementptr inbounds double, double* %x, i64 1
  store double 2.000000e+00, double* %arrayidx, align 8
  ret void
}

; Function Attrs: noinline nounwind uwtable
define dso_local double @dsumsquare(i1 zeroext %z, double* %x, double* %xp) local_unnamed_addr #1 {
entry:
  %call = tail call fast double @__enzyme_autodiff(i8* bitcast (void (i1, double*)* @f to i8*), i1 zeroext %z, double* %x, double* %xp)
  ret double %call
}

declare dso_local double @__enzyme_autodiff(i8*, i1 zeroext, double*, double*)

; CHECK: define internal {} @diffef(i1 zeroext %z, double* nocapture %x, double* %"x'") {
; CHECK-NEXT: entry:
; CHECK-NEXT:   %0 = call { { double } } @augmented_subf(i1 %z, double* %x, double* %"x'")
; CHECK-NEXT:   %1 = extractvalue { { double } } %0, 0
; CHECK-NEXT:   %"arrayidx'ipge" = getelementptr inbounds double, double* %"x'", i64 1
; CHECK-NEXT:   %arrayidx = getelementptr inbounds double, double* %x, i64 1
; CHECK-NEXT:   store double 2.000000e+00, double* %arrayidx, align 8
; CHECK-NEXT:   store double 0.000000e+00, double* %"arrayidx'ipge", align 8
; CHECK-NEXT:   %2 = call {} @diffesubf(i1 %z, double* nonnull %x, double* %"x'", { double } %1)
; CHECK-NEXT:   ret {} undef
; CHECK-NEXT: }

; CHECK: define internal { { double } } @augmented_subf(i1 zeroext %z, double* nocapture %x, double* %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT:   br i1 %z, label %if.then, label %if.end

; CHECK: if.then:                                          ; preds = %entry
; CHECK-NEXT:   %0 = load double, double* %x, align 8
; CHECK-NEXT:   %mul = fmul fast double %0, %0
; CHECK-NEXT:   store double %mul, double* %x, align 8
; CHECK-NEXT:   br label %if.end

; CHECK: if.end:                                           ; preds = %if.then, %entry
; CHECK-NEXT:   %[[val:.+]] = phi double [ %0, %if.then ], [ undef, %entry ]
; CHECK-NEXT:   %[[toret:.+]] = insertvalue { { double } } undef, double %[[val]], 0, 0
; CHECK-NEXT:   ret { { double } } %[[toret]]
; CHECK-NEXT: }

; CHECK: define internal {} @diffesubf(i1 zeroext %z, double* nocapture %x, double* %"x'", { double } %tapeArg)
; CHECK-NEXT: entry:
; CHECK-NEXT:   br i1 %z, label %invertif.then, label %invertentry

; CHECK: invertentry:                                      ; preds = %entry, %invertif.then
; CHECK-NEXT:   ret {} undef

; CHECK: invertif.then:                                    ; preds = %entry
; CHECK-NEXT:   %0 = load double, double* %"x'"
; CHECK-NEXT:   store double 0.000000e+00, double* %"x'", align 8
; CHECK-NEXT:   %_unwrap = extractvalue { double } %tapeArg, 0
; CHECK-NEXT:   %m0diffe = fmul fast double %0, %_unwrap
; CHECK-NEXT:   %m1diffe = fmul fast double %0, %_unwrap
; CHECK-NEXT:   %1 = fadd fast double %m0diffe, %m1diffe
; CHECK-NEXT:   %2 = load double, double* %"x'"
; CHECK-NEXT:   %3 = fadd fast double %2, %1
; CHECK-NEXT:   store double %3, double* %"x'"
; CHECK-NEXT:   br label %invertentry
; CHECK-NEXT: }
Loading