Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOSA] bug fix infer shape for slice #108306

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Tai78641
Copy link
Contributor

This fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1

added tests to check:

  • size = -1
  • size is out of bound
  • start is out of bound

This fixes the infer output shape of TOSA slice op for start/size values
that are out-of-bound or -1

added tests to check:
  - size = -1
  - size is out of bound
  - start is out of bound

Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I8b59502a93cb332fe5c9a9f87970b83742538126
@llvmbot
Copy link
Collaborator

llvmbot commented Sep 11, 2024

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

Changes

This fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1

added tests to check:

  • size = -1
  • size is out of bound
  • start is out of bound

Full diff: https://github.com/llvm/llvm-project/pull/108306.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+34-2)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+42)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0d0241fea5152c..4ca42cc99a507a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -842,8 +842,40 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     SliceOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  inferredReturnShapes.push_back(
-      ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
+  auto start = adaptor.getStart();
+  auto size = adaptor.getSize();
+
+  // if size[i] is -1, all remaining elements in dimension i are included
+  // in the slice, similar to TF.
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
+  // initialize outputShape to all unknown
+  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
+  if (inputShape.hasRank()) {
+    for (size_t i = 0; i < size.size(); i++) {
+      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
+          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
+           start[i] < inputShape.getDimSize(i))) {
+        // size[i] is not 0 and not < -1, and start[i] is in valid range
+        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
+          // input shape has unknown dim[i] - only valid if size[i] > 0
+          if (size[i] > 0) {
+            outputShape[i] = size[i];
+          }
+        } else {
+          // input shape has known dim[i]
+          if (size[i] == -1) {
+            outputShape[i] = inputShape.getDimSize(i) - start[i];
+          } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
+            // start[i] + size[i] is within bound of input shape's dim[i]
+            outputShape[i] = size[i];
+          }
+        }
+      }
+    }
+  } else {
+    outputShape = convertToMlirShape(size);
+  }
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d46de740800e93..d2314698afa925 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -532,6 +532,48 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_slice_size_minus_one
+func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
+  // CHECK: tosa.slice %arg0 {size = array<i64: -1, -1, -1, -1>, start = array<i64: 0, 1, -1, 8>} : (tensor<?x8x8x8xi32>) -> tensor<?x7x?x?xi32>
+  // this checks following
+  //  dim 0: size=-1, input dim=? => inferred output dim is ?
+  //  dim 1: size=-1 => inferred output dim is input_dim - start
+  //  dim 2: size=-1, start=-1 => inferred output dim is ?
+  //  dim 3: size=-1, start=8 => inferred output dim is ? because start is out of bound
+  %2= tosa.slice %arg0 { start = array<i64: 0, 1, -1, 8>, size = array<i64: -1, -1, -1, -1> } : (tensor<?x8x8x8xi32>) -> tensor<?x?x?x?xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_size_out_of_bound
+func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+  // CHECK: tosa.slice %arg0 {size = array<i64: 0, -2, 9, 4>, start = array<i64: 0, 0, 0, 0>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+  // this checks following
+  //  dim 0: size=0 => inferred output dim is ?
+  //  dim 1: size=-2 => inferred output dim is ?
+  //  dim 3: start+size out of bound because size too big: inferred output dim is ?
+  //  dim 4: size=4, input dim=? => inferred output dim is 4
+  %2= tosa.slice %arg0 { start = array<i64: 0, 0, 0, 0>, size = array<i64: 0, -2, 9, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_start_out_of_bound
+func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+  // CHECK: tosa.slice %arg0 {size = array<i64: 1, 1, 3, 4>, start = array<i64: -1, 8, 6, 8000000>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+  // this checks following
+  //  dim 0: start=-1 => inferred output dim is ?
+  //  dim 1: start=8 => inferred output dim is ?
+  //  dim 2: start+size out of bound: inferred output dim is ?
+  //  dim 3: start=8000000, size=4, input dim=? => inferred output dim is 4
+  %2= tosa.slice %arg0 { start = array<i64: -1, 8, 6, 8000000>, size = array<i64: 1, 1, 3, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_slice_dynamic
 func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
   // CHECK: tosa.slice %arg0 {size = array<i64: 7, -1, 1>, start = array<i64: 1, 0, 0>} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32>

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 11, 2024

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

Changes

This fixes the infer output shape of TOSA slice op for start/size values that are out-of-bound or -1

added tests to check:

  • size = -1
  • size is out of bound
  • start is out of bound

Full diff: https://github.com/llvm/llvm-project/pull/108306.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+34-2)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+42)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0d0241fea5152c..4ca42cc99a507a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -842,8 +842,40 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     SliceOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  inferredReturnShapes.push_back(
-      ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
+  auto start = adaptor.getStart();
+  auto size = adaptor.getSize();
+
+  // if size[i] is -1, all remaining elements in dimension i are included
+  // in the slice, similar to TF.
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
+  // initialize outputShape to all unknown
+  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
+  if (inputShape.hasRank()) {
+    for (size_t i = 0; i < size.size(); i++) {
+      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
+          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
+           start[i] < inputShape.getDimSize(i))) {
+        // size[i] is not 0 and not < -1, and start[i] is in valid range
+        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
+          // input shape has unknown dim[i] - only valid if size[i] > 0
+          if (size[i] > 0) {
+            outputShape[i] = size[i];
+          }
+        } else {
+          // input shape has known dim[i]
+          if (size[i] == -1) {
+            outputShape[i] = inputShape.getDimSize(i) - start[i];
+          } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
+            // start[i] + size[i] is within bound of input shape's dim[i]
+            outputShape[i] = size[i];
+          }
+        }
+      }
+    }
+  } else {
+    outputShape = convertToMlirShape(size);
+  }
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index d46de740800e93..d2314698afa925 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -532,6 +532,48 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_slice_size_minus_one
+func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
+  // CHECK: tosa.slice %arg0 {size = array<i64: -1, -1, -1, -1>, start = array<i64: 0, 1, -1, 8>} : (tensor<?x8x8x8xi32>) -> tensor<?x7x?x?xi32>
+  // this checks following
+  //  dim 0: size=-1, input dim=? => inferred output dim is ?
+  //  dim 1: size=-1 => inferred output dim is input_dim - start
+  //  dim 2: size=-1, start=-1 => inferred output dim is ?
+  //  dim 3: size=-1, start=8 => inferred output dim is ? because start is out of bound
+  %2= tosa.slice %arg0 { start = array<i64: 0, 1, -1, 8>, size = array<i64: -1, -1, -1, -1> } : (tensor<?x8x8x8xi32>) -> tensor<?x?x?x?xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_size_out_of_bound
+func.func @test_slice_size_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+  // CHECK: tosa.slice %arg0 {size = array<i64: 0, -2, 9, 4>, start = array<i64: 0, 0, 0, 0>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+  // this checks following
+  //  dim 0: size=0 => inferred output dim is ?
+  //  dim 1: size=-2 => inferred output dim is ?
+  //  dim 3: start+size out of bound because size too big: inferred output dim is ?
+  //  dim 4: size=4, input dim=? => inferred output dim is 4
+  %2= tosa.slice %arg0 { start = array<i64: 0, 0, 0, 0>, size = array<i64: 0, -2, 9, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_slice_start_out_of_bound
+func.func @test_slice_start_out_of_bound(%arg0 : tensor<8x8x8x?xi32>) -> () {
+  // CHECK: tosa.slice %arg0 {size = array<i64: 1, 1, 3, 4>, start = array<i64: -1, 8, 6, 8000000>} : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x4xi32>
+  // this checks following
+  //  dim 0: start=-1 => inferred output dim is ?
+  //  dim 1: start=8 => inferred output dim is ?
+  //  dim 2: start+size out of bound: inferred output dim is ?
+  //  dim 3: start=8000000, size=4, input dim=? => inferred output dim is 4
+  %2= tosa.slice %arg0 { start = array<i64: -1, 8, 6, 8000000>, size = array<i64: 1, 1, 3, 4> } : (tensor<8x8x8x?xi32>) -> tensor<?x?x?x?xi32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_slice_dynamic
 func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
   // CHECK: tosa.slice %arg0 {size = array<i64: 7, -1, 1>, start = array<i64: 1, 0, 0>} : (tensor<10x?x2xf32>) -> tensor<7x?x1xf32>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants