Skip to content

[GPU] Add 2d load/store ops validation #390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 260 additions & 3 deletions lib/gc/Dialect/LLVMIR/IR/XeVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,275 @@

#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/MathExtras.h"

using namespace mlir;
using namespace xevm;

#include "gc/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
#include "gc/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"

// TODO
// NOTE(review): both stubs below accept every op unconditionally. Full
// definitions of BlockLoad2dOp::verify and BlockStore2dOp::verify also appear
// further down in this listing; presumably these stubs are the lines the diff
// removes -- confirm that only one definition of each remains in the real
// file.
LogicalResult BlockLoad2dOp::verify() { return success(); }
LogicalResult BlockStore2dOp::verify() { return success(); }
namespace {
// Divisor applied to the total tile bit count when computing the expected
// result size of a 2D block load. Presumably the number of work-items per
// subgroup -- TODO confirm against the target hardware.
constexpr uint32_t subgroupSize = 16;

/// Checks shared by 2D block loads and stores.
///
/// Verifies, in this order:
///  - base pitch >= base width, when both operands are known constants;
///  - elem_size_in_bits is 8, 16, or 32;
///  - tile_height is a power of two no larger than 32;
///  - v_blocks is a power of two no larger than 8.
///
/// Returns success() when every check passes; otherwise emits an op error
/// describing the first violated constraint.
template <typename Op> LogicalResult verifyMatrixInput(Op op) {
  static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp>::value,
                "Unexpected template parameter");

  // The pitch/width relation can only be checked when both values are
  // constants; otherwise the comparison is skipped.
  const std::optional<int64_t> baseWidth =
      getConstantIntValue(op.getBaseWidth());
  const std::optional<int64_t> basePitch =
      getConstantIntValue(op.getBasePitch());
  if (basePitch.has_value() && baseWidth.has_value() &&
      basePitch.value() < baseWidth.value())
    return op->emitOpError(
        "4th operand (base pitch) should be >= 2nd operand (base width)");

  // A power of two in [8, 32] is exactly one of 8, 16, 32.
  const uint32_t elemBits = op.getElemSizeInBits();
  const bool elemBitsOk = elemBits == 8 || elemBits == 16 || elemBits == 32;
  if (!elemBitsOk)
    return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");

  // isPowerOf2_32 rejects zero, so only the upper bound needs a separate test.
  const uint32_t rows = op.getTileHeight();
  if (!(llvm::isPowerOf2_32(rows) && rows <= 32))
    return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");

  const uint32_t blockCount = op.getVBlocks();
  if (!(llvm::isPowerOf2_32(blockCount) && blockCount <= 8))
    return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");

  return success();
}

/// Checks tile_height, tile_width and v_blocks of a block load that uses
/// neither transpose nor vnni_transform.
LogicalResult verifyPlainBlockLoadShape(BlockLoad2dOp op) {
  uint32_t tileHeight = op.getTileHeight();
  if (tileHeight < 1 || tileHeight > 32)
    return op.emitOpError("expecting tile_height to be between 1 and 32");

  uint32_t tileWidth = op.getTileWidth();
  uint32_t vBlocks = op.getVBlocks();
  switch (op.getElemSizeInBits()) {
  case 8:
    if (tileWidth < 4 || tileWidth > 64)
      return op.emitOpError("expecting tile_width to be between 4 and 64");
    if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
      return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
    if (tileWidth * vBlocks > 64)
      return op.emitOpError(
          "tile_width * v_blocks should be less than or equal "
          "to 64 for 8 bit elements");
    break;
  case 16:
    if (tileWidth < 2 || tileWidth > 32)
      return op.emitOpError("expecting tile_width to be between 2 and 32");
    if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
      return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
    if (tileWidth * vBlocks > 32)
      return op.emitOpError(
          "tile_width * v_blocks should be less than or equal "
          "to 32 for 16 bit elements");
    break;
  case 32:
    if (tileWidth < 1 || tileWidth > 16)
      return op.emitOpError("expecting tile_width to be between 1 and 16");
    if (vBlocks != 1 && vBlocks != 2)
      return op.emitOpError("expecting v_blocks to be 1 or 2");
    if (tileWidth * vBlocks > 16)
      return op.emitOpError(
          "tile_width * v_blocks should be less than or equal "
          "to 16 for 32 bit elements");
    break;
  case 64:
    if (tileWidth < 1 || tileWidth > 8)
      return op.emitOpError("expecting tile_width to be between 1 and 8");
    if (vBlocks != 1)
      return op.emitOpError("expecting v_blocks to be 1");
    break;
  default:
    return op.emitOpError(
        "expecting elem_size_in_bits to be 8, 16, 32, or 64");
  }

  return success();
}

/// Checks v_blocks, tile_height and tile_width of a block load with transpose
/// set. Only 32 and 64 bit elements are accepted.
LogicalResult verifyTransposedBlockLoadShape(BlockLoad2dOp op) {
  uint32_t vBlocks = op.getVBlocks();
  if (vBlocks != 1)
    return op.emitOpError("expecting v_blocks to be 1");

  uint32_t tileHeight = op.getTileHeight();
  uint32_t tileWidth = op.getTileWidth();
  switch (op.getElemSizeInBits()) {
  case 32:
    if (tileHeight < 1 || tileHeight > 32)
      return op.emitOpError("expecting tile_height to be between 1 and 32");
    if (tileWidth < 1 || tileWidth > 8)
      return op.emitOpError("expecting tile_width to be between 1 and 8");
    break;
  case 64:
    if (tileHeight != 8)
      return op.emitOpError(
          "expecting tile_height to be 8 for 64 bit elements");
    if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
      return op.emitOpError("expecting tile_width to be 1, 2, or 4");
    break;
  default:
    return op.emitOpError("transpose is only supported for 32 and 64 bit "
                          "elements");
  }

  return success();
}

/// Checks v_blocks, tile_height and tile_width of a block load with
/// vnni_transform set. Only 8 and 16 bit elements are accepted.
LogicalResult verifyVnniBlockLoadShape(BlockLoad2dOp op) {
  uint32_t vBlocks = op.getVBlocks();
  if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
    return op.emitOpError("expecting v_blocks to be 1, 2, or 4");

  uint32_t tileHeight = op.getTileHeight();
  uint32_t tileWidth = op.getTileWidth();
  switch (op.getElemSizeInBits()) {
  case 8:
    if (tileHeight < 4 || tileHeight > 32)
      return op.emitOpError("expecting tile_height to be between 4 and 32");
    if (tileWidth < 4 || tileWidth > 16)
      return op.emitOpError("expecting tile_width to be between 4 and 16");
    break;
  case 16:
    if (tileHeight < 2 || tileHeight > 32)
      return op.emitOpError("expecting tile_height to be between 2 and 32");
    if (tileWidth < 2 || tileWidth > 16)
      return op.emitOpError("expecting tile_width to be between 2 and 16");
    if (tileWidth * vBlocks > 32)
      return op.emitOpError(
          "tile_width * v_blocks should be less than or equal "
          "to 32 for 16 bit elements");
    break;
  default:
    return op.emitOpError("vnni_transform is only supported for 8 and 16 bit "
                          "elements");
  }

  return success();
}

/// Verifies the shape restrictions of a 2D block load.
///
/// Checks, in this order:
///  1. the result element type is an integer or float type;
///  2. the total bit size of the result vector equals
///     elem_size_in_bits * tile_height * tile_width * v_blocks / subgroupSize;
///  3. transpose and vnni_transform are not both set;
///  4. the per-mode limits on tile_height, tile_width and v_blocks (plain,
///     transposed, or VNNI-transformed load).
///
/// Returns success() when all checks pass; otherwise emits an op error for
/// the first violated constraint.
LogicalResult verify2DBlockLoadHWRestriction(BlockLoad2dOp op) {
  VectorType resTy = op.getRes().getType();
  if (!resTy.getElementType().isIntOrFloat())
    return op.emitOpError()
           << "expecting result element type to be int or float";

  // Sizes are computed in 64 bits: getNumElements() is 64-bit, and narrowing
  // the product to 32 bits would let an oversized result vector wrap around
  // to the expected size and be accepted.
  uint64_t resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
  uint64_t resSize =
      static_cast<uint64_t>(resTy.getNumElements()) * resElemTySize;
  // NOTE(review): the division truncates, so any remainder of the tile bit
  // count modulo subgroupSize is ignored here.
  uint64_t expectedSize = static_cast<uint64_t>(op.getElemSizeInBits()) *
                          op.getTileHeight() * op.getTileWidth() *
                          op.getVBlocks() / subgroupSize;
  if (resSize != expectedSize)
    return op.emitOpError() << "result size of " << resSize
                            << " bits does not match the expected size of "
                            << expectedSize << " bits";

  const bool isTransposed = op.getTranspose();
  const bool isVnni = op.getVnniTransform();
  if (isTransposed && isVnni)
    return op.emitOpError(
        "transpose and vnni_transform are mutually exclusive");

  if (isTransposed)
    return verifyTransposedBlockLoadShape(op);
  if (isVnni)
    return verifyVnniBlockLoadShape(op);
  return verifyPlainBlockLoadShape(op);
}

/// Verifies the shape restrictions of a 2D block store.
///
/// Checks, in this order: tile_height is within [1, 8]; tile_width is within
/// the range allowed for the element size (8, 16, 32 or 64 bits); v_blocks is
/// exactly 1. Emits an op error for the first violated constraint.
static LogicalResult verify2DBlockStoreHWRestriction(BlockStore2dOp op) {
  const uint32_t rows = op.getTileHeight();
  if (!(rows >= 1 && rows <= 8))
    return op.emitOpError("expecting tile_height to be between 1 and 8");

  const uint32_t cols = op.getTileWidth();
  // True when the tile width lies outside the inclusive range [lo, hi].
  const auto colsOutside = [cols](uint32_t lo, uint32_t hi) {
    return cols < lo || cols > hi;
  };

  const uint32_t elemBits = op.getElemSizeInBits();
  if (elemBits == 8) {
    if (colsOutside(4, 64))
      return op.emitOpError("expecting tile_width to be between 4 and 64");
  } else if (elemBits == 16) {
    if (colsOutside(2, 32))
      return op.emitOpError("expecting tile_width to be between 2 and 32");
  } else if (elemBits == 32) {
    if (colsOutside(1, 16))
      return op.emitOpError("expecting tile_width to be between 1 and 16");
  } else if (elemBits == 64) {
    if (colsOutside(1, 8))
      return op.emitOpError("expecting tile_width to be between 1 and 8");
  } else {
    return op.emitOpError("expecting elem_size_in_bits to be 8, 16, 32, or 64");
  }

  if (op.getVBlocks() != 1)
    return op.emitOpError("expecting v_blocks to be 1");
  return success();
}

} // namespace

/// Verifies a 2D block load.
///
/// Runs verify2DBlockLoadHWRestriction and verifyMatrixInput first, then
/// checks that:
///  - the result element type is 32 bits wide when elem_size_in_bits is 32 or
///    vnni_transform is set;
///  - tile_width is 16 when vnni_transform is set.
///
/// Returns success() when every check passes; otherwise an op error has been
/// emitted for the first violated constraint.
LogicalResult BlockLoad2dOp::verify() {
  if (verify2DBlockLoadHWRestriction(*this).failed())
    return failure();

  if (verifyMatrixInput(*this).failed())
    return failure();

  VectorType resTy = getRes().getType();
  // verify2DBlockLoadHWRestriction performs the same check; it is repeated
  // here so that getIntOrFloatBitWidth() below is only called on an int or
  // float element type.
  if (!resTy.getElementType().isIntOrFloat())
    return emitOpError() << "expecting result element type to be int or float";
  unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
  if ((getElemSizeInBits() == 32 || getVnniTransform()) && resElemTySize != 32)
    return emitOpError() << "expecting result element type to be 32 bits";

  if (getVnniTransform() && getTileWidth() != 16)
    return emitOpError(
        "tile_width when vnni_transform is true should be equal "
        "to subgroup size (16 elements)");

  return success();
}

/// Verifies a 2D block store.
///
/// Runs verify2DBlockStoreHWRestriction and then verifyMatrixInput; if both
/// pass, additionally requires tile_width to be 16 or 32 for 8 bit elements
/// and exactly 16 for 16 and 32 bit elements.
LogicalResult BlockStore2dOp::verify() {
  // Short-circuiting keeps the original order: the matrix-input checks only
  // run once the hardware restrictions have passed.
  if (verify2DBlockStoreHWRestriction(*this).failed() ||
      verifyMatrixInput(*this).failed())
    return failure();

  const uint32_t cols = getTileWidth();
  const uint32_t elemBits = getElemSizeInBits();
  if (elemBits == 8) {
    const bool widthAllowed = cols == 16 || cols == 32;
    if (!widthAllowed)
      return emitOpError("tile_width for 8 bit elements should be equal to "
                         "16 or 32");
  } else if (elemBits == 16) {
    if (cols != 16)
      return emitOpError("tile_width for 16 bit elements should be equal "
                         "to 16");
  } else if (elemBits == 32) {
    if (cols != 16)
      return emitOpError("tile_width for 32 bit elements should be equal "
                         "to 16");
  } else {
    // verifyMatrixInput has already rejected every other element size.
    llvm_unreachable("unexpected element size");
  }

  return success();
}

void XeVMDialect::initialize() {
// NOLINTBEGIN
Expand Down
Loading