Skip to content

Commit b9049ce

Browse files
committed
Better alias info
1 parent 3973cfb commit b9049ce

File tree

13 files changed

+898
-135
lines changed

13 files changed

+898
-135
lines changed

enzyme/Enzyme/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ file(GLOB ENZYME_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
77
)
88

99
list(APPEND ENZYME_SRC SCEV/ScalarEvolutionExpander.cpp) # Attributor/Attributor.cpp)
10-
10+
1111
message("found enzyme sources " ${ENZYME_SRC})
1212

1313
if (${LLVM_VERSION_MAJOR} LESS 8)

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 95 additions & 104 deletions
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,6 @@ bool is_value_mustcache_from_origin(Value* obj, AAResults& AA, GradientUtils* gu
116116
mustcache = true;
117117
}
118118
//llvm::errs() << " + argument (mustcache=" << mustcache << ") " << " object: " << *obj << " arg: " << *arg << "e\n";
119-
//TODO this case (alloca goes out of scope/allocation is freed and we dont force it to continue needs to be forcibly cached)
120119
} else {
121120

122121
// Pointer operands originating from call instructions that are not malloc/free are conservatively considered uncacheable.
@@ -158,34 +157,36 @@ bool is_load_uncacheable(LoadInst& li, AAResults& AA, GradientUtils* gutils, Tar
158157
// Find the underlying object for the pointer operand of the load instruction.
159158
auto obj = GetUnderlyingObject(li.getPointerOperand(), gutils->oldFunc->getParent()->getDataLayout(), 100);
160159

161-
//llvm::errs() << "underlying object for load " << li << " is " << *obj << "\n";
162160

163161
bool can_modref = is_value_mustcache_from_origin(obj, AA, gutils, TLI, uncacheable_args);
164162

165-
allFollowersOf(&li, [&](Instruction* inst2) {
166-
// Don't consider modref from malloc/free as a need to cache
167-
if (auto obj_op = dyn_cast<CallInst>(inst2)) {
168-
Function* called = obj_op->getCalledFunction();
169-
if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
170-
if (castinst->isCast()) {
171-
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
172-
if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
173-
called = fn;
163+
//llvm::errs() << "underlying object for load " << li << " is " << *obj << " fromorigin: " << can_modref << "\n";
164+
165+
if (!can_modref) {
166+
allFollowersOf(&li, [&](Instruction* inst2) {
167+
// Don't consider modref from malloc/free as a need to cache
168+
if (auto obj_op = dyn_cast<CallInst>(inst2)) {
169+
Function* called = obj_op->getCalledFunction();
170+
if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
171+
if (castinst->isCast()) {
172+
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
173+
if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
174+
called = fn;
175+
}
174176
}
175177
}
176178
}
179+
if (called && isCertainMallocOrFree(called)) {
180+
return;
181+
}
177182
}
178-
if (called && isCertainMallocOrFree(called)) {
183+
184+
if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::get(&li)))) {
185+
can_modref = true;
179186
return;
180187
}
181-
}
182-
183-
if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::get(&li)))) {
184-
can_modref = true;
185-
//llvm::errs() << li << " needs to be cached due to: " << *inst2 << "\n";
186-
return;
187-
}
188-
});
188+
});
189+
}
189190

190191
//llvm::errs() << "F - " << li << " can_modref" << can_modref << "\n";
191192
return can_modref;
@@ -210,9 +211,13 @@ std::map<Instruction*, bool> compute_uncacheable_load_map(GradientUtils* gutils,
210211
std::map<Argument*, bool> compute_uncacheable_args_for_one_callsite(CallInst* callsite_op, DominatorTree &DT,
211212
TargetLibraryInfo &TLI, AAResults& AA, GradientUtils* gutils, const std::map<Argument*, bool> parent_uncacheable_args) {
212213

214+
if (!callsite_op->getCalledFunction()) return {};
215+
213216
std::vector<Value*> args;
214217
std::vector<bool> args_safe;
215218

219+
//llvm::errs() << "CallInst: " << *callsite_op<< "CALL ARGUMENT INFO: \n";
220+
216221
// First, we need to propagate the uncacheable status from the parent function to the callee.
217222
// because memory location x modified after parent returns => x modified after callee returns.
218223
for (unsigned i = 0; i < callsite_op->getNumArgOperands(); i++) {
@@ -231,78 +236,45 @@ std::map<Argument*, bool> compute_uncacheable_args_for_one_callsite(CallInst* ca
231236

232237
// Second, we check for memory modifications that can occur in the continuation of the
233238
// callee inside the parent function.
234-
for (inst_iterator I = inst_begin(*gutils->oldFunc), E = inst_end(*gutils->oldFunc); I != E; ++I) {
235-
Instruction* inst = &*I;
236-
assert(inst->getParent()->getParent() == callsite_op->getParent()->getParent());
237-
238-
if (inst == callsite_op) continue;
239-
240-
// If the "inst" does not dominate "callsite_op" then we cannot prove that
241-
// "inst" happens before "callsite_op". If "inst" modifies an argument of the call,
242-
// then that call needs to consider the argument uncacheable.
243-
// To correctly handle case where inst == callsite_op, we need to look at next instruction after callsite_op.
244-
if (!gutils->OrigDT.dominates(inst, callsite_op)) {
245-
//llvm::errs() << "Instruction " << *inst << " DOES NOT dominates " << *callsite_op << "\n";
246-
// Consider Store Instructions.
247-
if (auto op = dyn_cast<StoreInst>(inst)) {
248-
for (unsigned i = 0; i < args.size(); i++) {
249-
// If the modification flag is set, then this instruction may modify the $i$th argument of the call.
250-
if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(callsite_op, i, TLI)))) {
251-
//llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n";
252-
} else {
253-
//llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n";
254-
args_safe[i] = false;
239+
allFollowersOf(callsite_op, [&](Instruction* inst2) {
240+
// Don't consider modref from malloc/free as a need to cache
241+
if (auto obj_op = dyn_cast<CallInst>(inst2)) {
242+
Function* called = obj_op->getCalledFunction();
243+
if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
244+
if (castinst->isCast()) {
245+
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
246+
if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
247+
called = fn;
248+
}
255249
}
256250
}
257251
}
252+
if (called && isCertainMallocOrFree(called)) {
253+
return;
254+
}
255+
}
258256

259-
// Consider Call Instructions.
260-
if (auto op = dyn_cast<CallInst>(inst)) {
261-
//llvm::errs() << "OP is call inst: " << *op << "\n";
262-
// Ignore memory allocation functions.
263-
Function* called = op->getCalledFunction();
264-
if (auto castinst = dyn_cast<ConstantExpr>(op->getCalledValue())) {
265-
if (castinst->isCast()) {
266-
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
267-
if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
268-
called = fn;
269-
}
270-
}
271-
}
272-
}
273-
if (isCertainMallocOrFree(called)) {
274-
//llvm::errs() << "OP is certain malloc or free: " << *op << "\n";
275-
continue;
276-
}
277257

278-
// For all the arguments, perform same check as for Stores, but ignore non-pointer arguments.
279-
for (unsigned i = 0; i < args.size(); i++) {
280-
if (!args[i]->getType()->isPointerTy()) continue; // Ignore non-pointer arguments.
281-
if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(callsite_op, i, TLI)))) {
282-
//llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n";
283-
} else {
284-
//llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n";
285-
args_safe[i] = false;
286-
}
287-
}
258+
for (unsigned i = 0; i < args.size(); i++) {
259+
if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::getForArgument(callsite_op, i, TLI)))) {
260+
args_safe[i] = false;
261+
//llvm::errs() << "Instruction " << *inst2 << " is maybe ModRef with call argument " << *args[i] << "\n";
288262
}
289-
} else {
290-
//llvm::errs() << "Instruction " << *inst << " DOES dominates " << *callsite_op << "\n";
291263
}
292-
}
264+
});
293265

294266
std::map<Argument*, bool> uncacheable_args;
295-
//llvm::errs() << "CallInst: " << *callsite_op<< "CALL ARGUMENT INFO: \n";
296-
if (callsite_op->getCalledFunction()) {
297267

298268
auto arg = callsite_op->getCalledFunction()->arg_begin();
299269
for (unsigned i = 0; i < args.size(); i++) {
300270
uncacheable_args[arg] = !args_safe[i];
301-
//llvm::errs() << "callArg: " << *args[i] << " arg:" << *arg << " STATUS: " << args_safe[i] << "\n";
271+
//llvm::errs() << "callArg: " << *args[i] << " arg:" << *arg << " uncacheable: " << uncacheable_args[arg] << "\n";
302272
arg++;
273+
if (arg ==callsite_op->getCalledFunction()->arg_end()) {
274+
break;
275+
}
303276
}
304277

305-
}
306278
return uncacheable_args;
307279
}
308280

@@ -661,38 +633,57 @@ bool legalCombinedForwardReverse(CallInst &ci, const std::map<ReturnInst*,StoreI
661633
return false;
662634
}
663635

664-
auto getMRI = [&](Instruction* inst, Instruction* inst2) {
665-
if (auto call = dyn_cast<CallInst>(inst)) {
636+
auto writesToMemoryReadBy = [&](Instruction* maybeReader, Instruction* maybeWriter) -> bool {
637+
if (auto call = dyn_cast<CallInst>(maybeWriter)) {
666638
if (call->getCalledFunction() && isCertainMallocOrFree(call->getCalledFunction())) {
667-
return ModRefInfo::NoModRef;
639+
return false;
668640
}
669641
}
642+
if (auto call = dyn_cast<CallInst>(maybeReader)) {
643+
if (call->getCalledFunction() && isCertainMallocOrFree(call->getCalledFunction())) {
644+
return false;
645+
}
646+
}
647+
if (auto call = dyn_cast<InvokeInst>(maybeWriter)) {
648+
if (call->getCalledFunction() && isCertainMallocOrFree(call->getCalledFunction())) {
649+
return false;
650+
}
651+
}
652+
if (auto call = dyn_cast<InvokeInst>(maybeReader)) {
653+
if (call->getCalledFunction() && isCertainMallocOrFree(call->getCalledFunction())) {
654+
return false;
655+
}
656+
}
657+
assert(maybeWriter->mayWriteToMemory());
658+
assert(maybeReader->mayReadFromMemory());
670659

671-
if (auto li = dyn_cast<LoadInst>(inst2)) {
672-
return gutils->AA.getModRefInfo(inst, MemoryLocation::get(li));
660+
if (auto li = dyn_cast<LoadInst>(maybeReader)) {
661+
return isModSet(gutils->AA.getModRefInfo(maybeWriter, MemoryLocation::get(li)));
673662
}
674-
if (auto si = dyn_cast<StoreInst>(inst2)) {
675-
return gutils->AA.getModRefInfo(inst, MemoryLocation::get(si));
663+
if (auto rmw = dyn_cast<AtomicRMWInst>(maybeReader)) {
664+
return isModSet(gutils->AA.getModRefInfo(maybeWriter, MemoryLocation::get(rmw)));
676665
}
677-
if (auto rmw = dyn_cast<AtomicRMWInst>(inst2)) {
678-
return gutils->AA.getModRefInfo(inst, MemoryLocation::get(rmw));
666+
if (auto xch = dyn_cast<AtomicCmpXchgInst>(maybeReader)) {
667+
return isModSet(gutils->AA.getModRefInfo(maybeWriter, MemoryLocation::get(xch)));
679668
}
680-
if (auto xch = dyn_cast<AtomicCmpXchgInst>(inst2)) {
681-
return gutils->AA.getModRefInfo(inst, MemoryLocation::get(xch));
669+
670+
if (auto si = dyn_cast<StoreInst>(maybeWriter)) {
671+
return isRefSet(gutils->AA.getModRefInfo(maybeReader, MemoryLocation::get(si)));
682672
}
683-
if (auto cb = dyn_cast<CallInst>(inst2)) {
684-
if (cb->getCalledFunction() && isCertainMallocOrFree(cb->getCalledFunction())) {
685-
return ModRefInfo::NoModRef;
686-
}
687-
return gutils->AA.getModRefInfo(inst, cb);
673+
if (auto rmw = dyn_cast<AtomicRMWInst>(maybeWriter)) {
674+
return isRefSet(gutils->AA.getModRefInfo(maybeReader, MemoryLocation::get(rmw)));
688675
}
689-
if (auto cb = dyn_cast<InvokeInst>(inst2)) {
690-
if (cb->getCalledFunction() && isCertainMallocOrFree(cb->getCalledFunction())) {
691-
return ModRefInfo::NoModRef;
692-
}
693-
return gutils->AA.getModRefInfo(inst, cb);
676+
if (auto xch = dyn_cast<AtomicCmpXchgInst>(maybeWriter)) {
677+
return isRefSet(gutils->AA.getModRefInfo(maybeReader, MemoryLocation::get(xch)));
678+
}
679+
680+
if (auto cb = dyn_cast<CallInst>(maybeReader)) {
681+
return isModOrRefSet(gutils->AA.getModRefInfo(maybeWriter, cb));
694682
}
695-
llvm::errs() << " inst2: " << *inst2 << "\n";
683+
if (auto cb = dyn_cast<InvokeInst>(maybeReader)) {
684+
return isModOrRefSet(gutils->AA.getModRefInfo(maybeWriter, cb));
685+
}
686+
llvm::errs() << " maybeReader: " << *maybeReader << " maybeWriter: " << *maybeWriter << "\n";
696687
llvm_unreachable("unknown inst2");
697688
};
698689

@@ -806,12 +797,11 @@ bool legalCombinedForwardReverse(CallInst &ci, const std::map<ReturnInst*,StoreI
806797
if (inst->mayWriteToMemory()) {
807798
auto consider = [&](Instruction* user) {
808799
if (!user->mayReadFromMemory()) return;
809-
auto mri = getMRI(user, inst);
810-
//llvm::errs() << " checking if need follower of " << *inst << " - " << *user << " : mri " << mri << "\n";
811-
if (isRefSet(mri)) {
800+
if (writesToMemoryReadBy(/*maybeReader*/user, /*maybeWriter*/inst)) {
801+
//llvm::errs() << " memory deduced need follower of " << *inst << " - " << *user << "\n";
812802
propagate(user);
813803
if (!legal) return;
814-
}
804+
}
815805
};
816806
allFollowersOf(inst, consider);
817807
if (!legal) return false;
@@ -826,13 +816,14 @@ bool legalCombinedForwardReverse(CallInst &ci, const std::map<ReturnInst*,StoreI
826816
// llvm::errs() << " + " << *u << "\n";
827817

828818
// Check if any of the unmoved operations will make it illegal to move the instruction
819+
829820
for (auto inst : usetree) {
830821
if (!inst->mayReadFromMemory()) continue;
831822
allFollowersOf(inst, [&](Instruction* post) {
823+
if (unnecessaryInstructions.count(post)) return;
832824
if (!post->mayWriteToMemory()) return;
833825
//llvm::errs() << " checking if illegal move of " << *inst << " due to " << *post << "\n";
834-
auto mri = getMRI(inst, post);
835-
if (isModSet(mri)) {
826+
if (writesToMemoryReadBy(/*maybeReader*/inst, /*maybeWriter*/post)) {
836827
if (called)
837828
llvm::errs() << " failed to replace function " << (called->getName()) << " due to " << *post << " usetree: " << *inst << "\n";
838829
else

0 commit comments

Comments
 (0)