Skip to content

Commit 5f58f3d

Browse files
[mlir][tosa] Avoid overflow in reduction folders (#132786)
Avoid operations that can overflow in constant folders for `tosa.reduce_max` and `tosa.reduce_min` Includes tests to avoid regressions Signed-off-by: Ian Tayler Lessa <ian.taylerlessa@arm.com>
1 parent bed4c58 commit 5f58f3d

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

+2-4
Original file line numberDiff line numberDiff line change
@@ -1728,8 +1728,7 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
17281728

17291729
/// Return the max of the two integer operands
17301730
static inline APInt calcOneElement(APInt leftOperand, APInt rightOperand) {
1731-
const llvm::APInt subtractRes = leftOperand - rightOperand;
1732-
return (!subtractRes.isNegative()) ? leftOperand : rightOperand;
1731+
return (leftOperand.sge(rightOperand)) ? leftOperand : rightOperand;
17331732
}
17341733
}];
17351734
}
@@ -1769,8 +1768,7 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
17691768

17701769
/// Return the min of the two integer operands
17711770
static inline APInt calcOneElement(APInt leftOperand, APInt rightOperand) {
1772-
const llvm::APInt subtractRes = leftOperand - rightOperand;
1773-
return (!subtractRes.isNegative()) ? rightOperand : leftOperand;
1771+
return (leftOperand.sle(rightOperand)) ? leftOperand : rightOperand;
17741772
}
17751773
}];
17761774
}

mlir/test/Dialect/Tosa/constant-op-fold.mlir

+25
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,18 @@ func.func @reduce_max_constant() -> tensor<1x1x1xi32> {
883883
return %0 : tensor<1x1x1xi32>
884884
}
885885

886+
// -----
887+
888+
func.func @reduce_max_constant_no_overflow() -> tensor<1xi8> {
889+
// CHECK-LABEL: func.func @reduce_max_constant_no_overflow() -> tensor<1xi8> {
890+
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<120> : tensor<1xi8>}> : () -> tensor<1xi8>
891+
// CHECK: return %[[VAL_0]] : tensor<1xi8>
892+
// CHECK: }
893+
%const = "tosa.const"() <{values = dense<[-127, 120, -126]> : tensor<3xi8>}> : () -> tensor<3xi8>
894+
%0 = tosa.reduce_max %const {axis = 0 : i32} : (tensor<3xi8>) -> tensor<1xi8>
895+
return %0 : tensor<1xi8>
896+
}
897+
886898
// -----
887899

888900
func.func @reduce_min_constant() -> tensor<1x3xi32> {
@@ -968,6 +980,19 @@ func.func @reduce_min_constant() -> tensor<1x1x1xi32> {
968980
return %0 : tensor<1x1x1xi32>
969981
}
970982

983+
// -----
984+
985+
func.func @reduce_min_constant_no_overflow() -> tensor<1xi8> {
986+
// CHECK-LABEL: func.func @reduce_min_constant_no_overflow() -> tensor<1xi8> {
987+
// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<-127> : tensor<1xi8>}> : () -> tensor<1xi8>
988+
// CHECK: return %[[VAL_0]] : tensor<1xi8>
989+
// CHECK: }
990+
%const = "tosa.const"() <{values = dense<[-127, 120, -126]> : tensor<3xi8>}> : () -> tensor<3xi8>
991+
%0 = tosa.reduce_min %const {axis = 0 : i32} : (tensor<3xi8>) -> tensor<1xi8>
992+
return %0 : tensor<1xi8>
993+
}
994+
995+
971996
// -----
972997

973998
func.func @reduce_any_constant() -> tensor<1x3xi1> {

0 commit comments

Comments
 (0)