@@ -12,6 +12,9 @@ include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
12
12
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
13
13
include "mlir/Interfaces/SideEffectInterfaces.td"
14
14
15
+ include "mlir/IR/OpBase.td"
16
+ include "mlir/IR/EnumAttr.td"
17
+
15
18
def XeVM_Dialect : Dialect {
16
19
let name = "xevm";
17
20
let cppNamespace = "::mlir::xevm";
@@ -26,4 +29,138 @@ class XeVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
26
29
27
30
// Declares the `XeVMTarget` attribute (assembly mnemonic `target`) by
// instantiating XeVM_Attr with no traits and an empty body.
// NOTE(review): the record name is spelled "Targett" — presumably a typo for
// "XeVM_TargetAttr"; renaming would affect every reference to this record, so
// confirm against users before changing it.
def XeVM_TargettAttr : XeVM_Attr<"XeVMTarget", "target"> {}
28
31
32
// Common base for every XeVM operation: attaches the op to XeVM_Dialect and
// forwards the mnemonic and (optional) trait list to the generic ODS `Op`.
class XeVM_Op<string mnemonic, list<Trait> traits = []>
    : Op<XeVM_Dialect, mnemonic, traits>;
34
+
35
// Element types permitted inside the vectors produced/consumed by the XeVM
// 2D block operations.
def XeVM_ElemType
    : AnyTypeOf<[
        AnyI8, AnyI16, AnyI32,  // 8-, 16- and 32-bit integers
        F32, F16, BF16          // floating-point formats
      ]>;
36
+
37
// Template for the 32-bit enum attribute describing the cache policy of XeVM
// load operations. `cacheMnemonic` (e.g. "L1") is prepended to the generated
// enum name and to the textual form of every case except `DEFAULT`.
class XeVM_LoadCacheControl<string cacheMnemonic>
    : I32EnumAttr<
          cacheMnemonic # "LoadCacheControl",
          "XeVM load ops cache control",
          [
            I32EnumAttrCase<"DEFAULT", 0, "Default">,
            // Uncached.
            I32EnumAttrCase<"UC", 1, cacheMnemonic # "UC">,
            // Cached.
            I32EnumAttrCase<"C", 2, cacheMnemonic # "C">,
            // Streaming.
            I32EnumAttrCase<"S", 3, cacheMnemonic # "S">,
            // Invalidate-after-read.
            I32EnumAttrCase<"IAR", 4, cacheMnemonic # "IAR">
          ]> {
  let cppNamespace = "::mlir::xevm";
}
47
+
48
// Load cache-control enum instantiated with the "L1" prefix.
def XeVM_L1LoadCacheControl
    : XeVM_LoadCacheControl<"L1">;
49
// Load cache-control enum instantiated with the "L3" prefix.
def XeVM_L3LoadCacheControl
    : XeVM_LoadCacheControl<"L3">;
50
+
51
// Template for the 32-bit enum attribute describing the cache policy of XeVM
// store operations. `cacheMnemonic` (e.g. "L1") is prepended to the generated
// enum name and to the textual form of every case except `DEFAULT`.
class XeVM_StoreCacheControl<string cacheMnemonic>
    : I32EnumAttr<
          cacheMnemonic # "StoreCacheControl",
          "XeVM store ops cache control",
          [
            I32EnumAttrCase<"DEFAULT", 0, "Default">,
            // Uncached.
            I32EnumAttrCase<"UC", 1, cacheMnemonic # "UC">,
            // Write-through.
            I32EnumAttrCase<"WT", 2, cacheMnemonic # "WT">,
            // Streaming.
            I32EnumAttrCase<"S", 3, cacheMnemonic # "S">,
            // Write-back.
            I32EnumAttrCase<"WB", 4, cacheMnemonic # "WB">
          ]> {
  let cppNamespace = "::mlir::xevm";
}
61
+
62
// Store cache-control enum instantiated with the "L1" prefix.
def XeVM_L1StoreCacheControl
    : XeVM_StoreCacheControl<"L1">;
63
// Store cache-control enum instantiated with the "L3" prefix.
def XeVM_L3StoreCacheControl
    : XeVM_StoreCacheControl<"L3">;
64
+
65
// `xevm.blockload2d`: reads a 2D tile out of a larger matrix in memory and
// returns it as a fixed-size vector of XeVM_ElemType.
//
// Operands (SSA values): the base pointer (marked MemRead), the surface
// geometry (base_width / base_height / base_pitch) and the tile origin (x, y).
// Attributes (compile-time constants): element size, tile shape, block count,
// the two layout flags and one cache-control enum per cache level; both
// cache-control attributes default to `DEFAULT`.
def XeVM_BlockLoad2dOp : XeVM_Op<"blockload2d">,
  Results<(outs FixedVectorOf<[XeVM_ElemType]>:$res)>,
  Arguments<(ins
    Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
    I32:$base_width,
    I32:$base_height,
    I32:$base_pitch,
    I32:$x,
    I32:$y,
    I32Attr:$elem_size_in_bits,
    I32Attr:$tile_width,
    I32Attr:$tile_height,
    I32Attr:$v_blocks,
    I1Attr:$transpose,
    I1Attr:$vnni_transform,
    DefaultValuedAttr<XeVM_L1LoadCacheControl, "::mlir::xevm::L1LoadCacheControl::DEFAULT">:$l1_cache_control,
    DefaultValuedAttr<XeVM_L3LoadCacheControl, "::mlir::xevm::L3LoadCacheControl::DEFAULT">:$l3_cache_control
  )> {

  let summary = "2D block load";

  let description = [{
    The `xevm.blockload2d` operation loads a two dimensional matrix tile
    from a larger matrix residing in memory. The parameters are:

    * `ptr` - the base address of the matrix containing the tile to load
    * `base_width`, `base_height`, `base_pitch` - the shape of the matrix
    * `x`, `y`, `tile_width`, `tile_height` - the starting offsets and shape
      of the tile to load
    * `elem_size_in_bits` - the size in bits of the matrix element
      - 32 for f32, tf32
      - 16 for f16, int16, bf16
      - 8 for int8, int4, int2
    * `v_blocks` - number of tiles to load
    * `transpose` - transpose the tile in registers (useful for 32 bit
      element types)
    * `vnni_transform` - transpose and pack the submatrix in registers
      (useful for < 32 bit element types)
    * `l1_cache_control`, `l3_cache_control` - enumerators that set the L1
      and L3 cache behaviour, respectively

    Notes:
    - the `transpose` and `vnni_transform` parameters are mutually exclusive
    - transposing the loaded tile is typically used for the B matrix operand
      (D = C + A * B), where A has row-major layout in registers and B should
      have column-major layout.
    - if the loaded tile contains out-of-bounds elements of the matrix, they
      are filled with 0.
    - coordinates are provided in elements, while width and pitch are
      provided in bytes.
  }];

  let assemblyFormat = [{
    operands ` ` `{` `elem_size_in_bits` `=` $elem_size_in_bits `,` `tile_width` `=` $tile_width `,`
    `tile_height` `=` $tile_height `,` `v_blocks` `=` $v_blocks `,` `transpose` `=` $transpose `,`
    `vnni_transform` `=` $vnni_transform `,` `l1_cache_control` `=` $l1_cache_control `,`
    `l3_cache_control` `=` $l3_cache_control `}` attr-dict `:` functional-type(operands, results)
  }];

  // The verifier is implemented in C++ outside this file.
  let hasVerifier = 1;
}
118
+
119
// `xevm.blockstore2d`: writes a 2D tile, held in a fixed-size vector of
// XeVM_ElemType, into a larger matrix in memory. The op has no results.
//
// Operands (SSA values): the base pointer (marked MemWrite), the surface
// geometry (base_width / base_height / base_pitch), the tile origin (x, y) and
// the vector to store. Attributes (compile-time constants): element size,
// tile shape, block count and one cache-control enum per cache level; both
// cache-control attributes default to `DEFAULT`.
def XeVM_BlockStore2dOp : XeVM_Op<"blockstore2d">,
  Arguments<(ins
    Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
    I32:$base_width,
    I32:$base_height,
    I32:$base_pitch,
    I32:$x,
    I32:$y,
    I32Attr:$elem_size_in_bits,
    I32Attr:$tile_width,
    I32Attr:$tile_height,
    I32Attr:$v_blocks,
    FixedVectorOf<[XeVM_ElemType]>:$stored_val,
    DefaultValuedAttr<XeVM_L1StoreCacheControl, "::mlir::xevm::L1StoreCacheControl::DEFAULT">:$l1_cache_control,
    DefaultValuedAttr<XeVM_L3StoreCacheControl, "::mlir::xevm::L3StoreCacheControl::DEFAULT">:$l3_cache_control
  )> {

  let summary = "2D block store";

  let description = [{
    The `xevm.blockstore2d` operation stores a two dimensional tile into a
    larger matrix residing in memory. The parameters are:

    * `ptr` - the base address of the matrix where to store the tile
    * `base_width`, `base_height`, `base_pitch` - the shape of the matrix
    * `x`, `y`, `tile_width`, `tile_height` - the starting offsets and shape
      of the tile to store
    * `elem_size_in_bits` - the size in bits of the matrix element
      - 32 for f32, tf32
      - 16 for f16, int16, bf16
      - 8 for int8, int4, int2
    * `v_blocks` - number of tiles to store
    * `stored_val` - the tile to store
    * `l1_cache_control`, `l3_cache_control` - enumerators that set the L1
      and L3 cache behaviour, respectively

    Notes:
    - coordinates are provided in elements, while width and pitch are
      provided in bytes.
  }];

  let assemblyFormat = [{
    operands ` ` `{` `elem_size_in_bits` `=` $elem_size_in_bits `,` `tile_width` `=` $tile_width `,`
    `tile_height` `=` $tile_height `,` `v_blocks` `=` $v_blocks `,` `l1_cache_control` `=` $l1_cache_control `,`
    `l3_cache_control` `=` $l3_cache_control `}`
    attr-dict `:` `(` type(operands) `)`
  }];

  // The verifier is implemented in C++ outside this file.
  let hasVerifier = 1;
}
165
+
29
166
#endif // XEVMIR_OPS
0 commit comments