Skip to content

Commit 56910a8

Browse files
authored
[NVPTX] Improve kernel byval parameter lowering (#136008)
This change introduces a new pattern for lowering kernel byval parameters in `NVPTXLowerArgs`. Each byval argument is wrapped in a call to a new intrinsic, `@llvm.nvvm.internal.addrspace.wrap`. This intrinsic explicitly equates to no instructions and is removed during operation legalization in SDAG. However, it allows us to change the addrspace of the arguments to 101 to reflect the fact that they will occupy this space when lowered by `LowerFormalArgs` in `NVPTXISelLowering`. Optionally, if a generic pointer to a param is needed, a standard `addrspacecast` is used. This approach offers several advantages: - Exposes addrspace optimizations: By using a standard `addrspacecast` back to generic space we allow InferAS to optimize this instruction, potentially sinking it through control flow or in other ways unsupported by `NVPTXLowerArgs`. This is demonstrated in several existing tests. - Clearer, more consistent semantics: Previously an `addrspacecast` from generic to param space was implicitly a no-op. This is problematic because it's not reciprocal with the inverse cast, violating LLVM semantics. Further it is very confusing given the existence of `cvta.to.param`. After this change the cast equates to this instruction. - Allow for the removal of all nvvm.ptr.* intrinsics: In a follow-up change the nvvm.ptr.gen.to.param and nvvm.ptr.param.to.gen intrinsics may be removed.
1 parent c40d3a4 commit 56910a8

File tree

13 files changed

+400
-388
lines changed

13 files changed

+400
-388
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

+16
Original file line numberDiff line numberDiff line change
@@ -1909,6 +1909,22 @@ def int_nvvm_ptr_param_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
19091909
[IntrNoMem, IntrSpeculatable, IntrNoCallback],
19101910
"llvm.nvvm.ptr.param.to.gen">;
19111911

1912+
// Represents an explicit hole in the LLVM IR type system. It may be inserted by
1913+
// the compiler in cases where a pointer is of the wrong type. In the backend
1914+
// this intrinsic will be folded away and not equate to any instruction. It
1915+
// should not be used by any frontend and should only be considered well defined
1916+
// when added in the following cases:
1917+
//
1918+
// - NVPTXLowerArgs: When wrapping a byval pointer argument to a kernel
1919+
// function to convert the address space from generic (0) to param (101).
1920+
// This accounts for the fact that the parameter symbols will occupy this
1921+
// space when lowered during ISel.
1922+
//
1923+
def int_nvvm_internal_addrspace_wrap :
1924+
DefaultAttrsIntrinsic<[llvm_anyptr_ty], [llvm_anyptr_ty],
1925+
[IntrNoMem, IntrSpeculatable, NoUndef<ArgIndex<0>>,
1926+
NoUndef<RetIndex>]>;
1927+
19121928
// Move intrinsics, used in nvvm internally
19131929

19141930
def int_nvvm_move_i16 : Intrinsic<[llvm_i16_ty], [llvm_i16_ty], [IntrNoMem],

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,9 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
985985
case ADDRESS_SPACE_LOCAL:
986986
Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local;
987987
break;
988+
case ADDRESS_SPACE_PARAM:
989+
Opc = TM.is64Bit() ? NVPTX::cvta_param_64 : NVPTX::cvta_param;
990+
break;
988991
}
989992
ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src));
990993
return;
@@ -1008,7 +1011,7 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
10081011
Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local;
10091012
break;
10101013
case ADDRESS_SPACE_PARAM:
1011-
Opc = TM.is64Bit() ? NVPTX::IMOV64r : NVPTX::IMOV32r;
1014+
Opc = TM.is64Bit() ? NVPTX::cvta_to_param_64 : NVPTX::cvta_to_param;
10121015
break;
10131016
}
10141017

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
10141014
{MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
10151015
MVT::v32i32, MVT::v64i32, MVT::v128i32},
10161016
Custom);
1017+
1018+
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
10171019
}
10181020

10191021
const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
@@ -1426,6 +1428,17 @@ static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
14261428

14271429
return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
14281430
}
1431+
1432+
// Peel off an addrspacecast to generic and load directly from the specific
1433+
// address space.
1434+
if (Ptr->getOpcode() == ISD::ADDRSPACECAST) {
1435+
const auto *ASC = cast<AddrSpaceCastSDNode>(Ptr);
1436+
if (ASC->getDestAddressSpace() == ADDRESS_SPACE_GENERIC) {
1437+
Ptr = ASC->getOperand(0);
1438+
return MachinePointerInfo(ASC->getSrcAddressSpace());
1439+
}
1440+
}
1441+
14291442
return MachinePointerInfo();
14301443
}
14311444

@@ -2746,6 +2759,15 @@ static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
27462759
return Op;
27472760
}
27482761

2762+
static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
2763+
switch (Op->getConstantOperandVal(0)) {
2764+
default:
2765+
return Op;
2766+
case Intrinsic::nvvm_internal_addrspace_wrap:
2767+
return Op.getOperand(1);
2768+
}
2769+
}
2770+
27492771
// In PTX 64-bit CTLZ and CTPOP are supported, but they return a 32-bit value.
27502772
// Lower these into a node returning the correct type which is zero-extended
27512773
// back to the correct size.
@@ -2889,6 +2911,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28892911
return LowerGlobalAddress(Op, DAG);
28902912
case ISD::INTRINSIC_W_CHAIN:
28912913
return Op;
2914+
case ISD::INTRINSIC_WO_CHAIN:
2915+
return lowerIntrinsicWOChain(Op, DAG);
28922916
case ISD::INTRINSIC_VOID:
28932917
return LowerIntrinsicVoid(Op, DAG);
28942918
case ISD::BUILD_VECTOR:

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

+4-12
Original file line numberDiff line numberDiff line change
@@ -2395,18 +2395,10 @@ multiclass G_TO_NG<string Str> {
23952395
"cvta.to." # Str # ".u64 \t$result, $src;", []>;
23962396
}
23972397

2398-
defm cvta_local : NG_TO_G<"local">;
2399-
defm cvta_shared : NG_TO_G<"shared">;
2400-
defm cvta_global : NG_TO_G<"global">;
2401-
defm cvta_const : NG_TO_G<"const">;
2402-
2403-
defm cvta_to_local : G_TO_NG<"local">;
2404-
defm cvta_to_shared : G_TO_NG<"shared">;
2405-
defm cvta_to_global : G_TO_NG<"global">;
2406-
defm cvta_to_const : G_TO_NG<"const">;
2407-
2408-
// nvvm.ptr.param.to.gen
2409-
defm cvta_param : NG_TO_G<"param">;
2398+
foreach space = ["local", "shared", "global", "const", "param"] in {
2399+
defm cvta_#space : NG_TO_G<space>;
2400+
defm cvta_to_#space : G_TO_NG<space>;
2401+
}
24102402

24112403
def : Pat<(int_nvvm_ptr_param_to_gen i32:$src),
24122404
(cvta_param $src)>;

llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp

+43-46
Original file line numberDiff line numberDiff line change
@@ -265,18 +265,9 @@ static void convertToParamAS(Use *OldUse, Value *Param, bool HasCvtaParam,
265265
if (HasCvtaParam) {
266266
auto GetParamAddrCastToGeneric =
267267
[](Value *Addr, Instruction *OriginalUser) -> Value * {
268-
PointerType *ReturnTy =
269-
PointerType::get(OriginalUser->getContext(), ADDRESS_SPACE_GENERIC);
270-
Function *CvtToGen = Intrinsic::getOrInsertDeclaration(
271-
OriginalUser->getModule(), Intrinsic::nvvm_ptr_param_to_gen,
272-
{ReturnTy, PointerType::get(OriginalUser->getContext(),
273-
ADDRESS_SPACE_PARAM)});
274-
275-
// Cast param address to generic address space
276-
Value *CvtToGenCall =
277-
CallInst::Create(CvtToGen, Addr, Addr->getName() + ".gen",
278-
OriginalUser->getIterator());
279-
return CvtToGenCall;
268+
IRBuilder<> IRB(OriginalUser);
269+
Type *GenTy = IRB.getPtrTy(ADDRESS_SPACE_GENERIC);
270+
return IRB.CreateAddrSpaceCast(Addr, GenTy, Addr->getName() + ".gen");
280271
};
281272
auto *ParamInGenericAS =
282273
GetParamAddrCastToGeneric(I.NewParam, I.OldInstruction);
@@ -515,33 +506,34 @@ void copyByValParam(Function &F, Argument &Arg) {
515506
BasicBlock::iterator FirstInst = F.getEntryBlock().begin();
516507
Type *StructType = Arg.getParamByValType();
517508
const DataLayout &DL = F.getDataLayout();
518-
AllocaInst *AllocA = new AllocaInst(StructType, DL.getAllocaAddrSpace(),
519-
Arg.getName(), FirstInst);
509+
IRBuilder<> IRB(&*FirstInst);
510+
AllocaInst *AllocA = IRB.CreateAlloca(StructType, nullptr, Arg.getName());
520511
// Set the alignment to alignment of the byval parameter. This is because,
521512
// later load/stores assume that alignment, and we are going to replace
522513
// the use of the byval parameter with this alloca instruction.
523-
AllocA->setAlignment(F.getParamAlign(Arg.getArgNo())
524-
.value_or(DL.getPrefTypeAlign(StructType)));
514+
AllocA->setAlignment(
515+
Arg.getParamAlign().value_or(DL.getPrefTypeAlign(StructType)));
525516
Arg.replaceAllUsesWith(AllocA);
526517

527-
Value *ArgInParam = new AddrSpaceCastInst(
528-
&Arg, PointerType::get(Arg.getContext(), ADDRESS_SPACE_PARAM),
529-
Arg.getName(), FirstInst);
518+
Value *ArgInParam =
519+
IRB.CreateIntrinsic(Intrinsic::nvvm_internal_addrspace_wrap,
520+
{IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg.getType()},
521+
&Arg, {}, Arg.getName());
522+
530523
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
531524
// addrspacecast preserves alignment. Since params are constant, this load
532525
// is definitely not volatile.
533526
const auto ArgSize = *AllocA->getAllocationSize(DL);
534-
IRBuilder<> IRB(&*FirstInst);
535527
IRB.CreateMemCpy(AllocA, AllocA->getAlign(), ArgInParam, AllocA->getAlign(),
536528
ArgSize);
537529
}
538530
} // namespace
539531

540532
static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
541533
Function *Func = Arg->getParent();
542-
bool HasCvtaParam =
543-
TM.getSubtargetImpl(*Func)->hasCvtaParam() && isKernelFunction(*Func);
544-
bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
534+
assert(isKernelFunction(*Func));
535+
const bool HasCvtaParam = TM.getSubtargetImpl(*Func)->hasCvtaParam();
536+
const bool IsGridConstant = HasCvtaParam && isParamGridConstant(*Arg);
545537
const DataLayout &DL = Func->getDataLayout();
546538
BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
547539
Type *StructType = Arg->getParamByValType();
@@ -556,9 +548,11 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
556548
// skip creation of a local copy of the argument.
557549
SmallVector<Use *, 16> UsesToUpdate(llvm::make_pointer_range(Arg->uses()));
558550

559-
Value *ArgInParamAS = new AddrSpaceCastInst(
560-
Arg, PointerType::get(StructType->getContext(), ADDRESS_SPACE_PARAM),
561-
Arg->getName(), FirstInst);
551+
IRBuilder<> IRB(&*FirstInst);
552+
Value *ArgInParamAS = IRB.CreateIntrinsic(
553+
Intrinsic::nvvm_internal_addrspace_wrap,
554+
{IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getType()}, {Arg});
555+
562556
for (Use *U : UsesToUpdate)
563557
convertToParamAS(U, ArgInParamAS, HasCvtaParam, IsGridConstant);
564558
LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");
@@ -576,30 +570,31 @@ static void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg) {
576570
// However, we're still not allowed to write to it. If the user specified
577571
// `__grid_constant__` for the argument, we'll consider escaped pointer as
578572
// read-only.
579-
if (HasCvtaParam && (ArgUseIsReadOnly || IsGridConstant)) {
573+
if (IsGridConstant || (HasCvtaParam && ArgUseIsReadOnly)) {
580574
LLVM_DEBUG(dbgs() << "Using non-copy pointer to " << *Arg << "\n");
581575
// Replace all argument pointer uses (which might include a device function
582576
// call) with a cast to the generic address space using cvta.param
583577
// instruction, which avoids a local copy.
584578
IRBuilder<> IRB(&Func->getEntryBlock().front());
585579

586-
// Cast argument to param address space
587-
auto *CastToParam = cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
588-
Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
580+
// Cast argument to param address space. Because the backend will emit the
581+
// argument already in the param address space, we need to use the noop
582+
// intrinsic. This has the added benefit of preventing other optimizations
583+
// from folding away this pair of addrspacecasts.
584+
auto *ParamSpaceArg =
585+
IRB.CreateIntrinsic(Intrinsic::nvvm_internal_addrspace_wrap,
586+
{IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getType()},
587+
Arg, {}, Arg->getName() + ".param");
589588

590-
// Cast param address to generic address space. We do not use an
591-
// addrspacecast to generic here, because, LLVM considers `Arg` to be in the
592-
// generic address space, and a `generic -> param` cast followed by a `param
593-
// -> generic` cast will be folded away. The `param -> generic` intrinsic
594-
// will be correctly lowered to `cvta.param`.
595-
Value *CvtToGenCall = IRB.CreateIntrinsic(
596-
IRB.getPtrTy(ADDRESS_SPACE_GENERIC), Intrinsic::nvvm_ptr_param_to_gen,
597-
CastToParam, nullptr, CastToParam->getName() + ".gen");
589+
// Cast param address to generic address space.
590+
Value *GenericArg = IRB.CreateAddrSpaceCast(
591+
ParamSpaceArg, IRB.getPtrTy(ADDRESS_SPACE_GENERIC),
592+
Arg->getName() + ".gen");
598593

599-
Arg->replaceAllUsesWith(CvtToGenCall);
594+
Arg->replaceAllUsesWith(GenericArg);
600595

601596
// Do not replace Arg in the cast to param space
602-
CastToParam->setOperand(0, Arg);
597+
ParamSpaceArg->setOperand(0, Arg);
603598
} else
604599
copyByValParam(*Func, *Arg);
605600
}
@@ -713,12 +708,14 @@ static bool copyFunctionByValArgs(Function &F) {
713708
LLVM_DEBUG(dbgs() << "Creating a copy of byval args of " << F.getName()
714709
<< "\n");
715710
bool Changed = false;
716-
for (Argument &Arg : F.args())
717-
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr() &&
718-
!(isParamGridConstant(Arg) && isKernelFunction(F))) {
719-
copyByValParam(F, Arg);
720-
Changed = true;
721-
}
711+
if (isKernelFunction(F)) {
712+
for (Argument &Arg : F.args())
713+
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr() &&
714+
!isParamGridConstant(Arg)) {
715+
copyByValParam(F, Arg);
716+
Changed = true;
717+
}
718+
}
722719
return Changed;
723720
}
724721

llvm/lib/Target/NVPTX/NVPTXUtilities.cpp

+24-9
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
#include "llvm/ADT/ArrayRef.h"
1717
#include "llvm/ADT/SmallVector.h"
1818
#include "llvm/ADT/StringRef.h"
19+
#include "llvm/IR/Argument.h"
1920
#include "llvm/IR/Constants.h"
2021
#include "llvm/IR/Function.h"
2122
#include "llvm/IR/GlobalVariable.h"
2223
#include "llvm/IR/Module.h"
2324
#include "llvm/Support/Alignment.h"
25+
#include "llvm/Support/ModRef.h"
2426
#include "llvm/Support/Mutex.h"
2527
#include <cstdint>
2628
#include <cstring>
@@ -228,17 +230,30 @@ static std::optional<uint64_t> getVectorProduct(ArrayRef<unsigned> V) {
228230
return std::accumulate(V.begin(), V.end(), 1, std::multiplies<uint64_t>{});
229231
}
230232

231-
bool isParamGridConstant(const Value &V) {
232-
if (const Argument *Arg = dyn_cast<Argument>(&V)) {
233-
// "grid_constant" counts argument indices starting from 1
234-
if (Arg->hasByValAttr() &&
235-
argHasNVVMAnnotation(*Arg, "grid_constant",
236-
/*StartArgIndexAtOne*/ true)) {
237-
assert(isKernelFunction(*Arg->getParent()) &&
238-
"only kernel arguments can be grid_constant");
233+
bool isParamGridConstant(const Argument &Arg) {
234+
assert(isKernelFunction(*Arg.getParent()) &&
235+
"only kernel arguments can be grid_constant");
236+
237+
if (!Arg.hasByValAttr())
238+
return false;
239+
240+
// Lowering an argument as a grid_constant violates the byval semantics (and
241+
// the C++ API) by reusing the same memory location for the argument across
242+
// multiple threads. If an argument doesn't read memory and its address is not
243+
// captured (its address is not compared with any value), then the tweak of
244+
// the C++ API and byval semantics is unobservable by the program and we can
245+
// lower the arg as a grid_constant.
246+
if (Arg.onlyReadsMemory()) {
247+
const auto CI = Arg.getAttributes().getCaptureInfo();
248+
if (!capturesAddress(CI) && !capturesFullProvenance(CI))
239249
return true;
240-
}
241250
}
251+
252+
// "grid_constant" counts argument indices starting from 1
253+
if (argHasNVVMAnnotation(Arg, "grid_constant",
254+
/*StartArgIndexAtOne*/ true))
255+
return true;
256+
242257
return false;
243258
}
244259

llvm/lib/Target/NVPTX/NVPTXUtilities.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ inline bool isKernelFunction(const Function &F) {
6363
return F.getCallingConv() == CallingConv::PTX_Kernel;
6464
}
6565

66-
bool isParamGridConstant(const Value &);
66+
bool isParamGridConstant(const Argument &);
6767

6868
inline MaybeAlign getAlign(const Function &F, unsigned Index) {
6969
return F.getAttributes().getAttributes(Index).getStackAlignment();

llvm/test/CodeGen/NVPTX/bug21465.ll

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ define ptx_kernel void @_Z11TakesStruct1SPi(ptr byval(%struct.S) nocapture reado
1212
entry:
1313
; CHECK-LABEL: @_Z11TakesStruct1SPi
1414
; PTX-LABEL: .visible .entry _Z11TakesStruct1SPi(
15-
; CHECK: addrspacecast ptr %input to ptr addrspace(101)
15+
; CHECK: call ptr addrspace(101) @llvm.nvvm.internal.addrspace.wrap.p101.p0(ptr %input)
1616
%b = getelementptr inbounds %struct.S, ptr %input, i64 0, i32 1
1717
%0 = load i32, ptr %b, align 4
1818
; PTX-NOT: ld.param.u32 {{%r[0-9]+}}, [{{%rd[0-9]+}}]

llvm/test/CodeGen/NVPTX/forward-ld-param.ll

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ define void @test_ld_param_byval(ptr byval(i32) %a) {
6565
; CHECK-LABEL: test_ld_param_byval(
6666
; CHECK: {
6767
; CHECK-NEXT: .reg .b32 %r<2>;
68-
; CHECK-NEXT: .reg .b64 %rd<3>;
68+
; CHECK-NEXT: .reg .b64 %rd<2>;
6969
; CHECK-EMPTY:
7070
; CHECK-NEXT: // %bb.0:
7171
; CHECK-NEXT: ld.param.u32 %r1, [test_ld_param_byval_param_0];

0 commit comments

Comments
 (0)