9999#include " llvm/IR/Type.h"
100100#include " llvm/Pass.h"
101101
102+ #define DEBUG_TYPE " nvptx-lower-args"
103+
102104using namespace llvm ;
103105
104106namespace llvm {
@@ -166,40 +168,60 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
166168 Value *NewParam;
167169 };
168170 SmallVector<IP> ItemsToConvert = {{I, Param}};
169- SmallVector<GetElementPtrInst *> GEPsToDelete ;
170- while (!ItemsToConvert. empty ()) {
171- IP I = ItemsToConvert. pop_back_val ();
172- if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction ))
171+ SmallVector<Instruction *> InstructionsToDelete ;
172+
173+ auto CloneInstInParamAS = []( const IP &I) -> Value * {
174+ if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction )) {
173175 LI->setOperand (0 , I.NewParam );
174- else if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction )) {
176+ return LI;
177+ }
178+ if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction )) {
175179 SmallVector<Value *, 4 > Indices (GEP->indices ());
176180 auto *NewGEP = GetElementPtrInst::Create (nullptr , I.NewParam , Indices,
177181 GEP->getName (), GEP);
178182 NewGEP->setIsInBounds (GEP->isInBounds ());
179- llvm::for_each (GEP->users (), [NewGEP, &ItemsToConvert](Value *V) {
180- ItemsToConvert.push_back ({cast<Instruction>(V), NewGEP});
181- });
182- GEPsToDelete.push_back (GEP);
183- } else
184- llvm_unreachable (" Only Load and GEP can be converted to param AS." );
185- }
186- llvm::for_each (GEPsToDelete,
187- [](GetElementPtrInst *GEP) { GEP->eraseFromParent (); });
188- }
183+ return NewGEP;
184+ }
185+ if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction )) {
186+ auto *NewBCType = BC->getType ()->getPointerElementType ()->getPointerTo (
187+ ADDRESS_SPACE_PARAM);
188+ return BitCastInst::Create (BC->getOpcode (), I.NewParam , NewBCType,
189+ BC->getName (), BC);
190+ }
191+ if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction )) {
192+ assert (ASC->getDestAddressSpace () == ADDRESS_SPACE_PARAM);
193+ // Just pass through the argument, the old ASC is no longer needed.
194+ return I.NewParam ;
195+ }
196+ llvm_unreachable (" Unsupported instruction" );
197+ };
189198
190- static bool isALoadChain (Value *Start) {
191- SmallVector<Value *, 16 > ValuesToCheck = {Start};
192- while (!ValuesToCheck.empty ()) {
193- Value *V = ValuesToCheck.pop_back_val ();
194- Instruction *I = dyn_cast<Instruction>(V);
195- if (!I)
196- return false ;
197- if (isa<GetElementPtrInst>(I))
198- ValuesToCheck.append (I->user_begin (), I->user_end ());
199- else if (!isa<LoadInst>(I))
200- return false ;
199+ while (!ItemsToConvert.empty ()) {
200+ IP I = ItemsToConvert.pop_back_val ();
201+ Value *NewInst = CloneInstInParamAS (I);
202+
203+ if (NewInst && NewInst != I.OldInstruction ) {
204+ // We've created a new instruction. Queue users of the old instruction to
205+ // be converted and the instruction itself to be deleted. We can't delete
206+ // the old instruction yet, because it's still in use by a load somewhere.
207+ llvm::for_each (
208+ I.OldInstruction ->users (), [NewInst, &ItemsToConvert](Value *V) {
209+ ItemsToConvert.push_back ({cast<Instruction>(V), NewInst});
210+ });
211+
212+ InstructionsToDelete.push_back (I.OldInstruction );
213+ }
201214 }
202- return true ;
215+
216+ // Now we know that all argument loads are using addresses in parameter space
217+ // and we can finally remove the old instructions in generic AS. Instructions
218+ // scheduled for removal should be processed in reverse order so the ones
219+ // closest to the load are deleted first. Otherwise they may still be in use.
220+ // E.g if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
221+ // have {GEP,BitCast}. GEP can't be deleted first, because it's still used by
222+ // the BitCast.
223+ llvm::for_each (reverse (InstructionsToDelete),
224+ [](Instruction *I) { I->eraseFromParent (); });
203225}
204226
205227void NVPTXLowerArgs::handleByValParam (Argument *Arg) {
@@ -211,16 +233,43 @@ void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
211233
212234 Type *StructType = PType->getElementType ();
213235
214- if (llvm::all_of (Arg->users (), isALoadChain)) {
215- // Replace all loads with the loads in param AS. This allows loading the Arg
216- // directly from parameter AS, without making a temporary copy.
236+ auto IsALoadChain = [Arg](Value *Start) {
237+ SmallVector<Value *, 16 > ValuesToCheck = {Start};
238+ auto IsALoadChainInstr = [](Value *V) -> bool {
239+ if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
240+ return true ;
241+ // ASC to param space are OK, too -- we'll just strip them.
242+ if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
243+ if (ASC->getDestAddressSpace () == ADDRESS_SPACE_PARAM)
244+ return true ;
245+ }
246+ return false ;
247+ };
248+
249+ while (!ValuesToCheck.empty ()) {
250+ Value *V = ValuesToCheck.pop_back_val ();
251+ if (!IsALoadChainInstr (V)) {
252+ LLVM_DEBUG (dbgs () << " Need a copy of " << *Arg << " because of " << *V
253+ << " \n " );
254+ return false ;
255+ }
256+ if (!isa<LoadInst>(V))
257+ llvm::append_range (ValuesToCheck, V->users ());
258+ }
259+ return true ;
260+ };
261+
262+ if (llvm::all_of (Arg->users (), IsALoadChain)) {
263+ // Convert all loads and intermediate operations to use parameter AS and
264+ // skip creation of a local copy of the argument.
217265 SmallVector<User *, 16 > UsersToUpdate (Arg->users ());
218266 Value *ArgInParamAS = new AddrSpaceCastInst (
219267 Arg, PointerType::get (StructType, ADDRESS_SPACE_PARAM), Arg->getName (),
220268 FirstInst);
221269 llvm::for_each (UsersToUpdate, [ArgInParamAS](Value *V) {
222270 convertToParamAS (V, ArgInParamAS);
223271 });
272+ LLVM_DEBUG (dbgs () << " No need to copy " << *Arg << " \n " );
224273 return ;
225274 }
226275
@@ -297,6 +346,7 @@ bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
297346 }
298347 }
299348
349+ LLVM_DEBUG (dbgs () << " Lowering kernel args of " << F.getName () << " \n " );
300350 for (Argument &Arg : F.args ()) {
301351 if (Arg.getType ()->isPointerTy ()) {
302352 if (Arg.hasByValAttr ())
@@ -310,6 +360,7 @@ bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
310360
311361// Device functions only need to copy byval args into local memory.
312362bool NVPTXLowerArgs::runOnDeviceFunction (Function &F) {
363+ LLVM_DEBUG (dbgs () << " Lowering function args of " << F.getName () << " \n " );
313364 for (Argument &Arg : F.args ())
314365 if (Arg.getType ()->isPointerTy () && Arg.hasByValAttr ())
315366 handleByValParam (&Arg);
0 commit comments