99
99
#include " llvm/IR/Type.h"
100
100
#include " llvm/Pass.h"
101
101
102
+ #define DEBUG_TYPE " nvptx-lower-args"
103
+
102
104
using namespace llvm ;
103
105
104
106
namespace llvm {
@@ -166,40 +168,60 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
166
168
Value *NewParam;
167
169
};
168
170
SmallVector<IP> ItemsToConvert = {{I, Param}};
169
- SmallVector<GetElementPtrInst *> GEPsToDelete ;
170
- while (!ItemsToConvert. empty ()) {
171
- IP I = ItemsToConvert. pop_back_val ();
172
- if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction ))
171
+ SmallVector<Instruction *> InstructionsToDelete ;
172
+
173
+ auto CloneInstInParamAS = []( const IP &I) -> Value * {
174
+ if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction )) {
173
175
LI->setOperand (0 , I.NewParam );
174
- else if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction )) {
176
+ return LI;
177
+ }
178
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction )) {
175
179
SmallVector<Value *, 4 > Indices (GEP->indices ());
176
180
auto *NewGEP = GetElementPtrInst::Create (nullptr , I.NewParam , Indices,
177
181
GEP->getName (), GEP);
178
182
NewGEP->setIsInBounds (GEP->isInBounds ());
179
- llvm::for_each (GEP->users (), [NewGEP, &ItemsToConvert](Value *V) {
180
- ItemsToConvert.push_back ({cast<Instruction>(V), NewGEP});
181
- });
182
- GEPsToDelete.push_back (GEP);
183
- } else
184
- llvm_unreachable (" Only Load and GEP can be converted to param AS." );
185
- }
186
- llvm::for_each (GEPsToDelete,
187
- [](GetElementPtrInst *GEP) { GEP->eraseFromParent (); });
188
- }
183
+ return NewGEP;
184
+ }
185
+ if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction )) {
186
+ auto *NewBCType = BC->getType ()->getPointerElementType ()->getPointerTo (
187
+ ADDRESS_SPACE_PARAM);
188
+ return BitCastInst::Create (BC->getOpcode (), I.NewParam , NewBCType,
189
+ BC->getName (), BC);
190
+ }
191
+ if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction )) {
192
+ assert (ASC->getDestAddressSpace () == ADDRESS_SPACE_PARAM);
193
+ // Just pass through the argument, the old ASC is no longer needed.
194
+ return I.NewParam ;
195
+ }
196
+ llvm_unreachable (" Unsupported instruction" );
197
+ };
189
198
190
- static bool isALoadChain (Value *Start) {
191
- SmallVector<Value *, 16 > ValuesToCheck = {Start};
192
- while (!ValuesToCheck.empty ()) {
193
- Value *V = ValuesToCheck.pop_back_val ();
194
- Instruction *I = dyn_cast<Instruction>(V);
195
- if (!I)
196
- return false ;
197
- if (isa<GetElementPtrInst>(I))
198
- ValuesToCheck.append (I->user_begin (), I->user_end ());
199
- else if (!isa<LoadInst>(I))
200
- return false ;
199
+ while (!ItemsToConvert.empty ()) {
200
+ IP I = ItemsToConvert.pop_back_val ();
201
+ Value *NewInst = CloneInstInParamAS (I);
202
+
203
+ if (NewInst && NewInst != I.OldInstruction ) {
204
+ // We've created a new instruction. Queue users of the old instruction to
205
+ // be converted and the instruction itself to be deleted. We can't delete
206
+ // the old instruction yet, because it's still in use by a load somewhere.
207
+ llvm::for_each (
208
+ I.OldInstruction ->users (), [NewInst, &ItemsToConvert](Value *V) {
209
+ ItemsToConvert.push_back ({cast<Instruction>(V), NewInst});
210
+ });
211
+
212
+ InstructionsToDelete.push_back (I.OldInstruction );
213
+ }
201
214
}
202
- return true ;
215
+
216
+ // Now we know that all argument loads are using addresses in parameter space
217
+ // and we can finally remove the old instructions in generic AS. Instructions
218
+ // scheduled for removal should be processed in reverse order so the ones
219
+ // closest to the load are deleted first. Otherwise they may still be in use.
220
+ // E.g if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
221
+ // have {GEP,BitCast}. GEP can't be deleted first, because it's still used by
222
+ // the BitCast.
223
+ llvm::for_each (reverse (InstructionsToDelete),
224
+ [](Instruction *I) { I->eraseFromParent (); });
203
225
}
204
226
205
227
void NVPTXLowerArgs::handleByValParam (Argument *Arg) {
@@ -211,16 +233,43 @@ void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
211
233
212
234
Type *StructType = PType->getElementType ();
213
235
214
- if (llvm::all_of (Arg->users (), isALoadChain)) {
215
- // Replace all loads with the loads in param AS. This allows loading the Arg
216
- // directly from parameter AS, without making a temporary copy.
236
+ auto IsALoadChain = [Arg](Value *Start) {
237
+ SmallVector<Value *, 16 > ValuesToCheck = {Start};
238
+ auto IsALoadChainInstr = [](Value *V) -> bool {
239
+ if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
240
+ return true ;
241
+ // ASC to param space are OK, too -- we'll just strip them.
242
+ if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
243
+ if (ASC->getDestAddressSpace () == ADDRESS_SPACE_PARAM)
244
+ return true ;
245
+ }
246
+ return false ;
247
+ };
248
+
249
+ while (!ValuesToCheck.empty ()) {
250
+ Value *V = ValuesToCheck.pop_back_val ();
251
+ if (!IsALoadChainInstr (V)) {
252
+ LLVM_DEBUG (dbgs () << " Need a copy of " << *Arg << " because of " << *V
253
+ << " \n " );
254
+ return false ;
255
+ }
256
+ if (!isa<LoadInst>(V))
257
+ llvm::append_range (ValuesToCheck, V->users ());
258
+ }
259
+ return true ;
260
+ };
261
+
262
+ if (llvm::all_of (Arg->users (), IsALoadChain)) {
263
+ // Convert all loads and intermediate operations to use parameter AS and
264
+ // skip creation of a local copy of the argument.
217
265
SmallVector<User *, 16 > UsersToUpdate (Arg->users ());
218
266
Value *ArgInParamAS = new AddrSpaceCastInst (
219
267
Arg, PointerType::get (StructType, ADDRESS_SPACE_PARAM), Arg->getName (),
220
268
FirstInst);
221
269
llvm::for_each (UsersToUpdate, [ArgInParamAS](Value *V) {
222
270
convertToParamAS (V, ArgInParamAS);
223
271
});
272
+ LLVM_DEBUG (dbgs () << " No need to copy " << *Arg << " \n " );
224
273
return ;
225
274
}
226
275
@@ -297,6 +346,7 @@ bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
297
346
}
298
347
}
299
348
349
+ LLVM_DEBUG (dbgs () << " Lowering kernel args of " << F.getName () << " \n " );
300
350
for (Argument &Arg : F.args ()) {
301
351
if (Arg.getType ()->isPointerTy ()) {
302
352
if (Arg.hasByValAttr ())
@@ -310,6 +360,7 @@ bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
310
360
311
361
// Device functions only need to copy byval args into local memory.
312
362
bool NVPTXLowerArgs::runOnDeviceFunction (Function &F) {
363
+ LLVM_DEBUG (dbgs () << " Lowering function args of " << F.getName () << " \n " );
313
364
for (Argument &Arg : F.args ())
314
365
if (Arg.getType ()->isPointerTy () && Arg.hasByValAttr ())
315
366
handleByValParam (&Arg);
0 commit comments