@@ -52,6 +52,26 @@ namespace mlir {
 
 using namespace mlir;
 
+/// Query function for static subgroup size lookup for the given chipset.
+// TODO: move this function to a common place.
+static int64_t querySubgroupSize(const amdgpu::Chipset &chipset) {
+  // The subgroup size is the same as the wavefront size for all chipsets.
+  // The wavefront size is 64 for GCN and 32 for RDNA.
+
+  // There are two ways we can know the subgroup size:
+  // 1. The subgroup size is passed down as part of configuration by the caller.
+  // 2. Lower the subgroup size down to the LLVM intrinsic
+  //    `Intrinsic::amdgcn_wavefrontsize`, which will then be folded into a
+  //    constant according to subtarget info.
+
+  // TODO: change to prefer method 1 if the caller has provided a subgroup size;
+  // otherwise use method 2. For now, statically query the subgroup size
+  // according to the chipset.
+  if (chipset.majorVersion >= 10)
+    return 32;
+  return 64;
+}
+
 /// Returns true if the given `gpu.func` can be safely called using the bare
 /// pointer calling convention.
 static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
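A quick sanity check of the chipset-keyed lookup added above — a minimal standalone sketch, assuming `mlir::amdgpu::Chipset::parse` (which accepts strings such as "gfx906" or "gfx1100" and returns a `FailureOr<Chipset>`); the default for unrecognized chipsets is this sketch's own assumption:

```cpp
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"

// Sketch: map a gfx architecture string to a static subgroup size.
static int64_t subgroupSizeFor(llvm::StringRef gfxArch) {
  mlir::FailureOr<mlir::amdgpu::Chipset> chipset =
      mlir::amdgpu::Chipset::parse(gfxArch);
  if (failed(chipset))
    return 64; // Conservative default for unrecognized chipsets (assumption).
  // RDNA (gfx10xx and newer) defaults to wave32; GCN/CDNA run wave64.
  return chipset->majorVersion >= 10 ? 32 : 64;
}

// subgroupSizeFor("gfx906")  == 64  (GCN/CDNA, wave64)
// subgroupSizeFor("gfx1100") == 32  (RDNA3, wave32)
```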
@@ -90,7 +110,7 @@ static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
   int64_t indexBitwidth = converter.getIndexTypeBitwidth();
   auto indexBitwidthType =
       IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
-  // TODO: use <=> in C++20
+  // TODO: use <=> in C++20.
   if (indexBitwidth > intWidth) {
     return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
   }
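For context, `truncOrExtToLLVMType` reconciles an intrinsic's `i32` result with the converter's index bitwidth. A standalone analogue of that three-way decision (a hypothetical helper, not the pass's code), using `llvm::APInt` so it is self-contained:

```cpp
#include <llvm/ADT/APInt.h>

// Sketch: widen by sign-extension, narrow by truncation, or pass through
// when the widths already match -- the same three cases the pattern above
// handles with LLVM::SExtOp / LLVM::TruncOp / no-op.
static llvm::APInt truncOrExt(const llvm::APInt &value,
                              unsigned indexBitwidth) {
  if (indexBitwidth > value.getBitWidth())
    return value.sext(indexBitwidth);  // LLVM::SExtOp case
  if (indexBitwidth < value.getBitWidth())
    return value.trunc(indexBitwidth); // LLVM::TruncOp case
  return value;                        // widths match: no conversion
}
```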
@@ -203,13 +223,21 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
 
 struct GPUSubgroupIdOpToROCDL final
     : ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
-  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
   GPUSubgroupIdOpToROCDL(const LLVMTypeConverter &converter,
-                         const mlir::amdgpu::Chipset &chipset)
-      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
+                         const mlir::amdgpu::Chipset &chipset,
+                         std::optional<int64_t> subgroupSize = std::nullopt)
+      : ConvertOpToLLVMPattern(converter), chipset(chipset),
+        subgroupSize(subgroupSize) {}
 
   const mlir::amdgpu::Chipset chipset;
+  const std::optional<int64_t> subgroupSize;
+
+  int64_t getSubgroupSize() const {
+    if (subgroupSize)
+      return *subgroupSize;
+    return querySubgroupSize(chipset);
+  }
 
   LogicalResult
   matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
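Usage sketch for the widened constructor: the caller can pin the subgroup size explicitly or omit it and fall back to the chipset query. Names mirror the diff; `converter` and `chipset` setup is elided and assumed:

```cpp
// Both construction modes of the pattern (sketch). With no explicit size,
// getSubgroupSize() defers to querySubgroupSize(chipset); with an explicit
// size, the override wins.
GPUSubgroupIdOpToROCDL byChipset(converter, chipset);      // chipset query
GPUSubgroupIdOpToROCDL pinned(converter, chipset,
                              /*subgroupSize=*/32);        // explicit wave32
```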
@@ -218,7 +246,7 @@ struct GPUSubgroupIdOpToROCDL final
     auto loc = op.getLoc();
     LLVM::IntegerOverflowFlags flags =
         LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;
-    // w_id.x + w_dim.x * (w_id.y + w_dim.y * w_id.z)) / subgroup_size
+    // (w_id.x + w_dim.x * (w_id.y + w_dim.y * w_id.z)) / subgroup_size
     Value workitemIdX = rewriter.create<ROCDL::ThreadIdXOp>(loc, int32Type);
     Value workitemIdY = rewriter.create<ROCDL::ThreadIdYOp>(loc, int32Type);
     Value workitemIdZ = rewriter.create<ROCDL::ThreadIdZOp>(loc, int32Type);
@@ -233,8 +261,9 @@ struct GPUSubgroupIdOpToROCDL final
     Value workitemIdXPlusDimYxIdZPlusIdYTimesDimX =
         rewriter.create<LLVM::AddOp>(loc, int32Type, workitemIdX,
                                      dimYxIdZPlusIdYTimesDimX, flags);
-    Value subgroupSize = rewriter.create<LLVM::ConstantOp>(
-        loc, IntegerType::get(rewriter.getContext(), 32), 64);
+
+    Value subgroupSize =
+        rewriter.create<LLVM::ConstantOp>(loc, int32Type, getSubgroupSize());
     Value waveIdOp = rewriter.create<LLVM::SDivOp>(
         loc, workitemIdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
 
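To see the formula in action, a worked example with concrete numbers (the workgroup shape and wave size are assumptions, not from the diff):

```cpp
#include <cstdio>

// Evaluate (w_id.x + w_dim.x * (w_id.y + w_dim.y * w_id.z)) / wave for a
// 64x2x2 workgroup on a wave64 chipset.
int main() {
  const int dimX = 64, dimY = 2, wave = 64;
  // Workitem (0,1,0): linear id = 0 + 64 * (1 + 2 * 0) = 64 -> subgroup 1.
  int x = 0, y = 1, z = 0;
  int linear = x + dimX * (y + dimY * z);
  std::printf("subgroup id = %d\n", linear / wave); // prints 1
  return 0;
}
```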
@@ -361,8 +390,10 @@ struct LowerGpuOpsToROCDLOpsPass final
 
     populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
                                             *maybeChipset);
-    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime,
-                                         *maybeChipset);
+    populateGpuToROCDLConversionPatterns(
+        converter, llvmPatterns, runtime, *maybeChipset,
+        subgroupSize == 0 ? std::nullopt
+                          : std::optional<int64_t>(subgroupSize));
     configureGpuToROCDLConversionLegality(target);
     if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
       signalPassFailure();
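The pass option uses 0 as a "not set" sentinel and maps it to `std::nullopt` at the call site, so the pattern falls back to the chipset query. That mapping, isolated as a sketch (hypothetical helper name):

```cpp
#include <cstdint>
#include <optional>

// A pass option of 0 means "unspecified" and yields std::nullopt; any other
// value becomes an explicit override.
std::optional<int64_t> toOptionalSubgroupSize(int64_t option) {
  return option == 0 ? std::nullopt : std::optional<int64_t>(option);
}
```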
@@ -410,7 +441,8 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
 
 void mlir::populateGpuToROCDLConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    mlir::gpu::amd::Runtime runtime, mlir::amdgpu::Chipset chipset) {
+    mlir::gpu::amd::Runtime runtime, mlir::amdgpu::Chipset chipset,
+    std::optional<int64_t> subgroupSize) {
   using gpu::index_lowering::IndexKind;
   using gpu::index_lowering::IntrType;
   using mlir::gpu::amd::Runtime;
@@ -449,7 +481,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
   patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
 
   patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
-  patterns.add<GPUSubgroupIdOpToROCDL>(converter, chipset);
+  patterns.add<GPUSubgroupIdOpToROCDL>(converter, chipset, subgroupSize);
   populateMathToROCDLConversionPatterns(converter, patterns);
 }
 
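End-to-end wiring sketch for the new parameter, mirroring how the pass above calls the populate function (setup and error handling elided; the chipset string is an assumption):

```cpp
// A downstream pass pinning wave32 explicitly: the override flows through
// populateGpuToROCDLConversionPatterns into GPUSubgroupIdOpToROCDL.
LLVMTypeConverter converter(module.getContext(), options);
RewritePatternSet llvmPatterns(module.getContext());
FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse("gfx1100");
if (succeeded(maybeChipset)) {
  populateGpuToROCDLConversionPatterns(converter, llvmPatterns,
                                       gpu::amd::Runtime::HIP, *maybeChipset,
                                       /*subgroupSize=*/32);
}
```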