139
139
#include " NVPTX.h"
140
140
#include " NVPTXTargetMachine.h"
141
141
#include " NVPTXUtilities.h"
142
+ #include " llvm/ADT/STLExtras.h"
143
+ #include " llvm/Analysis/PtrUseVisitor.h"
142
144
#include " llvm/Analysis/ValueTracking.h"
143
145
#include " llvm/CodeGen/TargetPassConfig.h"
144
146
#include " llvm/IR/Function.h"
145
147
#include " llvm/IR/IRBuilder.h"
146
148
#include " llvm/IR/Instructions.h"
149
+ #include " llvm/IR/IntrinsicInst.h"
147
150
#include " llvm/IR/IntrinsicsNVPTX.h"
148
151
#include " llvm/IR/Module.h"
149
152
#include " llvm/IR/Type.h"
150
153
#include " llvm/InitializePasses.h"
151
154
#include " llvm/Pass.h"
155
+ #include " llvm/Support/Debug.h"
156
+ #include " llvm/Support/ErrorHandling.h"
152
157
#include < numeric>
153
158
#include < queue>
154
159
@@ -217,7 +222,8 @@ INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
217
222
// pointer in parameter AS.
218
223
// For "escapes" (to memory, a function call, or a ptrtoint), cast the OldUse to
219
224
// generic using cvta.param.
220
- static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
225
+ static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam,
226
+ bool IsGridConstant) {
221
227
Instruction *I = dyn_cast<Instruction>(OldUse->getUser ());
222
228
assert (I && " OldUse must be in an instruction" );
223
229
struct IP {
@@ -228,7 +234,8 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
228
234
SmallVector<IP> ItemsToConvert = {{OldUse, I, Param}};
229
235
SmallVector<Instruction *> InstructionsToDelete;
230
236
231
- auto CloneInstInParamAS = [GridConstant](const IP &I) -> Value * {
237
+ auto CloneInstInParamAS = [HasCvtaParam,
238
+ IsGridConstant](const IP &I) -> Value * {
232
239
if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction )) {
233
240
LI->setOperand (0 , I.NewParam );
234
241
return LI;
@@ -252,8 +259,25 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
252
259
// Just pass through the argument, the old ASC is no longer needed.
253
260
return I.NewParam ;
254
261
}
262
+ if (auto *MI = dyn_cast<MemTransferInst>(I.OldInstruction )) {
263
+ if (MI->getRawSource () == I.OldUse ->get ()) {
264
+ // convert to memcpy/memmove from param space.
265
+ IRBuilder<> Builder (I.OldInstruction );
266
+ Intrinsic::ID ID = MI->getIntrinsicID ();
267
+
268
+ CallInst *B = Builder.CreateMemTransferInst (
269
+ ID, MI->getRawDest (), MI->getDestAlign (), I.NewParam ,
270
+ MI->getSourceAlign (), MI->getLength (), MI->isVolatile ());
271
+ for (unsigned I : {0 , 1 })
272
+ if (uint64_t Bytes = MI->getParamDereferenceableBytes (I))
273
+ B->addDereferenceableParamAttr (I, Bytes);
274
+ return B;
275
+ }
276
+ // We may be able to handle other cases if the argument is
277
+ // __grid_constant__
278
+ }
255
279
256
- if (GridConstant ) {
280
+ if (HasCvtaParam ) {
257
281
auto GetParamAddrCastToGeneric =
258
282
[](Value *Addr, Instruction *OriginalUser) -> Value * {
259
283
PointerType *ReturnTy =
@@ -269,24 +293,44 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
269
293
OriginalUser->getIterator ());
270
294
return CvtToGenCall;
271
295
};
272
-
273
- if (auto *CI = dyn_cast<CallInst>(I.OldInstruction )) {
274
- I.OldUse ->set (GetParamAddrCastToGeneric (I.NewParam , CI));
275
- return CI;
296
+ auto *ParamInGenericAS =
297
+ GetParamAddrCastToGeneric (I.NewParam , I.OldInstruction );
298
+
299
+ // phi/select could use generic arg pointers w/o __grid_constant__
300
+ if (auto *PHI = dyn_cast<PHINode>(I.OldInstruction )) {
301
+ for (auto [Idx, V] : enumerate(PHI->incoming_values ())) {
302
+ if (V.get () == I.OldUse ->get ())
303
+ PHI->setIncomingValue (Idx, ParamInGenericAS);
304
+ }
276
305
}
277
- if (auto *SI = dyn_cast<StoreInst >(I.OldInstruction )) {
278
- // byval address is being stored, cast it to generic
279
- if ( SI->getValueOperand () == I. OldUse -> get ())
280
- SI->setOperand ( 0 , GetParamAddrCastToGeneric (I. NewParam , SI));
281
- return SI ;
306
+ if (auto *SI = dyn_cast<SelectInst >(I.OldInstruction )) {
307
+ if (SI-> getTrueValue () == I. OldUse -> get ())
308
+ SI->setTrueValue (ParamInGenericAS);
309
+ if ( SI->getFalseValue () == I. OldUse -> get ())
310
+ SI-> setFalseValue (ParamInGenericAS) ;
282
311
}
283
- if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction )) {
284
- if (PI->getPointerOperand () == I.OldUse ->get ())
285
- PI->setOperand (0 , GetParamAddrCastToGeneric (I.NewParam , PI));
286
- return PI;
312
+
313
+ // Escapes or writes can only use generic param pointers if
314
+ // __grid_constant__ is in effect.
315
+ if (IsGridConstant) {
316
+ if (auto *CI = dyn_cast<CallInst>(I.OldInstruction )) {
317
+ I.OldUse ->set (ParamInGenericAS);
318
+ return CI;
319
+ }
320
+ if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction )) {
321
+ // byval address is being stored, cast it to generic
322
+ if (SI->getValueOperand () == I.OldUse ->get ())
323
+ SI->setOperand (0 , ParamInGenericAS);
324
+ return SI;
325
+ }
326
+ if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction )) {
327
+ if (PI->getPointerOperand () == I.OldUse ->get ())
328
+ PI->setOperand (0 , ParamInGenericAS);
329
+ return PI;
330
+ }
331
+ // TODO: If we allow stores, we should allow memcpy/memset to
332
+ // parameter, too.
287
333
}
288
- llvm_unreachable (
289
- " Instruction unsupported even for grid_constant argument" );
290
334
}
291
335
292
336
llvm_unreachable (" Unsupported instruction" );
@@ -409,49 +453,110 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
409
453
}
410
454
}
411
455
456
+ namespace {
457
+ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
458
+ using Base = PtrUseVisitor<ArgUseChecker>;
459
+
460
+ bool IsGridConstant;
461
+ // Set of phi/select instructions using the Arg
462
+ SmallPtrSet<Instruction *, 4 > Conditionals;
463
+
464
+ ArgUseChecker (const DataLayout &DL, bool IsGridConstant)
465
+ : PtrUseVisitor(DL), IsGridConstant(IsGridConstant) {}
466
+
467
+ PtrInfo visitArgPtr (Argument &A) {
468
+ assert (A.getType ()->isPointerTy ());
469
+ IntegerType *IntIdxTy = cast<IntegerType>(DL.getIndexType (A.getType ()));
470
+ IsOffsetKnown = false ;
471
+ Offset = APInt (IntIdxTy->getBitWidth (), 0 );
472
+ PI.reset ();
473
+ Conditionals.clear ();
474
+
475
+ LLVM_DEBUG (dbgs () << " Checking Argument " << A << " \n " );
476
+ // Enqueue the uses of this pointer.
477
+ enqueueUsers (A);
478
+
479
+ // Visit all the uses off the worklist until it is empty.
480
+ // Note that unlike PtrUseVisitor we intentionally do not track offsets.
481
+ // We're only interested in how we use the pointer.
482
+ while (!(Worklist.empty () || PI.isAborted ())) {
483
+ UseToVisit ToVisit = Worklist.pop_back_val ();
484
+ U = ToVisit.UseAndIsOffsetKnown .getPointer ();
485
+ Instruction *I = cast<Instruction>(U->getUser ());
486
+ if (isa<PHINode>(I) || isa<SelectInst>(I))
487
+ Conditionals.insert (I);
488
+ LLVM_DEBUG (dbgs () << " Processing " << *I << " \n " );
489
+ Base::visit (I);
490
+ }
491
+ if (PI.isEscaped ())
492
+ LLVM_DEBUG (dbgs () << " Argument pointer escaped: " << *PI.getEscapingInst ()
493
+ << " \n " );
494
+ else if (PI.isAborted ())
495
+ LLVM_DEBUG (dbgs () << " Pointer use needs a copy: " << *PI.getAbortingInst ()
496
+ << " \n " );
497
+ LLVM_DEBUG (dbgs () << " Traversed " << Conditionals.size ()
498
+ << " conditionals\n " );
499
+ return PI;
500
+ }
501
+
502
+ void visitStoreInst (StoreInst &SI) {
503
+ // Storing the pointer escapes it.
504
+ if (U->get () == SI.getValueOperand ())
505
+ return PI.setEscapedAndAborted (&SI);
506
+ // Writes to the pointer are UB w/ __grid_constant__, but do not force a
507
+ // copy.
508
+ if (!IsGridConstant)
509
+ return PI.setAborted (&SI);
510
+ }
511
+
512
+ void visitAddrSpaceCastInst (AddrSpaceCastInst &ASC) {
513
+ // ASC to param space are no-ops and do not need a copy
514
+ if (ASC.getDestAddressSpace () != ADDRESS_SPACE_PARAM)
515
+ return PI.setEscapedAndAborted (&ASC);
516
+ Base::visitAddrSpaceCastInst (ASC);
517
+ }
518
+
519
+ void visitPtrToIntInst (PtrToIntInst &I) {
520
+ if (IsGridConstant)
521
+ return ;
522
+ Base::visitPtrToIntInst (I);
523
+ }
524
+ void visitPHINodeOrSelectInst (Instruction &I) {
525
+ assert (isa<PHINode>(I) || isa<SelectInst>(I));
526
+ }
527
+ // PHI and select just pass through the pointers.
528
+ void visitPHINode (PHINode &PN) { enqueueUsers (PN); }
529
+ void visitSelectInst (SelectInst &SI) { enqueueUsers (SI); }
530
+
531
+ void visitMemTransferInst (MemTransferInst &II) {
532
+ if (*U == II.getRawDest () && !IsGridConstant)
533
+ PI.setAborted (&II);
534
+ // memcpy/memmove are OK when the pointer is source. We can convert them to
535
+ // AS-specific memcpy.
536
+ }
537
+
538
+ void visitMemSetInst (MemSetInst &II) {
539
+ if (!IsGridConstant)
540
+ PI.setAborted (&II);
541
+ }
542
+ }; // struct ArgUseChecker
543
+ } // namespace
544
+
412
545
void NVPTXLowerArgs::handleByValParam (const NVPTXTargetMachine &TM,
413
546
Argument *Arg) {
414
- bool IsGridConstant = isParamGridConstant (*Arg);
415
547
Function *Func = Arg->getParent ();
548
+ bool HasCvtaParam = TM.getSubtargetImpl (*Func)->hasCvtaParam ();
549
+ bool IsGridConstant = HasCvtaParam && isParamGridConstant (*Arg);
550
+ const DataLayout &DL = Func->getDataLayout ();
416
551
BasicBlock::iterator FirstInst = Func->getEntryBlock ().begin ();
417
552
Type *StructType = Arg->getParamByValType ();
418
553
assert (StructType && " Missing byval type" );
419
554
420
- auto AreSupportedUsers = [&](Value *Start) {
421
- SmallVector<Value *, 16 > ValuesToCheck = {Start};
422
- auto IsSupportedUse = [IsGridConstant](Value *V) -> bool {
423
- if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
424
- return true ;
425
- // ASC to param space are OK, too -- we'll just strip them.
426
- if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
427
- if (ASC->getDestAddressSpace () == ADDRESS_SPACE_PARAM)
428
- return true ;
429
- }
430
- // Simple calls and stores are supported for grid_constants
431
- // writes to these pointers are undefined behaviour
432
- if (IsGridConstant &&
433
- (isa<CallInst>(V) || isa<StoreInst>(V) || isa<PtrToIntInst>(V)))
434
- return true ;
435
- return false ;
436
- };
437
-
438
- while (!ValuesToCheck.empty ()) {
439
- Value *V = ValuesToCheck.pop_back_val ();
440
- if (!IsSupportedUse (V)) {
441
- LLVM_DEBUG (dbgs () << " Need a "
442
- << (isParamGridConstant (*Arg) ? " cast " : " copy " )
443
- << " of " << *Arg << " because of " << *V << " \n " );
444
- (void )Arg;
445
- return false ;
446
- }
447
- if (!isa<LoadInst>(V) && !isa<CallInst>(V) && !isa<StoreInst>(V) &&
448
- !isa<PtrToIntInst>(V))
449
- llvm::append_range (ValuesToCheck, V->users ());
450
- }
451
- return true ;
452
- };
453
-
454
- if (llvm::all_of (Arg->users (), AreSupportedUsers)) {
555
+ ArgUseChecker AUC (DL, IsGridConstant);
556
+ ArgUseChecker::PtrInfo PI = AUC.visitArgPtr (*Arg);
557
+ bool ArgUseIsReadOnly = !(PI.isEscaped () || PI.isAborted ());
558
+ // Easy case, accessing parameter directly is fine.
559
+ if (ArgUseIsReadOnly && AUC.Conditionals .empty ()) {
455
560
// Convert all loads and intermediate operations to use parameter AS and
456
561
// skip creation of a local copy of the argument.
457
562
SmallVector<Use *, 16 > UsesToUpdate;
@@ -462,7 +567,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
462
567
Arg, PointerType::get (StructType, ADDRESS_SPACE_PARAM), Arg->getName (),
463
568
FirstInst);
464
569
for (Use *U : UsesToUpdate)
465
- convertToParamAS (U, ArgInParamAS, IsGridConstant);
570
+ convertToParamAS (U, ArgInParamAS, HasCvtaParam, IsGridConstant);
466
571
LLVM_DEBUG (dbgs () << " No need to copy or cast " << *Arg << " \n " );
467
572
468
573
const auto *TLI =
@@ -473,13 +578,17 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
473
578
return ;
474
579
}
475
580
476
- const DataLayout &DL = Func->getDataLayout ();
581
+ // We can't access byval arg directly and need a pointer. On sm_70+ we have
582
+ // ability to take a pointer to the argument without making a local copy.
583
+ // However, we're still not allowed to write to it. If the user specified
584
+ // `__grid_constant__` for the argument, we'll consider escaped pointer as
585
+ // read-only.
477
586
unsigned AS = DL.getAllocaAddrSpace ();
478
- if (isParamGridConstant (*Arg )) {
479
- // Writes to a grid constant are undefined behaviour. We do not need a
480
- // temporary copy. When a pointer might have escaped, conservatively replace
481
- // all of its uses (which might include a device function call) with a cast
482
- // to the generic address space .
587
+ if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant )) {
588
+ LLVM_DEBUG ( dbgs () << " Using non-copy pointer to " << *Arg << " \n " );
589
+ // Replace all argument pointer uses (which might include a device function
590
+ // call) with a cast to the generic address space using cvta.param
591
+ // instruction, which avoids a local copy .
483
592
IRBuilder<> IRB (&Func->getEntryBlock ().front ());
484
593
485
594
// Cast argument to param address space
@@ -500,6 +609,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
500
609
// Do not replace Arg in the cast to param space
501
610
CastToParam->setOperand (0 , Arg);
502
611
} else {
612
+ LLVM_DEBUG (dbgs () << " Creating a local copy of " << *Arg << " \n " );
503
613
// Otherwise we have to create a temporary copy.
504
614
AllocaInst *AllocA =
505
615
new AllocaInst (StructType, AS, Arg->getName (), FirstInst);
0 commit comments