Skip to content

[tosa] Add verifier checks for Scatter #142661

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 5, 2025
Merged

Conversation

Tai78641
Copy link
Contributor

@Tai78641 Tai78641 commented Jun 3, 2025

This adds verifier checks for the scatter op
to make sure the shapes of inputs and output
are consistent with respect to spec.

This adds verifier checks for the scatter op
to make sure the shapes of inputs and output
are consistent with respect to spec.

Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I59531fa63e2d1dbd2865e0ef9b08b76991915c9a
@llvmbot
Copy link
Member

llvmbot commented Jun 3, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

Changes

This adds verifier checks for the scatter op
to make sure the shapes of inputs and output
are consistent with respect to spec.


Full diff: https://github.com/llvm/llvm-project/pull/142661.diff

9 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+67)
  • (modified) mlir/test/Dialect/Tosa/availability.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+11-11)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+72)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a22e6b7aa9791..f707770970e5f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2692,6 +2692,73 @@ LogicalResult tosa::ScatterOp::verify() {
           .failed()) {
     return failure();
   }
+
+  const ShapeAdaptor valuesInShape(getValuesIn().getType());
+  const ShapeAdaptor indicesShape(getIndices().getType());
+  const ShapeAdaptor inputShape(getInput().getType());
+  const ShapeAdaptor outputShape(getValuesOut().getType());
+
+  int64_t N = ShapedType::kDynamic;
+  int64_t K = ShapedType::kDynamic;
+  int64_t W = ShapedType::kDynamic;
+  int64_t C = ShapedType::kDynamic;
+  if (valuesInShape.hasRank()) {
+    N = valuesInShape.getDimSize(0);
+    K = valuesInShape.getDimSize(1);
+    C = valuesInShape.getDimSize(2);
+  }
+  if (indicesShape.hasRank()) {
+    const int64_t indicesN = indicesShape.getDimSize(0);
+    W = indicesShape.getDimSize(1);
+    if (N == ShapedType::kDynamic)
+      N = indicesN;
+    else if (indicesN != ShapedType::kDynamic && N != indicesN)
+      return emitOpError() << "requires indices dimension 0 to have size " << N
+                           << ", got " << indicesN;
+  }
+  if (inputShape.hasRank()) {
+    const int64_t inputN = inputShape.getDimSize(0);
+    const int64_t inputW = inputShape.getDimSize(1);
+    const int64_t inputC = inputShape.getDimSize(2);
+    if (N == ShapedType::kDynamic)
+      N = inputN;
+    else if (inputN != ShapedType::kDynamic && N != inputN)
+      return emitOpError() << "requires input dimension 0 to have size " << N
+                           << ", got " << inputN;
+    if (W == ShapedType::kDynamic)
+      W = inputW;
+    else if (inputW != ShapedType::kDynamic && W != inputW)
+      return emitOpError() << "requires input dimension 1 to have size " << W
+                           << ", got " << inputW;
+
+    if (C == ShapedType::kDynamic)
+      C = inputC;
+    else if (inputC != ShapedType::kDynamic && C != inputC)
+      return emitOpError() << "requires input dimension 2 to have size " << C
+                           << ", got " << inputC;
+  }
+  if (outputShape.hasRank()) {
+    const int64_t outputN = outputShape.getDimSize(0);
+    const int64_t outputK = outputShape.getDimSize(1);
+    const int64_t outputC = outputShape.getDimSize(2);
+    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
+        N != outputN)
+      return emitOpError() << "requires values_out dimension 0 to have size "
+                           << N << ", got " << outputN;
+    if (K == ShapedType::kDynamic)
+      K = outputK;
+    else if (outputK != ShapedType::kDynamic && K != outputK)
+      return emitOpError() << "requires values_out dimension 1 to have size "
+                           << K << ", got " << outputK;
+    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
+        C != outputC)
+      return emitOpError() << "requires values_out dimension 2 to have size "
+                           << C << ", got " << outputC;
+  }
+  if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
+    return emitOpError() << "requires dimensions K >= W, got K=" << K
+                         << " and W=" << W;
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 75126a11ac504..0176fc2883518 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -583,11 +583,11 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
 
 // -----
 // CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
   // CHECK: profiles: [ [pro_int, pro_fp] ]
   // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
-  return %0 : tensor<13x21x3xf32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
+  return %0 : tensor<13x28x3xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 2364985442e43..5630c33639d86 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -243,10 +243,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>) ->
 }
 
 // -----
-func.func @test_scatter(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16> {
+func.func @test_scatter(%arg0: tensor<13x26x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16> {
   // expected-error@+1 {{'tosa.scatter' op illegal: requires [bf16] but not enabled in target}}
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16>
-  return %0 : tensor<13x21x3xbf16>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x26x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16>
+  return %0 : tensor<13x26x3xbf16>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 223bf3b635e18..0dddf26fb1f85 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1080,10 +1080,10 @@ func.func @test_gather_tensor_size_invalid(%arg0: tensor<268435456x21x3xf32>, %a
 
 // -----
 
-func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32> {
+func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32> {
   // expected-error@+1 {{'tosa.scatter' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x210000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32>
-  return %0 : tensor<13x210000000x3xf32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x260000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32>
+  return %0 : tensor<13x260000000x3xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 767fa833dedd4..1ac82400843ed 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -714,9 +714,9 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
 
 // -----
 // CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
-  return %0 : tensor<13x21x3xf32>
+func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+  return %0 : tensor<13x52x3xf32>
 }
 
 // -----
@@ -728,8 +728,8 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
 
 // -----
 // CHECK-LABEL: scatter_unranked_indices
-func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
+func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
@@ -1010,9 +1010,9 @@ func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26
 
 // -----
 // CHECK-LABEL: scatter_f8E5M2
-func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
-  return %0 : tensor<13x21x3xf8E5M2>
+func.func @test_scatter_f8E5M2(%arg0: tensor<13x52x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2>
+  return %0 : tensor<13x52x3xf8E5M2>
 }
 
 // -----
@@ -1155,7 +1155,7 @@ func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<1
 
 // -----
 // CHECK-LABEL: scatter_f8E4M3FN
-func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
-  return %0 : tensor<13x21x3xf8E4M3FN>
+func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN>
+  return %0 : tensor<13x29x3xf8E4M3FN>
 }
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 72669c62c95ca..fad4859351251 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -310,10 +310,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
 }
 
 // -----
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
   // expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_fp] but not enabled in target}}
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
-  return %0 : tensor<13x21x3xf32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
+  return %0 : tensor<13x28x3xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index e98b906377b22..9438179622aad 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -242,10 +242,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>) ->
 }
 
 // -----
-func.func @test_scatter(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x21x3xi32> {
+func.func @test_scatter(%arg0: tensor<13x27x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x27x3xi32> {
   // expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_int] but not enabled in target}}
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x21x3xi32>
-  return %0 : tensor<13x21x3xi32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x27x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x27x3xi32>
+  return %0 : tensor<13x27x3xi32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 1ad1e6c76c294..591a3f0acf65d 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -656,9 +656,9 @@ func.func @gather_minimum_info(%arg0 : tensor<3x?x5xi32>, %arg1 : tensor<?x6xi32
 // -----
 
 // CHECK-LABEL: @scatter_static
-func.func @scatter_static(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
-  // CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x4x5xi32>
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
+func.func @scatter_static(%arg0 : tensor<3x8x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
+  // CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x8x5xi32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 990e0d954f54e..b3052369b055e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -864,3 +864,75 @@ func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
   tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_indices_N
+func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
+  %1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_N
+func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_N
+func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<3x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_K
+func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x3x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_W
+func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x3x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x3x5xi32>) -> tensor<2x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_C
+func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x6xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x6xi32>) -> tensor<2x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_C
+func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x6xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_K_W
+func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
+  return
+}

@llvmbot
Copy link
Member

llvmbot commented Jun 3, 2025

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

Changes

This adds verifier checks for the scatter op
to make sure the shapes of inputs and output
are consistent with respect to spec.


Full diff: https://github.com/llvm/llvm-project/pull/142661.diff

9 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+67)
  • (modified) mlir/test/Dialect/Tosa/availability.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+11-11)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+72)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a22e6b7aa9791..f707770970e5f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2692,6 +2692,73 @@ LogicalResult tosa::ScatterOp::verify() {
           .failed()) {
     return failure();
   }
+
+  const ShapeAdaptor valuesInShape(getValuesIn().getType());
+  const ShapeAdaptor indicesShape(getIndices().getType());
+  const ShapeAdaptor inputShape(getInput().getType());
+  const ShapeAdaptor outputShape(getValuesOut().getType());
+
+  int64_t N = ShapedType::kDynamic;
+  int64_t K = ShapedType::kDynamic;
+  int64_t W = ShapedType::kDynamic;
+  int64_t C = ShapedType::kDynamic;
+  if (valuesInShape.hasRank()) {
+    N = valuesInShape.getDimSize(0);
+    K = valuesInShape.getDimSize(1);
+    C = valuesInShape.getDimSize(2);
+  }
+  if (indicesShape.hasRank()) {
+    const int64_t indicesN = indicesShape.getDimSize(0);
+    W = indicesShape.getDimSize(1);
+    if (N == ShapedType::kDynamic)
+      N = indicesN;
+    else if (indicesN != ShapedType::kDynamic && N != indicesN)
+      return emitOpError() << "requires indices dimension 0 to have size " << N
+                           << ", got " << indicesN;
+  }
+  if (inputShape.hasRank()) {
+    const int64_t inputN = inputShape.getDimSize(0);
+    const int64_t inputW = inputShape.getDimSize(1);
+    const int64_t inputC = inputShape.getDimSize(2);
+    if (N == ShapedType::kDynamic)
+      N = inputN;
+    else if (inputN != ShapedType::kDynamic && N != inputN)
+      return emitOpError() << "requires input dimension 0 to have size " << N
+                           << ", got " << inputN;
+    if (W == ShapedType::kDynamic)
+      W = inputW;
+    else if (inputW != ShapedType::kDynamic && W != inputW)
+      return emitOpError() << "requires input dimension 1 to have size " << W
+                           << ", got " << inputW;
+
+    if (C == ShapedType::kDynamic)
+      C = inputC;
+    else if (inputC != ShapedType::kDynamic && C != inputC)
+      return emitOpError() << "requires input dimension 2 to have size " << C
+                           << ", got " << inputC;
+  }
+  if (outputShape.hasRank()) {
+    const int64_t outputN = outputShape.getDimSize(0);
+    const int64_t outputK = outputShape.getDimSize(1);
+    const int64_t outputC = outputShape.getDimSize(2);
+    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
+        N != outputN)
+      return emitOpError() << "requires values_out dimension 0 to have size "
+                           << N << ", got " << outputN;
+    if (K == ShapedType::kDynamic)
+      K = outputK;
+    else if (outputK != ShapedType::kDynamic && K != outputK)
+      return emitOpError() << "requires values_out dimension 1 to have size "
+                           << K << ", got " << outputK;
+    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
+        C != outputC)
+      return emitOpError() << "requires values_out dimension 2 to have size "
+                           << C << ", got " << outputC;
+  }
+  if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
+    return emitOpError() << "requires dimensions K >= W, got K=" << K
+                         << " and W=" << W;
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 75126a11ac504..0176fc2883518 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -583,11 +583,11 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
 
 // -----
 // CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
   // CHECK: profiles: [ [pro_int, pro_fp] ]
   // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
-  return %0 : tensor<13x21x3xf32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
+  return %0 : tensor<13x28x3xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 2364985442e43..5630c33639d86 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -243,10 +243,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>) ->
 }
 
 // -----
-func.func @test_scatter(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16> {
+func.func @test_scatter(%arg0: tensor<13x26x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16> {
   // expected-error@+1 {{'tosa.scatter' op illegal: requires [bf16] but not enabled in target}}
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16>
-  return %0 : tensor<13x21x3xbf16>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x26x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16>
+  return %0 : tensor<13x26x3xbf16>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 223bf3b635e18..0dddf26fb1f85 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1080,10 +1080,10 @@ func.func @test_gather_tensor_size_invalid(%arg0: tensor<268435456x21x3xf32>, %a
 
 // -----
 
-func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32> {
+func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32> {
   // expected-error@+1 {{'tosa.scatter' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x210000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32>
-  return %0 : tensor<13x210000000x3xf32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x260000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32>
+  return %0 : tensor<13x260000000x3xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 767fa833dedd4..1ac82400843ed 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -714,9 +714,9 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
 
 // -----
 // CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
-  return %0 : tensor<13x21x3xf32>
+func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+  return %0 : tensor<13x52x3xf32>
 }
 
 // -----
@@ -728,8 +728,8 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
 
 // -----
 // CHECK-LABEL: scatter_unranked_indices
-func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
+func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
@@ -1010,9 +1010,9 @@ func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26
 
 // -----
 // CHECK-LABEL: scatter_f8E5M2
-func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
-  return %0 : tensor<13x21x3xf8E5M2>
+func.func @test_scatter_f8E5M2(%arg0: tensor<13x52x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2>
+  return %0 : tensor<13x52x3xf8E5M2>
 }
 
 // -----
@@ -1155,7 +1155,7 @@ func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<1
 
 // -----
 // CHECK-LABEL: scatter_f8E4M3FN
-func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
-  return %0 : tensor<13x21x3xf8E4M3FN>
+func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN>
+  return %0 : tensor<13x29x3xf8E4M3FN>
 }
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 72669c62c95ca..fad4859351251 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -310,10 +310,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
 }
 
 // -----
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
   // expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_fp] but not enabled in target}}
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
-  return %0 : tensor<13x21x3xf32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
+  return %0 : tensor<13x28x3xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index e98b906377b22..9438179622aad 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -242,10 +242,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>) ->
 }
 
 // -----
-func.func @test_scatter(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x21x3xi32> {
+func.func @test_scatter(%arg0: tensor<13x27x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x27x3xi32> {
   // expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_int] but not enabled in target}}
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x21x3xi32>
-  return %0 : tensor<13x21x3xi32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x27x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x27x3xi32>
+  return %0 : tensor<13x27x3xi32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 1ad1e6c76c294..591a3f0acf65d 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -656,9 +656,9 @@ func.func @gather_minimum_info(%arg0 : tensor<3x?x5xi32>, %arg1 : tensor<?x6xi32
 // -----
 
 // CHECK-LABEL: @scatter_static
-func.func @scatter_static(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
-  // CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x4x5xi32>
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
+func.func @scatter_static(%arg0 : tensor<3x8x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
+  // CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x8x5xi32>
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 990e0d954f54e..b3052369b055e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -864,3 +864,75 @@ func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
   tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_indices_N
+func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
+  %1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_N
+func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_N
+func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<3x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_K
+func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x3x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_W
+func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x3x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x3x5xi32>) -> tensor<2x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_C
+func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x6xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x6xi32>) -> tensor<2x4x5xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_C
+func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x6xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_K_W
+func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) {
+  // expected-error@+1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
+  %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
+  return
+}

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Tai78641!

@lhutton1 lhutton1 merged commit c140783 into llvm:main Jun 5, 2025
14 checks passed
trcrsired pushed a commit to trcrsired/llvm-project that referenced this pull request Jun 5, 2025
This adds verifier checks for the scatter op
to make sure the shapes of inputs and output
are consistent with respect to spec.

Signed-off-by: Tai Ly <tai.ly@arm.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants