Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions enzyme/Enzyme/ActiveVariable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,24 +369,39 @@ bool isconstantM(TypeResults &TR, Instruction* inst, SmallPtrSetImpl<Value*> &co
return true;
}

if (isa<LoadInst>(inst) || isa<StoreInst>(inst)) {
if (isa<StoreInst>(inst)) {
if (parseTBAA(inst).typeEnum == IntType::Integer) {
if (printconst)
llvm::errs() << " constant instruction from TBAA " << *inst << "\n";
constants.insert(inst);
return true;
}
}
if (auto li = dyn_cast<LoadInst>(inst)) {
if (constants.find(li->getPointerOperand()) != constants.end()) {
constants.insert(li);
return true;

if (directions & UP) {
if (auto li = dyn_cast<LoadInst>(inst)) {
if (constants.find(li->getPointerOperand()) != constants.end()) {
constants.insert(li);
return true;
}
}
}
if (auto rmw = dyn_cast<AtomicRMWInst>(inst)) {
if (constants.find(rmw->getPointerOperand()) != constants.end()) {
constants.insert(rmw);
return true;
if (auto rmw = dyn_cast<AtomicRMWInst>(inst)) {
if (constants.find(rmw->getPointerOperand()) != constants.end()) {
constants.insert(rmw);
return true;
}
}
if (auto gep = dyn_cast<GetElementPtrInst>(inst)) {
if (constants.find(gep->getPointerOperand()) != constants.end()) {
constants.insert(gep);
return true;
}
}
if (auto cst = dyn_cast<CastInst>(inst)) {
if (constants.find(cst->getOperand(0)) != constants.end()) {
constants.insert(cst);
return true;
}
}
}

Expand All @@ -412,8 +427,7 @@ bool isconstantM(TypeResults &TR, Instruction* inst, SmallPtrSetImpl<Value*> &co
// * integers that we know are not pointers
bool containsPointer = true;
if (inst->getType()->isFPOrFPVectorTy()) containsPointer = false;
// TODO propagate typeInfo here so can do more aggressive constant analysis rather than using empty map {}
if (inst->getType()->isIntOrIntVectorTy() && !TR.intType(inst, /*errIfNotFound*/false).isPossiblePointer()) containsPointer = false;
if (!TR.intType(inst, /*errIfNotFound*/false).isPossiblePointer()) containsPointer = false;

if (containsPointer) {

Expand Down Expand Up @@ -661,7 +675,7 @@ bool isconstantM(TypeResults &TR, Instruction* inst, SmallPtrSetImpl<Value*> &co
}

//TODO use typeInfo for more aggressive activity analysis
if (!(inst->getType()->isPointerTy() || (inst->getType()->isIntOrIntVectorTy() && TR.intType(inst, /*errIfNotFound*/false).isPossiblePointer()) ) && ( !inst->mayWriteToMemory() ) && (directions & DOWN) && (retvals.find(inst) == retvals.end()) ) {
if (!containsPointer && ( !inst->mayWriteToMemory() ) && (directions & DOWN) && (retvals.find(inst) == retvals.end()) ) {
//Proceed assuming this is constant, can we prove this should be constant otherwise
SmallPtrSet<Value*, 20> constants2;
constants2.insert(constants.begin(), constants.end());
Expand Down
17 changes: 12 additions & 5 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo
if (enzyme_print)
llvm::errs() << "prefn:\n" << *fn << "\n";

std::set<unsigned> constants;
std::vector<DIFFE_TYPE> constants;
SmallVector<Value*,2> args;

unsigned truei = 0;
Expand All @@ -85,6 +85,8 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo
auto MS = cast<MDString>(av->getMetadata())->getString();
if (MS == "diffe_dup") {
ty = DIFFE_TYPE::DUP_ARG;
} else if(MS == "diffe_dupnoneed") {
ty = DIFFE_TYPE::DUP_NONEED;
} else if(MS == "diffe_out") {
llvm::errs() << "saw metadata for diffe_out\n";
ty = DIFFE_TYPE::OUT_DIFF;
Expand All @@ -102,6 +104,10 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo
ty = DIFFE_TYPE::DUP_ARG;
i++;
res = CI->getArgOperand(i);
} else if(MS == "diffe_dupnoneed") {
ty = DIFFE_TYPE::DUP_NONEED;
i++;
res = CI->getArgOperand(i);
} else if(MS == "diffe_out") {
llvm::errs() << "saw metadata for diffe_out\n";
ty = DIFFE_TYPE::OUT_DIFF;
Expand All @@ -119,8 +125,7 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo

//llvm::errs() << "considering arg " << *res << " argnum " << truei << "\n";

if (ty == DIFFE_TYPE::CONSTANT)
constants.insert(truei);
constants.push_back(ty);

assert(truei < FT->getNumParams());
if (PTy != res->getType()) {
Expand All @@ -143,7 +148,7 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo
}

args.push_back(res);
if (ty == DIFFE_TYPE::DUP_ARG) {
if (ty == DIFFE_TYPE::DUP_ARG || ty == DIFFE_TYPE::DUP_NONEED) {
i++;

Value* res = CI->getArgOperand(i);
Expand Down Expand Up @@ -181,6 +186,8 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo

bool differentialReturn = cast<Function>(fn)->getReturnType()->isFPOrFPVectorTy();

DIFFE_TYPE retType = whatType(cast<Function>(fn)->getReturnType());

std::map<Argument*, bool> volatile_args;
NewFnTypeInfo type_args(cast<Function>(fn));
for(auto &a : type_args.function->args()) {
Expand All @@ -204,7 +211,7 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo
TypeAnalysis TA;
type_args = TA.analyzeFunction(type_args).getAnalyzedTypeInfo();

auto newFunc = CreatePrimalAndGradient(cast<Function>(fn), constants, TLI, TA, AA, /*should return*/false, differentialReturn, /*dretPtr*/false, /*topLevel*/true, /*addedType*/nullptr, type_args, volatile_args, /*index mapping*/nullptr); //llvm::Optional<std::map<std::pair<Instruction*, std::string>, unsigned>>({}));
auto newFunc = CreatePrimalAndGradient(cast<Function>(fn), retType, constants, TLI, TA, AA, /*should return*/false, /*dretPtr*/false, /*topLevel*/true, /*addedType*/nullptr, type_args, volatile_args, /*index mapping*/nullptr); //llvm::Optional<std::map<std::pair<Instruction*, std::string>, unsigned>>({}));

if (differentialReturn)
args.push_back(ConstantFP::get(cast<Function>(fn)->getReturnType(), 1.0));
Expand Down
Loading