-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[MLIR][GPU] Add a pattern to rewrite gpu.subgroup_id #137671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
2cc88a6
to
8c603f0
Compare
This patch impelemnts a rewrite pattern for transforming `gpu.subgroup_id` to: ``` subgroup_id = linearized_thread_id / gpu.subgroup_size ``` where: ``` linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z) ```
8c603f0
to
a19415f
Compare
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir Author: Alan Li (lialan) ChangesThis patch impelemnts a rewrite pattern for transforming
where:
Full diff: https://github.com/llvm/llvm-project/pull/137671.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index a13ad33df29cd..cbb990e603a38 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -39,6 +39,10 @@ class FuncOp;
/// Collect a set of patterns to rewrite GlobalIdOp op within the GPU dialect.
void populateGpuGlobalIdPatterns(RewritePatternSet &patterns);
+/// Collect a set of patterns to rewrite SubgroupIdOp op within the GPU
+/// dialect.
+void populateGpuSubgroupIdPatterns(RewritePatternSet &patterns);
+
/// Collect a set of patterns to rewrite shuffle ops within the GPU dialect.
void populateGpuShufflePatterns(RewritePatternSet &patterns);
@@ -88,6 +92,7 @@ inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
populateGpuGlobalIdPatterns(patterns);
populateGpuShufflePatterns(patterns);
+ populateGpuSubgroupIdPatterns(patterns);
}
namespace gpu {
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index be6492a22f34f..e21fa501bae6b 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -40,6 +40,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/ROCDLAttachTarget.cpp
Transforms/ShuffleRewriter.cpp
Transforms/SPIRVAttachTarget.cpp
+ Transforms/SubgroupIdRewriter.cpp
Transforms/SubgroupReduceLowering.cpp
OBJECT
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
new file mode 100644
index 0000000000000..1c322c1016c01
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
@@ -0,0 +1,82 @@
+//===- SubgroupIdRewriter.cpp - Implementation of SugroupId rewriting ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements in-dialect rewriting of the gpu.subgroup_id op for archs
+// where:
+// subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> {
+ using OpRewritePattern<gpu::SubgroupIdOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(gpu::SubgroupIdOp op,
+ PatternRewriter &rewriter) const override {
+ // Calculation of the thread's subgroup identifier.
+ //
+ // The process involves mapping the thread's 3D identifier within its
+ // block (b_id.x, b_id.y, b_id.z) to a 1D linear index.
+ // This linearization assumes a layout where the x-dimension (w_dim.x)
+ // varies most rapidly (i.e., it is the innermost dimension).
+ //
+ // The formula for the linearized thread index is:
+ // L = tid.x + dim.x * (tid.y + (dim.y * tid.z))
+ //
+ // Subsequently, the range of linearized indices [0, N_threads-1] is
+ // divided into consecutive, non-overlapping segments, each representing
+ // a subgroup of size 'subgroup_size'.
+ //
+ // Example Partitioning (N = subgroup_size):
+ // | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... |
+ // | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... |
+ //
+ // The subgroup identifier is obtained via integer division of the
+ // linearized thread index by the predefined 'subgroup_size'.
+ //
+ // subgroup_id = floor( L / subgroup_size )
+ // = (tid.x + dim.x * (tid.y + dim.y * tid.z)) /
+ // subgroup_size
+
+ auto loc = op->getLoc();
+
+ Value dimX = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
+ Value dimY = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::y);
+ Value tidX = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
+ Value tidY = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::y);
+ Value tidZ = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::z);
+
+ Value dimYxIdZ = rewriter.create<index::MulOp>(loc, dimY, tidZ);
+ Value dimYxIdZPlusIdY = rewriter.create<index::AddOp>(loc, dimYxIdZ, tidY);
+ Value dimYxIdZPlusIdYTimesDimX =
+ rewriter.create<index::MulOp>(loc, dimX, dimYxIdZPlusIdY);
+ Value IdXPlusDimYxIdZPlusIdYTimesDimX =
+ rewriter.create<index::AddOp>(loc, tidX, dimYxIdZPlusIdYTimesDimX);
+ Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>(
+ loc, rewriter.getIndexType(), /*upper_bound = */ nullptr);
+ Value subgroupIdOp = rewriter.create<index::DivUOp>(
+ loc, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
+ rewriter.replaceOp(op, {subgroupIdOp});
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::populateGpuSubgroupIdPatterns(RewritePatternSet &patterns) {
+ patterns.add<GpuSubgroupIdRewriter>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir b/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir
new file mode 100644
index 0000000000000..02fcb2ba21dad
--- /dev/null
+++ b/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt --test-gpu-rewrite -split-input-file %s | FileCheck %s
+
+module {
+ // CHECK-LABEL: func.func @subgroupId
+ // CHECK-SAME: (%[[SZ:.*]]: index, %[[MEM:.*]]: memref<index, 1>) {
+ func.func @subgroupId(%sz : index, %mem: memref<index, 1>) {
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+ threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
+ // CHECK: %[[DIMX:.*]] = gpu.block_dim x
+ // CHECK-NEXT: %[[DIMY:.*]] = gpu.block_dim y
+ // CHECK-NEXT: %[[TIDX:.*]] = gpu.thread_id x
+ // CHECK-NEXT: %[[TIDY:.*]] = gpu.thread_id y
+ // CHECK-NEXT: %[[TIDZ:.*]] = gpu.thread_id z
+ // CHECK-NEXT: %[[T0:.*]] = index.mul %[[DIMY]], %[[TIDZ]]
+ // CHECK-NEXT: %[[T1:.*]] = index.add %[[T0]], %[[TIDY]]
+ // CHECK-NEXT: %[[T2:.*]] = index.mul %[[DIMX]], %[[T1]]
+ // CHECK-NEXT: %[[T3:.*]] = index.add %[[TIDX]], %[[T2]]
+ // CHECK-NEXT: %[[T4:.*]] = gpu.subgroup_size : index
+ // CHECK-NEXT: %[[T5:.*]] = index.divu %[[T3]], %[[T4]]
+ %idz = gpu.subgroup_id : index
+ memref.store %idz, %mem[] : memref<index, 1>
+ gpu.terminator
+ }
+ return
+ }
+}
|
@krzysz00 suggests we can move decomposing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR implements a new rewrite pattern for the GPU dialect to transform gpu.subgroup_id using a linearized thread ID calculation.
- Introduces the GpuSubgroupIdRewriter rewrite pattern in SubgroupIdRewriter.cpp
- Updates the passes header to register the new rewrite pattern
Reviewed Changes
Copilot reviewed 2 out of 4 changed files in this pull request and generated 1 comment.
File | Description |
---|---|
mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp | Implements the rewrite pattern for gpu.subgroup_id. |
mlir/include/mlir/Dialect/GPU/Transforms/Passes.h | Registers the new SubgroupId rewrite pattern. |
Files not reviewed (2)
- mlir/lib/Dialect/GPU/CMakeLists.txt: Language not supported
- mlir/test/Dialect/GPU/subgroupId-rewrite.mlir: Language not supported
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor notes, ideas seem fine
@@ -88,6 +92,7 @@ inline void populateGpuRewritePatterns(RewritePatternSet &patterns) { | |||
populateGpuAllReducePatterns(patterns); | |||
populateGpuGlobalIdPatterns(patterns); | |||
populateGpuShufflePatterns(patterns); | |||
populateGpuSubgroupIdPatterns(patterns); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's make sure that this doesn't end up in the SPIR-V lowerings, which seem to have an alternate approach to this op
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. SPV links get_sub_group_id
function to get the subgroup id. So I think we should just remove this line inside populateGpuRewritePatterns
.
side note: SPIRV uses the same calculation method to compute subgroup_id.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I eventually stripped it from this function.
Value dimYxIdZ = rewriter.create<index::MulOp>(loc, dimY, tidZ); | ||
Value dimYxIdZPlusIdY = rewriter.create<index::AddOp>(loc, dimYxIdZ, tidY); | ||
Value dimYxIdZPlusIdYTimesDimX = | ||
rewriter.create<index::MulOp>(loc, dimX, dimYxIdZPlusIdY); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think arith::
over index::
are fine here - none of these values are at the point where the stuff that caused the index dialect to come into existence is a problem
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
... Hey, why'd this get landed with index::
?
e3e30d1
to
bbde763
Compare
@@ -0,0 +1,26 @@ | |||
// RUN: mlir-opt --test-gpu-rewrite -split-input-file %s | FileCheck %s | |||
|
|||
module { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This contains both an explicit module and and implicit one -- I don't think we need both. Can we drop either one?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated
bbde763
to
e316a6e
Compare
Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % use of the index dialect
This patch impelemnts a rewrite pattern for transforming `gpu.subgroup_id` to: ``` subgroup_id = linearized_thread_id / gpu.subgroup_size ``` where: ``` linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z) ```
This patch impelemnts a rewrite pattern for transforming `gpu.subgroup_id` to: ``` subgroup_id = linearized_thread_id / gpu.subgroup_size ``` where: ``` linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z) ```
This patch impelemnts a rewrite pattern for transforming `gpu.subgroup_id` to: ``` subgroup_id = linearized_thread_id / gpu.subgroup_size ``` where: ``` linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z) ```
This patch impelemnts a rewrite pattern for transforming `gpu.subgroup_id` to: ``` subgroup_id = linearized_thread_id / gpu.subgroup_size ``` where: ``` linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z) ```
This patch impelemnts a rewrite pattern for transforming `gpu.subgroup_id` to: ``` subgroup_id = linearized_thread_id / gpu.subgroup_size ``` where: ``` linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z) ```
This patch impelemnts a rewrite pattern for transforming
gpu.subgroup_id
to:where: