Skip to content

[MLIR][GPU] Add a pattern to rewrite gpu.subgroup_id #137671

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class FuncOp;
/// Collect a set of patterns to rewrite GlobalIdOp op within the GPU dialect.
void populateGpuGlobalIdPatterns(RewritePatternSet &patterns);

/// Collect a set of patterns to rewrite SubgroupIdOp op within the GPU
/// dialect.
void populateGpuSubgroupIdPatterns(RewritePatternSet &patterns);

/// Collect a set of patterns to rewrite shuffle ops within the GPU dialect.
void populateGpuShufflePatterns(RewritePatternSet &patterns);

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/GPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/ROCDLAttachTarget.cpp
Transforms/ShuffleRewriter.cpp
Transforms/SPIRVAttachTarget.cpp
Transforms/SubgroupIdRewriter.cpp
Transforms/SubgroupReduceLowering.cpp

OBJECT
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct GpuGlobalIdRewriter : public OpRewritePattern<gpu::GlobalIdOp> {

LogicalResult matchAndRewrite(gpu::GlobalIdOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Location loc = op.getLoc();
auto dim = op.getDimension();
auto blockId = rewriter.create<gpu::BlockIdOp>(loc, dim);
auto blockDim = rewriter.create<gpu::BlockDimOp>(loc, dim);
Expand Down
82 changes: 82 additions & 0 deletions mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
//===- SubgroupIdRewriter.cpp - Implementation of SubgroupId rewriting ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect rewriting of the gpu.subgroup_id op for archs
// where:
// subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {
struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> {
using OpRewritePattern<gpu::SubgroupIdOp>::OpRewritePattern;

LogicalResult matchAndRewrite(gpu::SubgroupIdOp op,
PatternRewriter &rewriter) const override {
// Calculation of the thread's subgroup identifier.
//
// The process involves mapping the thread's 3D identifier within its
// block (b_id.x, b_id.y, b_id.z) to a 1D linear index.
// This linearization assumes a layout where the x-dimension (w_dim.x)
// varies most rapidly (i.e., it is the innermost dimension).
//
// The formula for the linearized thread index is:
// L = tid.x + dim.x * (tid.y + (dim.y * tid.z))
//
// Subsequently, the range of linearized indices [0, N_threads-1] is
// divided into consecutive, non-overlapping segments, each representing
// a subgroup of size 'subgroup_size'.
//
// Example Partitioning (N = subgroup_size):
// | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... |
// | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... |
//
// The subgroup identifier is obtained via integer division of the
// linearized thread index by the predefined 'subgroup_size'.
//
// subgroup_id = floor( L / subgroup_size )
// = (tid.x + dim.x * (tid.y + dim.y * tid.z)) /
// subgroup_size

Location loc = op->getLoc();

Value dimX = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
Value dimY = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::y);
Value tidX = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
Value tidY = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::y);
Value tidZ = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::z);

Value dimYxIdZ = rewriter.create<index::MulOp>(loc, dimY, tidZ);
Value dimYxIdZPlusIdY = rewriter.create<index::AddOp>(loc, dimYxIdZ, tidY);
Value dimYxIdZPlusIdYTimesDimX =
rewriter.create<index::MulOp>(loc, dimX, dimYxIdZPlusIdY);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think arith:: over index:: are fine here - none of these values are at the point where the stuff that caused the index dialect to come into existence is a problem

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... Hey, why'd this get landed with index::?

Value IdXPlusDimYxIdZPlusIdYTimesDimX =
rewriter.create<index::AddOp>(loc, tidX, dimYxIdZPlusIdYTimesDimX);
Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>(
loc, rewriter.getIndexType(), /*upper_bound = */ nullptr);
Value subgroupIdOp = rewriter.create<index::DivUOp>(
loc, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
rewriter.replaceOp(op, {subgroupIdOp});
return success();
}
};

} // namespace

void mlir::populateGpuSubgroupIdPatterns(RewritePatternSet &patterns) {
patterns.add<GpuSubgroupIdRewriter>(patterns.getContext());
}
24 changes: 24 additions & 0 deletions mlir/test/Dialect/GPU/subgroupId-rewrite.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: mlir-opt --test-gpu-rewrite -split-input-file %s | FileCheck %s

// CHECK-LABEL: func.func @subgroupId
// CHECK-SAME: (%[[SZ:.*]]: index, %[[MEM:.*]]: memref<index, 1>) {
func.func @subgroupId(%sz : index, %mem: memref<index, 1>) {
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
// CHECK: %[[DIMX:.*]] = gpu.block_dim x
// CHECK-NEXT: %[[DIMY:.*]] = gpu.block_dim y
// CHECK-NEXT: %[[TIDX:.*]] = gpu.thread_id x
// CHECK-NEXT: %[[TIDY:.*]] = gpu.thread_id y
// CHECK-NEXT: %[[TIDZ:.*]] = gpu.thread_id z
// CHECK-NEXT: %[[T0:.*]] = index.mul %[[DIMY]], %[[TIDZ]]
// CHECK-NEXT: %[[T1:.*]] = index.add %[[T0]], %[[TIDY]]
// CHECK-NEXT: %[[T2:.*]] = index.mul %[[DIMX]], %[[T1]]
// CHECK-NEXT: %[[T3:.*]] = index.add %[[TIDX]], %[[T2]]
// CHECK-NEXT: %[[T4:.*]] = gpu.subgroup_size : index
// CHECK-NEXT: %[[T5:.*]] = index.divu %[[T3]], %[[T4]]
%idz = gpu.subgroup_id : index
memref.store %idz, %mem[] : memref<index, 1>
gpu.terminator
}
return
}
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ struct TestGpuRewritePass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateGpuRewritePatterns(patterns);
populateGpuSubgroupIdPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
Expand Down