Skip to content

Commit 0a535e8

Browse files
committed
update according to comments.
1 parent 25ab466 commit 0a535e8

File tree

3 files changed

+49
-12
lines changed

3 files changed

+49
-12
lines changed

mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ class GPUModuleOp;
3434
void populateGpuToROCDLConversionPatterns(const LLVMTypeConverter &converter,
3535
RewritePatternSet &patterns,
3636
gpu::amd::Runtime runtime,
37-
mlir::amdgpu::Chipset chipset);
37+
mlir::amdgpu::Chipset chipset,
38+
std::optional<int64_t> subgroupSize);
3839

3940
/// Configure target to convert from the GPU dialect to ROCDL.
4041
void configureGpuToROCDLConversionLegality(ConversionTarget &target);

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,10 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
608608
clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
609609
clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL",
610610
"OpenCL"))}]>,
611+
Option<"subgroupSize", "subgroup-size", "unsigned",
612+
"0",
613+
"specify subgroup size for the kernel, if left empty, the default "
614+
"value will be decided by the target chipset.">,
611615
ListOption<"allowedDialects", "allowed-dialects", "std::string",
612616
"Run conversion patterns of only the specified dialects">,
613617
];

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,26 @@ namespace mlir {
5252

5353
using namespace mlir;
5454

55+
/// Query function for static subgroup size lookup for given chipset.
56+
// TODO: move this function to a common place.
57+
static int64_t querySubgroupSize(const amdgpu::Chipset &chipset) {
58+
// The subgroup size is the same as the wavefront size for all chipsets.
59+
// The wavefront size is 64 for GCN and 32 for RDNA.
60+
61+
// There are two ways we can know the subgroup size:
62+
// 1. subgroup size is passed down as part of configuration by the caller.
63+
// 2. lower subgroup size down to LLVM intrinsic:
64+
// `Intrinsic::amdgcn_wavefrontsize`, which will then be folded into a
65+
// constant according to subtarget info.
66+
67+
// TODO: change to prefer method 1 if the caller has provided a subgroup size,
68+
// otherwise use method 2. For now, statically query the subgroup size
69+
// according to the chipset.
70+
if (chipset.majorVersion >= 10)
71+
return 32;
72+
return 64;
73+
}
74+
5575
/// Returns true if the given `gpu.func` can be safely called using the bare
5676
/// pointer calling convention.
5777
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
@@ -90,7 +110,7 @@ static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
90110
int64_t indexBitwidth = converter.getIndexTypeBitwidth();
91111
auto indexBitwidthType =
92112
IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
93-
// TODO: use <=> in C++20
113+
// TODO: use <=> in C++20.
94114
if (indexBitwidth > intWidth) {
95115
return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
96116
}
@@ -203,13 +223,21 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
203223

204224
struct GPUSubgroupIdOpToROCDL final
205225
: ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
206-
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
207226

208227
GPUSubgroupIdOpToROCDL(const LLVMTypeConverter &converter,
209-
const mlir::amdgpu::Chipset &chipset)
210-
: ConvertOpToLLVMPattern(converter), chipset(chipset) {}
228+
const mlir::amdgpu::Chipset &chipset,
229+
std::optional<int64_t> subgroupSize = std::nullopt)
230+
: ConvertOpToLLVMPattern(converter), chipset(chipset),
231+
subgroupSize(subgroupSize) {}
211232

212233
const mlir::amdgpu::Chipset chipset;
234+
const std::optional<int64_t> subgroupSize;
235+
236+
int64_t getSubgroupSize() const {
237+
if (subgroupSize)
238+
return *subgroupSize;
239+
return querySubgroupSize(chipset);
240+
}
213241

214242
LogicalResult
215243
matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
@@ -218,7 +246,7 @@ struct GPUSubgroupIdOpToROCDL final
218246
auto loc = op.getLoc();
219247
LLVM::IntegerOverflowFlags flags =
220248
LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;
221-
// w_id.x + w_dim.x * (w_id.y + w_dim.y * w_id.z)) / subgroup_size
249+
// (w_id.x + w_dim.x * (w_id.y + w_dim.y * w_id.z)) / subgroup_size
222250
Value workitemIdX = rewriter.create<ROCDL::ThreadIdXOp>(loc, int32Type);
223251
Value workitemIdY = rewriter.create<ROCDL::ThreadIdYOp>(loc, int32Type);
224252
Value workitemIdZ = rewriter.create<ROCDL::ThreadIdZOp>(loc, int32Type);
@@ -233,8 +261,9 @@ struct GPUSubgroupIdOpToROCDL final
233261
Value workitemIdXPlusDimYxIdZPlusIdYTimesDimX =
234262
rewriter.create<LLVM::AddOp>(loc, int32Type, workitemIdX,
235263
dimYxIdZPlusIdYTimesDimX, flags);
236-
Value subgroupSize = rewriter.create<LLVM::ConstantOp>(
237-
loc, IntegerType::get(rewriter.getContext(), 32), 64);
264+
265+
Value subgroupSize =
266+
rewriter.create<LLVM::ConstantOp>(loc, int32Type, getSubgroupSize());
238267
Value waveIdOp = rewriter.create<LLVM::SDivOp>(
239268
loc, workitemIdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
240269

@@ -361,8 +390,10 @@ struct LowerGpuOpsToROCDLOpsPass final
361390

362391
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
363392
*maybeChipset);
364-
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime,
365-
*maybeChipset);
393+
populateGpuToROCDLConversionPatterns(
394+
converter, llvmPatterns, runtime, *maybeChipset,
395+
subgroupSize == 0 ? std::nullopt
396+
: std::optional<int64_t>(subgroupSize));
366397
configureGpuToROCDLConversionLegality(target);
367398
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
368399
signalPassFailure();
@@ -410,7 +441,8 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
410441

411442
void mlir::populateGpuToROCDLConversionPatterns(
412443
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
413-
mlir::gpu::amd::Runtime runtime, mlir::amdgpu::Chipset chipset) {
444+
mlir::gpu::amd::Runtime runtime, mlir::amdgpu::Chipset chipset,
445+
std::optional<int64_t> subgroupSize) {
414446
using gpu::index_lowering::IndexKind;
415447
using gpu::index_lowering::IntrType;
416448
using mlir::gpu::amd::Runtime;
@@ -449,7 +481,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
449481
patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
450482

451483
patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
452-
patterns.add<GPUSubgroupIdOpToROCDL>(converter, chipset);
484+
patterns.add<GPUSubgroupIdOpToROCDL>(converter, chipset, subgroupSize);
453485
populateMathToROCDLConversionPatterns(converter, patterns);
454486
}
455487

0 commit comments

Comments
 (0)