[mlir][tosa] Enhance error_if and verify checks for RESCALE Op #137021

Merged
lhutton1 merged 1 commit into llvm:main from psunn/rescale_errorif
Apr 25, 2025

psunn
Contributor

@psunn psunn commented Apr 23, 2025

  • Add a verifier check that rejects rank-0 input when per_channel is true
  • Add checkErrorIfRescale to the TOSA validation pass, aligning it with the TOSA v1.0 specification
  • Add LIT tests
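
The ERROR_IF conditions that checkErrorIfRescale enforces can be read independently of MLIR. The block below is a minimal sketch of the same rule order on plain data, assuming a hypothetical `RescaleParams` struct and `firstRescaleError` helper (neither is part of the patch or of MLIR). The messages are copied from the diff; everything else is illustrative.

```cpp
#include <optional>
#include <string>

// Hypothetical stand-in for the element widths and attributes that
// checkErrorIfRescale reads from tosa::RescaleOp. Not an MLIR type.
struct RescaleParams {
  unsigned inWidth;         // bit width of the input element type
  unsigned outWidth;        // bit width of the output element type
  bool inputUnsigned;       // input_unsigned attribute
  bool outputUnsigned;      // output_unsigned attribute
  bool scale32;             // scale32 attribute
  std::string roundingMode; // rounding_mode attribute
};

// Returns the message of the first violated rule, checked in the same order
// as in the patch, or std::nullopt when no rule is violated.
std::optional<std::string> firstRescaleError(const RescaleParams &p) {
  if (p.scale32 && p.inWidth == 48)
    return std::string("scale32 is not allowed with 48-bit input.");
  if (!p.scale32 && p.roundingMode == "DOUBLE_ROUND")
    return std::string("DOUBLE_ROUND is only allowed with scale32=true.");
  if (p.inputUnsigned && p.outputUnsigned)
    return std::string("input and output cannot be both unsigned.");
  if (p.outWidth == 32 && p.inputUnsigned)
    return std::string("i32 output type is not allowed with unsigned input.");
  if (p.inWidth == 32 && p.outputUnsigned)
    return std::string("i32 input type is not allowed with unsigned output.");
  if (p.inWidth == 48 && p.outputUnsigned)
    return std::string("i48 input type is not allowed with unsigned output.");
  if (p.inWidth == 48 && p.inputUnsigned)
    return std::string("i48 input type cannot be unsigned.");
  if (p.inWidth == 32 && p.inputUnsigned)
    return std::string("i32 input type cannot be unsigned.");
  if (p.outWidth == 32 && p.outputUnsigned)
    return std::string("i32 output type cannot be unsigned.");
  return std::nullopt;
}
```

Because the rules are checked in order and the first match wins, a configuration that violates several rules reports only the earliest one. For example, unsigned input combined with unsigned output always reports the "both unsigned" message, whatever the element widths.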

@llvmbot
Member

llvmbot commented Apr 23, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Peng Sun (psunn)

Changes
  • Add a verifier check that rejects rank-0 input when per_channel is true
  • Add checkErrorIfRescale to the TOSA validation pass, aligning it with the TOSA v1.0 specification
  • Add LIT tests

Full diff: https://github.com/llvm/llvm-project/pull/137021.diff

5 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+6)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+82-1)
  • (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (+108)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+12)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+13)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1ab4ce7d4558b..f1bed1241f971 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3186,6 +3186,12 @@ LogicalResult RescaleOp::verify() {
   // otherwise numChannel is dimension in input shape's last axis
   int64_t numChannels = 1;
   if (getPerChannel()) {
+    if (inputType.getRank() < 1) {
+      emitOpError("requires input to be at least rank 1 when per_channel is "
+                  "true, but got rank ")
+          << inputType.getRank();
+      return failure();
+    }
     numChannels = inputType.getDimSize(inputType.getRank() - 1);
   }
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index baa202833e285..fa337f350197c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1033,8 +1033,89 @@ bool checkErrorIfTable(Operation *op) {
   return true;
 }
 
+bool checkErrorIfRescale(Operation *op) {
+  auto rescale = dyn_cast<tosa::RescaleOp>(op);
+  if (!rescale)
+    return true;
+
+  auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
+  auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
+  if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
+      !outputType.getElementType().isInteger())
+    return true;
+
+  auto inElemType = inputType.getElementType();
+  auto outElemType = outputType.getElementType();
+  auto inWidth = inElemType.getIntOrFloatBitWidth();
+  auto outWidth = outElemType.getIntOrFloatBitWidth();
+
+  bool inputUnsigned = rescale.getInputUnsigned();
+  bool outputUnsigned = rescale.getOutputUnsigned();
+
+  bool scale32 = rescale.getScale32();
+  auto roundingMode = rescale.getRoundingMode();
+
+
+  // ERROR_IF(scale32 && is_same<in_t,i48_t>())
+  if (scale32 && inWidth == 48) {
+    op->emitOpError() << "scale32 is not allowed with 48-bit input.";
+    return false;
+  }
+
+  // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
+  if (!scale32 && roundingMode == "DOUBLE_ROUND") {
+    op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
+    return false;
+  }
+
+  // ERROR_IF(input_unsigned && output_unsigned)
+  if (inputUnsigned && outputUnsigned) {
+    op->emitOpError() << "input and output cannot be both unsigned.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
+  if (outWidth == 32 && inputUnsigned) {
+    op->emitOpError() << "i32 output type is not allowed with unsigned input.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
+  if (inWidth == 32 && outputUnsigned) {
+    op->emitOpError() << "i32 input type is not allowed with unsigned output.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
+  if (inWidth == 48 && outputUnsigned) {
+    op->emitOpError() << "i48 input type is not allowed with unsigned output.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
+  if (inWidth == 48 && inputUnsigned) {
+    op->emitOpError() << "i48 input type cannot be unsigned.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
+  if (inWidth == 32 && inputUnsigned) {
+    op->emitOpError() << "i32 input type cannot be unsigned.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
+  if (outWidth == 32 && outputUnsigned) {
+    op->emitOpError() << "i32 output type cannot be unsigned.";
+    return false;
+  }
+
+  return true;
+}
+
 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
-  if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || !checkErrorIfTable(op))
+  if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
+      !checkErrorIfTable(op) || !checkErrorIfRescale(op))
     return failure();
   return success();
 }
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 65a69be91e0c8..c6a173c92ff9a 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -129,3 +129,111 @@ func.func @test_i8_table_size(%arg0: tensor<2x64xi8>, %arg1: tensor<513xi8>) ->
     %0 = tosa.table %arg0, %arg1 : (tensor<2x64xi8>, tensor<513xi8>) -> tensor<2x64xi8>
     return %0 : tensor<2x64xi8>
 }
+
+// -----
+// CHECK-LABEL: test_error_input_zp_not_allowed
+func.func @test_error_input_zp_not_allowed(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op scale32 is not allowed with 48-bit input}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_scale32_with_i48
+func.func @test_error_scale32_with_i48(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op scale32 is not allowed with 48-bit input}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_input_output_unsigned
+func.func @test_error_input_output_unsigned(%arg0: tensor<1xi8>) -> tensor<1xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op input and output cannot be both unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
+  return %0 : tensor<1xi16>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_output_unsigned_input
+func.func @test_error_i32_output_unsigned_input(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.rescale' op i32 output type is not allowed with unsigned input}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
+  return %0 : tensor<1xi32>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_input_unsigned_output
+func.func @test_error_i32_input_unsigned_output(%arg0: tensor<1xi32>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i32 input type is not allowed with unsigned output}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i48_input_unsigned_output
+func.func @test_error_i48_input_unsigned_output(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i48 input type is not allowed with unsigned output}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i48_unsigned_input
+func.func @test_error_i48_unsigned_input(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i48 input type cannot be unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_unsigned_input
+func.func @test_error_i32_unsigned_input(%arg0: tensor<1xi32>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i32 input type cannot be unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_unsigned_output
+func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.rescale' op i32 output type cannot be unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
+  return %0 : tensor<1xi32>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 269ed58fdc81c..fe4cc49e89c0d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1669,6 +1669,18 @@ func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3
   return %0 : tensor<13x21x3xi16>
 }
 
+// -----
+// CHECK-LABEL: test_error_double_round_without_scale32
+func.func @test_error_double_round_without_scale32(%arg0: tensor<1xi8>) -> tensor<1xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op DOUBLE_ROUND is only allowed with scale32=true}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
+  return %0 : tensor<1xi16>
+}
+
 // -----
 // CHECK-LABEL: test_matmul_a_zp_same_element_type
 func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index fb8726cba1853..a42cf03a0a5cb 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -319,3 +319,16 @@ func.func @test_conv3d_wholly_divisible_output_width(%arg0: tensor<1x4x8x21x19xf
     : (tensor<1x4x8x21x19xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x19x34xf32>
   return %0 : tensor<1x4x8x19x34xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_error_scalar_input_with_per_channel
+func.func @test_error_scalar_input_with_per_channel(%arg0: tensor<i8>) -> tensor<i16> {
+  %multiplier = "tosa.const"() {values = dense<4> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<2> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op requires input to be at least rank 1 when per_channel is true, but got rank 0}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<i8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<i16>
+  return %0 : tensor<i16>
+}
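
The rank check added to RescaleOp::verify() guards the channel-count lookup: with per_channel set and a rank-0 input, `inputType.getRank() - 1` evaluates to -1, so the last-axis query would have no valid dimension to read. A minimal sketch of the same guard on a plain shape vector follows, assuming a hypothetical `numChannels` helper that is not part of the patch:

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper mirroring the channel-count logic in the verifier.
// Returns std::nullopt where the real verifier would emit
// "requires input to be at least rank 1 when per_channel is true".
std::optional<int64_t> numChannels(const std::vector<int64_t> &shape,
                                   bool perChannel) {
  if (!perChannel)
    return 1;            // a single scale applies to the whole tensor
  if (shape.empty())
    return std::nullopt; // rank 0: there is no last axis to read
  return shape.back();   // size of the last axis
}
```

With this sketch, a `13x21x3` shape yields 3 channels when per_channel is true, and a rank-0 shape yields no value instead of reading out of bounds.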

github-actions bot commented Apr 23, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@psunn psunn force-pushed the psunn/rescale_errorif branch from 621c120 to e0d037f Compare April 23, 2025 17:18
@wonjeon
Contributor

wonjeon commented Apr 23, 2025

Thanks for the patch. Looks good to me overall.

Contributor

@lhutton1 lhutton1 left a comment

LGTM, thanks! Had one small nit.

  * Add a verifier check that rejects rank-0 input when per_channel is true
  * Add checkErrorIfRescale to the TOSA validation pass, aligning it with
    the TOSA v1.0 specification
  * Add LIT tests

Change-Id: Ia07e8c2ee66d8ee4113bea5ad9fa859b5986b009
Signed-off-by: Peng Sun <peng.sun@arm.com>
@psunn psunn force-pushed the psunn/rescale_errorif branch from e0d037f to 99ea7d6 Compare April 24, 2025 20:53
@psunn psunn requested a review from lhutton1 April 24, 2025 20:54
@lhutton1 lhutton1 merged commit e046f20 into llvm:main Apr 25, 2025
11 checks passed
@psunn psunn deleted the psunn/rescale_errorif branch April 25, 2025 08:52