Skip to content

Commit

Permalink
Refactor decimal multiplication (StarRocks#4211)
Browse files Browse the repository at this point in the history
  • Loading branch information
satanson authored and liuyehcf committed Apr 18, 2022
1 parent 43b447e commit faebf89
Show file tree
Hide file tree
Showing 7 changed files with 1,782 additions and 16 deletions.
3 changes: 2 additions & 1 deletion be/src/exprs/vectorized/arithmetic_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class VectorizedArithmeticExpr final : public Expr {
auto l = _children[0]->evaluate(context, ptr);
auto r = _children[1]->evaluate(context, ptr);
if constexpr (pt_is_decimal<Type>) {
return VectorizedStrictDecimalBinaryFunction<OP, false>::template evaluate<Type>(l, r);
// Enable overflow checking in decimal arithmetic
return VectorizedStrictDecimalBinaryFunction<OP, true>::template evaluate<Type>(l, r);
} else {
using ArithmeticOp = ArithmeticBinaryOperator<OP, Type>;
return VectorizedStrictBinaryFunction<ArithmeticOp>::template evaluate<Type>(l, r);
Expand Down
46 changes: 33 additions & 13 deletions fe/fe-core/src/main/java/com/starrocks/analysis/ArithmeticExpr.java
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ public static TypeTriple getReturnTypeOfDecimal(Operator op, ScalarType lhsType,
"Types of lhs and rhs must be DecimalV3");
final PrimitiveType lhsPtype = lhsType.getPrimitiveType();
final PrimitiveType rhsPtype = rhsType.getPrimitiveType();
final int lhsPrecision = lhsType.getPrecision();
final int rhsPrecision = rhsType.getPrecision();
final int lhsScale = lhsType.getScalarScale();
final int rhsScale = rhsType.getScalarScale();

Expand All @@ -160,6 +162,7 @@ public static TypeTriple getReturnTypeOfDecimal(Operator op, ScalarType lhsType,
result.lhsTargetType = ScalarType.createDecimalV3Type(widerType, maxPrecision, lhsScale);
result.rhsTargetType = ScalarType.createDecimalV3Type(widerType, maxPrecision, rhsScale);
int returnScale = 0;
int returnPrecision = 0;
switch (op) {
case ADD:
case SUBTRACT:
Expand All @@ -168,20 +171,37 @@ public static TypeTriple getReturnTypeOfDecimal(Operator op, ScalarType lhsType,
break;
case MULTIPLY:
returnScale = lhsScale + rhsScale;
// promote type result type of multiplication if it is too narrow to hold all significant bits
if (returnScale > maxPrecision) {
final int maxPrecisionOfDecimal128 =
PrimitiveType.getMaxPrecisionOfDecimal(PrimitiveType.DECIMAL128);
// decimal128 is already the widest decimal types, so throw an error if scale of result exceeds 38
Preconditions.checkState(widerType != PrimitiveType.DECIMAL128,
String.format("Return scale(%d) exceeds maximum value(%d)", returnScale,
maxPrecisionOfDecimal128));
widerType = PrimitiveType.DECIMAL128;
maxPrecision = maxPrecisionOfDecimal128;
result.lhsTargetType = ScalarType.createDecimalV3Type(widerType, maxPrecision, lhsScale);
result.rhsTargetType = ScalarType.createDecimalV3Type(widerType, maxPrecision, rhsScale);
returnPrecision = lhsPrecision + rhsPrecision;
final int maxDecimalPrecision = PrimitiveType.getMaxPrecisionOfDecimal(PrimitiveType.DECIMAL128);
if (returnPrecision <= maxDecimalPrecision) {
// returnPrecision <= 38, result never overflows, use the narrowest decimal type that can holds the result.
// for examples:
// decimal32(4,3) * decimal32(4,3) => decimal32(8,6);
// decimal64(15,3) * decimal32(9,4) => decimal128(24,7).
PrimitiveType commonPtype = ScalarType.createDecimalV3NarrowestType(returnPrecision, returnScale).getPrimitiveType();
// a common type shall never be narrower than type of lhs and rhs
commonPtype = PrimitiveType.getWiderDecimalV3Type(commonPtype, lhsPtype);
commonPtype = PrimitiveType.getWiderDecimalV3Type(commonPtype, rhsPtype);
result.returnType = ScalarType.createDecimalV3Type(commonPtype, returnPrecision, returnScale);
result.lhsTargetType = ScalarType.createDecimalV3Type(commonPtype, lhsPrecision, lhsScale);
result.rhsTargetType = ScalarType.createDecimalV3Type(commonPtype, rhsPrecision, rhsScale);
return result;
} else if (returnScale <= maxDecimalPrecision) {
// returnPrecision > 38 and returnScale <= 38, the multiplication is computable but the result maybe
// overflow, so use decimal128 arithmetic and adopt maximum decimal precision(38) as precision of
// the result.
// for examples:
// decimal128(23,5) * decimal64(18,4) => decimal128(38, 9).
result.returnType = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL128, maxDecimalPrecision, returnScale);
result.lhsTargetType = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL128, lhsPrecision, lhsScale);
result.rhsTargetType = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL128, rhsPrecision, rhsScale);
return result;
} else {
// returnScale > 38, so it is cannot be represented as decimal.
throw new AnalysisException(
String.format("Return scale(%d) exceeds maximum value(%d), please cast decimal type to low-precision one",
returnScale, maxDecimalPrecision));
}
break;
case INT_DIVIDE:
case DIVIDE:
if (lhsScale <= 6) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,121 @@ public void testDecimal32Add() throws IOException {
}

}

private ScalarType dec(int bits, int precision, int scale) {

PrimitiveType pType = PrimitiveType.INVALID_TYPE;
switch (bits) {
case 32:
pType = PrimitiveType.DECIMAL32;
break;
case 64:
pType = PrimitiveType.DECIMAL64;
break;
case 128:
pType = PrimitiveType.DECIMAL128;
break;
}
return ScalarType.createDecimalV3Type(pType, precision, scale);
}

@Test
public void testDecimalMultiply() throws AnalysisException {
Object[][] cases = new Object[][]{
{
dec(32, 4, 3),
dec(32, 4, 3),
dec(32, 8, 6),
dec(32, 4, 3),
dec(32, 4, 3),
},
{
dec(64, 7, 2),
dec(64, 7, 2),
dec(64, 14, 4),
dec(64, 7, 2),
dec(64, 7, 2),
},
{
dec(32, 7, 2),
dec(32, 9, 4),
dec(64, 16, 6),
dec(64, 7, 2),
dec(64, 9, 4),
},
{
dec(64, 14, 4),
dec(32, 7, 2),
dec(128, 21, 6),
dec(128, 14, 4),
dec(128, 7, 2),
},
{
dec(64, 14, 4),
dec(64, 18, 14),
dec(128, 32, 18),
dec(128, 14, 4),
dec(128, 18, 14),
},
{
dec(128, 35, 18),
dec(128, 35, 20),
dec(128, 38, 38),
dec(128, 35, 18),
dec(128, 35, 20),
},
{
dec(128, 35, 30),
dec(32, 8, 7),
dec(128, 38, 37),
dec(128, 35, 30),
dec(128, 8, 7),
},
{
dec(128, 36, 31),
dec(64, 18, 7),
dec(128, 38, 38),
dec(128, 36, 31),
dec(128, 18, 7),
}
};
for (Object[] c : cases) {
ScalarType lhsType = (ScalarType) c[0];
ScalarType rhsType = (ScalarType) c[1];
ScalarType expectReturnType = (ScalarType) c[2];
ScalarType expectLhsType = (ScalarType) c[3];
ScalarType expectRhsType = (ScalarType) c[4];
ArithmeticExpr.TypeTriple tr = ArithmeticExpr.getReturnTypeOfDecimal(ArithmeticExpr.Operator.MULTIPLY, lhsType, rhsType);
Assert.assertEquals(tr.returnType, expectReturnType);
Assert.assertEquals(tr.lhsTargetType, expectLhsType);
Assert.assertEquals(tr.rhsTargetType, expectRhsType);
}
}

@Test
public void testDecimalMultiplyFail() {
Object[][] cases = new Object[][]{
{
dec(128, 38, 36),
dec(32, 4, 3),
},
{
dec(128, 37, 24),
dec(64, 18, 15),
},
{
dec(128, 30, 19),
dec(128, 30, 20),
},
};
for (Object[] c : cases) {
ScalarType lhsType = (ScalarType) c[0];
ScalarType rhsType = (ScalarType) c[1];
try {
ArithmeticExpr.getReturnTypeOfDecimal(ArithmeticExpr.Operator.MULTIPLY, lhsType, rhsType);
Assert.fail("should throw exception");
} catch (AnalysisException ignored) {
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ public void testBinaryPredicateGe() throws Exception {
stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
Expr expr = stmt.selectList.getItems().get(0).getExpr();
Assert.assertEquals(expr.type, Type.BOOLEAN);
Type decimal128p38s5 = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL64, 18, 12);
Type decimal128p38s5 = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL128, 24, 12);
Assert.assertEquals(expr.getChild(0).type, decimal128p38s5);
Assert.assertEquals(expr.getChild(1).type, decimal128p38s5);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ public void testBinaryPredicateGe() throws Exception {
((LogicalProjectOperator) logicalPlan.getRoot().getOp()).getColumnRefMap()
.get(logicalPlan.getOutputColumn().get(0));

Type decimal128p38s5 = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL64, 18, 12);
Type decimal128p38s5 = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL128, 24, 12);
Assert.assertEquals(op.getChild(0).getType(), decimal128p38s5);
Assert.assertEquals(op.getChild(1).getType(), decimal128p38s5);
}
Expand Down
Loading

0 comments on commit faebf89

Please sign in to comment.