Skip to content

[MLIR][GPU] Add a pattern to rewrite gpu.subgroup_id #137671

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 29, 2025

Conversation

lialan
Copy link
Member

@lialan lialan commented Apr 28, 2025

This patch impelemnts a rewrite pattern for transforming gpu.subgroup_id to:

subgroup_id = linearized_thread_id / gpu.subgroup_size

where:

linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z)

Copy link

github-actions bot commented Apr 28, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@lialan lialan force-pushed the lialan/rewrite_subgroup_id branch from 2cc88a6 to 8c603f0 Compare April 28, 2025 17:22
This patch impelemnts a rewrite pattern for transforming `gpu.subgroup_id`
to:
```
subgroup_id = linearized_thread_id / gpu.subgroup_size
```

where:
```
linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z)
```
@lialan lialan force-pushed the lialan/rewrite_subgroup_id branch from 8c603f0 to a19415f Compare April 28, 2025 18:14
@lialan lialan marked this pull request as ready for review April 28, 2025 18:14
@llvmbot
Copy link
Member

llvmbot commented Apr 28, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Alan Li (lialan)

Changes

This patch impelemnts a rewrite pattern for transforming gpu.subgroup_id to:

subgroup_id = linearized_thread_id / gpu.subgroup_size

where:

linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z)

Full diff: https://github.com/llvm/llvm-project/pull/137671.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/Transforms/Passes.h (+5)
  • (modified) mlir/lib/Dialect/GPU/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp (+82)
  • (added) mlir/test/Dialect/GPU/subgroupId-rewrite.mlir (+26)
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index a13ad33df29cd..cbb990e603a38 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -39,6 +39,10 @@ class FuncOp;
 /// Collect a set of patterns to rewrite GlobalIdOp op within the GPU dialect.
 void populateGpuGlobalIdPatterns(RewritePatternSet &patterns);
 
+/// Collect a set of patterns to rewrite SubgroupIdOp op within the GPU
+/// dialect.
+void populateGpuSubgroupIdPatterns(RewritePatternSet &patterns);
+
 /// Collect a set of patterns to rewrite shuffle ops within the GPU dialect.
 void populateGpuShufflePatterns(RewritePatternSet &patterns);
 
@@ -88,6 +92,7 @@ inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
   populateGpuAllReducePatterns(patterns);
   populateGpuGlobalIdPatterns(patterns);
   populateGpuShufflePatterns(patterns);
+  populateGpuSubgroupIdPatterns(patterns);
 }
 
 namespace gpu {
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index be6492a22f34f..e21fa501bae6b 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -40,6 +40,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/ROCDLAttachTarget.cpp
   Transforms/ShuffleRewriter.cpp
   Transforms/SPIRVAttachTarget.cpp
+  Transforms/SubgroupIdRewriter.cpp
   Transforms/SubgroupReduceLowering.cpp
 
   OBJECT
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
new file mode 100644
index 0000000000000..1c322c1016c01
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
@@ -0,0 +1,82 @@
+//===- SubgroupIdRewriter.cpp - Implementation of SugroupId rewriting  ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements in-dialect rewriting of the gpu.subgroup_id op for archs
+// where:
+// subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> {
+  using OpRewritePattern<gpu::SubgroupIdOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(gpu::SubgroupIdOp op,
+                                PatternRewriter &rewriter) const override {
+    // Calculation of the thread's subgroup identifier.
+    //
+    // The process involves mapping the thread's 3D identifier within its
+    // block (b_id.x, b_id.y, b_id.z) to a 1D linear index.
+    // This linearization assumes a layout where the x-dimension (w_dim.x)
+    // varies most rapidly (i.e., it is the innermost dimension).
+    //
+    // The formula for the linearized thread index is:
+    // L = tid.x + dim.x * (tid.y + (dim.y * tid.z))
+    //
+    // Subsequently, the range of linearized indices [0, N_threads-1] is
+    // divided into consecutive, non-overlapping segments, each representing
+    // a subgroup of size 'subgroup_size'.
+    //
+    // Example Partitioning (N = subgroup_size):
+    // | Subgroup 0      | Subgroup 1      | Subgroup 2      | ... |
+    // | Indices 0..N-1  | Indices N..2N-1 | Indices 2N..3N-1| ... |
+    //
+    // The subgroup identifier is obtained via integer division of the
+    // linearized thread index by the predefined 'subgroup_size'.
+    //
+    // subgroup_id = floor( L / subgroup_size )
+    //             = (tid.x + dim.x * (tid.y + dim.y * tid.z)) /
+    //             subgroup_size
+
+    auto loc = op->getLoc();
+
+    Value dimX = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
+    Value dimY = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::y);
+    Value tidX = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
+    Value tidY = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::y);
+    Value tidZ = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::z);
+
+    Value dimYxIdZ = rewriter.create<index::MulOp>(loc, dimY, tidZ);
+    Value dimYxIdZPlusIdY = rewriter.create<index::AddOp>(loc, dimYxIdZ, tidY);
+    Value dimYxIdZPlusIdYTimesDimX =
+        rewriter.create<index::MulOp>(loc, dimX, dimYxIdZPlusIdY);
+    Value IdXPlusDimYxIdZPlusIdYTimesDimX =
+        rewriter.create<index::AddOp>(loc, tidX, dimYxIdZPlusIdYTimesDimX);
+    Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>(
+        loc, rewriter.getIndexType(), /*upper_bound = */ nullptr);
+    Value subgroupIdOp = rewriter.create<index::DivUOp>(
+        loc, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
+    rewriter.replaceOp(op, {subgroupIdOp});
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateGpuSubgroupIdPatterns(RewritePatternSet &patterns) {
+  patterns.add<GpuSubgroupIdRewriter>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir b/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir
new file mode 100644
index 0000000000000..02fcb2ba21dad
--- /dev/null
+++ b/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt --test-gpu-rewrite -split-input-file %s | FileCheck %s
+
+module {
+  // CHECK-LABEL: func.func @subgroupId
+  // CHECK-SAME: (%[[SZ:.*]]: index, %[[MEM:.*]]: memref<index, 1>) {
+  func.func @subgroupId(%sz : index, %mem: memref<index, 1>) {
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+               threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
+      // CHECK: %[[DIMX:.*]] = gpu.block_dim  x
+      // CHECK-NEXT: %[[DIMY:.*]] = gpu.block_dim  y
+      // CHECK-NEXT: %[[TIDX:.*]] = gpu.thread_id  x
+      // CHECK-NEXT: %[[TIDY:.*]] = gpu.thread_id  y
+      // CHECK-NEXT: %[[TIDZ:.*]] = gpu.thread_id  z
+      // CHECK-NEXT: %[[T0:.*]] = index.mul %[[DIMY]], %[[TIDZ]]
+      // CHECK-NEXT: %[[T1:.*]] = index.add %[[T0]], %[[TIDY]]
+      // CHECK-NEXT: %[[T2:.*]] = index.mul %[[DIMX]], %[[T1]]
+      // CHECK-NEXT: %[[T3:.*]] = index.add %[[TIDX]], %[[T2]]
+      // CHECK-NEXT: %[[T4:.*]] = gpu.subgroup_size : index
+      // CHECK-NEXT: %[[T5:.*]] = index.divu %[[T3]], %[[T4]]
+      %idz = gpu.subgroup_id : index
+      memref.store %idz, %mem[] : memref<index, 1>
+      gpu.terminator
+    }
+    return
+  }
+}

@lialan
Copy link
Member Author

lialan commented Apr 28, 2025

@krzysz00 suggests we can move decomposing gpu.subgroup_id to within gpu dialect.

Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR implements a new rewrite pattern for the GPU dialect to transform gpu.subgroup_id using a linearized thread ID calculation.

  • Introduces the GpuSubgroupIdRewriter rewrite pattern in SubgroupIdRewriter.cpp
  • Updates the passes header to register the new rewrite pattern

Reviewed Changes

Copilot reviewed 2 out of 4 changed files in this pull request and generated 1 comment.

File Description
mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp Implements the rewrite pattern for gpu.subgroup_id.
mlir/include/mlir/Dialect/GPU/Transforms/Passes.h Registers the new SubgroupId rewrite pattern.
Files not reviewed (2)
  • mlir/lib/Dialect/GPU/CMakeLists.txt: Language not supported
  • mlir/test/Dialect/GPU/subgroupId-rewrite.mlir: Language not supported

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor notes, ideas seem fine

@@ -88,6 +92,7 @@ inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
populateGpuGlobalIdPatterns(patterns);
populateGpuShufflePatterns(patterns);
populateGpuSubgroupIdPatterns(patterns);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make sure that this doesn't end up in the SPIR-V lowerings, which seem to have an alternate approach to this op

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. SPV links get_sub_group_id function to get the subgroup id. So I think we should just remove this line inside populateGpuRewritePatterns.

side note: SPIRV uses the same calculation method to compute subgroup_id.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I eventually stripped it from this function.

Value dimYxIdZ = rewriter.create<index::MulOp>(loc, dimY, tidZ);
Value dimYxIdZPlusIdY = rewriter.create<index::AddOp>(loc, dimYxIdZ, tidY);
Value dimYxIdZPlusIdYTimesDimX =
rewriter.create<index::MulOp>(loc, dimX, dimYxIdZPlusIdY);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think arith:: over index:: are fine here - none of these values are at the point where the stuff that caused the index dialect to come into existence is a problem

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... Hey, why'd this get landed with index::?

@@ -0,0 +1,26 @@
// RUN: mlir-opt --test-gpu-rewrite -split-input-file %s | FileCheck %s

module {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This contains both an explicit module and and implicit one -- I don't think we need both. Can we drop either one?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

@lialan lialan force-pushed the lialan/rewrite_subgroup_id branch from bbde763 to e316a6e Compare April 29, 2025 03:04
Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
@lialan lialan requested review from kuhar and krzysz00 April 29, 2025 13:26
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % use of the index dialect

@lialan lialan merged commit ac65b2c into llvm:main Apr 29, 2025
11 checks passed
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This patch impelemnts a rewrite pattern for transforming
`gpu.subgroup_id` to:
```
subgroup_id = linearized_thread_id / gpu.subgroup_size
```

where:
```
linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z)
```
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This patch impelemnts a rewrite pattern for transforming
`gpu.subgroup_id` to:
```
subgroup_id = linearized_thread_id / gpu.subgroup_size
```

where:
```
linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z)
```
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This patch impelemnts a rewrite pattern for transforming
`gpu.subgroup_id` to:
```
subgroup_id = linearized_thread_id / gpu.subgroup_size
```

where:
```
linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z)
```
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
This patch impelemnts a rewrite pattern for transforming
`gpu.subgroup_id` to:
```
subgroup_id = linearized_thread_id / gpu.subgroup_size
```

where:
```
linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z)
```
Ankur-0429 pushed a commit to Ankur-0429/llvm-project that referenced this pull request May 9, 2025
This patch impelemnts a rewrite pattern for transforming
`gpu.subgroup_id` to:
```
subgroup_id = linearized_thread_id / gpu.subgroup_size
```

where:
```
linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z)
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants