Skip to content

Commit c099caa

Browse files
authored
[MLIR][TOSA-Linalg] Fix rescale lowering for unsigned input zp (#138780)
Lowering of tosa.rescale to Linalg unconditionally sign-extend the input zero-point value, even when unsigned_input is true. This commit refactor zeropoint handling to share the same logic between input and output zeropoint.
1 parent e7bf750 commit c099caa

File tree

5 files changed

+85
-57
lines changed

5 files changed

+85
-57
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,6 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
8282
rhsOrResult);
8383
}
8484

85-
template <typename T>
86-
static arith::ConstantOp
87-
createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
88-
OpBuilder &rewriter) {
89-
auto castedN = static_cast<T>(zp);
90-
return rewriter.create<arith::ConstantOp>(
91-
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
92-
}
93-
9485
static Value createLinalgBodyCalculationForElementwiseOp(
9586
Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
9687
ConversionPatternRewriter &rewriter) {
@@ -1467,21 +1458,19 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14671458
Value value = blockArgs[0];
14681459
Type valueTy = value.getType();
14691460

1470-
// For now we do all of our math in 64-bit. This is not optimal but
1471-
// should be correct for now, consider computing correct bit depth
1472-
// later.
1473-
int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
1474-
14751461
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
14761462
if (failed(maybeIZp)) {
14771463
(void)rewriter.notifyMatchFailure(
14781464
op, "input zero point cannot be statically determined");
14791465
return;
14801466
}
14811467

1482-
auto inputZp = createConstOpFromZpVal<int32_t>(
1483-
op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
1484-
nestedBuilder);
1468+
const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
1469+
// Extend zeropoint for sub-32bits widths.
1470+
const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
1471+
auto inputZp = nestedBuilder.create<arith::ConstantOp>(
1472+
loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
1473+
*maybeIZp));
14851474

14861475
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
14871476
if (failed(maybeOZp)) {
@@ -1490,16 +1479,14 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14901479
return;
14911480
};
14921481

1493-
// pre-process OutputZP as it can be unsigned
1494-
auto outBitwidth = outputTy.getElementType().getIntOrFloatBitWidth();
1495-
APInt OZp(outBitwidth, !op.getOutputUnsigned());
1496-
OZp = static_cast<int64_t>(*maybeOZp);
1497-
*maybeOZp = op.getOutputUnsigned()
1498-
? static_cast<int64_t>(OZp.getZExtValue())
1499-
: OZp.getSExtValue();
1500-
1501-
auto outputZp = createConstOpFromZpVal<int32_t>(
1502-
op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
1482+
IntegerType outIntType =
1483+
cast<IntegerType>(blockArgs.back().getType());
1484+
unsigned outBitWidth = outIntType.getWidth();
1485+
const int32_t outAttrBitwidth = 32;
1486+
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
1487+
auto outputZp = nestedBuilder.create<arith::ConstantOp>(
1488+
loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
1489+
*maybeOZp));
15031490

15041491
Value multiplier = multiplierConstant ? multiplierConstant
15051492
: blockArgs[multiplierArg];
@@ -1527,10 +1514,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
15271514
nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
15281515

15291516
// Saturate to the output size.
1530-
IntegerType outIntType =
1531-
cast<IntegerType>(blockArgs.back().getType());
1532-
unsigned outBitWidth = outIntType.getWidth();
1533-
15341517
int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
15351518
int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
15361519

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2118,7 +2118,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
21182118
// return failure if val is not a constant
21192119
// set zp to -1 if val is non-zero float or val is not integer nor float
21202120
// otherwise set zp to val's constant value
2121-
static FailureOr<int64_t> getZeroPoint(Value val) {
2121+
static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
21222122
ElementsAttr zpAttr;
21232123
if (!matchPattern(val, m_Constant(&zpAttr))) {
21242124
return failure();
@@ -2135,7 +2135,10 @@ static FailureOr<int64_t> getZeroPoint(Value val) {
21352135
}
21362136

21372137
if (llvm::isa<IntegerType>(zpElemType)) {
2138-
return zpAttr.getValues<APInt>()[0].getSExtValue();
2138+
if (signExtend)
2139+
return zpAttr.getValues<APInt>()[0].getSExtValue();
2140+
else
2141+
return zpAttr.getValues<APInt>()[0].getZExtValue();
21392142
}
21402143

21412144
// return non-zero value to trigger error check
@@ -2175,8 +2178,7 @@ static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
21752178
return op.emitOpError()
21762179
<< "expect " << tensorName << "_zp of 0, got " << zp;
21772180
}
2178-
if (zpElemType.isInteger(16) && tensorUnsigned &&
2179-
zp != static_cast<int16_t>(32768)) {
2181+
if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
21802182
return op.emitOpError() << "expect " << tensorName
21812183
<< "_zp of 0 or 32768 for unsigned int16 "
21822184
<< tensorName << ", got " << zp;
@@ -2186,30 +2188,30 @@ static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
21862188
return success();
21872189
}
21882190

2189-
#define ZERO_POINT_HELPER(OP, OPERAND_NAME) \
2191+
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
21902192
FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2191-
return getZeroPoint(get##OPERAND_NAME##Zp()); \
2193+
return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
21922194
} \
21932195
LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
21942196
return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
21952197
}
21962198

2197-
ZERO_POINT_HELPER(Conv2DOp, Input)
2198-
ZERO_POINT_HELPER(Conv2DOp, Weight)
2199-
ZERO_POINT_HELPER(Conv3DOp, Input)
2200-
ZERO_POINT_HELPER(Conv3DOp, Weight)
2201-
ZERO_POINT_HELPER(DepthwiseConv2DOp, Input)
2202-
ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight)
2203-
ZERO_POINT_HELPER(TransposeConv2DOp, Input)
2204-
ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
2205-
ZERO_POINT_HELPER(AvgPool2dOp, Input)
2206-
ZERO_POINT_HELPER(AvgPool2dOp, Output)
2207-
ZERO_POINT_HELPER(MatMulOp, A)
2208-
ZERO_POINT_HELPER(MatMulOp, B)
2209-
ZERO_POINT_HELPER(NegateOp, Input1)
2210-
ZERO_POINT_HELPER(NegateOp, Output)
2211-
ZERO_POINT_HELPER(RescaleOp, Input)
2212-
ZERO_POINT_HELPER(RescaleOp, Output)
2199+
ZERO_POINT_HELPER(Conv2DOp, Input, true)
2200+
ZERO_POINT_HELPER(Conv2DOp, Weight, true)
2201+
ZERO_POINT_HELPER(Conv3DOp, Input, true)
2202+
ZERO_POINT_HELPER(Conv3DOp, Weight, true)
2203+
ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
2204+
ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
2205+
ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
2206+
ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
2207+
ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
2208+
ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
2209+
ZERO_POINT_HELPER(MatMulOp, A, true)
2210+
ZERO_POINT_HELPER(MatMulOp, B, true)
2211+
ZERO_POINT_HELPER(NegateOp, Input1, true)
2212+
ZERO_POINT_HELPER(NegateOp, Output, true)
2213+
ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
2214+
ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
22132215
#undef ZERO_POINT_HELPER
22142216

22152217
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,10 +1241,10 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
12411241
// CHECK: [[INIT:%.+]] = tensor.empty()
12421242
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
12431243
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
1244-
// CHECK: [[C17:%.+]] = arith.constant 17
1244+
// CHECK: [[C128:%.+]] = arith.constant 128
12451245
// CHECK: [[C22:%.+]] = arith.constant 22
12461246
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
1247-
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
1247+
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
12481248
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
12491249
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
12501250
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
@@ -1255,13 +1255,45 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
12551255
// CHECK: linalg.yield [[TRUNC]]
12561256
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
12571257
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
1258-
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
1258+
%input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
12591259
%output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
12601260
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
12611261

12621262
return
12631263
}
12641264

1265+
// -----
1266+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
1267+
1268+
// CHECK-LABEL: @rescale_i48_unsigned_output
1269+
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1270+
func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () {
1271+
// CHECK: [[C19689:%.+]] = arith.constant 19689
1272+
// CHECK: [[C15:%.+]] = arith.constant 15
1273+
// CHECK: [[INIT:%.+]] = tensor.empty()
1274+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
1275+
// CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
1276+
// CHECK: [[C0:%.+]] = arith.constant 0
1277+
// CHECK: [[C234:%.+]] = arith.constant 234
1278+
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
1279+
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
1280+
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
1281+
// CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
1282+
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
1283+
// CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
1284+
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1285+
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1286+
// CHECK: linalg.yield [[TRUNC]]
1287+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
1288+
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
1289+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
1290+
%output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
1291+
%1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
1292+
1293+
// CHECK: return
1294+
return
1295+
}
1296+
12651297
// -----
12661298

12671299
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1517,7 +1517,7 @@ func.func @test_rescale_invalid_output_zp_u16(%arg0: tensor<13x21x3xi16>) -> ten
15171517
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
15181518
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
15191519
%output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi16>} : () -> tensor<1xi16>
1520-
// expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got -1}}
1520+
// expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got 65535}}
15211521
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
15221522
return %0 : tensor<13x21x3xi16>
15231523
}

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,17 @@ func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439
753753
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
754754
}
755755

756+
// -----
757+
// CHECK-LABEL: rescale_i16_zp32768
758+
func.func @test_rescale_i16_zp32768(%arg0 : tensor<2xi8>) -> tensor<2xi16> {
759+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16>} : () -> tensor<1xi16>
760+
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
761+
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
762+
%output_zp = "tosa.const"() {values = dense<32768> : tensor<1xi16>} : () -> tensor<1xi16>
763+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<2xi16>
764+
return %0 : tensor<2xi16>
765+
}
766+
756767
// -----
757768
// CHECK-LABEL: const
758769
func.func @test_const(%arg0 : index) -> tensor<4xi32> {

0 commit comments

Comments
 (0)