Skip to content

Commit 71b4ca0

Browse files
committed
[MLIR][ROCDL] Lower gpu.subgroup_id to wavefrontsize
1 parent 34f3466 commit 71b4ca0

File tree

4 files changed

+49
-4
lines changed

4 files changed

+49
-4
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ def ROCDL_BlockIdXOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.x">;
216216
def ROCDL_BlockIdYOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.y">;
217217
def ROCDL_BlockIdZOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.z">;
218218

219+
def ROCDL_WavefrontSizeOp : ROCDL_SpecialIdRegisterOp<"wavefrontsize">;
220+
219221
//===----------------------------------------------------------------------===//
220222
// Thread range and Block range
221223
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,25 @@ namespace mlir {
5252

5353
using namespace mlir;
5454

55+
// Truncate or extend the result depending on the index bitwidth specified
56+
// by the LLVMTypeConverter options.
57+
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
58+
Location loc, Value value,
59+
const LLVMTypeConverter &converter) {
60+
int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
61+
int64_t indexBitwidth = converter.getIndexTypeBitwidth();
62+
auto indexBitwidthType =
63+
IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
64+
// TODO: use <=> in C++20.
65+
if (indexBitwidth > intWidth) {
66+
return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
67+
}
68+
if (indexBitwidth < intWidth) {
69+
return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
70+
}
71+
return value;
72+
}
73+
5574
/// Returns true if the given `gpu.func` can be safely called using the bare
5675
/// pointer calling convention.
5776
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
@@ -113,6 +132,20 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
113132
}
114133
};
115134

135+
struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
136+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
137+
LogicalResult
138+
matchAndRewrite(gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor,
139+
ConversionPatternRewriter &rewriter) const override {
140+
Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>(
141+
op.getLoc(), IntegerType::get(rewriter.getContext(), 32));
142+
wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
143+
*getTypeConverter());
144+
rewriter.replaceOp(op, {wavefrontOp});
145+
return success();
146+
}
147+
};
148+
116149
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
117150
using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
118151

@@ -405,7 +438,9 @@ void mlir::populateGpuToROCDLConversionPatterns(
405438
// TODO: Add alignment for workgroup memory
406439
patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
407440

408-
patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
441+
patterns
442+
.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupSizeOpToROCDL>(
443+
converter);
409444

410445
populateMathToROCDLConversionPatterns(converter, patterns);
411446
}

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ gpu.module @test_module {
1111
func.func @gpu_index_ops()
1212
-> (index, index, index, index, index, index,
1313
index, index, index, index, index, index,
14-
index) {
14+
index, index) {
1515
// CHECK32-NOT: = llvm.sext %{{.*}} : i32 to i64
1616

1717
// CHECK: rocdl.workitem.id.x : i32
@@ -59,12 +59,16 @@ gpu.module @test_module {
5959
// CHECK: = llvm.sext %{{.*}} : i32 to i64
6060
%laneId = gpu.lane_id
6161

62+
// CHECK: = rocdl.wavefrontsize : i32
63+
// CHECK: = llvm.sext %{{.*}} : i32 to i64
64+
%subgroupSize = gpu.subgroup_size : index
65+
6266
func.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
6367
%bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ,
64-
%laneId
68+
%laneId, %subgroupSize
6569
: index, index, index, index, index, index,
6670
index, index, index, index, index, index,
67-
index
71+
index, index
6872
}
6973
}
7074

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ llvm.func @rocdl_special_regs() -> i32 {
3232

3333
// CHECK: call range(i64 1, 65) i64 @__ockl_get_local_size(i32 0)
3434
%14 = rocdl.workgroup.dim.x range <i32, 1, 65> : i64
35+
36+
// CHECK: call i64 $llvm.amdgcn.wavefrontsize()
37+
%15 = rocdl.wavefrontsize : i32
38+
3539
llvm.return %1 : i32
3640
}
3741

0 commit comments

Comments
 (0)