Skip to content

Commit d0615a9

Browse files
committed
[NVPTX] Handle bitcast and ASC(101) when trying to avoid argument copy.
This allows us to skip the copy in few more cases. Differential Revision: https://reviews.llvm.org/D99979
1 parent 9ef6aa0 commit d0615a9

File tree

2 files changed

+128
-31
lines changed

2 files changed

+128
-31
lines changed

llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp

+81-30
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@
9999
#include "llvm/IR/Type.h"
100100
#include "llvm/Pass.h"
101101

102+
#define DEBUG_TYPE "nvptx-lower-args"
103+
102104
using namespace llvm;
103105

104106
namespace llvm {
@@ -166,40 +168,60 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
166168
Value *NewParam;
167169
};
168170
SmallVector<IP> ItemsToConvert = {{I, Param}};
169-
SmallVector<GetElementPtrInst *> GEPsToDelete;
170-
while (!ItemsToConvert.empty()) {
171-
IP I = ItemsToConvert.pop_back_val();
172-
if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction))
171+
SmallVector<Instruction *> InstructionsToDelete;
172+
173+
auto CloneInstInParamAS = [](const IP &I) -> Value * {
174+
if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
173175
LI->setOperand(0, I.NewParam);
174-
else if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
176+
return LI;
177+
}
178+
if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
175179
SmallVector<Value *, 4> Indices(GEP->indices());
176180
auto *NewGEP = GetElementPtrInst::Create(nullptr, I.NewParam, Indices,
177181
GEP->getName(), GEP);
178182
NewGEP->setIsInBounds(GEP->isInBounds());
179-
llvm::for_each(GEP->users(), [NewGEP, &ItemsToConvert](Value *V) {
180-
ItemsToConvert.push_back({cast<Instruction>(V), NewGEP});
181-
});
182-
GEPsToDelete.push_back(GEP);
183-
} else
184-
llvm_unreachable("Only Load and GEP can be converted to param AS.");
185-
}
186-
llvm::for_each(GEPsToDelete,
187-
[](GetElementPtrInst *GEP) { GEP->eraseFromParent(); });
188-
}
183+
return NewGEP;
184+
}
185+
if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) {
186+
auto *NewBCType = BC->getType()->getPointerElementType()->getPointerTo(
187+
ADDRESS_SPACE_PARAM);
188+
return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
189+
BC->getName(), BC);
190+
}
191+
if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) {
192+
assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
193+
// Just pass through the argument, the old ASC is no longer needed.
194+
return I.NewParam;
195+
}
196+
llvm_unreachable("Unsupported instruction");
197+
};
189198

190-
static bool isALoadChain(Value *Start) {
191-
SmallVector<Value *, 16> ValuesToCheck = {Start};
192-
while (!ValuesToCheck.empty()) {
193-
Value *V = ValuesToCheck.pop_back_val();
194-
Instruction *I = dyn_cast<Instruction>(V);
195-
if (!I)
196-
return false;
197-
if (isa<GetElementPtrInst>(I))
198-
ValuesToCheck.append(I->user_begin(), I->user_end());
199-
else if (!isa<LoadInst>(I))
200-
return false;
199+
while (!ItemsToConvert.empty()) {
200+
IP I = ItemsToConvert.pop_back_val();
201+
Value *NewInst = CloneInstInParamAS(I);
202+
203+
if (NewInst && NewInst != I.OldInstruction) {
204+
// We've created a new instruction. Queue users of the old instruction to
205+
// be converted and the instruction itself to be deleted. We can't delete
206+
// the old instruction yet, because it's still in use by a load somewhere.
207+
llvm::for_each(
208+
I.OldInstruction->users(), [NewInst, &ItemsToConvert](Value *V) {
209+
ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
210+
});
211+
212+
InstructionsToDelete.push_back(I.OldInstruction);
213+
}
201214
}
202-
return true;
215+
216+
// Now we know that all argument loads are using addresses in parameter space
217+
// and we can finally remove the old instructions in generic AS. Instructions
218+
// scheduled for removal should be processed in reverse order so the ones
219+
// closest to the load are deleted first. Otherwise they may still be in use.
220+
// E.g if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
221+
// have {GEP,BitCast}. GEP can't be deleted first, because it's still used by
222+
// the BitCast.
223+
llvm::for_each(reverse(InstructionsToDelete),
224+
[](Instruction *I) { I->eraseFromParent(); });
203225
}
204226

205227
void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
@@ -211,16 +233,43 @@ void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
211233

212234
Type *StructType = PType->getElementType();
213235

214-
if (llvm::all_of(Arg->users(), isALoadChain)) {
215-
// Replace all loads with the loads in param AS. This allows loading the Arg
216-
// directly from parameter AS, without making a temporary copy.
236+
auto IsALoadChain = [Arg](Value *Start) {
237+
SmallVector<Value *, 16> ValuesToCheck = {Start};
238+
auto IsALoadChainInstr = [](Value *V) -> bool {
239+
if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
240+
return true;
241+
// ASC to param space are OK, too -- we'll just strip them.
242+
if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
243+
if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
244+
return true;
245+
}
246+
return false;
247+
};
248+
249+
while (!ValuesToCheck.empty()) {
250+
Value *V = ValuesToCheck.pop_back_val();
251+
if (!IsALoadChainInstr(V)) {
252+
LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
253+
<< "\n");
254+
return false;
255+
}
256+
if (!isa<LoadInst>(V))
257+
llvm::append_range(ValuesToCheck, V->users());
258+
}
259+
return true;
260+
};
261+
262+
if (llvm::all_of(Arg->users(), IsALoadChain)) {
263+
// Convert all loads and intermediate operations to use parameter AS and
264+
// skip creation of a local copy of the argument.
217265
SmallVector<User *, 16> UsersToUpdate(Arg->users());
218266
Value *ArgInParamAS = new AddrSpaceCastInst(
219267
Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
220268
FirstInst);
221269
llvm::for_each(UsersToUpdate, [ArgInParamAS](Value *V) {
222270
convertToParamAS(V, ArgInParamAS);
223271
});
272+
LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
224273
return;
225274
}
226275

@@ -297,6 +346,7 @@ bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
297346
}
298347
}
299348

349+
LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
300350
for (Argument &Arg : F.args()) {
301351
if (Arg.getType()->isPointerTy()) {
302352
if (Arg.hasByValAttr())
@@ -310,6 +360,7 @@ bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
310360

311361
// Device functions only need to copy byval args into local memory.
312362
bool NVPTXLowerArgs::runOnDeviceFunction(Function &F) {
363+
LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
313364
for (Argument &Arg : F.args())
314365
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
315366
handleByValParam(&Arg);

llvm/test/CodeGen/NVPTX/lower-byval-args.ll

+47-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ target triple = "nvptx64-nvidia-cuda"
66

77
; // Verify that load with static offset into parameter is done directly.
88
; CHECK-LABEL: .visible .entry static_offset
9+
; CHECK-NOT: .local
910
; CHECK: ld.param.u64 [[result_addr:%rd[0-9]+]], [{{.*}}_param_0]
1011
; CHECK: mov.b64 %[[param_addr:rd[0-9]+]], {{.*}}_param_1
1112
; CHECK: mov.u64 %[[param_addr1:rd[0-9]+]], %[[param_addr]]
@@ -30,6 +31,7 @@ bb6: ; preds = %bb3, %bb
3031

3132
; // Verify that load with dynamic offset into parameter is also done directly.
3233
; CHECK-LABEL: .visible .entry dynamic_offset
34+
; CHECK-NOT: .local
3335
; CHECK: ld.param.u64 [[result_addr:%rd[0-9]+]], [{{.*}}_param_0]
3436
; CHECK: mov.b64 %[[param_addr:rd[0-9]+]], {{.*}}_param_1
3537
; CHECK: mov.u64 %[[param_addr1:rd[0-9]+]], %[[param_addr]]
@@ -48,6 +50,48 @@ bb:
4850
ret void
4951
}
5052

53+
; Same as above, but with a bitcast present in the chain
54+
; CHECK-LABEL:.visible .entry gep_bitcast
55+
; CHECK-NOT: .local
56+
; CHECK-DAG: ld.param.u64 [[out:%rd[0-9]+]], [gep_bitcast_param_0]
57+
; CHECK-DAG: mov.b64 {{%rd[0-9]+}}, gep_bitcast_param_1
58+
; CHECK-DAG: ld.param.u32 {{%r[0-9]+}}, [gep_bitcast_param_2]
59+
; CHECK: ld.param.u8 [[value:%rs[0-9]+]], [{{%rd[0-9]+}}]
60+
; CHECK: st.global.u8 [{{%rd[0-9]+}}], [[value]];
61+
;
62+
; Function Attrs: nofree norecurse nounwind willreturn mustprogress
63+
define dso_local void @gep_bitcast(i8* nocapture %out, %struct.ham* nocapture readonly byval(%struct.ham) align 4 %in, i32 %n) local_unnamed_addr #0 {
64+
bb:
65+
%n64 = sext i32 %n to i64
66+
%gep = getelementptr inbounds %struct.ham, %struct.ham* %in, i64 0, i32 0, i64 %n64
67+
%bc = bitcast i32* %gep to i8*
68+
%load = load i8, i8* %bc, align 4
69+
store i8 %load, i8* %out, align 4
70+
ret void
71+
}
72+
73+
; Same as above, but with an ASC(101) present in the chain
74+
; CHECK-LABEL:.visible .entry gep_bitcast_asc
75+
; CHECK-NOT: .local
76+
; CHECK-DAG: ld.param.u64 [[out:%rd[0-9]+]], [gep_bitcast_asc_param_0]
77+
; CHECK-DAG: mov.b64 {{%rd[0-9]+}}, gep_bitcast_asc_param_1
78+
; CHECK-DAG: ld.param.u32 {{%r[0-9]+}}, [gep_bitcast_asc_param_2]
79+
; CHECK: ld.param.u8 [[value:%rs[0-9]+]], [{{%rd[0-9]+}}]
80+
; CHECK: st.global.u8 [{{%rd[0-9]+}}], [[value]];
81+
;
82+
; Function Attrs: nofree norecurse nounwind willreturn mustprogress
83+
define dso_local void @gep_bitcast_asc(i8* nocapture %out, %struct.ham* nocapture readonly byval(%struct.ham) align 4 %in, i32 %n) local_unnamed_addr #0 {
84+
bb:
85+
%n64 = sext i32 %n to i64
86+
%gep = getelementptr inbounds %struct.ham, %struct.ham* %in, i64 0, i32 0, i64 %n64
87+
%bc = bitcast i32* %gep to i8*
88+
%asc = addrspacecast i8* %bc to i8 addrspace(101)*
89+
%load = load i8, i8 addrspace(101)* %asc, align 4
90+
store i8 %load, i8* %out, align 4
91+
ret void
92+
}
93+
94+
5195
; Verify that if the pointer escapes, then we do fall back onto using a temp copy.
5296
; CHECK-LABEL: .visible .entry pointer_escapes
5397
; CHECK: .local .align 8 .b8 __local_depot{{.*}}
@@ -82,11 +126,13 @@ declare dso_local i32* @escape(i32*) local_unnamed_addr
82126

83127

84128
!llvm.module.flags = !{!0, !1, !2}
85-
!nvvm.annotations = !{!3, !4, !5}
129+
!nvvm.annotations = !{!3, !4, !5, !6, !7}
86130

87131
!0 = !{i32 2, !"SDK Version", [2 x i32] [i32 9, i32 1]}
88132
!1 = !{i32 1, !"wchar_size", i32 4}
89133
!2 = !{i32 4, !"nvvm-reflect-ftz", i32 0}
90134
!3 = !{void (i32*, %struct.ham*, i32)* @static_offset, !"kernel", i32 1}
91135
!4 = !{void (i32*, %struct.ham*, i32)* @dynamic_offset, !"kernel", i32 1}
92136
!5 = !{void (i32*, %struct.ham*, i32)* @pointer_escapes, !"kernel", i32 1}
137+
!6 = !{void (i8*, %struct.ham*, i32)* @gep_bitcast, !"kernel", i32 1}
138+
!7 = !{void (i8*, %struct.ham*, i32)* @gep_bitcast_asc, !"kernel", i32 1}

0 commit comments

Comments
 (0)