Skip to content

Commit eb14a66

Browse files
committed
[Caching] more aggressive load caching & use TBAA results
1 parent 205f9a2 commit eb14a66

File tree

8 files changed

+633
-175
lines changed

8 files changed

+633
-175
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 104 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,103 @@ cl::opt<bool> nonmarkedglobals_inactiveloads(
6262
"enzyme_nonmarkedglobals_inactiveloads", cl::init(true), cl::Hidden,
6363
cl::desc("Consider loads of nonmarked globals to be inactive"));
6464

65+
66+
// Determine whether a single load is "uncacheable": whether the memory it
// reads may be modified by another instruction in the function that does not
// dominate it (or by the caller, via an uncacheable argument), in which case
// the loaded value must be explicitly cached for the reverse pass.
//
// Parameters:
//   li               - the load under analysis (must belong to gutils->oldFunc).
//   AA               - alias-analysis results used for the mod/ref scan.
//   gutils           - gradient utilities; supplies oldFunc and its OrigDT.
//   TLI              - target library info, used to recognize allocation and
//                      deallocation functions.
//   uncacheable_args - per-argument uncacheability computed at the call site;
//                      must contain every Argument of gutils->oldFunc.
//
// Returns true if the load must be cached, false if it is safe to recompute.
bool is_load_uncacheable(LoadInst& li, AAResults& AA, GradientUtils* gutils, TargetLibraryInfo& TLI,
                         const std::map<Argument*, bool>& uncacheable_args) {

  bool can_modref = false;
  // Find the underlying object for the pointer operand of the load instruction.
  auto obj = GetUnderlyingObject(li.getPointerOperand(), gutils->oldFunc->getParent()->getDataLayout(), 100);

  //llvm::errs() << "underlying object for load " << li << " is " << *obj << "\n";
  // If the pointer operand is from an argument to the function, we need to check if the argument
  // received from the caller is uncacheable.
  if (auto arg = dyn_cast<Argument>(obj)) {
    auto found = uncacheable_args.find(arg);
    if (found == uncacheable_args.end()) {
      // Dump the map before asserting so the failure is diagnosable.
      llvm::errs() << "uncacheable_args:\n";
      for(auto& pair : uncacheable_args) {
        llvm::errs() << " + " << *pair.first << ": " << pair.second << " of func " << pair.first->getParent()->getName() << "\n";
      }
      llvm::errs() << "could not find " << *arg << " of func " << arg->getParent()->getName() << " in args_map\n";
    }
    assert(found != uncacheable_args.end());
    if (found->second) {
      //llvm::errs() << "OP is uncacheable arg: " << li << "\n";
      can_modref = true;
    }
    //llvm::errs() << " + argument (can_modref=" << can_modref << ") " << li << " object: " << *obj << " arg: " << *arg << "e\n";
    //TODO this case (alloca goes out of scope/allocation is freed and we dont force it to continue needs to be forcibly cached)
  } else {
    // NOTE(TFK): In the case where the underlying object for the pointer operand is from a Load or Call we need
    // to check if we need to cache. Likely, we need to play it safe in this case and cache.
    // NOTE(TFK): The logic below is an attempt at a conservative handling of the case mentioned above, but it
    // needs to be verified.

    // Pointer operands originating from call instructions that are not malloc/free are conservatively considered uncacheable.
    if (auto obj_op = dyn_cast<CallInst>(obj)) {
      Function* called = obj_op->getCalledFunction();
      // Look through a constant-expression cast of the callee to recognize
      // bitcast-wrapped malloc/free declarations.
      if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
        if (castinst->isCast()) {
          if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
            if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
              called = fn;
            }
          }
        }
      }
      if (called && isCertainMallocOrFree(called)) {
        //llvm::errs() << "OP is certain malloc or free: " << *op << "\n";
      } else {
        //llvm::errs() << "OP is a non malloc/free call so we need to cache " << *op << "\n";
        can_modref = true;
      }
    } else if (auto sli = dyn_cast<LoadInst>(obj)) {
      // If obj is from a load instruction conservatively consider it uncacheable if that load itself cannot be cached.
      // NOTE(review): this recursion has no memoization or depth limit; chains of
      // pointer loads re-run the full mod/ref scan per level. Presumably SSA
      // ordering bounds the depth — confirm, and consider caching results.
      //llvm::errs() << "OP is from a load, needing to cache " << *op << "\n";
      can_modref = is_load_uncacheable(*sli, AA, gutils, TLI, uncacheable_args);
    } else {
      // In absence of more information, assume that the underlying object for pointer operand is uncacheable in caller.
      //llvm::errs() << "OP is an unknown instruction, needing to cache " << *op << "\n";
      can_modref = true;
    }
  }

  // Scan every instruction that does not dominate the load: if any may write
  // the loaded location, the load must be cached.
  for (inst_iterator I2 = inst_begin(*gutils->oldFunc), E2 = inst_end(*gutils->oldFunc); I2 != E2; ++I2) {
    Instruction* inst2 = &*I2;
    assert(li.getParent()->getParent() == inst2->getParent()->getParent());
    if (&li == inst2) continue;
    if (!gutils->OrigDT.dominates(inst2, &li)) {

      // Don't consider modref from malloc/free as a need to cache
      if (auto obj_op = dyn_cast<CallInst>(inst2)) {
        Function* called = obj_op->getCalledFunction();
        if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
          if (castinst->isCast()) {
            if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
              if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
                called = fn;
              }
            }
          }
        }
        if (called && isCertainMallocOrFree(called)) {
          continue;
        }
      }

      if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::get(&li)))) {
        can_modref = true;
        // Debug print kept commented out, matching the pre-existing style of
        // this analysis: unconditional stderr output on every cached load is
        // too noisy for the common path.
        //llvm::errs() << li << " needs to be cached due to: " << *inst2 << "\n";
        break;
      }
    }
  }
  //llvm::errs() << "F - " << li << " can_modref" << can_modref << "\n";
  return can_modref;
}
161+
65162
// Computes a map of LoadInst -> boolean for a function indicating whether that load is "uncacheable".
66163
// A load is considered "uncacheable" if the data at the loaded memory location can be modified after
67164
// the load instruction.
@@ -70,98 +167,11 @@ std::map<Instruction*, bool> compute_uncacheable_load_map(GradientUtils* gutils,
70167
std::map<Instruction*, bool> can_modref_map;
71168
for (inst_iterator I = inst_begin(*gutils->oldFunc), E = inst_end(*gutils->oldFunc); I != E; ++I) {
72169
Instruction* inst = &*I;
73-
// For each load instruction, determine if it is xuncacheable.
170+
// For each load instruction, determine if it is uncacheable.
74171
if (auto op = dyn_cast<LoadInst>(inst)) {
75172

76-
bool can_modref = false;
77-
// Find the underlying object for the pointer operand of the load instruction.
78-
auto obj = GetUnderlyingObject(op->getPointerOperand(), gutils->oldFunc->getParent()->getDataLayout(), 100);
79-
80-
//llvm::errs() << "underlying object for load " << *op << " is " << *obj << "\n";
81-
// If the pointer operand is from an argument to the function, we need to check if the argument
82-
// received from the caller is uncacheable.
83-
if (auto arg = dyn_cast<Argument>(obj)) {
84-
auto found = uncacheable_args.find(arg);
85-
if (found == uncacheable_args.end()) {
86-
llvm::errs() << "uncacheable_args:\n";
87-
for(auto& pair : uncacheable_args) {
88-
llvm::errs() << " + " << *pair.first << ": " << pair.second << " of func " << pair.first->getParent()->getName() << "\n";
89-
}
90-
llvm::errs() << "could not find " << *arg << " of func " << arg->getParent()->getName() << " in args_map\n";
91-
}
92-
assert(found != uncacheable_args.end());
93-
if (found->second) {
94-
//llvm::errs() << "OP is uncacheable arg: " << *op << "\n";
95-
can_modref = true;
96-
}
97-
//llvm::errs() << " + argument (can_modref=" << can_modref << ") " << *op << " object: " << *obj << " arg: " << *arg << "e\n";
98-
//TODO this case (alloca goes out of scope/allocation is freed and we dont force it to continue needs to be forcibly cached)
99-
} else {
100-
// NOTE(TFK): In the case where the underlying object for the pointer operand is from a Load or Call we need
101-
// to check if we need to cache. Likely, we need to play it safe in this case and cache.
102-
// NOTE(TFK): The logic below is an attempt at a conservative handling of the case mentioned above, but it
103-
// needs to be verified.
104-
105-
// Pointer operands originating from call instructions that are not malloc/free are conservatively considered uncacheable.
106-
if (auto obj_op = dyn_cast<CallInst>(obj)) {
107-
Function* called = obj_op->getCalledFunction();
108-
if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
109-
if (castinst->isCast()) {
110-
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
111-
if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
112-
called = fn;
113-
}
114-
}
115-
}
116-
}
117-
if (called && isCertainMallocOrFree(called)) {
118-
//llvm::errs() << "OP is certain malloc or free: " << *op << "\n";
119-
} else {
120-
//llvm::errs() << "OP is a non malloc/free call so we need to cache " << *op << "\n";
121-
can_modref = true;
122-
}
123-
} else if (isa<LoadInst>(obj)) {
124-
// If obj is from a load instruction conservatively consider it uncacheable.
125-
//llvm::errs() << "OP is from a load, needing to cache " << *op << "\n";
126-
can_modref = true;
127-
} else {
128-
// In absence of more information, assume that the underlying object for pointer operand is uncacheable in caller.
129-
//llvm::errs() << "OP is an unknown instruction, needing to cache " << *op << "\n";
130-
can_modref = true;
131-
}
132-
}
133-
134-
for (inst_iterator I2 = inst_begin(*gutils->oldFunc), E2 = inst_end(*gutils->oldFunc); I2 != E2; ++I2) {
135-
Instruction* inst2 = &*I2;
136-
assert(inst->getParent()->getParent() == inst2->getParent()->getParent());
137-
if (inst == inst2) continue;
138-
if (!gutils->OrigDT.dominates(inst2, inst)) {
139-
140-
// Don't consider modref from malloc/free as a need to cache
141-
if (auto obj_op = dyn_cast<CallInst>(inst2)) {
142-
Function* called = obj_op->getCalledFunction();
143-
if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
144-
if (castinst->isCast()) {
145-
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
146-
if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
147-
called = fn;
148-
}
149-
}
150-
}
151-
}
152-
if (called && isCertainMallocOrFree(called)) {
153-
continue;
154-
}
155-
}
156173

157-
if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::get(op)))) {
158-
can_modref = true;
159-
//llvm::errs() << *inst << " needs to be cached due to: " << *inst2 << "\n";
160-
break;
161-
}
162-
}
163-
}
164-
can_modref_map[inst] = can_modref;
174+
can_modref_map[inst] = is_load_uncacheable(*op, AA, gutils, TLI, uncacheable_args);
165175
}
166176
}
167177
return can_modref_map;
@@ -185,7 +195,10 @@ std::map<Argument*, bool> compute_uncacheable_args_for_one_callsite(CallInst* ca
185195
100);
186196
//llvm::errs() << "ocs underlying object for callsite " << *callsite_op << " idx: " << i << " is " << *obj << "\n";
187197
// If underlying object is an Argument, check parent volatility status.
188-
if (auto arg = dyn_cast<Argument>(obj)) {
198+
if (isa<UndefValue>(obj)) {
199+
init_safe = true;
200+
//llvm::errs() << " + ocs undef (safe=" << init_safe << ") " << *callsite_op << " object: " << *obj << "\n";
201+
} else if (auto arg = dyn_cast<Argument>(obj)) {
189202
auto found = parent_uncacheable_args.find(arg);
190203
if (found == parent_uncacheable_args.end()) {
191204
llvm::errs() << "parent_uncacheable_args:\n";
@@ -198,7 +211,7 @@ std::map<Argument*, bool> compute_uncacheable_args_for_one_callsite(CallInst* ca
198211
if (found->second) {
199212
init_safe = false;
200213
}
201-
//llvm::errs() << " + ocs argument (safe=" << init_safe << ") " << *callsite_op << " object: " << *obj << " arg: " << *arg << "e\n";
214+
//llvm::errs() << " + ocs argument (safe=" << init_safe << ") " << *callsite_op << " object: " << *obj << " arg: " << *arg << "å\n";
202215
} else {
203216
// Pointer operands originating from call instructions that are not malloc/free are conservatively considered uncacheable.
204217
if (auto obj_op = dyn_cast<CallInst>(obj)) {

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
#include "llvm/Analysis/MemoryDependenceAnalysis.h"
3838
#include "llvm/Analysis/MemorySSA.h"
3939
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
40+
41+
#include "llvm/Analysis/TypeBasedAliasAnalysis.h"
42+
4043
#if LLVM_VERSION_MAJOR > 6
4144
#include "llvm/Analysis/PhiValues.h"
4245
#endif
@@ -216,6 +219,7 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
216219
#endif
217220
);
218221
AA.addAAResult(*baa);//(cache_AA[F]));
222+
AA.addAAResult(*(new TypeBasedAAResult()));
219223
return cache[F];
220224
}
221225
Function *NewF = Function::Create(F->getFunctionType(), F->getLinkage(), "preprocess_" + F->getName(), F->getParent());
@@ -349,6 +353,7 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
349353
//cache_AA[F] = baa;
350354
//llvm::errs() << " basicAA(f=" << F->getName() << ")=" << baa << "\n";
351355
AA.addAAResult(*baa);
356+
AA.addAAResult(*(new TypeBasedAAResult()));
352357
//for(auto &a : AA.AAs) {
353358
// llvm::errs() << "&AA: " << &AA << " added baa &a: " << a.get() << "\n";
354359
//}

0 commit comments

Comments
 (0)