@@ -14,92 +14,115 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "LocalAccessorToSharedMemory.h"
-#include "../MCTargetDesc/NVPTXBaseInfo.h"
+#include "llvm/SYCLLowerIR/LocalAccessorToSharedMemory.h"
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Transforms/IPO.h"
 
 using namespace llvm;
 
 #define DEBUG_TYPE "localaccessortosharedmemory"
 
+static bool EnableLocalAccessor;
+
+static cl::opt<bool, true> EnableLocalAccessorFlag(
+    "sycl-enable-local-accessor", cl::Hidden,
+    cl::desc("Enable local accessor to shared memory optimisation."),
+    cl::location(EnableLocalAccessor), cl::init(false));
+
 namespace llvm {
 void initializeLocalAccessorToSharedMemoryPass(PassRegistry &);
-}
+} // namespace llvm
 
 namespace {
 
 class LocalAccessorToSharedMemory : public ModulePass {
+private:
+  enum class ArchType { Cuda, AMDHSA, Unsupported };
+
+  struct KernelPayload {
+    KernelPayload(Function *Kernel, MDNode *MD = nullptr)
+        : Kernel(Kernel), MD(MD){};
+    Function *Kernel;
+    MDNode *MD;
+  };
+
+  unsigned SharedASValue = 0;
+
 public:
   static char ID;
   LocalAccessorToSharedMemory() : ModulePass(ID) {}
 
   bool runOnModule(Module &M) override {
+    if (!EnableLocalAccessor)
+      return false;
+
+    auto AT = StringSwitch<ArchType>(M.getTargetTriple().c_str())
+                  .Case("nvptx64-nvidia-cuda", ArchType::Cuda)
+                  .Case("nvptx-nvidia-cuda", ArchType::Cuda)
+                  .Case("amdgcn-amd-amdhsa", ArchType::AMDHSA)
+                  .Default(ArchType::Unsupported);
+
     // Invariant: This pass is only intended to operate on SYCL kernels being
-    // compiled to the `nvptx{,64}-nvidia-cuda` triple.
-    // TODO: make sure that non-SYCL kernels are not impacted.
+    // compiled to either `nvptx{,64}-nvidia-cuda`, or `amdgcn-amd-amdhsa`
+    // triples.
+    if (ArchType::Unsupported == AT)
+      return false;
+
     if (skipModule(M))
       return false;
 
-    // Keep track of whether the module was changed.
-    auto Changed = false;
+    switch (AT) {
+    case ArchType::Cuda:
+      // ADDRESS_SPACE_SHARED = 3,
+      SharedASValue = 3;
+      break;
+    case ArchType::AMDHSA:
+      // LOCAL_ADDRESS = 3,
+      SharedASValue = 3;
+      break;
+    default:
+      SharedASValue = 0;
+      break;
+    }
 
-    // Access `nvvm.annotations` to determine which functions are kernel entry
-    // points.
-    auto NvvmMetadata = M.getNamedMetadata("nvvm.annotations");
-    if (!NvvmMetadata)
+    SmallVector<KernelPayload> Kernels;
+    SmallVector<std::pair<Function *, KernelPayload>> NewToOldKernels;
+    populateKernels(M, Kernels, AT);
+    if (Kernels.empty())
       return false;
 
-    for (auto MetadataNode : NvvmMetadata->operands()) {
-      if (MetadataNode->getNumOperands() != 3)
-        continue;
+    // Process the function and if changed, update the metadata.
+    for (auto K : Kernels) {
+      auto *NewKernel = processKernel(M, K.Kernel);
+      if (NewKernel)
+        NewToOldKernels.push_back(std::make_pair(NewKernel, K));
+    }
 
-      // NVPTX identifies kernel entry points using metadata nodes of the form:
-      // !X = !{<function>, !"kernel", i32 1}
-      const MDOperand &TypeOperand = MetadataNode->getOperand(1);
-      auto Type = dyn_cast<MDString>(TypeOperand);
-      if (!Type)
-        continue;
-      // Only process kernel entry points.
-      if (Type->getString() != "kernel")
-        continue;
+    if (NewToOldKernels.empty())
+      return false;
 
-      // Get a pointer to the entry point function from the metadata.
-      const MDOperand &FuncOperand = MetadataNode->getOperand(0);
-      if (!FuncOperand)
-        continue;
-      auto FuncConstant = dyn_cast<ConstantAsMetadata>(FuncOperand);
-      if (!FuncConstant)
-        continue;
-      auto Func = dyn_cast<Function>(FuncConstant->getValue());
-      if (!Func)
-        continue;
+    postProcessKernels(NewToOldKernels, AT);
 
-      // Process the function and if changed, update the metadata.
-      auto NewFunc = this->ProcessFunction(M, Func);
-      if (NewFunc) {
-        Changed = true;
-        MetadataNode->replaceOperandWith(
-            0, llvm::ConstantAsMetadata::get(NewFunc));
-      }
-    }
+    return true;
+  }
 
-    return Changed;
+  virtual llvm::StringRef getPassName() const override {
+    return "SYCL Local Accessor to Shared Memory";
   }
 
-  Function *ProcessFunction(Module &M, Function *F) {
+private:
+  Function *processKernel(Module &M, Function *F) {
     // Check if this function is eligible by having an argument that uses shared
     // memory.
     auto UsesLocalMemory = false;
     for (Function::arg_iterator FA = F->arg_begin(), FE = F->arg_end();
          FA != FE; ++FA) {
-      if (FA->getType()->isPointerTy()) {
-        UsesLocalMemory =
-            FA->getType()->getPointerAddressSpace() == ADDRESS_SPACE_SHARED;
-      }
-      if (UsesLocalMemory) {
+      if (FA->getType()->isPointerTy() &&
+          FA->getType()->getPointerAddressSpace() == SharedASValue) {
+        UsesLocalMemory = true;
         break;
       }
     }
@@ -111,9 +134,9 @@ class LocalAccessorToSharedMemory : public ModulePass {
     // Create a global symbol to CUDA shared memory.
     auto SharedMemGlobalName = F->getName().str();
     SharedMemGlobalName.append("_shared_mem");
-    auto SharedMemGlobalType =
+    auto *SharedMemGlobalType =
         ArrayType::get(Type::getInt8Ty(M.getContext()), 0);
-    auto SharedMemGlobal = new GlobalVariable(
+    auto *SharedMemGlobal = new GlobalVariable(
         /* Module= */ M,
         /* Type= */ &*SharedMemGlobalType,
         /* IsConstant= */ false,
@@ -122,7 +145,7 @@ class LocalAccessorToSharedMemory : public ModulePass {
         /* Name= */ Twine{SharedMemGlobalName},
         /* InsertBefore= */ nullptr,
         /* ThreadLocalMode= */ GlobalValue::NotThreadLocal,
-        /* AddressSpace= */ ADDRESS_SPACE_SHARED,
+        /* AddressSpace= */ SharedASValue,
         /* IsExternallyInitialized= */ false);
     SharedMemGlobal->setAlignment(Align(4));
 
@@ -139,7 +162,7 @@ class LocalAccessorToSharedMemory : public ModulePass {
     for (Function::arg_iterator FA = F->arg_begin(), FE = F->arg_end();
          FA != FE; ++FA, ++i) {
       if (FA->getType()->isPointerTy() &&
-          FA->getType()->getPointerAddressSpace() == ADDRESS_SPACE_SHARED) {
+          FA->getType()->getPointerAddressSpace() == SharedASValue) {
         // Replace pointers to shared memory with i32 offsets.
         Arguments.push_back(Type::getInt32Ty(M.getContext()));
         ArgumentAttributes.push_back(
@@ -178,8 +201,8 @@ class LocalAccessorToSharedMemory : public ModulePass {
       if (ArgumentReplaced[i]) {
        // If this argument was replaced, then create a `getelementptr`
        // instruction that uses it to recreate the pointer that was replaced.
-        auto InsertBefore = &NF->getEntryBlock().front();
-        auto PtrInst = GetElementPtrInst::CreateInBounds(
+        auto *InsertBefore = &NF->getEntryBlock().front();
+        auto *PtrInst = GetElementPtrInst::CreateInBounds(
            /* PointeeType= */ SharedMemGlobalType,
            /* Ptr= */ SharedMemGlobal,
            /* IdxList= */
@@ -191,7 +214,7 @@ class LocalAccessorToSharedMemory : public ModulePass {
        // Then create a bitcast to make sure the new pointer is the same type
        // as the old one. This will only ever be a `i8 addrspace(3)*` to `i32
        // addrspace(3)*` type of cast.
-        auto CastInst = new BitCastInst(PtrInst, FA->getType());
+        auto *CastInst = new BitCastInst(PtrInst, FA->getType());
        CastInst->insertAfter(PtrInst);
        NewValueForUse = CastInst;
      }
@@ -217,11 +240,85 @@ class LocalAccessorToSharedMemory : public ModulePass {
     return NF;
   }
 
-  virtual llvm::StringRef getPassName() const {
-    return "localaccessortosharedmemory";
+  void populateCudaKernels(Module &M, SmallVector<KernelPayload> &Kernels) {
+    // Access `nvvm.annotations` to determine which functions are kernel entry
+    // points.
+    auto *NvvmMetadata = M.getNamedMetadata("nvvm.annotations");
+    if (!NvvmMetadata)
+      return;
+
+    for (auto *MetadataNode : NvvmMetadata->operands()) {
+      if (MetadataNode->getNumOperands() != 3)
+        continue;
+
+      // NVPTX identifies kernel entry points using metadata nodes of the form:
+      // !X = !{<function>, !"kernel", i32 1}
+      const MDOperand &TypeOperand = MetadataNode->getOperand(1);
+      auto *Type = dyn_cast<MDString>(TypeOperand);
+      if (!Type)
+        continue;
+      // Only process kernel entry points.
+      if (Type->getString() != "kernel")
+        continue;
+
+      // Get a pointer to the entry point function from the metadata.
+      const MDOperand &FuncOperand = MetadataNode->getOperand(0);
+      if (!FuncOperand)
+        continue;
+      auto *FuncConstant = dyn_cast<ConstantAsMetadata>(FuncOperand);
+      if (!FuncConstant)
+        continue;
+      auto *Func = dyn_cast<Function>(FuncConstant->getValue());
+      if (!Func)
+        continue;
+
+      Kernels.push_back(KernelPayload(Func, MetadataNode));
+    }
+  }
+
+  void populateAMDKernels(Module &M, SmallVector<KernelPayload> &Kernels) {
+    for (auto &F : M) {
+      if (F.getCallingConv() == CallingConv::AMDGPU_KERNEL)
+        Kernels.push_back(KernelPayload(&F));
+    }
   }
-};
 
+  void populateKernels(Module &M, SmallVector<KernelPayload> &Kernels,
+                       ArchType AT) {
+    switch (AT) {
+    case ArchType::Cuda:
+      return populateCudaKernels(M, Kernels);
+    case ArchType::AMDHSA:
+      return populateAMDKernels(M, Kernels);
+    default:
+      llvm_unreachable("Unsupported arch type.");
+    }
+  }
+
+  void postProcessCudaKernels(
+      SmallVector<std::pair<Function *, KernelPayload>> &NewToOldKernels) {
+    for (auto &Pair : NewToOldKernels) {
+      std::get<1>(Pair).MD->replaceOperandWith(
+          0, llvm::ConstantAsMetadata::get(std::get<0>(Pair)));
+    }
+  }
+
+  void postProcessAMDKernels(
+      SmallVector<std::pair<Function *, KernelPayload>> &NewToOldKernels) {}
+
+  void postProcessKernels(
+      SmallVector<std::pair<Function *, KernelPayload>> &NewToOldKernels,
+      ArchType AT) {
+    switch (AT) {
+    case ArchType::Cuda:
+      return postProcessCudaKernels(NewToOldKernels);
+    case ArchType::AMDHSA:
+      return postProcessAMDKernels(NewToOldKernels);
+    default:
+      llvm_unreachable("Unsupported arch type.");
+    }
+  }
+};
 } // end anonymous namespace
 
 char LocalAccessorToSharedMemory::ID = 0;
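For reference, here is a hand-written before/after sketch (not part of the patch) of the rewrite `processKernel` performs on a kernel that takes a local-accessor argument. The kernel and value names are illustrative, and the global's `external` linkage is an assumption, since the linkage argument falls outside the excerpted hunks; the rest follows the code above: each addrspace(3) pointer parameter becomes an `i32` byte offset into a zero-sized, 4-byte-aligned shared-memory global named `<kernel>_shared_mem`, and the original pointer is recovered with an inbounds `getelementptr` plus a `bitcast` at the top of the entry block.

```llvm
; Before: the SYCL local accessor arrives as an addrspace(3) pointer
; parameter (SharedASValue == 3 for both nvptx{,64}-nvidia-cuda and
; amdgcn-amd-amdhsa).
define void @kern(i32 addrspace(3)* %acc) {
entry:
  store i32 42, i32 addrspace(3)* %acc
  ret void
}

; After: the pointer parameter is replaced by a byte offset into a single
; zero-sized shared-memory global. (external linkage assumed here.)
@kern_shared_mem = external addrspace(3) global [0 x i8], align 4

define void @kern(i32 %acc_offset) {
entry:
  %0 = getelementptr inbounds [0 x i8], [0 x i8] addrspace(3)* @kern_shared_mem, i32 0, i32 %acc_offset
  %1 = bitcast i8 addrspace(3)* %0 to i32 addrspace(3)*
  store i32 42, i32 addrspace(3)* %1
  ret void
}
```

For CUDA targets, `postProcessCudaKernels` additionally updates the kernel's `nvvm.annotations` entry to point at the rewritten function. Note the pass is off by default: it only runs when `-sycl-enable-local-accessor` is set (the `cl::opt` is hidden and initialised to `false`), and only for the triples matched in `runOnModule`.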