-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[tosa] Add verifier checks for Scatter #142661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This adds verifier checks for the scatter op to make sure the shapes of inputs and output are consistent with respect to spec. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I59531fa63e2d1dbd2865e0ef9b08b76991915c9a
@llvm/pr-subscribers-mlir-tosa Author: Tai Ly (Tai78641) ChangesThis adds verifier checks for the scatter op Full diff: https://github.com/llvm/llvm-project/pull/142661.diff 9 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a22e6b7aa9791..f707770970e5f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2692,6 +2692,73 @@ LogicalResult tosa::ScatterOp::verify() {
.failed()) {
return failure();
}
+
+ const ShapeAdaptor valuesInShape(getValuesIn().getType());
+ const ShapeAdaptor indicesShape(getIndices().getType());
+ const ShapeAdaptor inputShape(getInput().getType());
+ const ShapeAdaptor outputShape(getValuesOut().getType());
+
+ int64_t N = ShapedType::kDynamic;
+ int64_t K = ShapedType::kDynamic;
+ int64_t W = ShapedType::kDynamic;
+ int64_t C = ShapedType::kDynamic;
+ if (valuesInShape.hasRank()) {
+ N = valuesInShape.getDimSize(0);
+ K = valuesInShape.getDimSize(1);
+ C = valuesInShape.getDimSize(2);
+ }
+ if (indicesShape.hasRank()) {
+ const int64_t indicesN = indicesShape.getDimSize(0);
+ W = indicesShape.getDimSize(1);
+ if (N == ShapedType::kDynamic)
+ N = indicesN;
+ else if (indicesN != ShapedType::kDynamic && N != indicesN)
+ return emitOpError() << "requires indices dimension 0 to have size " << N
+ << ", got " << indicesN;
+ }
+ if (inputShape.hasRank()) {
+ const int64_t inputN = inputShape.getDimSize(0);
+ const int64_t inputW = inputShape.getDimSize(1);
+ const int64_t inputC = inputShape.getDimSize(2);
+ if (N == ShapedType::kDynamic)
+ N = inputN;
+ else if (inputN != ShapedType::kDynamic && N != inputN)
+ return emitOpError() << "requires input dimension 0 to have size " << N
+ << ", got " << inputN;
+ if (W == ShapedType::kDynamic)
+ W = inputW;
+ else if (inputW != ShapedType::kDynamic && W != inputW)
+ return emitOpError() << "requires input dimension 1 to have size " << W
+ << ", got " << inputW;
+
+ if (C == ShapedType::kDynamic)
+ C = inputC;
+ else if (inputC != ShapedType::kDynamic && C != inputC)
+ return emitOpError() << "requires input dimension 2 to have size " << C
+ << ", got " << inputC;
+ }
+ if (outputShape.hasRank()) {
+ const int64_t outputN = outputShape.getDimSize(0);
+ const int64_t outputK = outputShape.getDimSize(1);
+ const int64_t outputC = outputShape.getDimSize(2);
+ if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
+ N != outputN)
+ return emitOpError() << "requires values_out dimension 0 to have size "
+ << N << ", got " << outputN;
+ if (K == ShapedType::kDynamic)
+ K = outputK;
+ else if (outputK != ShapedType::kDynamic && K != outputK)
+ return emitOpError() << "requires values_out dimension 1 to have size "
+ << K << ", got " << outputK;
+ if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
+ C != outputC)
+ return emitOpError() << "requires values_out dimension 2 to have size "
+ << C << ", got " << outputC;
+ }
+ if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
+ return emitOpError() << "requires dimensions K >= W, got K=" << K
+ << " and W=" << W;
+
return success();
}
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 75126a11ac504..0176fc2883518 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -583,11 +583,11 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
// -----
// CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
- return %0 : tensor<13x21x3xf32>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
+ return %0 : tensor<13x28x3xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 2364985442e43..5630c33639d86 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -243,10 +243,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>) ->
}
// -----
-func.func @test_scatter(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16> {
+func.func @test_scatter(%arg0: tensor<13x26x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16> {
// expected-error@+1 {{'tosa.scatter' op illegal: requires [bf16] but not enabled in target}}
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16>
- return %0 : tensor<13x21x3xbf16>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x26x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16>
+ return %0 : tensor<13x26x3xbf16>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 223bf3b635e18..0dddf26fb1f85 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1080,10 +1080,10 @@ func.func @test_gather_tensor_size_invalid(%arg0: tensor<268435456x21x3xf32>, %a
// -----
-func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32> {
+func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32> {
// expected-error@+1 {{'tosa.scatter' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x210000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32>
- return %0 : tensor<13x210000000x3xf32>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x260000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32>
+ return %0 : tensor<13x260000000x3xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 767fa833dedd4..1ac82400843ed 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -714,9 +714,9 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
// -----
// CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
- return %0 : tensor<13x21x3xf32>
+func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+ return %0 : tensor<13x52x3xf32>
}
// -----
@@ -728,8 +728,8 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
// -----
// CHECK-LABEL: scatter_unranked_indices
-func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
+func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -1010,9 +1010,9 @@ func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26
// -----
// CHECK-LABEL: scatter_f8E5M2
-func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
- return %0 : tensor<13x21x3xf8E5M2>
+func.func @test_scatter_f8E5M2(%arg0: tensor<13x52x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2>
+ return %0 : tensor<13x52x3xf8E5M2>
}
// -----
@@ -1155,7 +1155,7 @@ func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<1
// -----
// CHECK-LABEL: scatter_f8E4M3FN
-func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
- return %0 : tensor<13x21x3xf8E4M3FN>
+func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN>
+ return %0 : tensor<13x29x3xf8E4M3FN>
}
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 72669c62c95ca..fad4859351251 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -310,10 +310,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
}
// -----
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
// expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_fp] but not enabled in target}}
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
- return %0 : tensor<13x21x3xf32>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
+ return %0 : tensor<13x28x3xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index e98b906377b22..9438179622aad 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -242,10 +242,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>) ->
}
// -----
-func.func @test_scatter(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x21x3xi32> {
+func.func @test_scatter(%arg0: tensor<13x27x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x27x3xi32> {
// expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_int] but not enabled in target}}
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x21x3xi32>
- return %0 : tensor<13x21x3xi32>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x27x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x27x3xi32>
+ return %0 : tensor<13x27x3xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 1ad1e6c76c294..591a3f0acf65d 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -656,9 +656,9 @@ func.func @gather_minimum_info(%arg0 : tensor<3x?x5xi32>, %arg1 : tensor<?x6xi32
// -----
// CHECK-LABEL: @scatter_static
-func.func @scatter_static(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
- // CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x4x5xi32>
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
+func.func @scatter_static(%arg0 : tensor<3x8x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
+ // CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x8x5xi32>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
return
}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 990e0d954f54e..b3052369b055e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -864,3 +864,75 @@ func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_indices_N
+func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
+ %1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_N
+func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_N
+func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<3x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_K
+func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x3x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_W
+func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x3x5xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x3x5xi32>) -> tensor<2x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_C
+func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x6xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x6xi32>) -> tensor<2x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_C
+func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x6xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_K_W
+func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
+ return
+}
|
@llvm/pr-subscribers-mlir Author: Tai Ly (Tai78641) ChangesThis adds verifier checks for the scatter op Full diff: https://github.com/llvm/llvm-project/pull/142661.diff 9 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a22e6b7aa9791..f707770970e5f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2692,6 +2692,73 @@ LogicalResult tosa::ScatterOp::verify() {
.failed()) {
return failure();
}
+
+ const ShapeAdaptor valuesInShape(getValuesIn().getType());
+ const ShapeAdaptor indicesShape(getIndices().getType());
+ const ShapeAdaptor inputShape(getInput().getType());
+ const ShapeAdaptor outputShape(getValuesOut().getType());
+
+ int64_t N = ShapedType::kDynamic;
+ int64_t K = ShapedType::kDynamic;
+ int64_t W = ShapedType::kDynamic;
+ int64_t C = ShapedType::kDynamic;
+ if (valuesInShape.hasRank()) {
+ N = valuesInShape.getDimSize(0);
+ K = valuesInShape.getDimSize(1);
+ C = valuesInShape.getDimSize(2);
+ }
+ if (indicesShape.hasRank()) {
+ const int64_t indicesN = indicesShape.getDimSize(0);
+ W = indicesShape.getDimSize(1);
+ if (N == ShapedType::kDynamic)
+ N = indicesN;
+ else if (indicesN != ShapedType::kDynamic && N != indicesN)
+ return emitOpError() << "requires indices dimension 0 to have size " << N
+ << ", got " << indicesN;
+ }
+ if (inputShape.hasRank()) {
+ const int64_t inputN = inputShape.getDimSize(0);
+ const int64_t inputW = inputShape.getDimSize(1);
+ const int64_t inputC = inputShape.getDimSize(2);
+ if (N == ShapedType::kDynamic)
+ N = inputN;
+ else if (inputN != ShapedType::kDynamic && N != inputN)
+ return emitOpError() << "requires input dimension 0 to have size " << N
+ << ", got " << inputN;
+ if (W == ShapedType::kDynamic)
+ W = inputW;
+ else if (inputW != ShapedType::kDynamic && W != inputW)
+ return emitOpError() << "requires input dimension 1 to have size " << W
+ << ", got " << inputW;
+
+ if (C == ShapedType::kDynamic)
+ C = inputC;
+ else if (inputC != ShapedType::kDynamic && C != inputC)
+ return emitOpError() << "requires input dimension 2 to have size " << C
+ << ", got " << inputC;
+ }
+ if (outputShape.hasRank()) {
+ const int64_t outputN = outputShape.getDimSize(0);
+ const int64_t outputK = outputShape.getDimSize(1);
+ const int64_t outputC = outputShape.getDimSize(2);
+ if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
+ N != outputN)
+ return emitOpError() << "requires values_out dimension 0 to have size "
+ << N << ", got " << outputN;
+ if (K == ShapedType::kDynamic)
+ K = outputK;
+ else if (outputK != ShapedType::kDynamic && K != outputK)
+ return emitOpError() << "requires values_out dimension 1 to have size "
+ << K << ", got " << outputK;
+ if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
+ C != outputC)
+ return emitOpError() << "requires values_out dimension 2 to have size "
+ << C << ", got " << outputC;
+ }
+ if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
+ return emitOpError() << "requires dimensions K >= W, got K=" << K
+ << " and W=" << W;
+
return success();
}
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 75126a11ac504..0176fc2883518 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -583,11 +583,11 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
// -----
// CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
- return %0 : tensor<13x21x3xf32>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
+ return %0 : tensor<13x28x3xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 2364985442e43..5630c33639d86 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -243,10 +243,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>) ->
}
// -----
-func.func @test_scatter(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16> {
+func.func @test_scatter(%arg0: tensor<13x26x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16> {
// expected-error@+1 {{'tosa.scatter' op illegal: requires [bf16] but not enabled in target}}
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16>
- return %0 : tensor<13x21x3xbf16>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x26x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16>
+ return %0 : tensor<13x26x3xbf16>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 223bf3b635e18..0dddf26fb1f85 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1080,10 +1080,10 @@ func.func @test_gather_tensor_size_invalid(%arg0: tensor<268435456x21x3xf32>, %a
// -----
-func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32> {
+func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32> {
// expected-error@+1 {{'tosa.scatter' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x210000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32>
- return %0 : tensor<13x210000000x3xf32>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x260000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32>
+ return %0 : tensor<13x260000000x3xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 767fa833dedd4..1ac82400843ed 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -714,9 +714,9 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
// -----
// CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
- return %0 : tensor<13x21x3xf32>
+func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+ return %0 : tensor<13x52x3xf32>
}
// -----
@@ -728,8 +728,8 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
// -----
// CHECK-LABEL: scatter_unranked_indices
-func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
+func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -1010,9 +1010,9 @@ func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26
// -----
// CHECK-LABEL: scatter_f8E5M2
-func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
- return %0 : tensor<13x21x3xf8E5M2>
+func.func @test_scatter_f8E5M2(%arg0: tensor<13x52x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2>
+ return %0 : tensor<13x52x3xf8E5M2>
}
// -----
@@ -1155,7 +1155,7 @@ func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<1
// -----
// CHECK-LABEL: scatter_f8E4M3FN
-func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
- return %0 : tensor<13x21x3xf8E4M3FN>
+func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN>
+ return %0 : tensor<13x29x3xf8E4M3FN>
}
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 72669c62c95ca..fad4859351251 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -310,10 +310,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
}
// -----
-func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
// expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_fp] but not enabled in target}}
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
- return %0 : tensor<13x21x3xf32>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
+ return %0 : tensor<13x28x3xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index e98b906377b22..9438179622aad 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -242,10 +242,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>) ->
}
// -----
-func.func @test_scatter(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x21x3xi32> {
+func.func @test_scatter(%arg0: tensor<13x27x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x27x3xi32> {
// expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_int] but not enabled in target}}
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x21x3xi32>
- return %0 : tensor<13x21x3xi32>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x27x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x27x3xi32>
+ return %0 : tensor<13x27x3xi32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 1ad1e6c76c294..591a3f0acf65d 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -656,9 +656,9 @@ func.func @gather_minimum_info(%arg0 : tensor<3x?x5xi32>, %arg1 : tensor<?x6xi32
// -----
// CHECK-LABEL: @scatter_static
-func.func @scatter_static(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
- // CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x4x5xi32>
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
+func.func @scatter_static(%arg0 : tensor<3x8x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
+ // CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x8x5xi32>
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
return
}
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 990e0d954f54e..b3052369b055e 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -864,3 +864,75 @@ func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_indices_N
+func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
+ %1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_N
+func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_N
+func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<3x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_K
+func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x3x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_W
+func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x3x5xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x3x5xi32>) -> tensor<2x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_input_C
+func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x6xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x6xi32>) -> tensor<2x4x5xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_out_C
+func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x6xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @scatter_invalid_K_W
+func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) {
+ // expected-error@+1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
+ %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
+ return
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Tai78641!
This adds verifier checks for the scatter op to make sure the shapes of inputs and output are consistent with respect to spec. Signed-off-by: Tai Ly <tai.ly@arm.com>
This adds verifier checks for the scatter op
to make sure the shapes of inputs and output
are consistent with respect to spec.