
[mlir][xegpu] Tensor descriptor type verifier #124548


Merged
merged 8 commits into llvm:main from xegpu-tensor-desc-verifier on Feb 7, 2025

Conversation

adam-smnk
Contributor

@adam-smnk adam-smnk commented Jan 27, 2025

Adds XeGPU tensor descriptor type verifier.

The type verifier covers general tensor descriptor invariants w.r.t. Xe ISA semantics.
Related operation verifiers are updated to account for the new descriptor checks and avoid duplication.
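As a rough illustration of the new invariants (a sketch only; the tests in the diff below show the exact rules and diagnostics), a subgroup-mapped descriptor is accepted when every tensor dimension can be evenly tiled by the corresponding wi_layout * wi_data product:

  // Valid: 8 % (1 * 1) == 0 and 16 % (16 * 1) == 0.
  !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>

  // Invalid: the innermost dimension holds only 8 elements, which cannot be
  // distributed across 16 work items ("cannot map 8 elements into 16 by 1 tiles").
  !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>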

@llvmbot
Member

llvmbot commented Jan 27, 2025

@llvm/pr-subscribers-mlir-gpu

Author: Adam Siemieniuk (adam-smnk)

Changes

Adds XeGPU tensor descriptor type verifier.

The checks focus on ensuring that the provided subgroup map is valid with respect to the underlying data.
Related operation verifiers are updated to account for the new descriptor validation.


Full diff: https://github.com/llvm/llvm-project/pull/124548.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+1-1)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+53-3)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+15-11)
  • (modified) mlir/test/Dialect/XeGPU/XeGPUOps.mlir (+22)
  • (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+127-4)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index d09c5c1870d506..494f11f041b71f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -179,7 +179,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   }];
 
   let hasCustomAssemblyFormat = true;
-
+  let genVerifyDecl = 1;
 }
 
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eb01b15de75c60..ef0ea38027c450 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -175,9 +175,10 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
   if (parser.parseGreater())
     return {};
 
-  return TensorDescType::get(parser.getContext(), shape, elementType,
-                             encoding.value_or(mlir::Attribute()),
-                             sg_map.value_or(mlir::Attribute()));
+  return TensorDescType::getChecked(
+      [&]() { return parser.emitError(parser.getNameLoc()); },
+      parser.getContext(), shape, elementType,
+      encoding.value_or(mlir::Attribute()), sg_map.value_or(mlir::Attribute()));
 }
 
 void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -223,6 +224,55 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
   return Base::get(context, shape, elementType, attr, sg_map);
 }
 
+LogicalResult TensorDescType::verify(
+    llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+    mlir::Attribute encoding, mlir::Attribute sg_map) {
+  size_t rank = shape.size();
+  if (rank > 2)
+    return emitError() << "desc shape rank exceeds 2";
+
+  if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
+    ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
+    ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
+
+    if (rank == 1) {
+      if (wiLayout[0] != 1 || wiData[0] != 1)
+        return emitError() << "outer layout and data mapping must be 1 "
+                              "for 1D tensor";
+    }
+
+    // For 1D tensor, pad the shape with an outer unit dimension to allow common
+    // validation logic.
+    SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
+    if (rank == 1)
+      tensorShape = {1, tensorShape.back()};
+
+    size_t dims = tensorShape.size();
+    for (size_t i = 0; i < dims; ++i) {
+      uint32_t numElemPerWi = wiLayout[i] * wiData[i];
+      if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
+        return emitError() << "cannot map " << tensorShape[i]
+                           << " elements into " << wiLayout[i] << " by "
+                           << wiData[i] << " tiles";
+    }
+
+    if (mlir::isa_and_nonnull<ScatterTensorDescAttr>(encoding)) {
+      auto scatterAttr = llvm::dyn_cast<ScatterTensorDescAttr>(encoding);
+      if (wiData[0] != 1)
+        return emitError()
+               << "cannot map over non-contiguous scattered elements";
+
+      unsigned chunkSize = scatterAttr.getChunkSize().getInt();
+      if (wiData[1] > chunkSize)
+        return emitError()
+               << "too few contiguous elements for work item mapping";
+    }
+  }
+
+  return success();
+}
+
 } // namespace xegpu
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81f46f941785a1..bf9eb8f7e10c3c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -81,24 +81,28 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
 //   each dimension.
 static bool isArgShapesValid(ArrayRef<int64_t> descShape,
                              ArrayRef<int64_t> valShape, SGMapAttr sgMap) {
-  if (descShape == valShape) {
-    if (!sgMap)
-      return true;
-
-    // this can be relaxed if necessary by supporting non-2d shapes distribution
-    // until the constraints are defined this lives here instead of the tensor
-    // descriptor type.
-    return valShape.size() == sgMap.getWiLayout().size();
-  }
+  // Equal shapes with no distribution - no further verification needed.
+  if (descShape == valShape && !sgMap)
+    return true;
 
+  // Unknown distribution - cannot perform operation on partial shape.
   if (!sgMap)
     return false;
 
-  if (valShape.size() != descShape.size())
+  // Invalid rank or mixed rank usage.
+  size_t descRank = descShape.size();
+  if (descRank > 2 || valShape.size() != descRank)
     return false;
 
+  // For 1D, SG map is guaranteed to be unit size in the outer dimension.
+  // Only take the distribution over the innermost dimension for validation.
+  ArrayRef<uint32_t> wiLayout = sgMap.getWiLayout();
+  SmallVector<uint32_t> mapLayout(wiLayout.begin(), wiLayout.end());
+  if (descRank == 1)
+    mapLayout = {wiLayout.back()};
+
   for (const auto &[factor, dim, expected] :
-       llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) {
+       llvm::zip_equal(mapLayout, valShape, descShape)) {
     if (factor * dim != expected)
       return false;
   }
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index d7174a489888a4..729abc5d69f3d1 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -97,6 +97,16 @@ gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: func @test_load_nd_vc_4(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>
+  gpu.return
+}
+
 // CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -132,6 +142,18 @@ gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) {
   gpu.return
 }
 
+// CHECK: func @test_store_nd_vc_4(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @test_store_nd_vc_4(%src: memref<24x32xf16>) {
+   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+  %1 = arith.constant dense<1.0>: vector<2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> ->
+    !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 7816bff0582f81..94dc15756fe4ae 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -82,16 +82,33 @@ func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
     !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x2xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+      l2_hint = #xegpu.cache_hint<uncached>}>
+    : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+    -> vector<8x2xf32>
   return
 }
 
 // -----
 func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+    !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
   // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
-  %2 = xegpu.load_nd %1: !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<16xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+      l2_hint = #xegpu.cache_hint<uncached>}>
+    : !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+    -> vector<8xf32>
+  return
+}
+
+// -----
+func.func @test_load_nd_vc_6(%src: memref<24x32xf32>) {
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
+      l2_hint = #xegpu.cache_hint<uncached>}>
+    : !xegpu.tensor_desc<8x16xf32> -> vector<8x1xf32>
   return
 }
 
@@ -116,6 +133,35 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
   return
 }
 
+// -----
+func.func @test_store_nd_vc_3(%dst: memref<24x32xf32>, %data: vector<8x2xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  xegpu.store_nd %data, %1
+    : vector<8x2xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @test_store_nd_vc_4(%dst: memref<24x32xf32>, %data: vector<2xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  xegpu.store_nd %data, %1
+    : vector<2xf32>, !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @test_store_nd_vc_5(%dst: memref<24x32xf32>, %data: vector<8x1xf32>) {
+  %1 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32>
+  // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}}
+  xegpu.store_nd %data, %1 : vector<8x1xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
 // -----
 func.func @test_update_nd_offset_1(%dst: memref<16xf16>) {
   %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
@@ -238,4 +284,81 @@ func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector
   // expected-error@+1 {{failed to verify that all of {tensorDesc, value, result} have same shape}}
   xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>, vector<16x4xf32> -> vector<16x8xf32>
   return
-}
\ No newline at end of file
+}
+
+// -----
+func.func @tensor_desc_invalid_rank(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{desc shape rank exceeds 2}}
+      !xegpu.tensor_desc<16x2x2xf32>
+  return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_layout(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}}
+      !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [2, 16], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{outer layout and data mapping must be 1 for 1D tensor}}
+      !xegpu.tensor_desc<16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [2, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot map 8 elements into 16 by 1 tiles}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot map 4 elements into 2 by 4 tiles}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [2, 8], wi_data = [4, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
+  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      // expected-error@+1 {{cannot map 4 elements into 8 by 1 tiles}}
+      !xegpu.tensor_desc<4x8xf32, #xegpu.sg_map<wi_layout = [8, 2], wi_data = [1, 2]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data(%src: ui64) {
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> ->
+      // expected-error@+1 {{cannot map over non-contiguous scattered elements}}
+      !xegpu.tensor_desc<4x2xf32,
+        #xegpu.scatter_tdesc_attr<chunk_size = 2>,
+        #xegpu.sg_map<wi_layout = [1, 1], wi_data = [2, 1]>>
+  return
+}
+
+// -----
+func.func @tensor_desc_scatter_invalid_map_data_1(%src: ui64, %offsets: vector<16xindex>) {
+  %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
+      // expected-error@+1 {{too few contiguous elements for work item mapping}}
+      !xegpu.tensor_desc<16xf32,
+        #xegpu.scatter_tdesc_attr<chunk_size = 1>,
+        #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 2]>>
+  return
+}

@llvmbot
Member

llvmbot commented Jan 27, 2025

@llvm/pr-subscribers-mlir

@adam-smnk
Contributor Author

@chencha3 It's a follow-up to one of the discussions in #120566.
It slightly relaxes the load/store checks to allow 1D distribution, aligning with the XeGPU design docs.
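
For instance (a minimal sketch mirroring the new test_load_nd_vc_4 case in the diff; %src is assumed to be a memref<24x32xf32>), a 1D descriptor of 32 elements distributed across 16 work items loads 2 elements per work item:

  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
    !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
  // Each of the 16 work items receives 32 / 16 = 2 contiguous elements.
  %val = xegpu.load_nd %tdesc
    : !xegpu.tensor_desc<32xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<2xf32>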

@adam-smnk
Contributor Author

@akroviakov

Contributor

@kurapov-peter kurapov-peter left a comment


Looks good, ignoring the scattered case.

@adam-smnk adam-smnk force-pushed the xegpu-tensor-desc-verifier branch from 2b8dd10 to a72c12f on February 7, 2025, 14:36
@adam-smnk
Contributor Author

Rebased on top of main, improved scattered tensor verification, and moved more invariant checks from op verifiers to the TensorDescType verifier.
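
As a sketch of what the scattered checks now allow (illustrative shape, offsets, and chunk_size; op-level verifiers may impose further constraints), work items must map onto contiguous elements within a chunk:

  // Assuming %src : ui64 and %offsets : vector<16xindex>.
  %tdesc = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
      !xegpu.tensor_desc<16x8xf32,
        #xegpu.scatter_tdesc_attr<chunk_size = 8>,
        #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 2]>>
  // Passes: wi_data[0] == 1 keeps the mapping contiguous per offset, and
  // wi_data[1] == 2 does not exceed chunk_size = 8.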

@adam-smnk adam-smnk requested a review from akroviakov February 7, 2025 14:40
@adam-smnk adam-smnk merged commit 8a03658 into llvm:main Feb 7, 2025
8 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
Adds XeGPU tensor descriptor type verifier.

The type verifier covers general tensor descriptor invariants w.r.t. Xe
ISA semantics.
Related operation verifiers are updated to account for the new
descriptor checks and avoid duplication.