Skip to content

Commit 8cc3474

Browse files
committed
[mlir][gpu] Patterns to promote gpu.shuffle to specialized AMDGPU ops
Only swizzle promotion for now, may add DPP ops support later.
1 parent 48585ca commit 8cc3474

File tree

7 files changed

+123
-15
lines changed

7 files changed

+123
-15
lines changed

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -132,24 +132,24 @@ def MapNestedForallToThreads :
132132
TransformEachOpTrait,
133133
TransformOpInterface]> {
134134
let description = [{
135-
Target the `gpu.launch op` and rewrite all `scf.forall` nested in it to
135+
Target the `gpu.launch op` and rewrite all `scf.forall` nested in it to
136136
distributed `gpu.thread_id` attribute.
137137

138138
The operation searches for `scf.forall` ops nested under `target` and maps
139-
each such op to GPU threads.
140-
139+
each such op to GPU threads.
140+
141141
`scf.forall` induction variables are rewritten to `gpu.thread_id` according
142142
to the `mapping` attribute.
143143

144144
Different types of mapping attributes are supported:
145145
- the block_dims is a list of integers that specifies the number of
146146
threads in each dimension. This is a mandatory attribute that is used
147-
to constrain the number of threads in each dimension. If an
147+
to constrain the number of threads in each dimension. If an
148148
`scf.forall` op is mapped to fewer threads, predication occurs.
149149
- the warp_dims is a list of integers that specifies the number of
150150
warps in each dimension. This is an optional attribute that is used
151151
to constrain the number of warps in each dimension. When present, this
152-
attribute must be specified in a way that is compatible with the
152+
attribute must be specified in a way that is compatible with the
153153
block_dims attribute. If an `scf.forall` op is mapped to fewer warps,
154154
predication occurs.
155155

@@ -164,7 +164,7 @@ def MapNestedForallToThreads :
164164
inserted after each scf.forall op. At this time, this is an all or nothing
165165
choice. This will need to be tightened in the future.
166166

167-
The operation alters the block size of the given gpu_launch using the
167+
The operation alters the block size of the given gpu_launch using the
168168
mandatory block_dims argument.
169169

170170
#### Return modes:
@@ -268,7 +268,7 @@ def MapForallToBlocks :
268268
Only scf.forall distributed to **at most 3 dimensions** are
269269
currently supported.
270270

271-
The operation alters the block size of the given gpu_launch using the
271+
The operation alters the block size of the given gpu_launch using the
272272
grid_dims argument.
273273

274274
#### Return modes:
@@ -300,7 +300,7 @@ def MapForallToBlocks :
300300
`:` functional-type($target, $result)
301301
}];
302302
let hasVerifier = 1;
303-
303+
304304
let extraClassDeclaration = [{
305305
::mlir::DiagnosedSilenceableFailure applyToOne(
306306
::mlir::transform::TransformRewriter &rewriter,
@@ -310,4 +310,15 @@ def MapForallToBlocks :
310310
}];
311311
}
312312

313+
def ApplyGPUPromoteShuffleToAMDGPUPatternsOp : Op<Transform_Dialect,
314+
"apply_patterns.gpu.gpu_shuffle_to_amdgpu",
315+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
316+
let description = [{
317+
Collects patterns that are trying to promote `gpu.shuffle`s to specialized
318+
AMDGPU intrinsics.
319+
}];
320+
let assemblyFormat = "attr-dict";
321+
}
322+
323+
313324
#endif // GPU_TRANSFORM_OPS

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns);
9494
/// Erase barriers that do not enforce conflicting memory side effects.
9595
void populateGpuEliminateBarriersPatterns(RewritePatternSet &patterns);
9696

97+
/// Tries to promote `gpu.shuffle`s to specialized AMDGPU intrinsics.
98+
void populateGpuPromoteShuffleToAMDGPUPatterns(RewritePatternSet &patterns);
99+
97100
/// Generate the code for registering passes.
98101
#define GEN_PASS_REGISTRATION
99102
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,6 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
150150
rewriter.create<LLVM::AndOp>(loc, int32Type, add, negwidth);
151151
Value dstLane;
152152
// TODO: Add support for gpu::ShuffleMode::UP and gpu::ShuffleMode::DOWN.
153-
// TODO: Use ds_swizzle for XOR when step/offsets are constants for better
154-
// perf.
155153
switch (op.getMode()) {
156154
case gpu::ShuffleMode::DOWN:
157155
dstLane = rewriter.create<LLVM::AddOp>(loc, int32Type, srcLaneId,

mlir/lib/Dialect/GPU/CMakeLists.txt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ add_mlir_dialect_library(MLIRGPUTransforms
3737
Transforms/ModuleToBinary.cpp
3838
Transforms/NVVMAttachTarget.cpp
3939
Transforms/ParallelLoopMapper.cpp
40+
Transforms/PromoteShuffleToAMDGPU.cpp
4041
Transforms/ROCDLAttachTarget.cpp
41-
Transforms/ShuffleRewriter.cpp
4242
Transforms/SPIRVAttachTarget.cpp
43+
Transforms/ShuffleRewriter.cpp
4344
Transforms/SubgroupReduceLowering.cpp
4445

4546
OBJECT
@@ -52,6 +53,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
5253
MLIRParallelLoopMapperEnumsGen
5354

5455
LINK_LIBS PUBLIC
56+
MLIRAMDGPUDialect
5557
MLIRAffineUtils
5658
MLIRArithDialect
5759
MLIRAsyncDialect
@@ -66,11 +68,11 @@ add_mlir_dialect_library(MLIRGPUTransforms
6668
MLIRMemRefDialect
6769
MLIRNVVMTarget
6870
MLIRPass
71+
MLIRROCDLTarget
6972
MLIRSCFDialect
70-
MLIRSideEffectInterfaces
7173
MLIRSPIRVTarget
74+
MLIRSideEffectInterfaces
7275
MLIRSupport
73-
MLIRROCDLTarget
7476
MLIRTransformUtils
7577
MLIRVectorDialect
7678
)

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
1212
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
1313
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
1415
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1516
#include "mlir/Dialect/Arith/IR/Arith.h"
1617
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -136,6 +137,11 @@ void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
136137
populateGpuRewritePatterns(patterns);
137138
}
138139

140+
void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
141+
RewritePatternSet &patterns) {
142+
populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
143+
}
144+
139145
//===----------------------------------------------------------------------===//
140146
// ApplyUnrollVectorsSubgroupMmaOp
141147
//===----------------------------------------------------------------------===//
@@ -914,9 +920,10 @@ class GPUTransformDialectExtension
914920
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)
915921

916922
GPUTransformDialectExtension() {
917-
declareGeneratedDialect<scf::SCFDialect>();
918-
declareGeneratedDialect<arith::ArithDialect>();
919923
declareGeneratedDialect<GPUDialect>();
924+
declareGeneratedDialect<amdgpu::AMDGPUDialect>();
925+
declareGeneratedDialect<arith::ArithDialect>();
926+
declareGeneratedDialect<scf::SCFDialect>();
920927
registerTransformOps<
921928
#define GET_OP_LIST
922929
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//===- PromoteShuffleToAMDGPU.cpp - Promote shuffle to AMDGPU -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains patterns to try to promote `gpu.shuffle`s to specialized
10+
// AMDGPU intrinsics.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/GPU/Transforms/Passes.h"
15+
16+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
17+
#include "mlir/Dialect/Arith/IR/Arith.h"
18+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
19+
#include "mlir/IR/PatternMatch.h"
20+
21+
using namespace mlir;
22+
23+
namespace {
24+
/// Try to promote an XOR-mode `gpu.shuffle` to `amdgpu.swizzle_bitmode`. The
25+
/// width must be 64 and the offset a constant integer in the range [0, 31].
26+
struct PromoteShuffleToSwizzlePattern
27+
: public OpRewritePattern<gpu::ShuffleOp> {
28+
using OpRewritePattern::OpRewritePattern;
29+
30+
LogicalResult matchAndRewrite(gpu::ShuffleOp op,
31+
PatternRewriter &rewriter) const override {
32+
if (op.getMode() != gpu::ShuffleMode::XOR)
33+
return rewriter.notifyMatchFailure(op,
34+
"only xor shuffle mode is supported");
35+
36+
if (!isConstantIntValue(op.getWidth(), 64))
37+
return rewriter.notifyMatchFailure(op,
38+
"only 64 width shuffle is supported");
39+
40+
std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
41+
if (!offset)
42+
return rewriter.notifyMatchFailure(op,
43+
"offset must be a constant integer");
44+
45+
int64_t offsetValue = *offset;
46+
if (offsetValue < 0 || offsetValue >= 32)
47+
return rewriter.notifyMatchFailure(op,
48+
"offset must be in the range [0, 31]");
49+
50+
Location loc = op.getLoc();
51+
Value res = rewriter.create<amdgpu::SwizzleBitModeOp>(
52+
loc, op.getResult(0).getType(), op.getValue(), /*andMask=*/31,
53+
/*orMask=*/0, /*xorMask=*/offsetValue);
54+
Value valid = rewriter.create<arith::ConstantIntOp>(loc, 1, /*width*/ 1);
55+
rewriter.replaceOp(op, {res, valid});
56+
return success();
57+
}
58+
};
59+
} // namespace
60+
61+
void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
62+
RewritePatternSet &patterns) {
63+
patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext());
64+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s
2+
3+
module attributes {transform.with_named_sequence} {
4+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
5+
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
6+
transform.apply_patterns to %func {
7+
transform.apply_patterns.gpu.gpu_shuffle_to_amdgpu
8+
} : !transform.any_op
9+
transform.yield
10+
}
11+
}
12+
13+
// CHECK-LABEL: func @gpu_shuffle_swizzle
14+
// CHECK-SAME: (%[[ARG:.*]]: i32)
15+
func.func @gpu_shuffle_swizzle(%arg0: i32) -> (i32, i1) {
16+
// CHECK: %[[TRUE:.*]] = arith.constant true
17+
// CHECK: %[[RES:.*]] = amdgpu.swizzle_bitmode %[[ARG]] 31 0 23 : i32
18+
// CHECK: return %[[RES]], %[[TRUE]] : i32, i1
19+
%width = arith.constant 64 : i32
20+
%offset = arith.constant 23 : i32
21+
%shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
22+
func.return %shfl, %pred : i32, i1
23+
}

0 commit comments

Comments
 (0)