-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[MLIR][ROCDL] Add conversion for gpu.subgroup_id to ROCDL #136405
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,25 +52,6 @@ namespace mlir { | |
|
||
using namespace mlir; | ||
|
||
// Truncate or extend the result depending on the index bitwidth specified | ||
// by the LLVMTypeConverter options. | ||
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter, | ||
Location loc, Value value, | ||
const LLVMTypeConverter &converter) { | ||
int64_t intWidth = cast<IntegerType>(value.getType()).getWidth(); | ||
int64_t indexBitwidth = converter.getIndexTypeBitwidth(); | ||
auto indexBitwidthType = | ||
IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth()); | ||
// TODO: use <=> in C++20. | ||
if (indexBitwidth > intWidth) { | ||
return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value); | ||
} | ||
if (indexBitwidth < intWidth) { | ||
return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value); | ||
} | ||
return value; | ||
} | ||
|
||
/// Returns true if the given `gpu.func` can be safely called using the bare | ||
/// pointer calling convention. | ||
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) { | ||
|
@@ -99,6 +80,25 @@ static constexpr StringLiteral amdgcnDataLayout = | |
"64-S32-A5-G1-ni:7:8:9"; | ||
|
||
namespace { | ||
|
||
// Truncate or extend the result depending on the index bitwidth specified | ||
// by the LLVMTypeConverter options. | ||
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter, | ||
Location loc, Value value, | ||
const LLVMTypeConverter &converter) { | ||
int64_t intWidth = cast<IntegerType>(value.getType()).getWidth(); | ||
int64_t indexBitwidth = converter.getIndexTypeBitwidth(); | ||
auto indexBitwidthType = | ||
IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth()); | ||
if (indexBitwidth > intWidth) { | ||
return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value); | ||
} | ||
if (indexBitwidth < intWidth) { | ||
return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value); | ||
} | ||
return value; | ||
} | ||
|
||
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> { | ||
using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern; | ||
|
||
|
@@ -117,16 +117,7 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> { | |
rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero}); | ||
Value laneId = rewriter.create<ROCDL::MbcntHiOp>( | ||
loc, intTy, ValueRange{minus1, mbcntLo}); | ||
// Truncate or extend the result depending on the index bitwidth specified | ||
// by the LLVMTypeConverter options. | ||
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); | ||
if (indexBitwidth > 32) { | ||
laneId = rewriter.create<LLVM::SExtOp>( | ||
loc, IntegerType::get(context, indexBitwidth), laneId); | ||
} else if (indexBitwidth < 32) { | ||
laneId = rewriter.create<LLVM::TruncOp>( | ||
loc, IntegerType::get(context, indexBitwidth), laneId); | ||
} | ||
laneId = truncOrExtToLLVMType(rewriter, loc, laneId, *getTypeConverter()); | ||
rewriter.replaceOp(op, {laneId}); | ||
return success(); | ||
} | ||
|
@@ -150,11 +141,11 @@ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> { | |
/*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32, | ||
/*upper=*/op.getUpperBoundAttr().getInt() + 1); | ||
} | ||
Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>( | ||
Value wavefrontSizeOp = rewriter.create<ROCDL::WavefrontSizeOp>( | ||
op.getLoc(), rewriter.getI32Type(), bounds); | ||
wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp, | ||
*getTypeConverter()); | ||
rewriter.replaceOp(op, {wavefrontOp}); | ||
wavefrontSizeOp = truncOrExtToLLVMType( | ||
rewriter, op.getLoc(), wavefrontSizeOp, *getTypeConverter()); | ||
rewriter.replaceOp(op, {wavefrontSizeOp}); | ||
return success(); | ||
} | ||
|
||
|
@@ -239,6 +230,65 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> { | |
} | ||
}; | ||
|
||
struct GPUSubgroupIdOpToROCDL final | ||
: ConvertOpToLLVMPattern<gpu::SubgroupIdOp> { | ||
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; | ||
|
||
LogicalResult | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd prefer that this not be a rewrite pattern on (That way, downstream, we can run this pattern before PRopagateDispatchSizeBounds) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like this idea. I created a new PR for it: #137671 |
||
matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
// Calculation of the thread's subgroup identifier. | ||
// | ||
// The process involves mapping the thread's 3D identifier within its | ||
// workgroup/block (w_id.x, w_id.y, w_id.z) to a 1D linear index. | ||
// This linearization assumes a layout where the x-dimension (w_dim.x) | ||
// varies most rapidly (i.e., it is the innermost dimension). | ||
// | ||
// The formula for the linearized thread index is: | ||
// L = w_id.x + w_dim.x * (w_id.y + (w_dim.y * w_id.z)) | ||
// | ||
// Subsequently, the range of linearized indices [0, N_threads-1] is | ||
// divided into consecutive, non-overlapping segments, each representing | ||
// a subgroup of size 'subgroup_size'. | ||
// | ||
// Example Partitioning (N = subgroup_size): | ||
// | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... | | ||
// | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... | | ||
// | ||
// The subgroup identifier is obtained via integer division of the | ||
// linearized thread index by the predefined 'subgroup_size'. | ||
// | ||
// subgroup_id = floor( L / subgroup_size ) | ||
// = (w_id.x + w_dim.x * (w_id.y + w_dim.y * w_id.z)) / | ||
// subgroup_size | ||
auto int32Type = IntegerType::get(rewriter.getContext(), 32); | ||
Location loc = op.getLoc(); | ||
LLVM::IntegerOverflowFlags flags = | ||
LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw; | ||
Value workitemIdX = rewriter.create<ROCDL::ThreadIdXOp>(loc, int32Type); | ||
Value workitemIdY = rewriter.create<ROCDL::ThreadIdYOp>(loc, int32Type); | ||
Value workitemIdZ = rewriter.create<ROCDL::ThreadIdZOp>(loc, int32Type); | ||
Value workitemDimX = rewriter.create<ROCDL::BlockDimXOp>(loc, int32Type); | ||
Value workitemDimY = rewriter.create<ROCDL::BlockDimYOp>(loc, int32Type); | ||
Value dimYxIdZ = rewriter.create<LLVM::MulOp>(loc, int32Type, workitemDimY, | ||
workitemIdZ, flags); | ||
Value dimYxIdZPlusIdY = rewriter.create<LLVM::AddOp>( | ||
loc, int32Type, dimYxIdZ, workitemIdY, flags); | ||
Value dimYxIdZPlusIdYTimesDimX = rewriter.create<LLVM::MulOp>( | ||
loc, int32Type, workitemDimX, dimYxIdZPlusIdY, flags); | ||
Value workitemIdXPlusDimYxIdZPlusIdYTimesDimX = | ||
rewriter.create<LLVM::AddOp>(loc, int32Type, workitemIdX, | ||
dimYxIdZPlusIdYTimesDimX, flags); | ||
Value subgroupSize = rewriter.create<ROCDL::WavefrontSizeOp>( | ||
loc, rewriter.getI32Type(), /*upper_bound = */ nullptr); | ||
Value waveIdOp = rewriter.create<LLVM::UDivOp>( | ||
loc, workitemIdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize); | ||
rewriter.replaceOp(op, {truncOrExtToLLVMType(rewriter, loc, waveIdOp, | ||
*getTypeConverter())}); | ||
return success(); | ||
} | ||
}; | ||
|
||
/// Import the GPU Ops to ROCDL Patterns. | ||
#include "GPUToROCDL.cpp.inc" | ||
|
||
|
@@ -249,19 +299,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> { | |
// code. | ||
struct LowerGpuOpsToROCDLOpsPass final | ||
: public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> { | ||
LowerGpuOpsToROCDLOpsPass() = default; | ||
LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth, | ||
bool useBarePtrCallConv, | ||
gpu::amd::Runtime runtime) { | ||
if (this->chipset.getNumOccurrences() == 0) | ||
this->chipset = chipset; | ||
if (this->indexBitwidth.getNumOccurrences() == 0) | ||
this->indexBitwidth = indexBitwidth; | ||
if (this->useBarePtrCallConv.getNumOccurrences() == 0) | ||
this->useBarePtrCallConv = useBarePtrCallConv; | ||
if (this->runtime.getNumOccurrences() == 0) | ||
this->runtime = runtime; | ||
} | ||
using Base::Base; | ||
|
||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
Base::getDependentDialects(registry); | ||
|
@@ -455,17 +493,15 @@ void mlir::populateGpuToROCDLConversionPatterns( | |
// TODO: Add alignment for workgroup memory | ||
patterns.add<GPUDynamicSharedMemoryOpLowering>(converter); | ||
|
||
patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter); | ||
patterns | ||
.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupIdOpToROCDL>( | ||
converter); | ||
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset); | ||
|
||
populateMathToROCDLConversionPatterns(converter, patterns); | ||
} | ||
|
||
std::unique_ptr<OperationPass<gpu::GPUModuleOp>> | ||
mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset, | ||
unsigned indexBitwidth, | ||
bool useBarePtrCallConv, | ||
gpu::amd::Runtime runtime) { | ||
lialan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return std::make_unique<LowerGpuOpsToROCDLOpsPass>( | ||
chipset, indexBitwidth, useBarePtrCallConv, runtime); | ||
mlir::createLowerGpuOpsToROCDLOpsPass() { | ||
return std::make_unique<LowerGpuOpsToROCDLOpsPass>(); | ||
} |
Uh oh!
There was an error while loading. Please reload this page.