Skip to content

Commit df152b7

Browse files
committed
address review comments
1 parent f2fcf95 commit df152b7

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

lib/gc/Dialect/LLVMIR/IR/XeVMDialect.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1313
#include "mlir/IR/DialectImplementation.h"
1414
#include "llvm/ADT/TypeSwitch.h"
15+
#include "llvm/Support/MathExtras.h"
1516

1617
using namespace mlir;
1718
using namespace xevm;
@@ -32,24 +33,26 @@ template <typename Op> LogicalResult verifyMatrixInput(Op op) {
3233
return op->emitOpError(
3334
"4th operand (base pitch) should be >= 2nd operand (base width)");
3435

35-
if (op.getElemSizeInBits() != 8 && op.getElemSizeInBits() != 16 &&
36-
op.getElemSizeInBits() != 32)
36+
uint32_t elemSize = op.getElemSizeInBits();
37+
if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
3738
return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");
3839

3940
uint32_t tileHeight = op.getTileHeight();
40-
if (tileHeight != 1 && tileHeight != 2 && tileHeight != 4 &&
41-
tileHeight != 8 && tileHeight != 16 && tileHeight != 32)
41+
if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
4242
return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");
4343

4444
uint32_t vBlocks = op.getVBlocks();
45-
if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4 && vBlocks != 8)
45+
if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
4646
return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");
4747

4848
return success();
4949
}
5050

5151
LogicalResult verify2DBlockLoadHWRestriction(BlockLoad2dOp op) {
5252
VectorType resTy = op.getRes().getType();
53+
if (!resTy.getElementType().isIntOrFloat())
54+
return op.emitOpError()
55+
<< "expecting result element type to be int or float";
5356
unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
5457
unsigned resSize = resTy.getNumElements() * resElemTySize;
5558
unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
@@ -225,6 +228,8 @@ LogicalResult BlockLoad2dOp::verify() {
225228
return failure();
226229

227230
VectorType resTy = getRes().getType();
231+
if (!resTy.getElementType().isIntOrFloat())
232+
return emitOpError() << "expecting result element type to be int of float";
228233
unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
229234
if (getElemSizeInBits() == 32 || getVnniTransform()) {
230235
if (resElemTySize != 32)

0 commit comments

Comments
 (0)