-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TOSA] bug fix infer shape for slice #108306
base: main
Are you sure you want to change the base?
Conversation
This fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1 added tests to check: - size = -1 - size is out of bound - start is out of bound Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I8b59502a93cb332fe5c9a9f87970b83742538126
@llvm/pr-subscribers-mlir-tosa Author: Tai Ly (Tai78641) ChangesThis fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1 added tests to check:
Full diff: https://github.com/llvm/llvm-project/pull/108306.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0d0241fea5152c..4ca42cc99a507a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -842,8 +842,40 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
SliceOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- inferredReturnShapes.push_back(
- ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
+ auto start = adaptor.getStart();
+ auto size = adaptor.getSize();
+
+ // if size[i] is -1, all remaining elements in dimension i are included
+ // in the slice, similar to TF.
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
+ // initialize outputShape to all unknown
+ SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
+ if (inputShape.hasRank()) {
+ for (size_t i = 0; i < size.size(); i++) {
+ if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
+ (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
+ start[i] < inputShape.getDimSize(i))) {
+ // size[i] is not 0 and not < -1, and start[i] is in valid range
+ if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
+ // input shape has unknown dim[i] - only valid if size[i] > 0
+ if (size[i] > 0) {
+ outputShape[i] = size[i];
+ }
+ } else {
+ // input shape has known dim[i]
+ if (size[i] == -1) {
+ outputShape[i] = inputShape.getDimSize(i) - start[i];
+ } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
+ // start[i] + size[i] is within bound of input shape's dim[i]
+ outputShape[i] = size[i];
+ }
+ }
+ }
+ }
+ } else {
+ outputShape = convertToMlirShape(size);
+ }
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d46de740800e93..d2314698afa925 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -532,6 +532,48 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
// -----
+// CHECK-LABEL: @test_slice_size_minus_one
+func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: -1, -1, -1, -1>, start = array<i64: 0, 1, -1, 8>} : (tensor<?x8x8x8xi32>) -> tensor<?x7x?x?xi32>
+ // this checks following
+ // dim 0: size=-1, input dim=? => inferred output dim is ?
+ // dim 1: size=-1 => inferred output dim is input_dim - start
+ // dim 2: size=-1, start=-1 => inferred output dim is ?
+ // dim 3: size=-1, start=8 => inferred output dim is ? because start is out of bound
+ %2= tosa.slice %arg0 { start = array<i64: 0, 1, -1, 8>, size = array<i64: -1, -1, -1, -1> } : (tensor<?x8x8x8xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_size_out_of_bound
+func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: 0, -2, 9, 4>, start = array<i64: 0, 0, 0, 0>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+ // this checks following
+ // dim 0: size=0 => inferred output dim is ?
+ // dim 1: size=-2 => inferred output dim is ?
+ // dim 3: start+size out of bound because size too big: inferred output dim is ?
+ // dim 4: size=4, input dim=? => inferred output dim is 4
+ %2= tosa.slice %arg0 { start = array<i64: 0, 0, 0, 0>, size = array<i64: 0, -2, 9, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_start_out_of_bound
+func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: 1, 1, 3, 4>, start = array<i64: -1, 8, 6, 8000000>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+ // this checks following
+ // dim 0: start=-1 => inferred output dim is ?
+ // dim 1: start=8 => inferred output dim is ?
+ // dim 2: start+size out of bound: inferred output dim is ?
+ // dim 3: start=8000000, size=4, input dim=? => inferred output dim is 4
+ %2= tosa.slice %arg0 { start = array<i64: -1, 8, 6, 8000000>, size = array<i64: 1, 1, 3, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_slice_dynamic
func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
// CHECK: tosa.slice %arg0 {size = array<i64: 7, -1, 1>, start = array<i64: 1, 0, 0>} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32>
|
@llvm/pr-subscribers-mlir Author: Tai Ly (Tai78641) ChangesThis fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1 added tests to check:
Full diff: https://github.com/llvm/llvm-project/pull/108306.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0d0241fea5152c..4ca42cc99a507a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -842,8 +842,40 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
SliceOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- inferredReturnShapes.push_back(
- ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
+ auto start = adaptor.getStart();
+ auto size = adaptor.getSize();
+
+ // if size[i] is -1, all remaining elements in dimension i are included
+ // in the slice, similar to TF.
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
+ // initialize outputShape to all unknown
+ SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
+ if (inputShape.hasRank()) {
+ for (size_t i = 0; i < size.size(); i++) {
+ if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
+ (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
+ start[i] < inputShape.getDimSize(i))) {
+ // size[i] is not 0 and not < -1, and start[i] is in valid range
+ if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
+ // input shape has unknown dim[i] - only valid if size[i] > 0
+ if (size[i] > 0) {
+ outputShape[i] = size[i];
+ }
+ } else {
+ // input shape has known dim[i]
+ if (size[i] == -1) {
+ outputShape[i] = inputShape.getDimSize(i) - start[i];
+ } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
+ // start[i] + size[i] is within bound of input shape's dim[i]
+ outputShape[i] = size[i];
+ }
+ }
+ }
+ }
+ } else {
+ outputShape = convertToMlirShape(size);
+ }
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d46de740800e93..d2314698afa925 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -532,6 +532,48 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
// -----
+// CHECK-LABEL: @test_slice_size_minus_one
+func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: -1, -1, -1, -1>, start = array<i64: 0, 1, -1, 8>} : (tensor<?x8x8x8xi32>) -> tensor<?x7x?x?xi32>
+ // this checks following
+ // dim 0: size=-1, input dim=? => inferred output dim is ?
+ // dim 1: size=-1 => inferred output dim is input_dim - start
+ // dim 2: size=-1, start=-1 => inferred output dim is ?
+ // dim 3: size=-1, start=8 => inferred output dim is ? because start is out of bound
+ %2= tosa.slice %arg0 { start = array<i64: 0, 1, -1, 8>, size = array<i64: -1, -1, -1, -1> } : (tensor<?x8x8x8xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_size_out_of_bound
+func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: 0, -2, 9, 4>, start = array<i64: 0, 0, 0, 0>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+ // this checks following
+ // dim 0: size=0 => inferred output dim is ?
+ // dim 1: size=-2 => inferred output dim is ?
+ // dim 3: start+size out of bound because size too big: inferred output dim is ?
+ // dim 4: size=4, input dim=? => inferred output dim is 4
+ %2= tosa.slice %arg0 { start = array<i64: 0, 0, 0, 0>, size = array<i64: 0, -2, 9, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_start_out_of_bound
+func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+ // CHECK: tosa.slice %arg0 {size = array<i64: 1, 1, 3, 4>, start = array<i64: -1, 8, 6, 8000000>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+ // this checks following
+ // dim 0: start=-1 => inferred output dim is ?
+ // dim 1: start=8 => inferred output dim is ?
+ // dim 2: start+size out of bound: inferred output dim is ?
+ // dim 3: start=8000000, size=4, input dim=? => inferred output dim is 4
+ %2= tosa.slice %arg0 { start = array<i64: -1, 8, 6, 8000000>, size = array<i64: 1, 1, 3, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_slice_dynamic
func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
// CHECK: tosa.slice %arg0 {size = array<i64: 7, -1, 1>, start = array<i64: 1, 0, 0>} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32>
|
This fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1
added tests to check: