12
12
#include " mlir/Dialect/Utils/StaticValueUtils.h"
13
13
#include " mlir/IR/DialectImplementation.h"
14
14
#include " llvm/ADT/TypeSwitch.h"
15
+ #include " llvm/Support/MathExtras.h"
15
16
16
17
using namespace mlir ;
17
18
using namespace xevm ;
@@ -32,24 +33,26 @@ template <typename Op> LogicalResult verifyMatrixInput(Op op) {
32
33
return op->emitOpError (
33
34
" 4th operand (base pitch) should be >= 2nd operand (base width)" );
34
35
35
- if (op. getElemSizeInBits () != 8 && op.getElemSizeInBits () != 16 &&
36
- op. getElemSizeInBits () != 32 )
36
+ uint32_t elemSize = op.getElemSizeInBits ();
37
+ if (elemSize < 8 || ! llvm::isPowerOf2_32 (elemSize) || elemSize > 32 )
37
38
return op->emitOpError (" expecting 'elem_size_in_bits' to be 8, 16, or 32" );
38
39
39
40
uint32_t tileHeight = op.getTileHeight ();
40
- if (tileHeight != 1 && tileHeight != 2 && tileHeight != 4 &&
41
- tileHeight != 8 && tileHeight != 16 && tileHeight != 32 )
41
+ if (tileHeight > 32 || !llvm::isPowerOf2_32 (tileHeight))
42
42
return op->emitOpError (" expecting tile_height to be 1, 2, 4, 8, 16, or 32" );
43
43
44
44
uint32_t vBlocks = op.getVBlocks ();
45
- if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4 && vBlocks != 8 )
45
+ if (vBlocks > 8 || ! llvm::isPowerOf2_32 ( vBlocks) )
46
46
return op->emitOpError (" expecting v_blocks to be 1, 2, 4, or 8" );
47
47
48
48
return success ();
49
49
}
50
50
51
51
LogicalResult verify2DBlockLoadHWRestriction (BlockLoad2dOp op) {
52
52
VectorType resTy = op.getRes ().getType ();
53
+ if (!resTy.getElementType ().isIntOrFloat ())
54
+ return op.emitOpError ()
55
+ << " expecting result element type to be int or float" ;
53
56
unsigned resElemTySize = resTy.getElementType ().getIntOrFloatBitWidth ();
54
57
unsigned resSize = resTy.getNumElements () * resElemTySize;
55
58
unsigned expectedSize = op.getElemSizeInBits () * op.getTileHeight () *
@@ -225,6 +228,8 @@ LogicalResult BlockLoad2dOp::verify() {
225
228
return failure ();
226
229
227
230
VectorType resTy = getRes ().getType ();
231
+ if (!resTy.getElementType ().isIntOrFloat ())
232
+ return emitOpError () << " expecting result element type to be int of float" ;
228
233
unsigned resElemTySize = resTy.getElementType ().getIntOrFloatBitWidth ();
229
234
if (getElemSizeInBits () == 32 || getVnniTransform ()) {
230
235
if (resElemTySize != 32 )
0 commit comments