Skip to content

[MLIR][ROCDL] Add conversion for gpu.subgroup_id to ROCDL #136405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,7 @@ void configureGpuToROCDLConversionLegality(ConversionTarget &target);
/// index bitwidth used for the lowering of the device side index computations
/// is configurable.
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
createLowerGpuOpsToROCDLOpsPass(
const std::string &chipset = "gfx900",
unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout,
bool useBarePtrCallConv = false,
gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown);
createLowerGpuOpsToROCDLOpsPass();

} // namespace mlir

Expand Down
142 changes: 89 additions & 53 deletions mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,6 @@ namespace mlir {

using namespace mlir;

// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
Location loc, Value value,
const LLVMTypeConverter &converter) {
int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
int64_t indexBitwidth = converter.getIndexTypeBitwidth();
auto indexBitwidthType =
IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
// TODO: use <=> in C++20.
if (indexBitwidth > intWidth) {
return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
}
if (indexBitwidth < intWidth) {
return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
}
return value;
}

/// Returns true if the given `gpu.func` can be safely called using the bare
/// pointer calling convention.
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
Expand Down Expand Up @@ -99,6 +80,25 @@ static constexpr StringLiteral amdgcnDataLayout =
"64-S32-A5-G1-ni:7:8:9";

namespace {

// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
Location loc, Value value,
const LLVMTypeConverter &converter) {
int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
int64_t indexBitwidth = converter.getIndexTypeBitwidth();
auto indexBitwidthType =
IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
if (indexBitwidth > intWidth) {
return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
}
if (indexBitwidth < intWidth) {
return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
}
return value;
}

struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

Expand All @@ -117,16 +117,7 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
loc, intTy, ValueRange{minus1, mbcntLo});
// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
if (indexBitwidth > 32) {
laneId = rewriter.create<LLVM::SExtOp>(
loc, IntegerType::get(context, indexBitwidth), laneId);
} else if (indexBitwidth < 32) {
laneId = rewriter.create<LLVM::TruncOp>(
loc, IntegerType::get(context, indexBitwidth), laneId);
}
laneId = truncOrExtToLLVMType(rewriter, loc, laneId, *getTypeConverter());
rewriter.replaceOp(op, {laneId});
return success();
}
Expand All @@ -150,11 +141,11 @@ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
/*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
/*upper=*/op.getUpperBoundAttr().getInt() + 1);
}
Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>(
Value wavefrontSizeOp = rewriter.create<ROCDL::WavefrontSizeOp>(
op.getLoc(), rewriter.getI32Type(), bounds);
wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
*getTypeConverter());
rewriter.replaceOp(op, {wavefrontOp});
wavefrontSizeOp = truncOrExtToLLVMType(
rewriter, op.getLoc(), wavefrontSizeOp, *getTypeConverter());
rewriter.replaceOp(op, {wavefrontSizeOp});
return success();
}

Expand Down Expand Up @@ -239,6 +230,65 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
}
};

struct GPUSubgroupIdOpToROCDL final
: ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer that this not be a rewrite pattern on rocdl, but a rewrite pattern that's gpu => gpu, and that can be applied before lowering to rocdl

(That way, downstream, we can run this pattern before PRopagateDispatchSizeBounds)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this idea. I created a new PR for it: #137671

matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Calculation of the thread's subgroup identifier.
//
// The process involves mapping the thread's 3D identifier within its
// workgroup/block (w_id.x, w_id.y, w_id.z) to a 1D linear index.
// This linearization assumes a layout where the x-dimension (w_dim.x)
// varies most rapidly (i.e., it is the innermost dimension).
//
// The formula for the linearized thread index is:
// L = w_id.x + w_dim.x * (w_id.y + (w_dim.y * w_id.z))
//
// Subsequently, the range of linearized indices [0, N_threads-1] is
// divided into consecutive, non-overlapping segments, each representing
// a subgroup of size 'subgroup_size'.
//
// Example Partitioning (N = subgroup_size):
// | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... |
// | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... |
//
// The subgroup identifier is obtained via integer division of the
// linearized thread index by the predefined 'subgroup_size'.
//
// subgroup_id = floor( L / subgroup_size )
// = (w_id.x + w_dim.x * (w_id.y + w_dim.y * w_id.z)) /
// subgroup_size
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Location loc = op.getLoc();
LLVM::IntegerOverflowFlags flags =
LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;
Value workitemIdX = rewriter.create<ROCDL::ThreadIdXOp>(loc, int32Type);
Value workitemIdY = rewriter.create<ROCDL::ThreadIdYOp>(loc, int32Type);
Value workitemIdZ = rewriter.create<ROCDL::ThreadIdZOp>(loc, int32Type);
Value workitemDimX = rewriter.create<ROCDL::BlockDimXOp>(loc, int32Type);
Value workitemDimY = rewriter.create<ROCDL::BlockDimYOp>(loc, int32Type);
Value dimYxIdZ = rewriter.create<LLVM::MulOp>(loc, int32Type, workitemDimY,
workitemIdZ, flags);
Value dimYxIdZPlusIdY = rewriter.create<LLVM::AddOp>(
loc, int32Type, dimYxIdZ, workitemIdY, flags);
Value dimYxIdZPlusIdYTimesDimX = rewriter.create<LLVM::MulOp>(
loc, int32Type, workitemDimX, dimYxIdZPlusIdY, flags);
Value workitemIdXPlusDimYxIdZPlusIdYTimesDimX =
rewriter.create<LLVM::AddOp>(loc, int32Type, workitemIdX,
dimYxIdZPlusIdYTimesDimX, flags);
Value subgroupSize = rewriter.create<ROCDL::WavefrontSizeOp>(
loc, rewriter.getI32Type(), /*upper_bound = */ nullptr);
Value waveIdOp = rewriter.create<LLVM::UDivOp>(
loc, workitemIdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
rewriter.replaceOp(op, {truncOrExtToLLVMType(rewriter, loc, waveIdOp,
*getTypeConverter())});
return success();
}
};

/// Import the GPU Ops to ROCDL Patterns.
#include "GPUToROCDL.cpp.inc"

Expand All @@ -249,19 +299,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
// code.
struct LowerGpuOpsToROCDLOpsPass final
: public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
LowerGpuOpsToROCDLOpsPass() = default;
LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
bool useBarePtrCallConv,
gpu::amd::Runtime runtime) {
if (this->chipset.getNumOccurrences() == 0)
this->chipset = chipset;
if (this->indexBitwidth.getNumOccurrences() == 0)
this->indexBitwidth = indexBitwidth;
if (this->useBarePtrCallConv.getNumOccurrences() == 0)
this->useBarePtrCallConv = useBarePtrCallConv;
if (this->runtime.getNumOccurrences() == 0)
this->runtime = runtime;
}
using Base::Base;

void getDependentDialects(DialectRegistry &registry) const override {
Base::getDependentDialects(registry);
Expand Down Expand Up @@ -455,17 +493,15 @@ void mlir::populateGpuToROCDLConversionPatterns(
// TODO: Add alignment for workgroup memory
patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);

patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
patterns
.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupIdOpToROCDL>(
converter);
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);

populateMathToROCDLConversionPatterns(converter, patterns);
}

std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
unsigned indexBitwidth,
bool useBarePtrCallConv,
gpu::amd::Runtime runtime) {
return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
chipset, indexBitwidth, useBarePtrCallConv, runtime);
mlir::createLowerGpuOpsToROCDLOpsPass() {
return std::make_unique<LowerGpuOpsToROCDLOpsPass>();
}
22 changes: 22 additions & 0 deletions mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -763,3 +763,25 @@ gpu.module @test_module {
gpu.module @test_custom_data_layout attributes {llvm.data_layout = "e"} {

}

// -----

gpu.module @test_module {
// CHECK-LABEL: func @gpu_subgroup_id()
func.func @gpu_subgroup_id() -> (index) {
// CHECK: %[[widx:.*]] = rocdl.workitem.id.x : i32
// CHECK: %[[widy:.*]] = rocdl.workitem.id.y : i32
// CHECK: %[[widz:.*]] = rocdl.workitem.id.z : i32
// CHECK: %[[dimx:.*]] = rocdl.workgroup.dim.x : i32
// CHECK: %[[dimy:.*]] = rocdl.workgroup.dim.y : i32
// CHECK: %[[int5:.*]] = llvm.mul %[[dimy]], %[[widz]] overflow<nsw, nuw> : i32
// CHECK: %[[int6:.*]] = llvm.add %[[int5]], %[[widy]] overflow<nsw, nuw> : i32
// CHECK: %[[int7:.*]] = llvm.mul %[[dimx]], %[[int6]] overflow<nsw, nuw> : i32
// CHECK: %[[int8:.*]] = llvm.add %[[widx]], %[[int7]] overflow<nsw, nuw> : i32
// CHECK: %[[wavefrontsize:.*]] = rocdl.wavefrontsize : i32
// CHECK: %[[result:.*]] = llvm.udiv %[[int8]], %[[wavefrontsize]] : i32
// CHECK: = llvm.sext %[[result]] : i32 to i64
%subgroupId = gpu.subgroup_id : index
func.return %subgroupId : index
}
}