
Commit b7b28e7

[NVPTX] Improve copy avoidance during lowering. (#106423)
On newer GPUs, where the `cvta.param` instruction is available, we can avoid making copies of byval arguments when their pointers are used in a few more cases, even when `__grid_constant__` is not specified:

- phi
- select
- memcpy from the parameter

Switched pointer traversal from a DIY implementation to PtrUseVisitor.
1 parent cb03126 commit b7b28e7
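For context, the cases newly handled without a local copy are, roughly, kernels where the address of a byval (pass-by-value aggregate) parameter flows through a phi/select or is only memcpy'd out of. The CUDA sketch below is illustrative only; the struct and kernel names are hypothetical and are not taken from this commit or its tests.

#include <cstring>

struct Params {
  float data[256];
  int n;
};

// Hypothetical kernels showing the byval-parameter uses listed above.
// Before this change, these uses forced the NVPTX lowering to materialize a
// local copy of the parameter; now they can avoid it (the phi/select cases
// additionally need cvta.param, i.e. sm_70+).
__global__ void select_use(Params p, Params q, float *out, bool flag) {
  // The parameter address flows through a select (or a phi after control-flow
  // lowering); it is only read through, never written.
  const Params *chosen = flag ? &p : &q;
  out[threadIdx.x] = chosen->data[threadIdx.x];
}

__global__ void memcpy_from_param(Params p, float *out) {
  // memcpy *from* the parameter can become a copy out of param space.
  __shared__ float tmp[256];
  if (threadIdx.x == 0)
    memcpy(tmp, p.data, sizeof(tmp));
  __syncthreads();
  out[threadIdx.x] = tmp[threadIdx.x];
}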

File tree

4 files changed: +984, -378 lines


llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp

+171, -61 lines
@@ -139,16 +139,21 @@
 #include "NVPTX.h"
 #include "NVPTXTargetMachine.h"
 #include "NVPTXUtilities.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/PtrUseVisitor.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 #include <numeric>
 #include <queue>

@@ -217,7 +222,8 @@ INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
 // pointer in parameter AS.
 // For "escapes" (to memory, a function call, or a ptrtoint), cast the OldUse to
 // generic using cvta.param.
-static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
+static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam,
+                             bool IsGridConstant) {
   Instruction *I = dyn_cast<Instruction>(OldUse->getUser());
   assert(I && "OldUse must be in an instruction");
   struct IP {
@@ -228,7 +234,8 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
   SmallVector<IP> ItemsToConvert = {{OldUse, I, Param}};
   SmallVector<Instruction *> InstructionsToDelete;
 
-  auto CloneInstInParamAS = [GridConstant](const IP &I) -> Value * {
+  auto CloneInstInParamAS = [HasCvtaParam,
+                             IsGridConstant](const IP &I) -> Value * {
     if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
       LI->setOperand(0, I.NewParam);
       return LI;
@@ -252,8 +259,25 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
       // Just pass through the argument, the old ASC is no longer needed.
       return I.NewParam;
     }
+    if (auto *MI = dyn_cast<MemTransferInst>(I.OldInstruction)) {
+      if (MI->getRawSource() == I.OldUse->get()) {
+        // convert to memcpy/memmove from param space.
+        IRBuilder<> Builder(I.OldInstruction);
+        Intrinsic::ID ID = MI->getIntrinsicID();
+
+        CallInst *B = Builder.CreateMemTransferInst(
+            ID, MI->getRawDest(), MI->getDestAlign(), I.NewParam,
+            MI->getSourceAlign(), MI->getLength(), MI->isVolatile());
+        for (unsigned I : {0, 1})
+          if (uint64_t Bytes = MI->getParamDereferenceableBytes(I))
+            B->addDereferenceableParamAttr(I, Bytes);
+        return B;
+      }
+      // We may be able to handle other cases if the argument is
+      // __grid_constant__
+    }
 
-    if (GridConstant) {
+    if (HasCvtaParam) {
       auto GetParamAddrCastToGeneric =
           [](Value *Addr, Instruction *OriginalUser) -> Value * {
         PointerType *ReturnTy =
@@ -269,24 +293,44 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
             OriginalUser->getIterator());
         return CvtToGenCall;
       };
-
-      if (auto *CI = dyn_cast<CallInst>(I.OldInstruction)) {
-        I.OldUse->set(GetParamAddrCastToGeneric(I.NewParam, CI));
-        return CI;
+      auto *ParamInGenericAS =
+          GetParamAddrCastToGeneric(I.NewParam, I.OldInstruction);
+
+      // phi/select could use generic arg pointers w/o __grid_constant__
+      if (auto *PHI = dyn_cast<PHINode>(I.OldInstruction)) {
+        for (auto [Idx, V] : enumerate(PHI->incoming_values())) {
+          if (V.get() == I.OldUse->get())
+            PHI->setIncomingValue(Idx, ParamInGenericAS);
+        }
       }
-      if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction)) {
-        // byval address is being stored, cast it to generic
-        if (SI->getValueOperand() == I.OldUse->get())
-          SI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, SI));
-        return SI;
+      if (auto *SI = dyn_cast<SelectInst>(I.OldInstruction)) {
+        if (SI->getTrueValue() == I.OldUse->get())
+          SI->setTrueValue(ParamInGenericAS);
+        if (SI->getFalseValue() == I.OldUse->get())
+          SI->setFalseValue(ParamInGenericAS);
       }
-      if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction)) {
-        if (PI->getPointerOperand() == I.OldUse->get())
-          PI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, PI));
-        return PI;
+
+      // Escapes or writes can only use generic param pointers if
+      // __grid_constant__ is in effect.
+      if (IsGridConstant) {
+        if (auto *CI = dyn_cast<CallInst>(I.OldInstruction)) {
+          I.OldUse->set(ParamInGenericAS);
+          return CI;
+        }
+        if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction)) {
+          // byval address is being stored, cast it to generic
+          if (SI->getValueOperand() == I.OldUse->get())
+            SI->setOperand(0, ParamInGenericAS);
+          return SI;
+        }
+        if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction)) {
+          if (PI->getPointerOperand() == I.OldUse->get())
+            PI->setOperand(0, ParamInGenericAS);
+          return PI;
+        }
+        // TODO: If we allow stores, we should allow memcpy/memset to
+        // parameter, too.
       }
-      llvm_unreachable(
-          "Instruction unsupported even for grid_constant argument");
     }
 
     llvm_unreachable("Unsupported instruction");
@@ -409,49 +453,110 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
   }
 }
 
+namespace {
+struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
+  using Base = PtrUseVisitor<ArgUseChecker>;
+
+  bool IsGridConstant;
+  // Set of phi/select instructions using the Arg
+  SmallPtrSet<Instruction *, 4> Conditionals;
+
+  ArgUseChecker(const DataLayout &DL, bool IsGridConstant)
+      : PtrUseVisitor(DL), IsGridConstant(IsGridConstant) {}
+
+  PtrInfo visitArgPtr(Argument &A) {
+    assert(A.getType()->isPointerTy());
+    IntegerType *IntIdxTy = cast<IntegerType>(DL.getIndexType(A.getType()));
+    IsOffsetKnown = false;
+    Offset = APInt(IntIdxTy->getBitWidth(), 0);
+    PI.reset();
+    Conditionals.clear();
+
+    LLVM_DEBUG(dbgs() << "Checking Argument " << A << "\n");
+    // Enqueue the uses of this pointer.
+    enqueueUsers(A);
+
+    // Visit all the uses off the worklist until it is empty.
+    // Note that unlike PtrUseVisitor we intentionally do not track offsets.
+    // We're only interested in how we use the pointer.
+    while (!(Worklist.empty() || PI.isAborted())) {
+      UseToVisit ToVisit = Worklist.pop_back_val();
+      U = ToVisit.UseAndIsOffsetKnown.getPointer();
+      Instruction *I = cast<Instruction>(U->getUser());
+      if (isa<PHINode>(I) || isa<SelectInst>(I))
+        Conditionals.insert(I);
+      LLVM_DEBUG(dbgs() << "Processing " << *I << "\n");
+      Base::visit(I);
+    }
+    if (PI.isEscaped())
+      LLVM_DEBUG(dbgs() << "Argument pointer escaped: " << *PI.getEscapingInst()
+                        << "\n");
+    else if (PI.isAborted())
+      LLVM_DEBUG(dbgs() << "Pointer use needs a copy: " << *PI.getAbortingInst()
+                        << "\n");
+    LLVM_DEBUG(dbgs() << "Traversed " << Conditionals.size()
                      << " conditionals\n");
+    return PI;
+  }
+
+  void visitStoreInst(StoreInst &SI) {
+    // Storing the pointer escapes it.
+    if (U->get() == SI.getValueOperand())
+      return PI.setEscapedAndAborted(&SI);
+    // Writes to the pointer are UB w/ __grid_constant__, but do not force a
+    // copy.
+    if (!IsGridConstant)
+      return PI.setAborted(&SI);
+  }
+
+  void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
+    // ASC to param space are no-ops and do not need a copy
+    if (ASC.getDestAddressSpace() != ADDRESS_SPACE_PARAM)
+      return PI.setEscapedAndAborted(&ASC);
+    Base::visitAddrSpaceCastInst(ASC);
+  }
+
+  void visitPtrToIntInst(PtrToIntInst &I) {
+    if (IsGridConstant)
+      return;
+    Base::visitPtrToIntInst(I);
+  }
+  void visitPHINodeOrSelectInst(Instruction &I) {
+    assert(isa<PHINode>(I) || isa<SelectInst>(I));
+  }
+  // PHI and select just pass through the pointers.
+  void visitPHINode(PHINode &PN) { enqueueUsers(PN); }
+  void visitSelectInst(SelectInst &SI) { enqueueUsers(SI); }
+
+  void visitMemTransferInst(MemTransferInst &II) {
+    if (*U == II.getRawDest() && !IsGridConstant)
+      PI.setAborted(&II);
+    // memcpy/memmove are OK when the pointer is source. We can convert them to
+    // AS-specific memcpy.
+  }
+
+  void visitMemSetInst(MemSetInst &II) {
+    if (!IsGridConstant)
+      PI.setAborted(&II);
+  }
+}; // struct ArgUseChecker
+} // namespace
+
 void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
                                       Argument *Arg) {
-  bool IsGridConstant = isParamGridConstant(*Arg);
   Function *Func = Arg->getParent();
+  bool HasCvtaParam = TM.getSubtargetImpl(*Func)->hasCvtaParam();
+  bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
+  const DataLayout &DL = Func->getDataLayout();
   BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
   Type *StructType = Arg->getParamByValType();
   assert(StructType && "Missing byval type");
 
-  auto AreSupportedUsers = [&](Value *Start) {
-    SmallVector<Value *, 16> ValuesToCheck = {Start};
-    auto IsSupportedUse = [IsGridConstant](Value *V) -> bool {
-      if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
-        return true;
-      // ASC to param space are OK, too -- we'll just strip them.
-      if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
-        if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
-          return true;
-      }
-      // Simple calls and stores are supported for grid_constants
-      // writes to these pointers are undefined behaviour
-      if (IsGridConstant &&
-          (isa<CallInst>(V) || isa<StoreInst>(V) || isa<PtrToIntInst>(V)))
-        return true;
-      return false;
-    };
-
-    while (!ValuesToCheck.empty()) {
-      Value *V = ValuesToCheck.pop_back_val();
-      if (!IsSupportedUse(V)) {
-        LLVM_DEBUG(dbgs() << "Need a "
-                          << (isParamGridConstant(*Arg) ? "cast " : "copy ")
-                          << "of " << *Arg << " because of " << *V << "\n");
-        (void)Arg;
-        return false;
-      }
-      if (!isa<LoadInst>(V) && !isa<CallInst>(V) && !isa<StoreInst>(V) &&
-          !isa<PtrToIntInst>(V))
-        llvm::append_range(ValuesToCheck, V->users());
-    }
-    return true;
-  };
-
-  if (llvm::all_of(Arg->users(), AreSupportedUsers)) {
+  ArgUseChecker AUC(DL, IsGridConstant);
+  ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(*Arg);
+  bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
+  // Easy case, accessing parameter directly is fine.
+  if (ArgUseIsReadOnly && AUC.Conditionals.empty()) {
     // Convert all loads and intermediate operations to use parameter AS and
     // skip creation of a local copy of the argument.
     SmallVector<Use *, 16> UsesToUpdate;
@@ -462,7 +567,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
         Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
         FirstInst);
     for (Use *U : UsesToUpdate)
-      convertToParamAS(U, ArgInParamAS, IsGridConstant);
+      convertToParamAS(U, ArgInParamAS, HasCvtaParam, IsGridConstant);
     LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");
 
     const auto *TLI =
@@ -473,13 +578,17 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
     return;
   }
 
-  const DataLayout &DL = Func->getDataLayout();
+  // We can't access byval arg directly and need a pointer. On sm_70+ we have
+  // ability to take a pointer to the argument without making a local copy.
+  // However, we're still not allowed to write to it. If the user specified
+  // `__grid_constant__` for the argument, we'll consider escaped pointer as
+  // read-only.
   unsigned AS = DL.getAllocaAddrSpace();
-  if (isParamGridConstant(*Arg)) {
-    // Writes to a grid constant are undefined behaviour. We do not need a
-    // temporary copy. When a pointer might have escaped, conservatively replace
-    // all of its uses (which might include a device function call) with a cast
-    // to the generic address space.
+  if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant)) {
+    LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
+    // Replace all argument pointer uses (which might include a device function
+    // call) with a cast to the generic address space using cvta.param
+    // instruction, which avoids a local copy.
     IRBuilder<> IRB(&Func->getEntryBlock().front());
 
     // Cast argument to param address space
@@ -500,6 +609,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
     // Do not replace Arg in the cast to param space
    CastToParam->setOperand(0, Arg);
   } else {
+    LLVM_DEBUG(dbgs() << "Creating a local copy of " << *Arg << "\n");
     // Otherwise we have to create a temporary copy.
     AllocaInst *AllocA =
         new AllocaInst(StructType, AS, Arg->getName(), FirstInst);

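Taken together, the NVPTXLowerArgs changes above pick one of three lowerings for a byval kernel argument. The CUDA sketch below is a rough illustration under stated assumptions: kernel, struct, and helper names are hypothetical, the per-kernel comments paraphrase the pass logic in the diff, and `__grid_constant__` needs a toolchain recent enough to accept it.

struct Box {
  int vals[64];
};

// Hypothetical non-inlined device function, so the pointer genuinely escapes.
__device__ __noinline__ int consume(const int *p) { return *p; }

// 1. Easy case: only loads/GEPs (and memcpy *from* the parameter), no
//    phi/select. Accesses are rewritten to the param address space; no copy.
__global__ void direct_reads(Box b, int *out) {
  out[threadIdx.x] = b.vals[threadIdx.x % 64];
}

// 2. Read-only uses routed through a phi/select, or a __grid_constant__
//    argument even if its address escapes: on sm_70+ the pass emits
//    cvta.param and hands out a generic pointer instead of making a copy.
__global__ void escaping_grid_const(const __grid_constant__ Box b, int *out) {
  out[threadIdx.x] = consume(&b.vals[0]);
}

// 3. Writes through the parameter (or escapes without __grid_constant__, or
//    no cvta.param support): a local copy is still created, as before.
__global__ void writes_param(Box b, int *out) {
  b.vals[threadIdx.x % 64] += 1;
  out[threadIdx.x] = b.vals[threadIdx.x % 64];
}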
llvm/lib/Target/NVPTX/NVPTXSubtarget.h

+1
@@ -94,6 +94,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
   bool hasDotInstructions() const {
     return SmVersion >= 61 && PTXVersion >= 50;
   }
+  bool hasCvtaParam() const { return SmVersion >= 70 && PTXVersion >= 77; }
   unsigned int getFullSmVersion() const { return FullSmVersion; }
   unsigned int getSmVersion() const { return getFullSmVersion() / 10; }
   // GPUs with "a" suffix have include architecture-accelerated features that
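The new hasCvtaParam() predicate gates the whole no-copy path: both the target SM version (sm_70 or newer) and the requested PTX version (7.7 or newer) must be high enough. A minimal hypothetical CUDA illustration follows; the struct, helper, and kernel names are made up, and the build-flag details are an assumption rather than a prescription.

// Built for sm_70+ with a sufficiently new PTX version (e.g. via
// --cuda-gpu-arch=sm_70 on a recent toolchain), hasCvtaParam() is true and a
// kernel like this needs no 4 KB local copy of its parameter; on older
// targets the lowering falls back to a local copy, as before this commit.
struct Big { float f[1024]; };

__device__ __noinline__ float first(const float *p) { return *p; }

__global__ void k(const __grid_constant__ Big b, float *out) {
  *out = first(b.f);
}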
