9
9
10
10
#include " mlir/Dialect/GPU/IR/CompilationInterfaces.h"
11
11
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
12
+ #include " mlir/Dialect/Utils/StaticValueUtils.h"
12
13
#include " mlir/IR/DialectImplementation.h"
13
14
#include " llvm/ADT/TypeSwitch.h"
14
15
@@ -18,9 +19,260 @@ using namespace xevm;
18
19
#include " gc/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
19
20
#include " gc/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
20
21
21
- // TODO
22
- LogicalResult BlockLoad2dOp::verify () { return success (); }
23
- LogicalResult BlockStore2dOp::verify () { return success (); }
22
+ namespace {
23
+ constexpr uint32_t subgroupSize = 16 ;
24
+
25
+ template <typename Op> LogicalResult verifyMatrixInput (Op op) {
26
+ static_assert (llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp>::value,
27
+ " Unexpected template parameter" );
28
+
29
+ std::optional<int64_t > width = getConstantIntValue (op.getBaseWidth ());
30
+ std::optional<int64_t > pitch = getConstantIntValue (op.getBasePitch ());
31
+ if (pitch && width && *pitch < *width)
32
+ return op->emitOpError (
33
+ " 4th operand (base pitch) should be >= 2nd operand (base width)" );
34
+
35
+ if (op.getElemSizeInBits () != 8 && op.getElemSizeInBits () != 16 &&
36
+ op.getElemSizeInBits () != 32 )
37
+ return op->emitOpError (" expecting 'elem_size_in_bits' to be 8, 16, or 32" );
38
+
39
+ uint32_t tileHeight = op.getTileHeight ();
40
+ if (tileHeight != 1 && tileHeight != 2 && tileHeight != 4 &&
41
+ tileHeight != 8 && tileHeight != 16 && tileHeight != 32 )
42
+ return op->emitOpError (" expecting tile_height to be 1, 2, 4, 8, 16, or 32" );
43
+
44
+ uint32_t vBlocks = op.getVBlocks ();
45
+ if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4 && vBlocks != 8 )
46
+ return op->emitOpError (" expecting v_blocks to be 1, 2, 4, or 8" );
47
+
48
+ return success ();
49
+ }
50
+
51
+ LogicalResult verify2DBlockLoadHWRestriction (BlockLoad2dOp op) {
52
+ VectorType resTy = op.getRes ().getType ();
53
+ unsigned resElemTySize = resTy.getElementType ().getIntOrFloatBitWidth ();
54
+ unsigned resSize = resTy.getNumElements () * resElemTySize;
55
+ unsigned expectedSize = op.getElemSizeInBits () * op.getTileHeight () *
56
+ op.getTileWidth () * op.getVBlocks () / subgroupSize;
57
+ if (resSize != expectedSize)
58
+ return op.emitOpError () << " result size of " << resSize
59
+ << " bits does not match the expected size of "
60
+ << expectedSize << " bits" ;
61
+
62
+ if (op.getTranspose () && op.getVnniTransform ())
63
+ return op.emitOpError (
64
+ " transpose and vnni_transform are mutually exclusive" );
65
+
66
+ if (!op.getTranspose () && !op.getVnniTransform ()) {
67
+ uint32_t tileHeight = op.getTileHeight ();
68
+ if (tileHeight < 1 || tileHeight > 32 )
69
+ return op.emitOpError (" expecting tile_height to be between 1 and 32" );
70
+
71
+ uint32_t tileWidth = op.getTileWidth ();
72
+ uint32_t vBlocks = op.getVBlocks ();
73
+ switch (op.getElemSizeInBits ()) {
74
+ case 8 :
75
+ if (tileWidth < 4 || tileWidth > 64 )
76
+ return op.emitOpError (" expecting tile_width to be between 4 and 64" );
77
+ if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4 )
78
+ return op.emitOpError (" expecting v_blocks to be 1, 2, or 4" );
79
+ if (tileWidth * vBlocks > 64 )
80
+ return op.emitOpError (
81
+ " tile_width * v_blocks should be less than or equal "
82
+ " to 64 for 8 bit elements" );
83
+ break ;
84
+ case 16 :
85
+ if (tileWidth < 2 || tileWidth > 32 )
86
+ return op.emitOpError (" expecting tile_width to be between 2 and 32" );
87
+ if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4 )
88
+ return op.emitOpError (" expecting v_blocks to be 1, 2, or 4" );
89
+ if (tileWidth * vBlocks > 32 )
90
+ return op.emitOpError (
91
+ " tile_width * v_blocks should be less than or equal "
92
+ " to 32 for 16 bit elements" );
93
+ break ;
94
+ case 32 :
95
+ if (tileWidth < 1 || tileWidth > 16 )
96
+ return op.emitOpError (" expecting tile_width to be between 1 and 16" );
97
+ if (vBlocks != 1 && vBlocks != 2 )
98
+ return op.emitOpError (" expecting v_blocks to be 1 or 2" );
99
+ if (tileWidth * vBlocks > 16 )
100
+ return op.emitOpError (
101
+ " tile_width * v_blocks should be less than or equal "
102
+ " to 16 for 32 bit elements" );
103
+ break ;
104
+ case 64 :
105
+ if (tileWidth < 1 || tileWidth > 8 )
106
+ return op.emitOpError (" expecting tile_width to be between 1 and 8" );
107
+ if (vBlocks != 1 )
108
+ return op.emitOpError (" expecting v_blocks to be 1" );
109
+ break ;
110
+ default :
111
+ return op.emitOpError (
112
+ " expecting elem_size_in_bits to be 8, 16, 32, or 64" );
113
+ }
114
+
115
+ return success ();
116
+ }
117
+
118
+ if (op.getTranspose ()) {
119
+ assert (!op.getVnniTransform () &&
120
+ " Expecting vnni_transform should be false" );
121
+
122
+ uint32_t vBlocks = op.getVBlocks ();
123
+ if (vBlocks != 1 )
124
+ return op.emitOpError (" expecting v_blocks to be 1" );
125
+
126
+ uint32_t tileHeight = op.getTileHeight ();
127
+ uint32_t tileWidth = op.getTileWidth ();
128
+ switch (op.getElemSizeInBits ()) {
129
+ case 32 :
130
+ if (tileHeight < 1 || tileHeight > 32 )
131
+ return op.emitOpError (" expecting tile_height to be between 1 and 32" );
132
+ if (tileWidth < 1 || tileWidth > 8 )
133
+ return op.emitOpError (" expecting tile_width to be between 1 and 8" );
134
+ break ;
135
+ case 64 :
136
+ if (tileHeight != 8 )
137
+ return op.emitOpError (
138
+ " expecting tile_height to be 8 for 64 bit elements" );
139
+ if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4 )
140
+ return op.emitOpError (" expecting tile_width to be 1, 2, or 4" );
141
+ break ;
142
+ default :
143
+ return op.emitOpError (" transpose is only supported for 32 and 64 bit "
144
+ " elements" );
145
+ }
146
+
147
+ return success ();
148
+ }
149
+
150
+ assert (op.getVnniTransform () && !op.getTranspose () &&
151
+ " Expecting vnni_transform should be true and transpose should be "
152
+ " false" );
153
+
154
+ uint32_t vBlocks = op.getVBlocks ();
155
+ if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4 )
156
+ return op.emitOpError (" expecting v_blocks to be 1, 2, or 4" );
157
+
158
+ uint32_t tileHeight = op.getTileHeight ();
159
+ uint32_t tileWidth = op.getTileWidth ();
160
+ switch (op.getElemSizeInBits ()) {
161
+ case 8 :
162
+ if (tileHeight < 4 || tileHeight > 32 )
163
+ return op.emitOpError (" expecting tile_height to be between 4 and 32" );
164
+ if (tileWidth < 4 || tileWidth > 16 )
165
+ return op.emitOpError (" expecting tile_width to be between 4 and 16" );
166
+ break ;
167
+ case 16 :
168
+ if (tileHeight < 2 || tileHeight > 32 )
169
+ return op.emitOpError (" expecting tile_height to be between 2 and 32" );
170
+ if (tileWidth < 2 || tileWidth > 16 )
171
+ return op.emitOpError (" expecting tile_width to be between 2 and 16" );
172
+ if (tileWidth * vBlocks > 32 )
173
+ return op.emitOpError (
174
+ " tile_width * v_blocks should be less than or equal "
175
+ " to 32 for 16 bit elements" );
176
+ break ;
177
+ default :
178
+ return op.emitOpError (" vnni_transform is only supported for 8 and 16 bit "
179
+ " elements" );
180
+ }
181
+
182
+ return success ();
183
+ }
184
+
185
+ static LogicalResult verify2DBlockStoreHWRestriction (BlockStore2dOp op) {
186
+ uint32_t tileHeight = op.getTileHeight ();
187
+ if (tileHeight < 1 || tileHeight > 8 )
188
+ return op.emitOpError (" expecting tile_height to be between 1 and 8" );
189
+
190
+ uint32_t tileWidth = op.getTileWidth ();
191
+ switch (op.getElemSizeInBits ()) {
192
+ case 8 :
193
+ if (tileWidth < 4 || tileWidth > 64 )
194
+ return op.emitOpError (" expecting tile_width to be between 4 and 64" );
195
+ break ;
196
+ case 16 :
197
+ if (tileWidth < 2 || tileWidth > 32 )
198
+ return op.emitOpError (" expecting tile_width to be between 2 and 32" );
199
+ break ;
200
+ case 32 :
201
+ if (tileWidth < 1 || tileWidth > 16 )
202
+ return op.emitOpError (" expecting tile_width to be between 1 and 16" );
203
+ break ;
204
+ case 64 :
205
+ if (tileWidth < 1 || tileWidth > 8 )
206
+ return op.emitOpError (" expecting tile_width to be between 1 and 8" );
207
+ break ;
208
+ default :
209
+ return op.emitOpError (" expecting elem_size_in_bits to be 8, 16, 32, or 64" );
210
+ }
211
+
212
+ uint32_t vBlocks = op.getVBlocks ();
213
+ if (vBlocks != 1 )
214
+ return op.emitOpError (" expecting v_blocks to be 1" );
215
+ return success ();
216
+ }
217
+
218
+ } // namespace
219
+
220
+ LogicalResult BlockLoad2dOp::verify () {
221
+ if (verify2DBlockLoadHWRestriction (*this ).failed ())
222
+ return failure ();
223
+
224
+ if (verifyMatrixInput (*this ).failed ())
225
+ return failure ();
226
+
227
+ VectorType resTy = getRes ().getType ();
228
+ unsigned resElemTySize = resTy.getElementType ().getIntOrFloatBitWidth ();
229
+ if (getElemSizeInBits () == 32 || getVnniTransform ()) {
230
+ if (resElemTySize != 32 )
231
+ return emitOpError () << " expecting result element type to be 32 bits" ;
232
+ }
233
+
234
+ uint32_t tileWidth = getTileWidth ();
235
+ if (getVnniTransform ()) {
236
+ if (tileWidth != 16 )
237
+ return emitOpError (
238
+ " tile_width when vnni_transform is true should be equal "
239
+ " to subgroup size (16 elements)" );
240
+ return success ();
241
+ }
242
+
243
+ return success ();
244
+ }
245
+
246
+ LogicalResult BlockStore2dOp::verify () {
247
+ if (verify2DBlockStoreHWRestriction (*this ).failed ())
248
+ return failure ();
249
+
250
+ if (verifyMatrixInput (*this ).failed ())
251
+ return failure ();
252
+
253
+ uint32_t tileWidth = getTileWidth ();
254
+ switch (getElemSizeInBits ()) {
255
+ case 8 :
256
+ if (tileWidth != 16 && tileWidth != 32 )
257
+ return emitOpError (" tile_width for 8 bit elements should be equal to "
258
+ " 16 or 32" );
259
+ break ;
260
+ case 16 :
261
+ if (tileWidth != 16 )
262
+ return emitOpError (" tile_width for 16 bit elements should be equal "
263
+ " to 16" );
264
+ break ;
265
+ case 32 :
266
+ if (tileWidth != 16 )
267
+ return emitOpError (" tile_width for 32 bit elements should be equal "
268
+ " to 16" );
269
+ break ;
270
+ default :
271
+ llvm_unreachable (" unexpected element size" );
272
+ }
273
+
274
+ return success ();
275
+ }
24
276
25
277
void XeVMDialect::initialize () {
26
278
// NOLINTBEGIN
0 commit comments