
[MLIR][ROCDL] Add conversion for gpu.subgroup_id to ROCDL #136405


Closed
wants to merge 2 commits

Conversation

Member

@lialan lialan commented Apr 19, 2025

This patch creates a path to lower gpu.subgroup_id. It also removes some now-redundant code from LowerGpuOpsToROCDLOps.cpp.

Contributor

@Copilot Copilot AI left a comment


Pull Request Overview

This PR introduces a path to convert gpu.subgroup_id to ROCDL by creating a new conversion pattern that maps the operation to a ROCDL wave_id op and further to the corresponding LLVM intrinsic. Key changes include:

  • Addition of a helper template function truncOrExtToLLVMType to handle bitwidth adjustments.
  • Refactoring of GPULaneIdOp conversion to use the new helper function.
  • Introduction of a new GPUSubgroupIdOpToROCDL conversion pattern for handling gpu.subgroup_id.
Files not reviewed (3)
  • mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td: Language not supported
  • mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir: Language not supported
  • mlir/test/Target/LLVMIR/rocdl.mlir: Language not supported
Comments suppressed due to low confidence (2)

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp:86

  • [nitpick] Consider renaming the template parameter N to something more descriptive like 'expectedBitwidth' for improved clarity.
template <int64_t N>

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp:211

  • [nitpick] Consider adding an inline comment here to explain the purpose of applying truncOrExtToLLVMType to waveIdOp for consistency with the GPULaneIdOp conversion.
waveIdOp = truncOrExtToLLVMType<32>(rewriter, op.getLoc(), waveIdOp, getTypeConverter()->getIndexTypeBitwidth());

@lialan lialan marked this pull request as ready for review April 19, 2025 05:35
@lialan lialan requested review from krzysz00 and kuhar April 19, 2025 05:35
Member

llvmbot commented Apr 19, 2025

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Alan Li (lialan)

Changes

This patch creates a path to convert gpu.subgroup_id to rocdl.wave_id op, then to __builtin_amdgcn_s_get_waveid_in_workgroup intrinsic.


Full diff: https://github.com/llvm/llvm-project/pull/136405.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+8)
  • (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+37-11)
  • (modified) mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir (+7-3)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+6)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 186a4f53f93cb..09d22da0d4c72 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -204,6 +204,14 @@ def ROCDL_ReadlaneOp : ROCDL_IntrOp<"readlane", [], [0], [AllTypesMatch<["res",
    }];
 }
 
+// the intrinsic function name is too long so we use a shorter name for rocdl.
+def ROCDL_WaveIdOp :  LLVM_IntrOpBase<ROCDL_Dialect, "wave_id",
+                        "amdgcn_s_get_waveid_in_workgroup", [], [], [Pure], 1>,
+  Arguments<(ins)> {
+  let results = (outs LLVM_Type:$res);
+  let assemblyFormat = "attr-dict `:` type($res)";
+}
+
 //===----------------------------------------------------------------------===//
 // Thread index and Block index
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index e6dd6f135884e..315bc7157cd83 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -80,6 +80,24 @@ static constexpr StringLiteral amdgcnDataLayout =
     "64-S32-A5-G1-ni:7:8:9";
 
 namespace {
+
+// Truncate or extend the result depending on the index bitwidth specified
+// by the LLVMTypeConverter options.
+static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value value,
+                                  const LLVMTypeConverter *converter) {
+  auto intWidth = cast<IntegerType>(value.getType()).getWidth();
+  auto indexBitwidth = converter->getIndexTypeBitwidth();
+  if (indexBitwidth > intWidth) {
+    return rewriter.create<LLVM::SExtOp>(
+        loc, IntegerType::get(rewriter.getContext(), indexBitwidth), value);
+  } else if (indexBitwidth < intWidth) {
+    return rewriter.create<LLVM::TruncOp>(
+        loc, IntegerType::get(rewriter.getContext(), indexBitwidth), value);
+  }
+  return value;
+}
+
 struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
   using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
 
@@ -98,16 +116,7 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
         rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
     Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
         loc, intTy, ValueRange{minus1, mbcntLo});
-    // Truncate or extend the result depending on the index bitwidth specified
-    // by the LLVMTypeConverter options.
-    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
-    if (indexBitwidth > 32) {
-      laneId = rewriter.create<LLVM::SExtOp>(
-          loc, IntegerType::get(context, indexBitwidth), laneId);
-    } else if (indexBitwidth < 32) {
-      laneId = rewriter.create<LLVM::TruncOp>(
-          loc, IntegerType::get(context, indexBitwidth), laneId);
-    }
+    laneId = truncOrExtToLLVMType(rewriter, loc, laneId, getTypeConverter());
     rewriter.replaceOp(op, {laneId});
     return success();
   }
@@ -190,6 +199,21 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   }
 };
 
+struct GPUSubgroupIdOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
+  using ConvertOpToLLVMPattern<gpu::SubgroupIdOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
+    Value waveIdOp = rewriter.create<ROCDL::WaveIdOp>(op.getLoc(), int32Type);
+    waveIdOp = truncOrExtToLLVMType(rewriter, op.getLoc(), waveIdOp,
+                                    getTypeConverter());
+    rewriter.replaceOp(op, {waveIdOp});
+    return success();
+  }
+};
+
 /// Import the GPU Ops to ROCDL Patterns.
 #include "GPUToROCDL.cpp.inc"
 
@@ -405,7 +429,9 @@ void mlir::populateGpuToROCDLConversionPatterns(
   // TODO: Add alignment for workgroup memory
   patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
 
-  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
+  patterns
+      .add<GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupIdOpToROCDL>(
+          converter);
 
   populateMathToROCDLConversionPatterns(converter, patterns);
 }
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 071cae9d5789f..a06b77dcff038 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -11,7 +11,7 @@ gpu.module @test_module {
   func.func @gpu_index_ops()
       -> (index, index, index, index, index, index,
           index, index, index, index, index, index,
-          index) {
+          index, index) {
     // CHECK32-NOT: = llvm.sext %{{.*}} : i32 to i64
 
     // CHECK: rocdl.workitem.id.x : i32
@@ -59,12 +59,16 @@ gpu.module @test_module {
     // CHECK: = llvm.sext %{{.*}} : i32 to i64
     %laneId = gpu.lane_id
 
+    // CHECK: = rocdl.wave_id : i32
+    // CHECK: = llvm.sext %{{.*}} : i32 to i64
+    %waveId = gpu.subgroup_id : index
+
     func.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
                %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ,
-               %laneId
+               %laneId, %waveId
         : index, index, index, index, index, index,
           index, index, index, index, index, index,
-          index
+          index, index
   }
 }
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 3db1f7b2b6427..f5767dd1fc95a 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -88,6 +88,12 @@ llvm.func @rocdl.lane_id() -> i32 {
   llvm.return %3 : i32
 }
 
+llvm.func @rocdl.wave_id() -> i32 {
+  // CHECK: call i32 @llvm.amdgcn.s.get.waveid.in.workgroup()
+  %0 = rocdl.wave_id : i32
+  llvm.return %0 : i32
+}
+
 llvm.func @rocdl.swizzle(%src : i32) -> i32 {
   // CHECK-LABEL: rocdl.swizzle
   // CHECK: call i32 @llvm.amdgcn.ds.swizzle

auto intWidth = cast<IntegerType>(value.getType()).getWidth();
auto indexBitwidth = converter->getIndexTypeBitwidth();
if (indexBitwidth > intWidth) {
return rewriter.create<LLVM::SExtOp>(
Contributor


@krzysz00 side question: does sext vs zext make any perf difference for AMDGPU?

Member Author


@krzysz00 Echo here: should we generally prefer sext over zext at lower levels?

Contributor


I don't know. I suspect that we want `zext nneg` when we know it's equivalent - ex, on these IDs that are guaranteed to be in `[0, i32_signed_max)` or some stricter range. However, there's an assumption in much of `affine` and co that `index` is signed when there's an ambiguity.

Int range optimization often creates sext/trunc pairs for these things in practice.

I suspect that for now we want `sext` for MLIR-semantic reasons and then to filter it down later in the backend, especially if we can stick a `range(i32 0, [upper bound we actually know])` on the intrinsic.

... Heck, even in the absence of the actual subgroup size, we do know said upper bound: 1024 / [wave size], aka 1024 / 32 = 32 ... which we can just hint to LLVM

(Though see also my note about how this intrinsic basically doesn't exist on any GPU of interest)

Contributor

@krzysz00 krzysz00 left a comment


Also, missing a test in mlir/test/Dialect/LLVMIR/rocdl.mlir for the ROCDL op

LogicalResult
matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (chipset.majorVersion < 10) {
Contributor


There's a valid implementation: `(workitem_id.x + workitem_dim.x * (workitem_id.y + workitem_dim.y * workitem_id.z)) / 64`, with all those adds and muls being nuw and nsw
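As a rough C++ sketch of that fallback (assuming it would sit next to the other patterns in LowerGpuOpsToROCDLOps.cpp; the helper name, the hardcoded wave size of 64, and the omission of the suggested nuw/nsw flags are illustrative choices, not the patch's actual code):

```cpp
// Hypothetical fallback: subgroup id as the linearized workitem id divided by
// the wave size, i.e. (tid.x + dim.x * (tid.y + dim.y * tid.z)) / 64.
static Value computeSubgroupIdFallback(ConversionPatternRewriter &rewriter,
                                       Location loc) {
  Type i32 = rewriter.getI32Type();
  Value tidX = rewriter.create<ROCDL::ThreadIdXOp>(loc, i32);
  Value tidY = rewriter.create<ROCDL::ThreadIdYOp>(loc, i32);
  Value tidZ = rewriter.create<ROCDL::ThreadIdZOp>(loc, i32);
  Value dimX = rewriter.create<ROCDL::BlockDimXOp>(loc, i32);
  Value dimY = rewriter.create<ROCDL::BlockDimYOp>(loc, i32);

  // flatId = tid.x + dim.x * (tid.y + dim.y * tid.z)
  Value yz = rewriter.create<LLVM::AddOp>(
      loc, i32, tidY, rewriter.create<LLVM::MulOp>(loc, i32, dimY, tidZ));
  Value flatId = rewriter.create<LLVM::AddOp>(
      loc, i32, tidX, rewriter.create<LLVM::MulOp>(loc, i32, dimX, yz));

  // Wave size hardcoded to 64 (gfx9-style); see the later discussion about
  // deriving this from the subgroup size instead.
  Value waveSize = rewriter.create<LLVM::ConstantOp>(
      loc, i32, rewriter.getI32IntegerAttr(64));
  return rewriter.create<LLVM::UDivOp>(loc, i32, flatId, waveSize);
}
```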

Member Author


Done. Though I prefer to sext indices to i64 and avoid nuw and nsw.

Member Author

lialan commented Apr 23, 2025

Also, missing a test in mlir/test/Dialect/LLVMIR/rocdl.mlir for the ROCDL op

@krzysz00 well the new op is gone so no need.

@lialan lialan requested review from krzysz00, kuhar and Hardcode84 April 23, 2025 00:31
Comment on lines 236 to 237
Value subgroupSize = rewriter.create<LLVM::ConstantOp>(
loc, IntegerType::get(rewriter.getContext(), 32), 64);
Member


Why is it safe to hardcode this for gfx9 here?

Member Author


Ahh, there is a subgroup_size op for this.

Member Author


So for now I statically query the wavefront size according to chipset info. I think in the future we should opt to use a more accurate approach. Left a TODO.
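For illustration only, the chipset-based default could look something like the sketch below; the helper name is hypothetical and the wave32-for-gfx10+ default is an assumption about the targets, not something taken from the patch:

```cpp
// Hypothetical helper: pick a default wavefront size from the parsed chipset.
// gfx10+ targets default to wave32; older targets (e.g. gfx9) use wave64.
static unsigned pickDefaultWavefrontSize(const amdgpu::Chipset &chipset) {
  return chipset.majorVersion >= 10 ? 32 : 64;
}
```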

@lialan lialan requested a review from kuhar April 23, 2025 18:23
@lialan lialan requested a review from kuhar April 23, 2025 22:35
Option<"subgroupSize", "subgroup-size", "unsigned",
"0",
"specify subgroup size for the kernel, if left empty, the default "
"value will be decided by the target chipset.">,
Contributor


Is this a sensible thing to do? You have chips that can be +wavesize64, and it's quite possible someone might end up not setting this flag. Might be worth seeing if LLVM has something for subgroup size.

Contributor


I'd go for wavefrontsize if the user doesn't specify

Member


+1

Member Author


@krzysz00 there is the TODO https://github.com/llvm/llvm-project/pull/136405/files#diff-cd4257dddc1cb3043071e5c7641774615ffd685cc779acf70a47a3e83401b515R67.

Currently wavefrontsize is not exposed to rocdl yet; I'd prefer to do that after this is merged. Do you think this is a good idea?

Member


I'd strongly prefer to use subgroup size here instead of a pass option

Member Author


@kuhar the subgroup size op is added, so I have changed this to divide by the subgroup size op instead.

Member


Awesome. Can we remove this option then?


@lialan lialan requested review from kuhar and krzysz00 April 28, 2025 14:18
: ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
Contributor


I'd prefer that this not be a rewrite pattern on rocdl, but a rewrite pattern that's gpu => gpu, and that can be applied before lowering to rocdl

(That way, downstream, we can run this pattern before PropagateDispatchSizeBounds)
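To make the suggestion concrete, a minimal sketch of such a gpu => gpu rewrite is shown below; the pattern name is made up and this is not the code that eventually landed in #137671, just one way to express gpu.subgroup_id as a linearized thread id divided by gpu.subgroup_size entirely within the GPU and Arith dialects:

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Illustrative gpu => gpu rewrite: subgroup_id = flat thread id / subgroup_size.
struct SubgroupIdToLinearizedId : OpRewritePattern<gpu::SubgroupIdOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupIdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type indexTy = rewriter.getIndexType();
    Value tidX = rewriter.create<gpu::ThreadIdOp>(loc, indexTy, gpu::Dimension::x);
    Value tidY = rewriter.create<gpu::ThreadIdOp>(loc, indexTy, gpu::Dimension::y);
    Value tidZ = rewriter.create<gpu::ThreadIdOp>(loc, indexTy, gpu::Dimension::z);
    Value dimX = rewriter.create<gpu::BlockDimOp>(loc, indexTy, gpu::Dimension::x);
    Value dimY = rewriter.create<gpu::BlockDimOp>(loc, indexTy, gpu::Dimension::y);

    // flatId = tid.x + dim.x * (tid.y + dim.y * tid.z)
    Value yz = rewriter.create<arith::AddIOp>(
        loc, tidY, rewriter.create<arith::MulIOp>(loc, dimY, tidZ));
    Value flatId = rewriter.create<arith::AddIOp>(
        loc, tidX, rewriter.create<arith::MulIOp>(loc, dimX, yz));

    // Divide by the subgroup size queried from the GPU dialect, keeping the
    // rewrite target-agnostic until later lowering.
    Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>(loc, indexTy);
    rewriter.replaceOpWithNewOp<arith::DivUIOp>(op, flatId, subgroupSize);
    return success();
  }
};
```

Keeping the computation at the GPU-dialect level means a downstream pass like the one mentioned above can still attach bounds to the index ops before any ROCDL lowering runs.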

Member Author


I like this idea. I created a new PR for it: #137671

Member Author

lialan commented Apr 28, 2025

Closing this in favor of #137671 as per discussion in #136405 (comment)

@lialan lialan closed this Apr 28, 2025