[mlir][xegpu] Tensor descriptor type verifier #124548
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

Changes: Adds an XeGPU tensor descriptor type verifier. The checks focus on ensuring that the provided subgroup map is valid with respect to the underlying data.

Full diff: https://github.com/llvm/llvm-project/pull/124548.diff

5 Files Affected:
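Before the diff, a minimal sketch of the invariant being verified: a subgroup map distributes a descriptor's elements across the work items of a subgroup, and each descriptor dimension must be evenly tileable by wi_layout[i] * wi_data[i]. The syntax follows the tests in this patch; the memref shape is illustrative.

// An 8x16 descriptor over a [1, 16] work-item layout with [1, 1] data per
// item: dim 0 keeps 8 / (1*1) = 8 rows per work item, dim 1 keeps
// 16 / (16*1) = 1 column, so each work item handles a vector<8x1xf32>.
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>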
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index d09c5c1870d506..494f11f041b71f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -179,7 +179,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
}];
let hasCustomAssemblyFormat = true;
-
+ let genVerifyDecl = 1;
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eb01b15de75c60..ef0ea38027c450 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -175,9 +175,10 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
if (parser.parseGreater())
return {};
- return TensorDescType::get(parser.getContext(), shape, elementType,
- encoding.value_or(mlir::Attribute()),
- sg_map.value_or(mlir::Attribute()));
+ return TensorDescType::getChecked(
+ [&]() { return parser.emitError(parser.getNameLoc()); },
+ parser.getContext(), shape, elementType,
+ encoding.value_or(mlir::Attribute()), sg_map.value_or(mlir::Attribute()));
}
void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -223,6 +224,55 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
return Base::get(context, shape, elementType, attr, sg_map);
}
+LogicalResult TensorDescType::verify(
+ llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+ llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+ mlir::Attribute encoding, mlir::Attribute sg_map) {
+ size_t rank = shape.size();
+ if (rank > 2)
+ return emitError() << "desc shape rank exceeds 2";
+
+ if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
+ ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
+ ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
+
+ if (rank == 1) {
+ if (wiLayout[0] != 1 || wiData[0] != 1)
+ return emitError() << "outer layout and data mapping must be 1 "
+ "for 1D tensor";
+ }
+
+ // For 1D tensor, pad the shape with an outer unit dimension to allow common
+ // validation logic.
+ SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
+ if (rank == 1)
+ tensorShape = {1, tensorShape.back()};
+
+ size_t dims = tensorShape.size();
+ for (size_t i = 0; i < dims; ++i) {
+ uint32_t numElemPerWi = wiLayout[i] * wiData[i];
+ if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
+ return emitError() << "cannot map " << tensorShape[i]
+ << " elements into " << wiLayout[i] << " by "
+ << wiData[i] << " tiles";
+ }
+
+ if (mlir::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
+ auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
+ if (wiData[0] != 1)
+ return emitError()
+ << "cannot map over non-contiguous scattered elements";
+
+ unsigned chunkSize = scatterAttr.getChunkSize().getInt();
+ if (wiData[1] > chunkSize)
+ return emitError()
+ << "too few contiguous elements for work item mapping";
+ }
+ }
+
+ return success();
+}
+
} // namespace xegpu
} // namespace mlir
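To make the per-dimension check above concrete, a sketch of types the new verifier rejects and accepts (the numbers follow the divisibility logic in this hunk):

// Rejected: dim 1 holds 8 elements but needs wi_layout[1] * wi_data[1]
// = 16 * 1 = 16, so verification fails with
// "cannot map 8 elements into 16 by 1 tiles".
!xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
// Accepted: 4 % (2*1) == 0 and 8 % (8*1) == 0.
!xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [1, 1]>>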
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81f46f941785a1..bf9eb8f7e10c3c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -81,24 +81,28 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
// each dimension.
static bool isArgShapesValid(ArrayRef<int64_t> descShape,
ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
- if (descShape == valShape) {
- if (!sgMap)
- return true;
-
- // this can be relaxed if necessary by supporting non-2d shapes distribution
- // until the constraints are defined this lives here instead of the tensor
- // descriptor type.
- return valShape.size() == sgMap.getWiLayout().size();
- }
+ // Equal shapes with no distribution - no further verification needed.
+ if (descShape == valShape && !sgMap)
+ return true;
+ // Unknown distribution - cannot perform operation on partial shape.
if (!sgMap)
return false;
- if (valShape.size() != descShape.size())
+ // Invalid rank or mixed rank usage.
+ size_t descRank = descShape.size();
+ if (descRank > 2 || valShape.size() != descRank)
return false;
+ // For 1D, SG map is guaranteed to be unit size in the outer dimension.
+ // Only take the distribution over the innermost dimension for validation.
+ ArrayRef<uint32_t> wiLayout = sgMap.getWiLayout();
+ SmallVector<uint32_t> mapLayout(wiLayout.begin(), wiLayout.end());
+ if (descRank == 1)
+ mapLayout = {wiLayout.back()};
+
for (const auto &[factor, dim, expected] :
- llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+ llvm::zip_equal(mapLayout, valShape, descShape)) {
if (factor * dim != expected)
return false;
}
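In effect, the updated check enforces valShape[i] * wiLayout[i] == descShape[i], with the outer unit layout dimension dropped for 1D descriptors. A sketch mirroring the tests below:

// A 32xf32 descriptor with wi_layout = [1, 16]: only the innermost factor
// applies to the 1D shape, so each work item loads 32 / 16 = 2 elements.
%v = xegpu.load_nd %tdesc
    : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
    -> vector<2xf32>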
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index d7174a489888a4..729abc5d69f3d1 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -97,6 +97,16 @@ gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
gpu.return
}
+// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+ gpu.return
+}
+
// CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
// CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -132,6 +142,18 @@ gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
gpu.return
}
+// CHECK: func @test_store_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) {
+ // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+ %1 = arith.constant dense<1.0>: vector<2xf16>
+ // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+ !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ gpu.return
+}
+
// CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 7816bff0582f81..94dc15756fe4ae 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -82,16 +82,33 @@ func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
// expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
- %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x2xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ -> vector<8x2xf32>
return
}
// -----
func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
- !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
// expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
- %2 = xegpu.load_nd %1: !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<16xf32>
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ -> vector<8xf32>
+ return
+}
+
+// -----
+func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
+ %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32>
+ // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+ l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
return
}
@@ -116,6 +133,35 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
return
}
+// -----
+func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
+ %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ xegpu.store_nd %data, %1
+ : vector<8x2xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
+ %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ xegpu.store_nd %data, %1
+ : vector<2xf32>, !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) {
+ %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+ !xegpu.tensor_desc<8x16xf32>
+ // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+ xegpu.store_nd %data, %1 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32>
+ return
+}
+
// -----
func.func @test_update_nd_offset_1(%dst: memref<16xf16>) {
%0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
@@ -238,4 +284,81 @@ func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector
// expected-error@+1 {{failed to verify that all of {tensorDesc, value, result} have same shape}}
xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>, vector<16x4xf32> -> vector<16x8xf32>
return
-}
\ No newline at end of file
+}
+
+// -----
+func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{desc shape rank exceeds 2}}
+ !xegpu.tensor_desc<16x2x2xf32>
+ return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}}
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [2, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}}
+ !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{cannot map 8 elements into 16 by 1 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{cannot map 4 elements into 2 by 4 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [4, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
+ %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+ // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}}
+ !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 2]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
+ %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+ %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> ->
+ // expected-error@+1 {{cannot map over non-contiguous scattered elements}}
+ !xegpu.tensor_desc<4x2xf32,
+ #xegpu.scatter_tdesc_attr<chunk_size = 2>,
+ #xegpu.sg_map<wi_layout = [1, 1], wi_data = [2, 1]>>
+ return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) {
+ %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+ // expected-error@+1 {{too few contiguous elements for work item mapping}}
+ !xegpu.tensor_desc<16xf32,
+ #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+ #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 2]>>
+ return
+}
Looks good, ignoring the scattered case.
Adds an XeGPU tensor descriptor type verifier. The checks focus on ensuring that the provided subgroup map is valid with respect to the underlying data.
Force-pushed from 2b8dd10 to a72c12f.
Rebased on top of main; improved scattered tensor verification and moved more invariant checks from op verifiers to the TensorDescType verifier.
Adds an XeGPU tensor descriptor type verifier.
The type verifier covers general tensor descriptor invariants w.r.t. Xe ISA semantics.
Related operation verifiers are updated to account for the new descriptor checks and avoid duplication.
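For scattered descriptors, the verifier additionally requires wi_data[0] == 1 (gathered rows are not contiguous in memory) and wi_data[1] <= chunk_size. An illustrative sketch derived from the verifier logic above (the layout values are assumptions, not taken from the patch):

// chunk_size = 2 provides two contiguous elements per offset, enough for
// wi_data = [1, 2]; wi_data = [2, 1] would instead fail with
// "cannot map over non-contiguous scattered elements".
!xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>,
    #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 2]>>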