Commit db42345
[MLIR][XeGPU] Add unroll patterns for XeGPU (1/N) (#137010)
Similar to vector ops, XeGPU ops need to be unrolled into smaller shapes such that they can be dispatched into a hardware instruction. This PR marks the initial phase of a series dedicated to incorporating unroll patterns for XeGPU operations. In this installment, we introduce patterns for the following operations:
1. createNd
2. updateNd
3. prefetchNd
4. loadNd
5. storeNd
6. dpas
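To illustrate the idea behind the commit (an op over a large shape is decomposed into hardware-native tiles), here is a minimal stdlib-only Python sketch that enumerates the tile offsets such an unrolling produces. The shapes used are hypothetical examples, not values taken from this commit.

```python
# Illustrative sketch only: decompose a shape into native-shape tiles,
# the core idea behind unrolling an XeGPU op into hardware instructions.
from itertools import product

def tile_offsets(shape, native_shape):
    """Enumerate the start offset of every native-shape tile in `shape`.

    Assumes each dimension of `shape` is evenly divisible by the
    corresponding dimension of `native_shape`.
    """
    counts = [s // n for s, n in zip(shape, native_shape)]
    return [tuple(i * n for i, n in zip(idx, native_shape))
            for idx in product(*map(range, counts))]

# A 16x32 op unrolled with a native 8x16 tile yields 2*2 = 4 instances.
print(tile_offsets((16, 32), (8, 16)))
```

Each offset corresponds to one cloned instance of the original op operating on a native-shape slice.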
1 parent d39ca81 commit db42345

File tree

11 files changed: +790 additions, -1 deletion


mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 0 additions & 1 deletion
@@ -303,7 +303,6 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
       return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
                              getLaneLayout(), getLaneData(), getOrder());
     }
-
   }];
 
   let assemblyFormat = "`<` struct(params) `>`";

mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h

Lines changed: 56 additions & 0 deletions
@@ -14,11 +14,67 @@ class RewritePatternSet;
 
 namespace xegpu {
 
+/// Options to control the XeGPU unrolling. Its main purpose is to
+/// provide a way to customize the native shape of the operation.
+struct UnrollOptions {
+  /// Callback function that indicates whether vector unrolling should be
+  /// attempted on the operation.
+  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
+  FilterConstraintFnType filterConstraint = nullptr;
+  UnrollOptions &setFilterConstraint(FilterConstraintFnType constraint) {
+    filterConstraint = std::move(constraint);
+    return *this;
+  }
+
+  /// Function that computes the target shape for unrolling. It returns an
+  /// optional vector of integers representing the shape. If it returns
+  /// `std::nullopt`, unrolling is aborted for the given operation.
+  using NativeShapeFnType =
+      std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
+  NativeShapeFnType nativeShape = nullptr;
+  UnrollOptions &setNativeShapeFn(NativeShapeFnType fn) {
+    nativeShape = std::move(fn);
+    return *this;
+  }
+
+  /// Function that converts a ShapedType (TensorDescType or VectorType)
+  /// into the unrolled type based on the tileShape. It returns a vector of
+  /// types representing the unrolled types for simplicity.
+  using UnrolledTypeFnType = std::function<SmallVector<Type>(
+      ShapedType type, ArrayRef<int64_t> tileShape)>;
+  UnrolledTypeFnType getUnrolledTypes = nullptr;
+  UnrollOptions &setUnrolledTypesFn(UnrolledTypeFnType fn) {
+    getUnrolledTypes = std::move(fn);
+    return *this;
+  }
+};
+
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
+
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
 
+/// Appends patterns for unrolling XeGPU operations to smaller shapes into
+/// `patterns`. Users can control whether an operation should be unrolled, as
+/// well as its target shape, via the `options` structure (by setting
+/// `filterConstraint` and `nativeShape` respectively; both are callbacks
+/// taking the `op` as input).
+/// An `op` is unrolled to the `targetShape` as follows, for each of its
+/// operands:
+/// 1. The unrolled type `unrolledType` and the number of unrolled instances
+///    `numUnrolledInstances` are computed from the `targetShape`.
+/// 2. Each operand is packed: ExtractStridedSlice ops are created to break
+///    up the vector operands, and UnrealizedConversionCast ops are created
+///    to break up the TensorDesc operands.
+/// 3. The original op is cloned `numUnrolledInstances` times, once for each
+///    result.
+/// 4. The results are unpacked: InsertStridedSlice ops are inserted for
+///    VectorType results, and UnrealizedConversionCast ops are inserted for
+///    TensorDescType results, to reassemble the slices into the original
+///    shape.
+void populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
+                                 const UnrollOptions &options);
+
 } // namespace xegpu
 } // namespace mlir
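The four-step flow described in the doc comment above (compute unrolled instances, pack operands into slices, clone the op per slice, unpack the results) can be sketched in plain Python. This is a conceptual stand-in only, not MLIR code: an elementwise binary op over nested lists plays the role of the cloned operation, and `extract`/`insert` mimic what the strided-slice ops do.

```python
# Stdlib-only sketch of the pack / clone / unpack flow used by the unroll
# patterns. Shapes and the elementwise op are hypothetical illustrations.

def extract(mat, off, tile):
    """Step 2: like ExtractStridedSlice, cut one tile out of an operand."""
    r, c = off
    h, w = tile
    return [row[c:c + w] for row in mat[r:r + h]]

def insert(dst, src, off):
    """Step 4: like InsertStridedSlice, write a tile back into the result."""
    r, c = off
    for i, row in enumerate(src):
        dst[r + i][c:c + len(row)] = row
    return dst

def unroll_elementwise(op, a, b, tile):
    shape = (len(a), len(a[0]))
    # Step 1: number of unrolled instances and their offsets.
    offs = [(i, j) for i in range(0, shape[0], tile[0])
                   for j in range(0, shape[1], tile[1])]
    out = [[None] * shape[1] for _ in range(shape[0])]
    for off in offs:  # Step 3: one "clone" of the op per tile.
        ta, tb = extract(a, off, tile), extract(b, off, tile)
        res = [[op(x, y) for x, y in zip(ra, rb)] for ra, rb in zip(ta, tb)]
        insert(out, res, off)
    return out

a = [[1] * 8 for _ in range(4)]
b = [[2] * 8 for _ in range(4)]
print(unroll_elementwise(lambda x, y: x + y, a, b, (2, 4)) ==
      [[3] * 8 for _ in range(4)])
```

In the actual patterns the per-tile computation is the cloned XeGPU op, and TensorDesc operands are broken up with unrealized casts rather than strided slices; the control flow, however, follows the same four steps.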

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Builders.h"
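The newly included IndexingUtils header provides index-arithmetic helpers such as linearize/delinearize, which are the kind of computation unrolling needs to map a flat instance index to multi-dimensional tile coordinates. The sketch below shows that arithmetic in plain Python over a per-dimension basis; MLIR's actual helpers may differ in signature (e.g. taking strides rather than a basis).

```python
# Hypothetical basis-based delinearization: map a flat index into
# multi-dimensional coordinates, row-major (last dimension fastest).
def delinearize(linear, basis):
    idx = []
    for b in reversed(basis):
        idx.append(linear % b)
        linear //= b
    return list(reversed(idx))

# With a 2x4 grid of tiles, flat instance 5 is tile (1, 1).
print(delinearize(5, [2, 4]))  # → [1, 1]
```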

mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
   XeGPUSubgroupDistribute.cpp
+  XeGPUUnroll.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
