@@ -116,7 +116,6 @@ bool is_value_mustcache_from_origin(Value* obj, AAResults& AA, GradientUtils* gu
116116 mustcache = true ;
117117 }
118118 // llvm::errs() << " + argument (mustcache=" << mustcache << ") " << " object: " << *obj << " arg: " << *arg << "e\n";
119- // TODO this case (alloca goes out of scope/allocation is freed and we dont force it to continue needs to be forcibly cached)
120119 } else {
121120
122121 // Pointer operands originating from call instructions that are not malloc/free are conservatively considered uncacheable.
@@ -158,34 +157,36 @@ bool is_load_uncacheable(LoadInst& li, AAResults& AA, GradientUtils* gutils, Tar
158157 // Find the underlying object for the pointer operand of the load instruction.
159158 auto obj = GetUnderlyingObject (li.getPointerOperand (), gutils->oldFunc ->getParent ()->getDataLayout (), 100 );
160159
161- // llvm::errs() << "underlying object for load " << li << " is " << *obj << "\n";
162160
163161 bool can_modref = is_value_mustcache_from_origin (obj, AA, gutils, TLI, uncacheable_args);
164162
165- allFollowersOf (&li, [&](Instruction* inst2) {
166- // Don't consider modref from malloc/free as a need to cache
167- if (auto obj_op = dyn_cast<CallInst>(inst2)) {
168- Function* called = obj_op->getCalledFunction ();
169- if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue ())) {
170- if (castinst->isCast ()) {
171- if (auto fn = dyn_cast<Function>(castinst->getOperand (0 ))) {
172- if (isAllocationFunction (*fn, TLI) || isDeallocationFunction (*fn, TLI)) {
173- called = fn;
163+ // llvm::errs() << "underlying object for load " << li << " is " << *obj << " fromorigin: " << can_modref << "\n";
164+
165+ if (!can_modref) {
166+ allFollowersOf (&li, [&](Instruction* inst2) {
167+ // Don't consider modref from malloc/free as a need to cache
168+ if (auto obj_op = dyn_cast<CallInst>(inst2)) {
169+ Function* called = obj_op->getCalledFunction ();
170+ if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue ())) {
171+ if (castinst->isCast ()) {
172+ if (auto fn = dyn_cast<Function>(castinst->getOperand (0 ))) {
173+ if (isAllocationFunction (*fn, TLI) || isDeallocationFunction (*fn, TLI)) {
174+ called = fn;
175+ }
174176 }
175177 }
176178 }
179+ if (called && isCertainMallocOrFree (called)) {
180+ return ;
181+ }
177182 }
178- if (called && isCertainMallocOrFree (called)) {
183+
184+ if (llvm::isModSet (AA.getModRefInfo (inst2, MemoryLocation::get (&li)))) {
185+ can_modref = true ;
179186 return ;
180187 }
181- }
182-
183- if (llvm::isModSet (AA.getModRefInfo (inst2, MemoryLocation::get (&li)))) {
184- can_modref = true ;
185- // llvm::errs() << li << " needs to be cached due to: " << *inst2 << "\n";
186- return ;
187- }
188- });
188+ });
189+ }
189190
190191 // llvm::errs() << "F - " << li << " can_modref" << can_modref << "\n";
191192 return can_modref;
@@ -210,9 +211,13 @@ std::map<Instruction*, bool> compute_uncacheable_load_map(GradientUtils* gutils,
210211std::map<Argument*, bool > compute_uncacheable_args_for_one_callsite (CallInst* callsite_op, DominatorTree &DT,
211212 TargetLibraryInfo &TLI, AAResults& AA, GradientUtils* gutils, const std::map<Argument*, bool > parent_uncacheable_args) {
212213
214+ if (!callsite_op->getCalledFunction ()) return {};
215+
213216 std::vector<Value*> args;
214217 std::vector<bool > args_safe;
215218
219+ // llvm::errs() << "CallInst: " << *callsite_op<< "CALL ARGUMENT INFO: \n";
220+
216221 // First, we need to propagate the uncacheable status from the parent function to the callee.
217222 // because memory location x modified after parent returns => x modified after callee returns.
218223 for (unsigned i = 0 ; i < callsite_op->getNumArgOperands (); i++) {
@@ -231,78 +236,45 @@ std::map<Argument*, bool> compute_uncacheable_args_for_one_callsite(CallInst* ca
231236
232237 // Second, we check for memory modifications that can occur in the continuation of the
233238 // callee inside the parent function.
234- for (inst_iterator I = inst_begin (*gutils->oldFunc ), E = inst_end (*gutils->oldFunc ); I != E; ++I) {
235- Instruction* inst = &*I;
236- assert (inst->getParent ()->getParent () == callsite_op->getParent ()->getParent ());
237-
238- if (inst == callsite_op) continue ;
239-
240- // If the "inst" does not dominate "callsite_op" then we cannot prove that
241- // "inst" happens before "callsite_op". If "inst" modifies an argument of the call,
242- // then that call needs to consider the argument uncacheable.
243- // To correctly handle case where inst == callsite_op, we need to look at next instruction after callsite_op.
244- if (!gutils->OrigDT .dominates (inst, callsite_op)) {
245- // llvm::errs() << "Instruction " << *inst << " DOES NOT dominates " << *callsite_op << "\n";
246- // Consider Store Instructions.
247- if (auto op = dyn_cast<StoreInst>(inst)) {
248- for (unsigned i = 0 ; i < args.size (); i++) {
249- // If the modification flag is set, then this instruction may modify the $i$th argument of the call.
250- if (!llvm::isModSet (AA.getModRefInfo (op, MemoryLocation::getForArgument (callsite_op, i, TLI)))) {
251- // llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n";
252- } else {
253- // llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n";
254- args_safe[i] = false ;
239+ allFollowersOf (callsite_op, [&](Instruction* inst2) {
240+ // Don't consider modref from malloc/free as a need to cache
241+ if (auto obj_op = dyn_cast<CallInst>(inst2)) {
242+ Function* called = obj_op->getCalledFunction ();
243+ if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue ())) {
244+ if (castinst->isCast ()) {
245+ if (auto fn = dyn_cast<Function>(castinst->getOperand (0 ))) {
246+ if (isAllocationFunction (*fn, TLI) || isDeallocationFunction (*fn, TLI)) {
247+ called = fn;
248+ }
255249 }
256250 }
257251 }
252+ if (called && isCertainMallocOrFree (called)) {
253+ return ;
254+ }
255+ }
258256
259- // Consider Call Instructions.
260- if (auto op = dyn_cast<CallInst>(inst)) {
261- // llvm::errs() << "OP is call inst: " << *op << "\n";
262- // Ignore memory allocation functions.
263- Function* called = op->getCalledFunction ();
264- if (auto castinst = dyn_cast<ConstantExpr>(op->getCalledValue ())) {
265- if (castinst->isCast ()) {
266- if (auto fn = dyn_cast<Function>(castinst->getOperand (0 ))) {
267- if (isAllocationFunction (*fn, TLI) || isDeallocationFunction (*fn, TLI)) {
268- called = fn;
269- }
270- }
271- }
272- }
273- if (isCertainMallocOrFree (called)) {
274- // llvm::errs() << "OP is certain malloc or free: " << *op << "\n";
275- continue ;
276- }
277257
278- // For all the arguments, perform same check as for Stores, but ignore non-pointer arguments.
279- for (unsigned i = 0 ; i < args.size (); i++) {
280- if (!args[i]->getType ()->isPointerTy ()) continue ; // Ignore non-pointer arguments.
281- if (!llvm::isModSet (AA.getModRefInfo (op, MemoryLocation::getForArgument (callsite_op, i, TLI)))) {
282- // llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n";
283- } else {
284- // llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n";
285- args_safe[i] = false ;
286- }
287- }
258+ for (unsigned i = 0 ; i < args.size (); i++) {
259+ if (llvm::isModSet (AA.getModRefInfo (inst2, MemoryLocation::getForArgument (callsite_op, i, TLI)))) {
260+ args_safe[i] = false ;
261+ // llvm::errs() << "Instruction " << *inst2 << " is maybe ModRef with call argument " << *args[i] << "\n";
288262 }
289- } else {
290- // llvm::errs() << "Instruction " << *inst << " DOES dominates " << *callsite_op << "\n";
291263 }
292- }
264+ });
293265
294266 std::map<Argument*, bool > uncacheable_args;
295- // llvm::errs() << "CallInst: " << *callsite_op<< "CALL ARGUMENT INFO: \n";
296- if (callsite_op->getCalledFunction ()) {
297267
298268 auto arg = callsite_op->getCalledFunction ()->arg_begin ();
299269 for (unsigned i = 0 ; i < args.size (); i++) {
300270 uncacheable_args[arg] = !args_safe[i];
301- // llvm::errs() << "callArg: " << *args[i] << " arg:" << *arg << " STATUS : " << args_safe[i ] << "\n";
271+ // llvm::errs() << "callArg: " << *args[i] << " arg:" << *arg << " uncacheable : " << uncacheable_args[arg ] << "\n";
302272 arg++;
273+ if (arg ==callsite_op->getCalledFunction ()->arg_end ()) {
274+ break ;
275+ }
303276 }
304277
305- }
306278 return uncacheable_args;
307279}
308280
@@ -661,38 +633,57 @@ bool legalCombinedForwardReverse(CallInst &ci, const std::map<ReturnInst*,StoreI
661633 return false ;
662634 }
663635
664- auto getMRI = [&](Instruction* inst , Instruction* inst2) {
665- if (auto call = dyn_cast<CallInst>(inst )) {
636+ auto writesToMemoryReadBy = [&](Instruction* maybeReader , Instruction* maybeWriter) -> bool {
637+ if (auto call = dyn_cast<CallInst>(maybeWriter )) {
666638 if (call->getCalledFunction () && isCertainMallocOrFree (call->getCalledFunction ())) {
667- return ModRefInfo::NoModRef ;
639+ return false ;
668640 }
669641 }
642+ if (auto call = dyn_cast<CallInst>(maybeReader)) {
643+ if (call->getCalledFunction () && isCertainMallocOrFree (call->getCalledFunction ())) {
644+ return false ;
645+ }
646+ }
647+ if (auto call = dyn_cast<InvokeInst>(maybeWriter)) {
648+ if (call->getCalledFunction () && isCertainMallocOrFree (call->getCalledFunction ())) {
649+ return false ;
650+ }
651+ }
652+ if (auto call = dyn_cast<InvokeInst>(maybeReader)) {
653+ if (call->getCalledFunction () && isCertainMallocOrFree (call->getCalledFunction ())) {
654+ return false ;
655+ }
656+ }
657+ assert (maybeWriter->mayWriteToMemory ());
658+ assert (maybeReader->mayReadFromMemory ());
670659
671- if (auto li = dyn_cast<LoadInst>(inst2 )) {
672- return gutils->AA .getModRefInfo (inst , MemoryLocation::get (li));
660+ if (auto li = dyn_cast<LoadInst>(maybeReader )) {
661+ return isModSet ( gutils->AA .getModRefInfo (maybeWriter , MemoryLocation::get (li) ));
673662 }
674- if (auto si = dyn_cast<StoreInst>(inst2 )) {
675- return gutils->AA .getModRefInfo (inst , MemoryLocation::get (si ));
663+ if (auto rmw = dyn_cast<AtomicRMWInst>(maybeReader )) {
664+ return isModSet ( gutils->AA .getModRefInfo (maybeWriter , MemoryLocation::get (rmw) ));
676665 }
677- if (auto rmw = dyn_cast<AtomicRMWInst>(inst2 )) {
678- return gutils->AA .getModRefInfo (inst , MemoryLocation::get (rmw ));
666+ if (auto xch = dyn_cast<AtomicCmpXchgInst>(maybeReader )) {
667+ return isModSet ( gutils->AA .getModRefInfo (maybeWriter , MemoryLocation::get (xch) ));
679668 }
680- if (auto xch = dyn_cast<AtomicCmpXchgInst>(inst2)) {
681- return gutils->AA .getModRefInfo (inst, MemoryLocation::get (xch));
669+
670+ if (auto si = dyn_cast<StoreInst>(maybeWriter)) {
671+ return isRefSet (gutils->AA .getModRefInfo (maybeReader, MemoryLocation::get (si)));
682672 }
683- if (auto cb = dyn_cast<CallInst>(inst2)) {
684- if (cb->getCalledFunction () && isCertainMallocOrFree (cb->getCalledFunction ())) {
685- return ModRefInfo::NoModRef;
686- }
687- return gutils->AA .getModRefInfo (inst, cb);
673+ if (auto rmw = dyn_cast<AtomicRMWInst>(maybeWriter)) {
674+ return isRefSet (gutils->AA .getModRefInfo (maybeReader, MemoryLocation::get (rmw)));
688675 }
689- if (auto cb = dyn_cast<InvokeInst>(inst2)) {
690- if (cb->getCalledFunction () && isCertainMallocOrFree (cb->getCalledFunction ())) {
691- return ModRefInfo::NoModRef;
692- }
693- return gutils->AA .getModRefInfo (inst, cb);
676+ if (auto xch = dyn_cast<AtomicCmpXchgInst>(maybeWriter)) {
677+ return isRefSet (gutils->AA .getModRefInfo (maybeReader, MemoryLocation::get (xch)));
678+ }
679+
680+ if (auto cb = dyn_cast<CallInst>(maybeReader)) {
681+ return isModOrRefSet (gutils->AA .getModRefInfo (maybeWriter, cb));
694682 }
695- llvm::errs () << " inst2: " << *inst2 << " \n " ;
683+ if (auto cb = dyn_cast<InvokeInst>(maybeReader)) {
684+ return isModOrRefSet (gutils->AA .getModRefInfo (maybeWriter, cb));
685+ }
686+ llvm::errs () << " maybeReader: " << *maybeReader << " maybeWriter: " << *maybeWriter << " \n " ;
696687 llvm_unreachable (" unknown inst2" );
697688 };
698689
@@ -806,12 +797,11 @@ bool legalCombinedForwardReverse(CallInst &ci, const std::map<ReturnInst*,StoreI
806797 if (inst->mayWriteToMemory ()) {
807798 auto consider = [&](Instruction* user) {
808799 if (!user->mayReadFromMemory ()) return ;
809- auto mri = getMRI (user, inst);
810- // llvm::errs() << " checking if need follower of " << *inst << " - " << *user << " : mri " << mri << "\n";
811- if (isRefSet (mri)) {
800+ if (writesToMemoryReadBy (/* maybeReader*/ user, /* maybeWriter*/ inst)) {
801+ // llvm::errs() << " memory deduced need follower of " << *inst << " - " << *user << "\n";
812802 propagate (user);
813803 if (!legal) return ;
814- }
804+ }
815805 };
816806 allFollowersOf (inst, consider);
817807 if (!legal) return false ;
@@ -826,13 +816,14 @@ bool legalCombinedForwardReverse(CallInst &ci, const std::map<ReturnInst*,StoreI
826816 // llvm::errs() << " + " << *u << "\n";
827817
828818 // Check if any of the unmoved operations will make it illegal to move the instruction
819+
829820 for (auto inst : usetree) {
830821 if (!inst->mayReadFromMemory ()) continue ;
831822 allFollowersOf (inst, [&](Instruction* post ) {
823+ if (unnecessaryInstructions.count (post )) return ;
832824 if (!post ->mayWriteToMemory ()) return ;
833825 // llvm::errs() << " checking if illegal move of " << *inst << " due to " << *post << "\n";
834- auto mri = getMRI (inst, post );
835- if (isModSet (mri)) {
826+ if (writesToMemoryReadBy (/* maybeReader*/ inst, /* maybeWriter*/ post )) {
836827 if (called)
837828 llvm::errs () << " failed to replace function " << (called->getName ()) << " due to " << *post << " usetree: " << *inst << " \n " ;
838829 else
0 commit comments