@@ -14,11 +14,67 @@ class RewritePatternSet;
14
14
15
15
namespace xegpu {
16
16
17
+ // / Options to control the XeGPU unrolling. Its main purpose is to
18
+ // / provide a way to customize the native shape of the operation.
19
+ struct UnrollOptions {
20
+ // / Callback function that indicates whether vector unrolling should be
21
+ // / attempted on the operation.
22
+ using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
23
+ FilterConstraintFnType filterConstraint = nullptr ;
24
+ UnrollOptions &setFilterConstraint (FilterConstraintFnType constraint) {
25
+ filterConstraint = std::move (constraint);
26
+ return *this ;
27
+ }
28
+
29
+ // / Function that computes the target shape for unrolling. It returns an
30
+ // / optional vector of integers representing the shape. If it returns
31
+ // / `std::nullopt`, unrolling is aborted for the given operation.
32
+ using NativeShapeFnType =
33
+ std::function<std::optional<SmallVector<int64_t >>(Operation *op)>;
34
+ NativeShapeFnType nativeShape = nullptr ;
35
+ UnrollOptions &setNativeShapeFn (NativeShapeFnType fn) {
36
+ nativeShape = std::move (fn);
37
+ return *this ;
38
+ }
39
+
40
+ // / Function that converts a ShapedType (TensorDescType or VectorType)
41
+ // / into the unrolled type based on the tileShape. It returns a vector of
42
+ // / types representing the unrolled types for simplicity.
43
+ using UnrolledTypeFnType = std::function<SmallVector<Type>(
44
+ ShapedType type, ArrayRef<int64_t > tileShape)>;
45
+ UnrolledTypeFnType getUnrolledTypes = nullptr ;
46
+ UnrollOptions &setUnrolledTypesFn (UnrolledTypeFnType fn) {
47
+ getUnrolledTypes = std::move (fn);
48
+ return *this ;
49
+ }
50
+ };
51
+
17
52
// / Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
18
53
void populateXeGPUFoldAliasOpsPatterns (RewritePatternSet &patterns);
54
+
19
55
// / Appends patterns for XeGPU SIMT distribution into `patterns`.
20
56
void populateXeGPUSubgroupDistributePatterns (RewritePatternSet &patterns);
21
57
58
+ // / Collect a set of patterns to unroll xegpu operations to a smaller shapes.
59
+ // / Users can control whether an operation to be unrolled or not, as well as
60
+ // / its target shape via `options` structure. (via setting filterConstraint
61
+ // / and nativeShape respectively, both of them are function refs taking `op` as
62
+ // / input).
63
+ // / An `op` is unrolled to the `targetShape` as follows, for each of its
64
+ // / operands:
65
+ // / 1. the unrolled type `unrolledType` and number of unrolled instances
66
+ // / `numUnrolledInstances` are computed from the `targetShape`.
67
+ // / 2. pack each operand. ExtractStridedSlice are created to break-up the
68
+ // / vector operands. And BuiltinUnrealizedCastop are created to break-up
69
+ // / the TensorDesc operands.
70
+ // / 3. the original op is cloned `numUnrolledInstances` times, once for each
71
+ // / result.
72
+ // / 4. unpack the results. InsertStridedSlice are inserted for VectorType
73
+ // / result, and BuiltinUnrealizedCastOp are inserted for TensorDescType result
74
+ // / to re-assemble the slices into the original shape.
75
+ void populateXeGPUUnrollPatterns (RewritePatternSet &patterns,
76
+ const UnrollOptions &options);
77
+
22
78
} // namespace xegpu
23
79
} // namespace mlir
24
80
0 commit comments