@@ -25,29 +25,80 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   Location loc = gpuFuncOp.getLoc();
 
   SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
-  workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
-  for (const auto [idx, attribution] :
-       llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
-    auto type = dyn_cast<MemRefType>(attribution.getType());
-    assert(type && type.hasStaticShape() && "unexpected type in attribution");
-
-    uint64_t numElements = type.getNumElements();
-
-    auto elementType =
-        cast<Type>(typeConverter->convertType(type.getElementType()));
-    auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
-    std::string name =
-        std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
-    uint64_t alignment = 0;
-    if (auto alignAttr =
-            dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getWorkgroupAttributionAttr(
-                idx, LLVM::LLVMDialect::getAlignAttrName())))
-      alignment = alignAttr.getInt();
-    auto globalOp = rewriter.create<LLVM::GlobalOp>(
-        gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
-        LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
-        workgroupAddrSpace);
-    workgroupBuffers.push_back(globalOp);
+  if (encodeWorkgroupAttributionsAsArguments) {
+    // Append an `llvm.ptr` argument to the function signature for each
+    // workgroup attribution.
+
+    ArrayRef<BlockArgument> workgroupAttributions =
+        gpuFuncOp.getWorkgroupAttributions();
+    size_t numAttributions = workgroupAttributions.size();
+
+    // Insert all arguments at the end.
+    unsigned index = gpuFuncOp.getNumArguments();
+    SmallVector<unsigned> argIndices(numAttributions, index);
+
+    // New arguments are plain `llvm.ptr` values in the workgroup address space.
+    Type workgroupPtrType =
+        rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
+    SmallVector<Type> argTypes(numAttributions, workgroupPtrType);
+
+    // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
+    std::array attrs{
+        rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(),
+                              rewriter.getUnitAttr()),
+        rewriter.getNamedAttr(
+            getDialect().getWorkgroupAttributionAttrHelper().getName(),
+            rewriter.getUnitAttr()),
+    };
+    SmallVector<DictionaryAttr> argAttrs;
+    for (BlockArgument attribution : workgroupAttributions) {
+      auto attributionType = cast<MemRefType>(attribution.getType());
+      IntegerAttr numElements =
+          rewriter.getI64IntegerAttr(attributionType.getNumElements());
+      Type llvmElementType =
+          getTypeConverter()->convertType(attributionType.getElementType());
+      if (!llvmElementType)
+        return failure();
+      TypeAttr type = TypeAttr::get(llvmElementType);
+      attrs.back().setValue(
+          rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(numElements, type));
+      argAttrs.push_back(rewriter.getDictionaryAttr(attrs));
+    }
+
+    // Argument locations match the function location.
+    SmallVector<Location> argLocs(numAttributions, gpuFuncOp.getLoc());
+
+    // Perform the signature modification.
+    rewriter.modifyOpInPlace(
+        gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() {
+          static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments(
+              argIndices, argTypes, argAttrs, argLocs);
+        });
+  } else {
+    workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
+    for (auto [idx, attribution] :
+         llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
+      auto type = dyn_cast<MemRefType>(attribution.getType());
+      assert(type && type.hasStaticShape() && "unexpected type in attribution");
+
+      uint64_t numElements = type.getNumElements();
+
+      auto elementType =
+          cast<Type>(typeConverter->convertType(type.getElementType()));
+      auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
+      std::string name =
+          std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
+      uint64_t alignment = 0;
+      if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(
+              gpuFuncOp.getWorkgroupAttributionAttr(
+                  idx, LLVM::LLVMDialect::getAlignAttrName())))
+        alignment = alignAttr.getInt();
+      auto globalOp = rewriter.create<LLVM::GlobalOp>(
+          gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
+          LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
+          workgroupAddrSpace);
+      workgroupBuffers.push_back(globalOp);
+    }
   }
 
   // Remap proper input types.
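Note on the new path above: the argument encoding rewires the `gpu.func` signature before any body conversion happens. A minimal before/after sketch of its effect on a function with one workgroup attribution, assuming workgroup address space 3; the printed form of the attribution attribute is illustrative:

```mlir
// Before: the attribution lives on the gpu.func itself.
gpu.func @kernel(%arg0: f32) workgroup(%wg: memref<32xf32, 3>) kernel {
  gpu.return
}

// After insertArguments: a trailing !llvm.ptr<3> argument carrying
// `llvm.noalias` plus the attribution's element count and element type.
gpu.func @kernel(%arg0: f32, %arg1: !llvm.ptr<3> {llvm.noalias,
    llvm.mlir.workgroup_attribution = #llvm.mlir.workgroup_attribution<32 : i64, f32>})
    workgroup(%wg: memref<32xf32, 3>) kernel {
  gpu.return
}
```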
@@ -101,16 +152,19 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   // attribute. The former is necessary for further translation while the
   // latter is expected by gpu.launch_func.
   if (gpuFuncOp.isKernel()) {
-    attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
+    if (kernelAttributeName)
+      attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
    // Set the dialect-specific block size attribute if there is one.
-    if (kernelBlockSizeAttributeName.has_value() && knownBlockSize) {
-      attributes.emplace_back(kernelBlockSizeAttributeName.value(),
-                              knownBlockSize);
+    if (kernelBlockSizeAttributeName && knownBlockSize) {
+      attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize);
    }
  }
+  LLVM::CConv callingConvention = gpuFuncOp.isKernel()
+                                      ? kernelCallingConvention
+                                      : nonKernelCallingConvention;
  auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
      gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
-      LLVM::Linkage::External, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C,
+      LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
      /*comdat=*/nullptr, attributes);
 
  {
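The calling convention is now selected per function kind instead of being hardcoded to `LLVM::CConv::C`. As a hypothetical configuration, a lowering that sets `kernelCallingConvention = LLVM::CConv::SPIR_KERNEL` and `nonKernelCallingConvention = LLVM::CConv::SPIR_FUNC` would emit roughly:

```mlir
// A gpu.func marked `kernel` lowers with the kernel calling convention...
llvm.func spir_kernelcc @kernel(%arg0: f32) {
  llvm.return
}

// ...while a plain device function gets the non-kernel one.
llvm.func spir_funccc @helper(%arg0: f32) {
  llvm.return
}
```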
@@ -125,24 +179,51 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
    rewriter.setInsertionPointToStart(&gpuFuncOp.front());
    unsigned numProperArguments = gpuFuncOp.getNumArguments();
 
-    for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
-      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
-                                                global.getAddrSpace());
-      Value address = rewriter.create<LLVM::AddressOfOp>(
-          loc, ptrType, global.getSymNameAttr());
-      Value memory =
-          rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(), address,
-                                       ArrayRef<LLVM::GEPArg>{0, 0});
-
-      // Build a memref descriptor pointing to the buffer to plug with the
-      // existing memref infrastructure. This may use more registers than
-      // otherwise necessary given that memref sizes are fixed, but we can try
-      // and canonicalize that away later.
-      Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
-      auto type = cast<MemRefType>(attribution.getType());
-      auto descr = MemRefDescriptor::fromStaticShape(
-          rewriter, loc, *getTypeConverter(), type, memory);
-      signatureConversion.remapInput(numProperArguments + idx, descr);
+    if (encodeWorkgroupAttributionsAsArguments) {
+      // Build a MemRefDescriptor with each of the arguments added above.
+
+      unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions();
+      assert(numProperArguments >= numAttributions &&
+             "Expecting attributions to be encoded as arguments already");
+
+      // Arguments encoding workgroup attributions are in positions
+      // [numProperArguments - numAttributions, numProperArguments).
+      ArrayRef<BlockArgument> attributionArguments =
+          gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
+                                         numAttributions);
+      for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal(
+               gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) {
+        auto [attribution, arg] = vals;
+        auto type = cast<MemRefType>(attribution.getType());
+
+        // Arguments are of llvm.ptr type and attributions are of memref type:
+        // we need to wrap them in memref descriptors.
+        Value descr = MemRefDescriptor::fromStaticShape(
+            rewriter, loc, *getTypeConverter(), type, arg);
+
+        // And remap the arguments.
+        signatureConversion.remapInput(numProperArguments + idx, descr);
+      }
+    } else {
+      for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
+        auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
+                                                  global.getAddrSpace());
+        Value address = rewriter.create<LLVM::AddressOfOp>(
+            loc, ptrType, global.getSymNameAttr());
+        Value memory =
+            rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(),
+                                         address, ArrayRef<LLVM::GEPArg>{0, 0});
+
+        // Build a memref descriptor pointing to the buffer to plug with the
+        // existing memref infrastructure. This may use more registers than
+        // otherwise necessary given that memref sizes are fixed, but we can try
+        // and canonicalize that away later.
+        Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
+        auto type = cast<MemRefType>(attribution.getType());
+        auto descr = MemRefDescriptor::fromStaticShape(
+            rewriter, loc, *getTypeConverter(), type, memory);
+        signatureConversion.remapInput(numProperArguments + idx, descr);
+      }
    }
 
    // Rewrite private memory attributions to alloca'ed buffers.
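Both branches above end in `MemRefDescriptor::fromStaticShape`, which wraps a raw pointer in the standard memref descriptor struct with all shape values known at compile time. For a `memref<32xf32, 3>` built around an incoming pointer, the materialized IR looks roughly like this sketch (value names are illustrative):

```mlir
llvm.func @wrap(%ptr: !llvm.ptr<3>) {
  // Descriptor layout: (allocated ptr, aligned ptr, offset, sizes, strides).
  %d0 = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
  %d1 = llvm.insertvalue %ptr, %d0[0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
  %d2 = llvm.insertvalue %ptr, %d1[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
  %off = llvm.mlir.constant(0 : index) : i64
  %d3 = llvm.insertvalue %off, %d2[2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
  %sz = llvm.mlir.constant(32 : index) : i64
  %d4 = llvm.insertvalue %sz, %d3[3, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
  %st = llvm.mlir.constant(1 : index) : i64
  %d5 = llvm.insertvalue %st, %d4[4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
  llvm.return
}
```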
@@ -239,6 +320,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
      copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
      copyPointerAttribute(
          LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
+      copyPointerAttribute(
+          LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr());
    }
  }
  rewriter.eraseOp(gpuFuncOp);
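Forwarding the workgroup attribution attribute through `copyPointerAttribute` means the metadata attached during the signature rewrite survives onto the resulting `llvm.func`, so later stages can recover each attribution's size and element type from the signature alone. A sketch of the final form, with the attribute rendering illustrative as above:

```mlir
llvm.func @kernel(%arg0: f32, %arg1: !llvm.ptr<3> {llvm.noalias,
    llvm.mlir.workgroup_attribution = #llvm.mlir.workgroup_attribution<32 : i64, f32>}) {
  llvm.return
}
```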