Skip to content

[NVPTX] Improve copy avoidance during lowering. #106423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 171 additions & 61 deletions llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,21 @@
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/PtrUseVisitor.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <numeric>
#include <queue>

Expand Down Expand Up @@ -217,7 +222,8 @@ INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
// pointer in parameter AS.
// For "escapes" (to memory, a function call, or a ptrtoint), cast the OldUse to
// generic using cvta.param.
static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam,
bool IsGridConstant) {
Instruction *I = dyn_cast<Instruction>(OldUse->getUser());
assert(I && "OldUse must be in an instruction");
struct IP {
Expand All @@ -228,7 +234,8 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
SmallVector<IP> ItemsToConvert = {{OldUse, I, Param}};
SmallVector<Instruction *> InstructionsToDelete;

auto CloneInstInParamAS = [GridConstant](const IP &I) -> Value * {
auto CloneInstInParamAS = [HasCvtaParam,
IsGridConstant](const IP &I) -> Value * {
if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
LI->setOperand(0, I.NewParam);
return LI;
Expand All @@ -252,8 +259,25 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
// Just pass through the argument, the old ASC is no longer needed.
return I.NewParam;
}
if (auto *MI = dyn_cast<MemTransferInst>(I.OldInstruction)) {
if (MI->getRawSource() == I.OldUse->get()) {
// convert to memcpy/memmove from param space.
IRBuilder<> Builder(I.OldInstruction);
Intrinsic::ID ID = MI->getIntrinsicID();

CallInst *B = Builder.CreateMemTransferInst(
ID, MI->getRawDest(), MI->getDestAlign(), I.NewParam,
MI->getSourceAlign(), MI->getLength(), MI->isVolatile());
for (unsigned I : {0, 1})
if (uint64_t Bytes = MI->getParamDereferenceableBytes(I))
B->addDereferenceableParamAttr(I, Bytes);
return B;
}
// We may be able to handle other cases if the argument is
// __grid_constant__
}

if (GridConstant) {
if (HasCvtaParam) {
auto GetParamAddrCastToGeneric =
[](Value *Addr, Instruction *OriginalUser) -> Value * {
PointerType *ReturnTy =
Expand All @@ -269,24 +293,44 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
OriginalUser->getIterator());
return CvtToGenCall;
};

if (auto *CI = dyn_cast<CallInst>(I.OldInstruction)) {
I.OldUse->set(GetParamAddrCastToGeneric(I.NewParam, CI));
return CI;
auto *ParamInGenericAS =
GetParamAddrCastToGeneric(I.NewParam, I.OldInstruction);

// phi/select could use generic arg pointers w/o __grid_constant__
if (auto *PHI = dyn_cast<PHINode>(I.OldInstruction)) {
for (auto [Idx, V] : enumerate(PHI->incoming_values())) {
if (V.get() == I.OldUse->get())
PHI->setIncomingValue(Idx, ParamInGenericAS);
}
}
if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction)) {
// byval address is being stored, cast it to generic
if (SI->getValueOperand() == I.OldUse->get())
SI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, SI));
return SI;
if (auto *SI = dyn_cast<SelectInst>(I.OldInstruction)) {
if (SI->getTrueValue() == I.OldUse->get())
SI->setTrueValue(ParamInGenericAS);
if (SI->getFalseValue() == I.OldUse->get())
SI->setFalseValue(ParamInGenericAS);
}
if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction)) {
if (PI->getPointerOperand() == I.OldUse->get())
PI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, PI));
return PI;

// Escapes or writes can only use generic param pointers if
// __grid_constant__ is in effect.
if (IsGridConstant) {
if (auto *CI = dyn_cast<CallInst>(I.OldInstruction)) {
I.OldUse->set(ParamInGenericAS);
return CI;
}
if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction)) {
// byval address is being stored, cast it to generic
if (SI->getValueOperand() == I.OldUse->get())
SI->setOperand(0, ParamInGenericAS);
return SI;
}
if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction)) {
if (PI->getPointerOperand() == I.OldUse->get())
PI->setOperand(0, ParamInGenericAS);
return PI;
}
// TODO: iIf we allow stores, we should allow memcpy/memset to
// parameter, too.
}
llvm_unreachable(
"Instruction unsupported even for grid_constant argument");
}

llvm_unreachable("Unsupported instruction");
Expand Down Expand Up @@ -409,49 +453,110 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
}
}

namespace {
struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
using Base = PtrUseVisitor<ArgUseChecker>;

bool IsGridConstant;
// Set of phi/select instructions using the Arg
SmallPtrSet<Instruction *, 4> Conditionals;

ArgUseChecker(const DataLayout &DL, bool IsGridConstant)
: PtrUseVisitor(DL), IsGridConstant(IsGridConstant) {}

PtrInfo visitArgPtr(Argument &A) {
assert(A.getType()->isPointerTy());
IntegerType *IntIdxTy = cast<IntegerType>(DL.getIndexType(A.getType()));
IsOffsetKnown = false;
Offset = APInt(IntIdxTy->getBitWidth(), 0);
PI.reset();
Conditionals.clear();

LLVM_DEBUG(dbgs() << "Checking Argument " << A << "\n");
// Enqueue the uses of this pointer.
enqueueUsers(A);

// Visit all the uses off the worklist until it is empty.
// Note that unlike PtrUseVisitor we intentionally do not track offsets.
// We're only interested in how we use the pointer.
while (!(Worklist.empty() || PI.isAborted())) {
UseToVisit ToVisit = Worklist.pop_back_val();
U = ToVisit.UseAndIsOffsetKnown.getPointer();
Instruction *I = cast<Instruction>(U->getUser());
if (isa<PHINode>(I) || isa<SelectInst>(I))
Conditionals.insert(I);
LLVM_DEBUG(dbgs() << "Processing " << *I << "\n");
Base::visit(I);
}
if (PI.isEscaped())
LLVM_DEBUG(dbgs() << "Argument pointer escaped: " << *PI.getEscapingInst()
<< "\n");
else if (PI.isAborted())
LLVM_DEBUG(dbgs() << "Pointer use needs a copy: " << *PI.getAbortingInst()
<< "\n");
LLVM_DEBUG(dbgs() << "Traversed " << Conditionals.size()
<< " conditionals\n");
return PI;
}

void visitStoreInst(StoreInst &SI) {
// Storing the pointer escapes it.
if (U->get() == SI.getValueOperand())
return PI.setEscapedAndAborted(&SI);
// Writes to the pointer are UB w/ __grid_constant__, but do not force a
// copy.
if (!IsGridConstant)
return PI.setAborted(&SI);
}

void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
// ASC to param space are no-ops and do not need a copy
if (ASC.getDestAddressSpace() != ADDRESS_SPACE_PARAM)
return PI.setEscapedAndAborted(&ASC);
Base::visitAddrSpaceCastInst(ASC);
}

void visitPtrToIntInst(PtrToIntInst &I) {
if (IsGridConstant)
return;
Base::visitPtrToIntInst(I);
}
void visitPHINodeOrSelectInst(Instruction &I) {
assert(isa<PHINode>(I) || isa<SelectInst>(I));
}
// PHI and select just pass through the pointers.
void visitPHINode(PHINode &PN) { enqueueUsers(PN); }
void visitSelectInst(SelectInst &SI) { enqueueUsers(SI); }

void visitMemTransferInst(MemTransferInst &II) {
if (*U == II.getRawDest() && !IsGridConstant)
PI.setAborted(&II);
// memcpy/memmove are OK when the pointer is source. We can convert them to
// AS-specific memcpy.
}

void visitMemSetInst(MemSetInst &II) {
if (!IsGridConstant)
PI.setAborted(&II);
}
}; // struct ArgUseChecker
} // namespace

void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
Argument *Arg) {
bool IsGridConstant = isParamGridConstant(*Arg);
Function *Func = Arg->getParent();
bool HasCvtaParam = TM.getSubtargetImpl(*Func)->hasCvtaParam();
bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
const DataLayout &DL = Func->getDataLayout();
BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
Type *StructType = Arg->getParamByValType();
assert(StructType && "Missing byval type");

auto AreSupportedUsers = [&](Value *Start) {
SmallVector<Value *, 16> ValuesToCheck = {Start};
auto IsSupportedUse = [IsGridConstant](Value *V) -> bool {
if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
return true;
// ASC to param space are OK, too -- we'll just strip them.
if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
return true;
}
// Simple calls and stores are supported for grid_constants
// writes to these pointers are undefined behaviour
if (IsGridConstant &&
(isa<CallInst>(V) || isa<StoreInst>(V) || isa<PtrToIntInst>(V)))
return true;
return false;
};

while (!ValuesToCheck.empty()) {
Value *V = ValuesToCheck.pop_back_val();
if (!IsSupportedUse(V)) {
LLVM_DEBUG(dbgs() << "Need a "
<< (isParamGridConstant(*Arg) ? "cast " : "copy ")
<< "of " << *Arg << " because of " << *V << "\n");
(void)Arg;
return false;
}
if (!isa<LoadInst>(V) && !isa<CallInst>(V) && !isa<StoreInst>(V) &&
!isa<PtrToIntInst>(V))
llvm::append_range(ValuesToCheck, V->users());
}
return true;
};

if (llvm::all_of(Arg->users(), AreSupportedUsers)) {
ArgUseChecker AUC(DL, IsGridConstant);
ArgUseChecker::PtrInfo PI = AUC.visitArgPtr(*Arg);
bool ArgUseIsReadOnly = !(PI.isEscaped() || PI.isAborted());
// Easy case, accessing parameter directly is fine.
if (ArgUseIsReadOnly && AUC.Conditionals.empty()) {
// Convert all loads and intermediate operations to use parameter AS and
// skip creation of a local copy of the argument.
SmallVector<Use *, 16> UsesToUpdate;
Expand All @@ -462,7 +567,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
FirstInst);
for (Use *U : UsesToUpdate)
convertToParamAS(U, ArgInParamAS, IsGridConstant);
convertToParamAS(U, ArgInParamAS, HasCvtaParam, IsGridConstant);
LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");

const auto *TLI =
Expand All @@ -473,13 +578,17 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
return;
}

const DataLayout &DL = Func->getDataLayout();
// We can't access byval arg directly and need a pointer. on sm_70+ we have
// ability to take a pointer to the argument without making a local copy.
// However, we're still not allowed to write to it. If the user specified
// `__grid_constant__` for the argument, we'll consider escaped pointer as
// read-only.
unsigned AS = DL.getAllocaAddrSpace();
if (isParamGridConstant(*Arg)) {
// Writes to a grid constant are undefined behaviour. We do not need a
// temporary copy. When a pointer might have escaped, conservatively replace
// all of its uses (which might include a device function call) with a cast
// to the generic address space.
if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant)) {
LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
// Replace all argument pointer uses (which might include a device function
// call) with a cast to the generic address space using cvta.param
// instruction, which avoids a local copy.
IRBuilder<> IRB(&Func->getEntryBlock().front());

// Cast argument to param address space
Expand All @@ -500,6 +609,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
// Do not replace Arg in the cast to param space
CastToParam->setOperand(0, Arg);
} else {
LLVM_DEBUG(dbgs() << "Creating a local copy of " << *Arg << "\n");
// Otherwise we have to create a temporary copy.
AllocaInst *AllocA =
new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXSubtarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
bool hasDotInstructions() const {
return SmVersion >= 61 && PTXVersion >= 50;
}
bool hasCvtaParam() const { return SmVersion >= 70 && PTXVersion >= 77; }
unsigned int getFullSmVersion() const { return FullSmVersion; }
unsigned int getSmVersion() const { return getFullSmVersion() / 10; }
// GPUs with "a" suffix have include architecture-accelerated features that
Expand Down
Loading
Loading