@@ -62,6 +62,103 @@ cl::opt<bool> nonmarkedglobals_inactiveloads(
6262 " enzyme_nonmarkedglobals_inactiveloads" , cl::init(true ), cl::Hidden,
6363 cl::desc(" Consider loads of nonmarked globals to be inactive" ));
6464
65+
66+ bool is_load_uncacheable (LoadInst& li, AAResults& AA, GradientUtils* gutils, TargetLibraryInfo& TLI,
67+ const std::map<Argument*, bool >& uncacheable_args) {
68+
69+ bool can_modref = false ;
70+ // Find the underlying object for the pointer operand of the load instruction.
71+ auto obj = GetUnderlyingObject (li.getPointerOperand (), gutils->oldFunc ->getParent ()->getDataLayout (), 100 );
72+
73+ // llvm::errs() << "underlying object for load " << li << " is " << *obj << "\n";
74+ // If the pointer operand is from an argument to the function, we need to check if the argument
75+ // received from the caller is uncacheable.
76+ if (auto arg = dyn_cast<Argument>(obj)) {
77+ auto found = uncacheable_args.find (arg);
78+ if (found == uncacheable_args.end ()) {
79+ llvm::errs () << " uncacheable_args:\n " ;
80+ for (auto & pair : uncacheable_args) {
81+ llvm::errs () << " + " << *pair.first << " : " << pair.second << " of func " << pair.first ->getParent ()->getName () << " \n " ;
82+ }
83+ llvm::errs () << " could not find " << *arg << " of func " << arg->getParent ()->getName () << " in args_map\n " ;
84+ }
85+ assert (found != uncacheable_args.end ());
86+ if (found->second ) {
87+ // llvm::errs() << "OP is uncacheable arg: " << li << "\n";
88+ can_modref = true ;
89+ }
90+ // llvm::errs() << " + argument (can_modref=" << can_modref << ") " << li << " object: " << *obj << " arg: " << *arg << "e\n";
91+ // TODO this case (alloca goes out of scope/allocation is freed and we dont force it to continue needs to be forcibly cached)
92+ } else {
93+ // NOTE(TFK): In the case where the underlying object for the pointer operand is from a Load or Call we need
94+ // to check if we need to cache. Likely, we need to play it safe in this case and cache.
95+ // NOTE(TFK): The logic below is an attempt at a conservative handling of the case mentioned above, but it
96+ // needs to be verified.
97+
98+ // Pointer operands originating from call instructions that are not malloc/free are conservatively considered uncacheable.
99+ if (auto obj_op = dyn_cast<CallInst>(obj)) {
100+ Function* called = obj_op->getCalledFunction ();
101+ if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue ())) {
102+ if (castinst->isCast ()) {
103+ if (auto fn = dyn_cast<Function>(castinst->getOperand (0 ))) {
104+ if (isAllocationFunction (*fn, TLI) || isDeallocationFunction (*fn, TLI)) {
105+ called = fn;
106+ }
107+ }
108+ }
109+ }
110+ if (called && isCertainMallocOrFree (called)) {
111+ // llvm::errs() << "OP is certain malloc or free: " << *op << "\n";
112+ } else {
113+ // llvm::errs() << "OP is a non malloc/free call so we need to cache " << *op << "\n";
114+ can_modref = true ;
115+ }
116+ } else if (auto sli = dyn_cast<LoadInst>(obj)) {
117+ // If obj is from a load instruction conservatively consider it uncacheable if that load itself cannot be cached
118+ // llvm::errs() << "OP is from a load, needing to cache " << *op << "\n";
119+ can_modref = is_load_uncacheable (*sli, AA, gutils, TLI, uncacheable_args);
120+ } else {
121+ // In absence of more information, assume that the underlying object for pointer operand is uncacheable in caller.
122+ // llvm::errs() << "OP is an unknown instruction, needing to cache " << *op << "\n";
123+ can_modref = true ;
124+ }
125+ }
126+
127+ for (inst_iterator I2 = inst_begin (*gutils->oldFunc ), E2 = inst_end (*gutils->oldFunc ); I2 != E2 ; ++I2) {
128+ Instruction* inst2 = &*I2;
129+ assert (li.getParent ()->getParent () == inst2->getParent ()->getParent ());
130+ if (&li == inst2) continue ;
131+ if (!gutils->OrigDT .dominates (inst2, &li)) {
132+
133+ // Don't consider modref from malloc/free as a need to cache
134+ if (auto obj_op = dyn_cast<CallInst>(inst2)) {
135+ Function* called = obj_op->getCalledFunction ();
136+ if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue ())) {
137+ if (castinst->isCast ()) {
138+ if (auto fn = dyn_cast<Function>(castinst->getOperand (0 ))) {
139+ if (isAllocationFunction (*fn, TLI) || isDeallocationFunction (*fn, TLI)) {
140+ called = fn;
141+ }
142+ }
143+ }
144+ }
145+ if (called && isCertainMallocOrFree (called)) {
146+ continue ;
147+ }
148+ }
149+
150+ if (llvm::isModSet (AA.getModRefInfo (inst2, MemoryLocation::get (&li)))) {
151+ can_modref = true ;
152+ llvm::errs () << li << " needs to be cached due to: " << *inst2 << " \n " ;
153+ break ;
154+ }
155+ }
156+ }
157+ // llvm::errs() << "F - " << li << " can_modref" << can_modref << "\n";
158+ return can_modref;
159+
160+ }
161+
65162// Computes a map of LoadInst -> boolean for a function indicating whether that load is "uncacheable".
66163// A load is considered "uncacheable" if the data at the loaded memory location can be modified after
67164// the load instruction.
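Note (not part of the commit): the heart of the helper added above is the alias-analysis scan over instructions that do not dominate the load. The sketch below shows that core check in isolation using stock LLVM APIs; the function name mayBeClobberedAfter and the omission of Enzyme's argument/allocation special cases are assumptions made purely for illustration.

    // Minimal sketch, assuming an LLVM pass context where AA and DT for F are available.
    #include "llvm/Analysis/AliasAnalysis.h"
    #include "llvm/Analysis/MemoryLocation.h"
    #include "llvm/IR/Dominators.h"
    #include "llvm/IR/InstIterator.h"
    #include "llvm/IR/Instructions.h"

    using namespace llvm;

    // Returns true if some instruction that does not dominate LI may modify the
    // memory LI reads, i.e. the loaded value cannot safely be recomputed later.
    static bool mayBeClobberedAfter(LoadInst &LI, Function &F, AAResults &AA,
                                    DominatorTree &DT) {
      MemoryLocation Loc = MemoryLocation::get(&LI);
      for (Instruction &I : instructions(F)) {
        if (&I == &LI)
          continue;
        // Instructions dominating the load always execute before it.
        if (DT.dominates(&I, &LI))
          continue;
        if (isModSet(AA.getModRefInfo(&I, Loc)))
          return true; // a non-dominating instruction may write the location
      }
      return false;
    }

Relative to this sketch, the committed helper additionally skips known malloc/free calls and seeds can_modref from the load's underlying object before scanning.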
@@ -70,98 +167,11 @@ std::map<Instruction*, bool> compute_uncacheable_load_map(GradientUtils* gutils,
   std::map<Instruction*, bool> can_modref_map;
   for (inst_iterator I = inst_begin(*gutils->oldFunc), E = inst_end(*gutils->oldFunc); I != E; ++I) {
     Instruction* inst = &*I;
-    // For each load instruction, determine if it is xuncacheable.
+    // For each load instruction, determine if it is uncacheable.
     if (auto op = dyn_cast<LoadInst>(inst)) {
 
-      bool can_modref = false;
-      // Find the underlying object for the pointer operand of the load instruction.
-      auto obj = GetUnderlyingObject(op->getPointerOperand(), gutils->oldFunc->getParent()->getDataLayout(), 100);
-
-      // llvm::errs() << "underlying object for load " << *op << " is " << *obj << "\n";
-      // If the pointer operand is from an argument to the function, we need to check if the argument
-      // received from the caller is uncacheable.
-      if (auto arg = dyn_cast<Argument>(obj)) {
-        auto found = uncacheable_args.find(arg);
-        if (found == uncacheable_args.end()) {
-          llvm::errs() << "uncacheable_args:\n";
-          for (auto& pair : uncacheable_args) {
-            llvm::errs() << " + " << *pair.first << ": " << pair.second << " of func " << pair.first->getParent()->getName() << "\n";
-          }
-          llvm::errs() << "could not find " << *arg << " of func " << arg->getParent()->getName() << " in args_map\n";
-        }
-        assert(found != uncacheable_args.end());
-        if (found->second) {
-          // llvm::errs() << "OP is uncacheable arg: " << *op << "\n";
-          can_modref = true;
-        }
-        // llvm::errs() << " + argument (can_modref=" << can_modref << ") " << *op << " object: " << *obj << " arg: " << *arg << "e\n";
-        // TODO this case (alloca goes out of scope/allocation is freed and we dont force it to continue needs to be forcibly cached)
-      } else {
-        // NOTE(TFK): In the case where the underlying object for the pointer operand is from a Load or Call we need
-        //   to check if we need to cache. Likely, we need to play it safe in this case and cache.
-        // NOTE(TFK): The logic below is an attempt at a conservative handling of the case mentioned above, but it
-        //   needs to be verified.
-
-        // Pointer operands originating from call instructions that are not malloc/free are conservatively considered uncacheable.
-        if (auto obj_op = dyn_cast<CallInst>(obj)) {
-          Function* called = obj_op->getCalledFunction();
-          if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
-            if (castinst->isCast()) {
-              if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
-                if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
-                  called = fn;
-                }
-              }
-            }
-          }
-          if (called && isCertainMallocOrFree(called)) {
-            // llvm::errs() << "OP is certain malloc or free: " << *op << "\n";
-          } else {
-            // llvm::errs() << "OP is a non malloc/free call so we need to cache " << *op << "\n";
-            can_modref = true;
-          }
-        } else if (isa<LoadInst>(obj)) {
-          // If obj is from a load instruction conservatively consider it uncacheable.
-          // llvm::errs() << "OP is from a load, needing to cache " << *op << "\n";
-          can_modref = true;
-        } else {
-          // In absence of more information, assume that the underlying object for pointer operand is uncacheable in caller.
-          // llvm::errs() << "OP is an unknown instruction, needing to cache " << *op << "\n";
-          can_modref = true;
-        }
-      }
-
-      for (inst_iterator I2 = inst_begin(*gutils->oldFunc), E2 = inst_end(*gutils->oldFunc); I2 != E2; ++I2) {
-        Instruction* inst2 = &*I2;
-        assert(inst->getParent()->getParent() == inst2->getParent()->getParent());
-        if (inst == inst2) continue;
-        if (!gutils->OrigDT.dominates(inst2, inst)) {
-
-          // Don't consider modref from malloc/free as a need to cache
-          if (auto obj_op = dyn_cast<CallInst>(inst2)) {
-            Function* called = obj_op->getCalledFunction();
-            if (auto castinst = dyn_cast<ConstantExpr>(obj_op->getCalledValue())) {
-              if (castinst->isCast()) {
-                if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
-                  if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) {
-                    called = fn;
-                  }
-                }
-              }
-            }
-            if (called && isCertainMallocOrFree(called)) {
-              continue;
-            }
-          }
 
-          if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::get(op)))) {
-            can_modref = true;
-            // llvm::errs() << *inst << " needs to be cached due to: " << *inst2 << "\n";
-            break;
-          }
-        }
-      }
-      can_modref_map[inst] = can_modref;
+      can_modref_map[inst] = is_load_uncacheable(*op, AA, gutils, TLI, uncacheable_args);
     }
   }
   return can_modref_map;
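Note (not part of the commit): after this refactor, compute_uncacheable_load_map only dispatches to the helper, so the per-load classification lives in one place. As a rough illustration of the first half of that classification, the hypothetical classifyLoadPointer below walks back to the load's underlying object and buckets it the same three ways the helper does (argument, call, other). The enum and function name are invented for this sketch, and GetUnderlyingObject matches the LLVM version used here (newer LLVM spells it getUnderlyingObject).

    // Illustrative sketch only; not Enzyme's API.
    #include "llvm/Analysis/ValueTracking.h"
    #include "llvm/IR/Argument.h"
    #include "llvm/IR/DataLayout.h"
    #include "llvm/IR/Instructions.h"

    using namespace llvm;

    enum class PtrSource { FromArgument, FromCall, Other };

    static PtrSource classifyLoadPointer(LoadInst &LI, const DataLayout &DL) {
      // Peel off GEPs/casts (up to 100 steps) to find what the pointer refers to.
      Value *Obj = GetUnderlyingObject(LI.getPointerOperand(), DL, 100);
      if (isa<Argument>(Obj))
        return PtrSource::FromArgument; // cacheability depends on the caller
      if (isa<CallInst>(Obj))
        return PtrSource::FromCall;     // safe only for known allocation calls
      return PtrSource::Other;          // play it safe: treat as uncacheable
    }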
@@ -185,7 +195,10 @@ std::map<Argument*, bool> compute_uncacheable_args_for_one_callsite(CallInst* ca
                                    100);
     // llvm::errs() << "ocs underlying object for callsite " << *callsite_op << " idx: " << i << " is " << *obj << "\n";
     // If underlying object is an Argument, check parent volatility status.
-    if (auto arg = dyn_cast<Argument>(obj)) {
+    if (isa<UndefValue>(obj)) {
+      init_safe = true;
+      // llvm::errs() << " + ocs undef (safe=" << init_safe << ") " << *callsite_op << " object: " << *obj << "\n";
+    } else if (auto arg = dyn_cast<Argument>(obj)) {
       auto found = parent_uncacheable_args.find(arg);
       if (found == parent_uncacheable_args.end()) {
         llvm::errs() << "parent_uncacheable_args:\n";
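Note (not part of the commit): the new branch above treats a callsite operand whose underlying object is undef as safe, since an undef pointer names no real memory a caller could later overwrite. A standalone sketch of just that check, with an assumed function name, might look like:

    // Hedged sketch: the undef special case in isolation.
    #include "llvm/Analysis/ValueTracking.h"
    #include "llvm/IR/Constants.h"
    #include "llvm/IR/DataLayout.h"
    #include "llvm/IR/Instructions.h"

    using namespace llvm;

    static bool isDefinitelySafeCallArg(CallInst &CI, unsigned ArgIdx,
                                        const DataLayout &DL) {
      Value *Obj = GetUnderlyingObject(CI.getArgOperand(ArgIdx), DL, 100);
      // An undef pointer cannot alias any store the caller performs later.
      return isa<UndefValue>(Obj);
    }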
@@ -198,7 +211,7 @@ std::map<Argument*, bool> compute_uncacheable_args_for_one_callsite(CallInst* ca
       if (found->second) {
         init_safe = false;
       }
-      // llvm::errs() << " + ocs argument (safe=" << init_safe << ") " << *callsite_op << " object: " << *obj << " arg: " << *arg << "e\n";
+      // llvm::errs() << " + ocs argument (safe=" << init_safe << ") " << *callsite_op << " object: " << *obj << " arg: " << *arg << "å\n";
     } else {
       // Pointer operands originating from call instructions that are not malloc/free are conservatively considered uncacheable.
       if (auto obj_op = dyn_cast<CallInst>(obj)) {