Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ccpp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ jobs:
strategy:
fail-fast: false
matrix:
llvm: ["7"]
build: ["Release", "Debug"] # "RelWithDebInfo"
llvm: ["7"] #8 and 9 not done because of gvn issues
build: ["Release"] # "RelWithDebInfo"
os: [self-hosted]

timeout-minutes: 45
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/enzyme.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
fail-fast: false
matrix:
llvm: ["7", "9"] # 8
llvm: ["7", "8", "9"]
build: ["Release", "Debug"] # "RelWithDebInfo"
os: [self-hosted]

Expand Down
207 changes: 117 additions & 90 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,103 @@ cl::opt<bool> nonmarkedglobals_inactiveloads(
"enzyme_nonmarkedglobals_inactiveloads", cl::init(true), cl::Hidden,
cl::desc("Consider loads of nonmarked globals to be inactive"));


// Decide whether the memory read by `li` may be modified after the load
// executes (within the function being differentiated), in which case the
// loaded value is "uncacheable" and must be explicitly stored for the
// reverse pass rather than re-loaded.
//
// Returns true ("can_modref") when either:
//   1. the load's underlying object is an argument the caller marked
//      uncacheable, or an unknown/non-malloc call result, or another
//      uncacheable load; or
//   2. some instruction in oldFunc that is not dominated by `li` may write
//      (mod) the loaded location, per alias analysis.
//
// Parameters:
//   li               - the original (oldFunc) load being classified.
//   AA               - alias analysis used for the mod/ref query.
//   gutils           - provides oldFunc and its dominator tree (OrigDT).
//   TLI              - used to recognize allocation/deallocation functions.
//   uncacheable_args - per-argument uncacheability computed at the callsite;
//                      must contain every Argument of oldFunc (asserted).
bool is_load_uncacheable(LoadInst& li, AAResults& AA, GradientUtils* gutils, TargetLibraryInfo& TLI,
                         const std::map<Argument*, bool>& uncacheable_args) {

  bool can_modref = false;
  // Find the underlying object for the pointer operand of the load instruction.
  auto obj = GetUnderlyingObject(li.getPointerOperand(), gutils->oldFunc->getParent()->getDataLayout(), 100);

  //llvm::errs() << "underlying object for load " << li << " is " << *obj << "\n";
  // If the pointer operand is from an argument to the function, we need to check if the argument
  // received from the caller is uncacheable.
  if (auto arg = dyn_cast<Argument>(obj)) {
    auto found = uncacheable_args.find(arg);
    // Diagnostic dump before the assert below fires: the map is expected to
    // cover every argument of this function.
    if (found == uncacheable_args.end()) {
      llvm::errs() << "uncacheable_args:\n";
      for(auto& pair : uncacheable_args) {
        llvm::errs() << " + " << *pair.first << ": " << pair.second << " of func " << pair.first->getParent()->getName() << "\n";
      }
      llvm::errs() << "could not find " << *arg << " of func " << arg->getParent()->getName() << " in args_map\n";
    }
    assert(found != uncacheable_args.end());
    if (found->second) {
      //llvm::errs() << "OP is uncacheable arg: " << li << "\n";
      can_modref = true;
    }
    //llvm::errs() << " + argument (can_modref=" << can_modref << ") " << li << " object: " << *obj << " arg: " << *arg << "e\n";
    //TODO this case (alloca goes out of scope/allocation is freed and we dont force it to continue needs to be forcibly cached)
  } else {
    // NOTE(TFK): In the case where the underlying object for the pointer operand is from a Load or Call we need
    //            to check if we need to cache. Likely, we need to play it safe in this case and cache.
    // NOTE(TFK): The logic below is an attempt at a conservative handling of the case mentioned above, but it
    //            needs to be verified.

    // Pointer operands originating from call instructions that are not malloc/free are conservatively considered uncacheable.
    if (auto obj_op = dyn_cast<CallInst>(obj)) {
      Function* called = obj_op->getCalledFunction();
      // Look through a constant-expression cast of the callee (e.g. a
      // bitcast-ed malloc) so allocation functions are still recognized.
      if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
        if (castinst->isCast()) {
          if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
            if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
              called = fn;
            }
          }
        }
      }
      if (called && isCertainMallocOrFree(called)) {
        // Fresh allocations are safe: nothing else aliases them here.
        //llvm::errs() << "OP is certain malloc or free: " << li << "\n";
      } else {
        //llvm::errs() << "OP is a non malloc/free call so we need to cache " << li << "\n";
        can_modref = true;
      }
    } else if (auto sli = dyn_cast<LoadInst>(obj)) {
      // If obj is from a load instruction conservatively consider it uncacheable if that load itself cannot be cached.
      // NOTE(review): this recursion has no visited set — assumed acyclic
      // because GetUnderlyingObject strictly peels the pointer chain; confirm
      // it cannot revisit the same load (e.g. through phis).
      //llvm::errs() << "OP is from a load, needing to cache " << li << "\n";
      can_modref = is_load_uncacheable(*sli, AA, gutils, TLI, uncacheable_args);
    } else {
      // In absence of more information, assume that the underlying object for pointer operand is uncacheable in caller.
      //llvm::errs() << "OP is an unknown instruction, needing to cache " << li << "\n";
      can_modref = true;
    }
  }

  // Scan every instruction of the original function: anything that is NOT
  // guaranteed to execute before `li` (i.e. does not dominate it) and that
  // alias analysis says may write the loaded location forces caching.
  for (inst_iterator I2 = inst_begin(*gutils->oldFunc), E2 = inst_end(*gutils->oldFunc); I2 != E2; ++I2) {
    Instruction* inst2 = &*I2;
    assert(li.getParent()->getParent() == inst2->getParent()->getParent());
    if (&li == inst2) continue;
    if (!gutils->OrigDT.dominates(inst2, &li)) {

      // Don't consider modref from malloc/free as a need to cache
      if (auto obj_op = dyn_cast<CallInst>(inst2)) {
        Function* called = obj_op->getCalledFunction();
        // Same look-through-cast handling as above for bitcast-ed callees.
        if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
          if (castinst->isCast()) {
            if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
              if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
                called = fn;
              }
            }
          }
        }
        if (called && isCertainMallocOrFree(called)) {
          continue;
        }
      }

      if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::get(&li)))) {
        can_modref = true;
        // NOTE(review): unconditional debug output left enabled (the older
        // copy of this logic had it commented out) — consider guarding it.
        llvm::errs() << li << " needs to be cached due to: " << *inst2 << "\n";
        break;
      }
    }
  }
  //llvm::errs() << "F - " << li << " can_modref" << can_modref << "\n";
  return can_modref;

}

// Computes a map of LoadInst -> boolean for a function indicating whether that load is "uncacheable".
// A load is considered "uncacheable" if the data at the loaded memory location can be modified after
// the load instruction.
Expand All @@ -73,95 +170,8 @@ std::map<Instruction*, bool> compute_uncacheable_load_map(GradientUtils* gutils,
// For each load instruction, determine if it is uncacheable.
if (auto op = dyn_cast<LoadInst>(inst)) {

bool can_modref = false;
// Find the underlying object for the pointer operand of the load instruction.
auto obj = GetUnderlyingObject(op->getPointerOperand(), gutils->oldFunc->getParent()->getDataLayout(), 100);

//llvm::errs() << "underlying object for load " << *op << " is " << *obj << "\n";
// If the pointer operand is from an argument to the function, we need to check if the argument
// received from the caller is uncacheable.
if (auto arg = dyn_cast<Argument>(obj)) {
auto found = uncacheable_args.find(arg);
if (found == uncacheable_args.end()) {
llvm::errs() << "uncacheable_args:\n";
for(auto& pair : uncacheable_args) {
llvm::errs() << " + " << *pair.first << ": " << pair.second << " of func " << pair.first->getParent()->getName() << "\n";
}
llvm::errs() << "could not find " << *arg << " of func " << arg->getParent()->getName() << " in args_map\n";
}
assert(found != uncacheable_args.end());
if (found->second) {
//llvm::errs() << "OP is uncacheable arg: " << *op << "\n";
can_modref = true;
}
//llvm::errs() << " + argument (can_modref=" << can_modref << ") " << *op << " object: " << *obj << " arg: " << *arg << "e\n";
//TODO this case (alloca goes out of scope/allocation is freed and we dont force it to continue needs to be forcibly cached)
} else {
// NOTE(TFK): In the case where the underlying object for the pointer operand is from a Load or Call we need
// to check if we need to cache. Likely, we need to play it safe in this case and cache.
// NOTE(TFK): The logic below is an attempt at a conservative handling of the case mentioned above, but it
// needs to be verified.

// Pointer operands originating from call instructions that are not malloc/free are conservatively considered uncacheable.
if (auto obj_op = dyn_cast<CallInst>(obj)) {
Function* called = obj_op->getCalledFunction();
if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
if (castinst->isCast()) {
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
called = fn;
}
}
}
}
if (called && isCertainMallocOrFree(called)) {
//llvm::errs() << "OP is certain malloc or free: " << *op << "\n";
} else {
//llvm::errs() << "OP is a non malloc/free call so we need to cache " << *op << "\n";
can_modref = true;
}
} else if (isa<LoadInst>(obj)) {
// If obj is from a load instruction conservatively consider it uncacheable.
//llvm::errs() << "OP is from a load, needing to cache " << *op << "\n";
can_modref = true;
} else {
// In absence of more information, assume that the underlying object for pointer operand is uncacheable in caller.
//llvm::errs() << "OP is an unknown instruction, needing to cache " << *op << "\n";
can_modref = true;
}
}

for (inst_iterator I2 = inst_begin(*gutils->oldFunc), E2 = inst_end(*gutils->oldFunc); I2 != E2; ++I2) {
Instruction* inst2 = &*I2;
assert(inst->getParent()->getParent() == inst2->getParent()->getParent());
if (inst == inst2) continue;
if (!gutils->OrigDT.dominates(inst2, inst)) {

// Don't consider modref from malloc/free as a need to cache
if (auto obj_op = dyn_cast<CallInst>(inst2)) {
Function* called = obj_op->getCalledFunction();
if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
if (castinst->isCast()) {
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
called = fn;
}
}
}
}
if (called && isCertainMallocOrFree(called)) {
continue;
}
}

if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::get(op)))) {
can_modref = true;
//llvm::errs() << *inst << " needs to be cached due to: " << *inst2 << "\n";
break;
}
}
}
can_modref_map[inst] = can_modref;
can_modref_map[inst] = is_load_uncacheable(*op, AA, gutils, TLI, uncacheable_args);
}
}
return can_modref_map;
Expand All @@ -185,7 +195,10 @@ std::map<Argument*, bool> compute_uncacheable_args_for_one_callsite(CallInst* ca
100);
//llvm::errs() << "ocs underlying object for callsite " << *callsite_op << " idx: " << i << " is " << *obj << "\n";
// If underlying object is an Argument, check parent volatility status.
if (auto arg = dyn_cast<Argument>(obj)) {
if (isa<UndefValue>(obj)) {
init_safe = true;
//llvm::errs() << " + ocs undef (safe=" << init_safe << ") " << *callsite_op << " object: " << *obj << "\n";
} else if (auto arg = dyn_cast<Argument>(obj)) {
auto found = parent_uncacheable_args.find(arg);
if (found == parent_uncacheable_args.end()) {
llvm::errs() << "parent_uncacheable_args:\n";
Expand All @@ -198,7 +211,7 @@ std::map<Argument*, bool> compute_uncacheable_args_for_one_callsite(CallInst* ca
if (found->second) {
init_safe = false;
}
//llvm::errs() << " + ocs argument (safe=" << init_safe << ") " << *callsite_op << " object: " << *obj << " arg: " << *arg << "e\n";
//llvm::errs() << " + ocs argument (safe=" << init_safe << ") " << *callsite_op << " object: " << *obj << " arg: " << *arg << "e\n";
} else {
// Pointer operands originating from call instructions that are not malloc/free are conservatively considered uncacheable.
if (auto obj_op = dyn_cast<CallInst>(obj)) {
Expand Down Expand Up @@ -451,6 +464,20 @@ bool is_value_needed_in_reverse(TypeResults &TR, const GradientUtils* gutils, Va
if (op->getOpcode() == Instruction::FAdd || op->getOpcode() == Instruction::FSub) {
continue;
}
if (op->getOpcode() == Instruction::FMul) {
bool needed = false;
if (op->getOperand(0) == inst && !gutils->isConstantValue(gutils->getNewFromOriginal(op->getOperand(1)))) needed = true;
if (op->getOperand(1) == inst && !gutils->isConstantValue(gutils->getNewFromOriginal(op->getOperand(0)))) needed = true;
//llvm::errs() << "needed " << *inst << " in mul " << *op << " - needed:" << needed << "\n";
if (!needed) continue;
}

if (op->getOpcode() == Instruction::FDiv) {
bool needed = false;
if (op->getOperand(1) == inst && !gutils->isConstantValue(gutils->getNewFromOriginal(op->getOperand(1)))) needed = true;
if (op->getOperand(1) == inst && !gutils->isConstantValue(gutils->getNewFromOriginal(op->getOperand(0)))) needed = true;
if (!needed) continue;
}
}

//We don't need only the indices of a GEP to compute the adjoint of a GEP
Expand Down
5 changes: 5 additions & 0 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
#include "llvm/Analysis/MemoryDependenceAnalysis.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"

#include "llvm/Analysis/TypeBasedAliasAnalysis.h"

#if LLVM_VERSION_MAJOR > 6
#include "llvm/Analysis/PhiValues.h"
#endif
Expand Down Expand Up @@ -216,6 +219,7 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
#endif
);
AA.addAAResult(*baa);//(cache_AA[F]));
AA.addAAResult(*(new TypeBasedAAResult()));
return cache[F];
}
Function *NewF = Function::Create(F->getFunctionType(), F->getLinkage(), "preprocess_" + F->getName(), F->getParent());
Expand Down Expand Up @@ -349,6 +353,7 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
//cache_AA[F] = baa;
//llvm::errs() << " basicAA(f=" << F->getName() << ")=" << baa << "\n";
AA.addAAResult(*baa);
AA.addAAResult(*(new TypeBasedAAResult()));
//for(auto &a : AA.AAs) {
// llvm::errs() << "&AA: " << &AA << " added baa &a: " << a.get() << "\n";
//}
Expand Down
19 changes: 11 additions & 8 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) {
Value* val1 = nullptr;


if (false && isConstantValue(arg->getOperand(0)) && isConstantValue(arg->getOperand(1))) {
if (isConstantValue(arg->getOperand(0)) && isConstantValue(arg->getOperand(1))) {
llvm::errs() << *oldFunc << "\n";
llvm::errs() << *newFunc << "\n";
dumpSet(this->originalInstructions);
Expand All @@ -447,29 +447,32 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) {
invertedPointers[arg] = li;
return lookupM(invertedPointers[arg], BuilderM);
} else if (auto arg = dyn_cast<GetElementPtrInst>(val)) {
if (arg->getParent() == &arg->getParent()->getParent()->getEntryBlock()) {
//if (arg->getParent() == &arg->getParent()->getParent()->getEntryBlock()) {
IRBuilder<> bb(arg);
SmallVector<Value*,4> invertargs;
for(auto &a: arg->indices()) {
auto b = lookupM(a, bb);
invertargs.push_back(b);
}
auto result = bb.CreateGEP(invertPointerM(arg->getPointerOperand(), bb), invertargs, arg->getName()+"'ipge");
auto result = bb.CreateGEP(invertPointerM(arg->getPointerOperand(), bb), invertargs, arg->getName()+"'ipg");
if (auto gep = dyn_cast<GetElementPtrInst>(result))
gep->setIsInBounds(arg->isInBounds());
invertedPointers[arg] = result;
return lookupM(invertedPointers[arg], BuilderM);
}

//}
/*
SmallVector<Value*,4> invertargs;
for(auto &a: arg->indices()) {
auto b = lookupM(a, BuilderM);
invertargs.push_back(b);
}
auto result = BuilderM.CreateGEP(invertPointerM(arg->getPointerOperand(), BuilderM), invertargs, arg->getName()+"'ipg");

auto result = bb.CreateGEP(invertPointerM(arg->getPointerOperand(), BuilderM), invertargs, arg->getName()+"'ipg");
if (auto gep = dyn_cast<GetElementPtrInst>(result))
gep->setIsInBounds(arg->isInBounds());
return result;
invertedPointers[arg] = result;
return lookupM(invertedPointers[arg], BuilderM);
*/
} else if (auto inst = dyn_cast<AllocaInst>(val)) {
IRBuilder<> bb(inst);
AllocaInst* antialloca = bb.CreateAlloca(inst->getAllocatedType(), inst->getType()->getPointerAddressSpace(), inst->getArraySize(), inst->getName()+"'ipa");
Expand Down Expand Up @@ -579,7 +582,7 @@ std::pair<PHINode*,Instruction*> insertNewCanonicalIV(Loop* L, Type* Ty) {
PHINode *CanonicalIV = B.CreatePHI(Ty, 1, "iv");

B.SetInsertPoint(Header->getFirstNonPHIOrDbg());
Instruction* inc = cast<Instruction>(B.CreateNUWAdd(CanonicalIV, ConstantInt::get(CanonicalIV->getType(), 1), "iv.next"));
Instruction* inc = cast<Instruction>(B.CreateAdd(CanonicalIV, ConstantInt::get(CanonicalIV->getType(), 1), "iv.next", /*NUW*/true, /*NSW*/true));

for (BasicBlock *Pred : predecessors(Header)) {
assert(Pred);
Expand Down
Loading