-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[mlir][tosa] Enhance error_if and verify checks for RESCALE Op #137021
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
psunn
commented
Apr 23, 2025
- Add a verifier check that rejects a rank-0 input when per_channel is true
- Add checkErrorIfRescale to the TOSA validation pass, aligning it with the TOSA v1.0 specification
- Add LIT tests
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Peng Sun (psunn) Changes
Full diff: https://github.com/llvm/llvm-project/pull/137021.diff 5 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1ab4ce7d4558b..f1bed1241f971 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3186,6 +3186,12 @@ LogicalResult RescaleOp::verify() {
// otherwise numChannel is dimension in input shape's last axis
int64_t numChannels = 1;
if (getPerChannel()) {
+ if (inputType.getRank() < 1) {
+ emitOpError("requires input to be at least rank 1 when per_channel is "
+ "true, but got rank ")
+ << inputType.getRank();
+ return failure();
+ }
numChannels = inputType.getDimSize(inputType.getRank() - 1);
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index baa202833e285..fa337f350197c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1033,8 +1033,89 @@ bool checkErrorIfTable(Operation *op) {
return true;
}
+/// Checks the TOSA specification ERROR_IF conditions for tosa.rescale.
+///
+/// Returns true when `op` is not a tosa.rescale, when its input or output is
+/// not a shaped type with an integer element type, or when none of the
+/// conditions below is hit. Otherwise emits an op error for the first
+/// violated condition (in the order listed) and returns false.
+bool checkErrorIfRescale(Operation *op) {
+  auto rescaleOp = dyn_cast<tosa::RescaleOp>(op);
+  if (!rescaleOp)
+    return true;
+
+  auto inShaped = llvm::dyn_cast<ShapedType>(rescaleOp.getInput().getType());
+  auto outShaped =
+      llvm::dyn_cast<ShapedType>(rescaleOp.getOutput().getType());
+  if (!inShaped || !outShaped)
+    return true;
+
+  // Only integer element types are subject to the checks below.
+  auto inElem = inShaped.getElementType();
+  auto outElem = outShaped.getElementType();
+  if (!inElem.isInteger() || !outElem.isInteger())
+    return true;
+
+  const auto inBits = inElem.getIntOrFloatBitWidth();
+  const auto outBits = outElem.getIntOrFloatBitWidth();
+  const bool inIsUnsigned = rescaleOp.getInputUnsigned();
+  const bool outIsUnsigned = rescaleOp.getOutputUnsigned();
+  const bool useScale32 = rescaleOp.getScale32();
+  const auto rounding = rescaleOp.getRoundingMode();
+
+  // Emits `message` as an op error and yields the failing result.
+  auto reject = [op](const char *message) {
+    op->emitOpError() << message;
+    return false;
+  };
+
+  // ERROR_IF(scale32 && is_same<in_t,i48_t>())
+  if (useScale32 && inBits == 48)
+    return reject("scale32 is not allowed with 48-bit input.");
+
+  // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
+  if (!useScale32 && rounding == "DOUBLE_ROUND")
+    return reject("DOUBLE_ROUND is only allowed with scale32=true.");
+
+  // ERROR_IF(input_unsigned && output_unsigned)
+  if (inIsUnsigned && outIsUnsigned)
+    return reject("input and output cannot be both unsigned.");
+
+  // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
+  if (outBits == 32 && inIsUnsigned)
+    return reject("i32 output type is not allowed with unsigned input.");
+
+  // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
+  if (inBits == 32 && outIsUnsigned)
+    return reject("i32 input type is not allowed with unsigned output.");
+
+  // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
+  if (inBits == 48 && outIsUnsigned)
+    return reject("i48 input type is not allowed with unsigned output.");
+
+  // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
+  if (inBits == 48 && inIsUnsigned)
+    return reject("i48 input type cannot be unsigned.");
+
+  // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
+  if (inBits == 32 && inIsUnsigned)
+    return reject("i32 input type cannot be unsigned.");
+
+  // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
+  if (outBits == 32 && outIsUnsigned)
+    return reject("i32 output type cannot be unsigned.");
+
+  return true;
+}
+
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
- if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || !checkErrorIfTable(op))
+ // Fails as soon as any ERROR_IF check returns false for `op`; `||`
+ // short-circuits, so later checks are not run after the first rejection.
+ if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
+ !checkErrorIfTable(op) || !checkErrorIfRescale(op))
return failure();
return success();
}
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 65a69be91e0c8..c6a173c92ff9a 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -129,3 +129,111 @@ func.func @test_i8_table_size(%arg0: tensor<2x64xi8>, %arg1: tensor<513xi8>) ->
%0 = tosa.table %arg0, %arg1 : (tensor<2x64xi8>, tensor<513xi8>) -> tensor<2x64xi8>
return %0 : tensor<2x64xi8>
}
+
+// -----
+// Expects tosa.rescale with scale32 = true and an i48 input to be rejected.
+// NOTE(review): the body and expected diagnostic are identical to
+// test_error_scale32_with_i48 below, and nothing here exercises a zero-point
+// rule despite the name - confirm whether a different check was intended.
+// CHECK-LABEL: test_error_input_zp_not_allowed
+func.func @test_error_input_zp_not_allowed(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+ %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op scale32 is not allowed with 48-bit input}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+ return %0 : tensor<1xi8>
+}
+
+// -----
+// Expects tosa.rescale with scale32 = true and an i48 input to be rejected.
+// CHECK-LABEL: test_error_scale32_with_i48
+func.func @test_error_scale32_with_i48(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+ %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op scale32 is not allowed with 48-bit input}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+ return %0 : tensor<1xi8>
+}
+
+// -----
+// Expects tosa.rescale with both input_unsigned and output_unsigned set to
+// true to be rejected.
+// CHECK-LABEL: test_error_input_output_unsigned
+func.func @test_error_input_output_unsigned(%arg0: tensor<1xi8>) -> tensor<1xi16> {
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+ // expected-error@+1 {{'tosa.rescale' op input and output cannot be both unsigned}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
+ return %0 : tensor<1xi16>
+}
+
+// -----
+// Expects tosa.rescale producing an i32 output from an input flagged
+// input_unsigned = true to be rejected.
+// CHECK-LABEL: test_error_i32_output_unsigned_input
+func.func @test_error_i32_output_unsigned_input(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ // expected-error@+1 {{'tosa.rescale' op i32 output type is not allowed with unsigned input}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
+ return %0 : tensor<1xi32>
+}
+
+// -----
+// Expects tosa.rescale with an i32 input and output_unsigned = true to be
+// rejected.
+// CHECK-LABEL: test_error_i32_input_unsigned_output
+func.func @test_error_i32_input_unsigned_output(%arg0: tensor<1xi32>) -> tensor<1xi8> {
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op i32 input type is not allowed with unsigned output}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
+ return %0 : tensor<1xi8>
+}
+
+// -----
+// Expects tosa.rescale with an i48 input and output_unsigned = true to be
+// rejected (scale32 is false, so the scale32/i48 check does not fire first).
+// CHECK-LABEL: test_error_i48_input_unsigned_output
+func.func @test_error_i48_input_unsigned_output(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+ %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op i48 input type is not allowed with unsigned output}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+ return %0 : tensor<1xi8>
+}
+
+// -----
+// Expects tosa.rescale with an i48 input flagged input_unsigned = true to be
+// rejected. The function name matches the CHECK-LABEL above it.
+// CHECK-LABEL: test_error_i48_unsigned_input
+func.func @test_error_i48_unsigned_input(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+ %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op i48 input type cannot be unsigned}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+ return %0 : tensor<1xi8>
+}
+
+// -----
+// Expects tosa.rescale with an i32 input flagged input_unsigned = true to be
+// rejected. The function name matches the CHECK-LABEL above it.
+// CHECK-LABEL: test_error_i32_unsigned_input
+func.func @test_error_i32_unsigned_input(%arg0: tensor<1xi32>) -> tensor<1xi8> {
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op i32 input type cannot be unsigned}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
+ return %0 : tensor<1xi8>
+}
+
+// -----
+// Expects tosa.rescale with an i32 output flagged output_unsigned = true to be
+// rejected.
+// CHECK-LABEL: test_error_i32_unsigned_output
+func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ // expected-error@+1 {{'tosa.rescale' op i32 output type cannot be unsigned}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
+ return %0 : tensor<1xi32>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 269ed58fdc81c..fe4cc49e89c0d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1669,6 +1669,18 @@ func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3
return %0 : tensor<13x21x3xi16>
}
+// -----
+// Expects tosa.rescale with rounding_mode = "DOUBLE_ROUND" and scale32 = false
+// to be rejected.
+// CHECK-LABEL: test_error_double_round_without_scale32
+func.func @test_error_double_round_without_scale32(%arg0: tensor<1xi8>) -> tensor<1xi16> {
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+ // expected-error@+1 {{'tosa.rescale' op DOUBLE_ROUND is only allowed with scale32=true}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
+ return %0 : tensor<1xi16>
+}
+
// -----
// CHECK-LABEL: test_matmul_a_zp_same_element_type
func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index fb8726cba1853..a42cf03a0a5cb 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -319,3 +319,16 @@ func.func @test_conv3d_wholly_divisible_output_width(%arg0: tensor<1x4x8x21x19xf
: (tensor<1x4x8x21x19xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x19x34xf32>
return %0 : tensor<1x4x8x19x34xf32>
}
+
+// -----
+
+// Expects the tosa.rescale verifier to reject a rank-0 input when
+// per_channel = true.
+// CHECK-LABEL: test_error_scalar_input_with_per_channel
+func.func @test_error_scalar_input_with_per_channel(%arg0: tensor<i8>) -> tensor<i16> {
+ %multiplier = "tosa.const"() {values = dense<4> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {values = dense<2> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+ // expected-error@+1 {{'tosa.rescale' op requires input to be at least rank 1 when per_channel is true, but got rank 0}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<i8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<i16>
+ return %0 : tensor<i16>
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
621c120
to
e0d037f
Compare
Thanks for the patch. Looks good to me overall. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks! had one small nit
* Add a verifier check that rejects a rank-0 input when per_channel is true * Add checkErrorIfRescale to the TOSA validation pass, aligning it with the TOSA v1.0 specification * Add LIT tests Change-Id: Ia07e8c2ee66d8ee4113bea5ad9fa859b5986b009 Signed-off-by: Peng Sun <peng.sun@arm.com>
e0d037f
to
99ea7d6
Compare