Skip to content

Commit

Permalink
[mlir][gpu] Use DenseI32Array for NVVM's maxntid and reqntid (NFC) (l…
Browse files Browse the repository at this point in the history
  • Loading branch information
grypp authored Jan 9, 2024
1 parent ca06c33 commit 2aec708
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 21 deletions.
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// If any of the dimensions are missing, fill them in with 1.
attributes.emplace_back(
kernelBlockSizeAttributeName.value(),
rewriter.getI32ArrayAttr(
rewriter.getDenseI32ArrayAttr(
{dimX.value_or(1), dimY.value_or(1), dimZ.value_or(1)}));
}
}
Expand Down
10 changes: 2 additions & 8 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1060,19 +1060,13 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
// If maxntid and reqntid exist, it must be an array with max 3 dim
if (attrName == NVVMDialect::getMaxntidAttrName() ||
attrName == NVVMDialect::getReqntidAttrName()) {
auto values = llvm::dyn_cast<ArrayAttr>(attr.getValue());
auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
if (!values || values.empty() || values.size() > 3)
return op->emitError()
<< "'" << attrName
<< "' attribute must be integer array with maximum 3 index";
for (auto val : llvm::cast<ArrayAttr>(attr.getValue())) {
if (!llvm::dyn_cast<IntegerAttr>(val))
return op->emitError()
<< "'" << attrName
<< "' attribute must be integer array with maximum 3 index";
}
}
// If minctasm and maxnreg exist, it must be an array with max 3 dim
// If minctasm and maxnreg exist, it must be an integer attribute
if (attrName == NVVMDialect::getMinctasmAttrName() ||
attrName == NVVMDialect::getMaxnregAttrName()) {
if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
Expand Down
10 changes: 4 additions & 6 deletions mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,20 +163,18 @@ class NVVMDialectLLVMIRTranslationInterface
->addOperand(llvmMetadataNode);
};
if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
if (!dyn_cast<ArrayAttr>(attribute.getValue()))
if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
return failure();
SmallVector<int64_t> values =
extractFromIntegerArrayAttr<int64_t>(attribute.getValue());
auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
if (values.size() > 1)
generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
if (values.size() > 2)
generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
} else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
if (!dyn_cast<ArrayAttr>(attribute.getValue()))
if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
return failure();
SmallVector<int64_t> values =
extractFromIntegerArrayAttr<int64_t>(attribute.getValue());
auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
if (values.size() > 1)
generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ gpu.module @test_module_31 {

gpu.module @gpumodule {
// CHECK-LABEL: func @kernel_with_block_size()
// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = [128 : i32, 1 : i32, 1 : i32]}
// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = array<i32: 128, 1, 1>}
gpu.func @kernel_with_block_size() kernel attributes {gpu.known_block_size = array<i32: 128, 1, 1>} {
gpu.return
}
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Target/LLVMIR/nvvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ llvm.func @kernel_func() attributes {nvvm.kernel} {

// -----

llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32]} {
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 23, 32>} {
llvm.return
}

Expand All @@ -410,7 +410,7 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32]} {
// CHECK: {ptr @kernel_func, !"maxntidz", i32 32}
// -----

llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = [1,23,32]} {
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = array<i32: 1, 23, 32>} {
llvm.return
}

Expand Down Expand Up @@ -442,7 +442,7 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxnreg = 16} {
// CHECK: {ptr @kernel_func, !"maxnreg", i32 16}
// -----

llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32],
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 23, 32>,
nvvm.minctasm = 16, nvvm.maxnreg = 32} {
llvm.return
}
Expand Down Expand Up @@ -472,13 +472,13 @@ nvvm.maxnreg = "boo"} {
}
// -----
// expected-error @below {{'"nvvm.reqntid"' attribute must be integer array with maximum 3 index}}
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = [3,4,5,6]} {
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = array<i32: 3, 4, 5, 6>} {
llvm.return
}

// -----
// expected-error @below {{'"nvvm.maxntid"' attribute must be integer array with maximum 3 index}}
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [3,4,5,6]} {
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 3, 4, 5, 6>} {
llvm.return
}

0 comments on commit 2aec708

Please sign in to comment.