Skip to content

Commit

Permalink
Refactor decimal multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
satanson authored and mergify-bot committed Mar 22, 2022
1 parent afecea8 commit 2db3785
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 32 deletions.
3 changes: 2 additions & 1 deletion be/src/exprs/vectorized/arithmetic_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class VectorizedArithmeticExpr final : public Expr {
auto l = _children[0]->evaluate(context, ptr);
auto r = _children[1]->evaluate(context, ptr);
if constexpr (pt_is_decimal<Type>) {
return VectorizedStrictDecimalBinaryFunction<OP, false>::template evaluate<Type>(l, r);
// Enable overflow checking in decimal arithmetic
return VectorizedStrictDecimalBinaryFunction<OP, true>::template evaluate<Type>(l, r);
} else {
using ArithmeticOp = ArithmeticBinaryOperator<OP, Type>;
return VectorizedStrictBinaryFunction<ArithmeticOp>::template evaluate<Type>(l, r);
Expand Down
46 changes: 33 additions & 13 deletions fe/fe-core/src/main/java/com/starrocks/analysis/ArithmeticExpr.java
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ public static TypeTriple getReturnTypeOfDecimal(Operator op, ScalarType lhsType,
"Types of lhs and rhs must be DecimalV3");
final PrimitiveType lhsPtype = lhsType.getPrimitiveType();
final PrimitiveType rhsPtype = rhsType.getPrimitiveType();
final int lhsPrecision = lhsType.getPrecision();
final int rhsPrecision = rhsType.getPrecision();
final int lhsScale = lhsType.getScalarScale();
final int rhsScale = rhsType.getScalarScale();

Expand All @@ -160,6 +162,7 @@ public static TypeTriple getReturnTypeOfDecimal(Operator op, ScalarType lhsType,
result.lhsTargetType = ScalarType.createDecimalV3Type(widerType, maxPrecision, lhsScale);
result.rhsTargetType = ScalarType.createDecimalV3Type(widerType, maxPrecision, rhsScale);
int returnScale = 0;
int returnPrecision = 0;
switch (op) {
case ADD:
case SUBTRACT:
Expand All @@ -168,20 +171,37 @@ public static TypeTriple getReturnTypeOfDecimal(Operator op, ScalarType lhsType,
break;
case MULTIPLY:
returnScale = lhsScale + rhsScale;
// promote type result type of multiplication if it is too narrow to hold all significant bits
if (returnScale > maxPrecision) {
final int maxPrecisionOfDecimal128 =
PrimitiveType.getMaxPrecisionOfDecimal(PrimitiveType.DECIMAL128);
// decimal128 is already the widest decimal types, so throw an error if scale of result exceeds 38
Preconditions.checkState(widerType != PrimitiveType.DECIMAL128,
String.format("Return scale(%d) exceeds maximum value(%d)", returnScale,
maxPrecisionOfDecimal128));
widerType = PrimitiveType.DECIMAL128;
maxPrecision = maxPrecisionOfDecimal128;
result.lhsTargetType = ScalarType.createDecimalV3Type(widerType, maxPrecision, lhsScale);
result.rhsTargetType = ScalarType.createDecimalV3Type(widerType, maxPrecision, rhsScale);
returnPrecision = lhsPrecision + rhsPrecision;
final int maxDecimalPrecision = PrimitiveType.getMaxPrecisionOfDecimal(PrimitiveType.DECIMAL128);
if (returnPrecision <= maxDecimalPrecision) {
// returnPrecision <= 38, result never overflows, use the narrowest decimal type that can holds the result.
// for examples:
// decimal32(4,3) * decimal32(4,3) => decimal32(8,6);
// decimal64(15,3) * decimal32(9,4) => decimal128(24,7).
PrimitiveType commonPtype = ScalarType.createDecimalV3NarrowestType(returnPrecision, returnScale).getPrimitiveType();
// a common type shall never be narrower than type of lhs and rhs
commonPtype = PrimitiveType.getWiderDecimalV3Type(commonPtype, lhsPtype);
commonPtype = PrimitiveType.getWiderDecimalV3Type(commonPtype, rhsPtype);
result.returnType = ScalarType.createDecimalV3Type(commonPtype, returnPrecision, returnScale);
result.lhsTargetType = ScalarType.createDecimalV3Type(commonPtype, lhsPrecision, lhsScale);
result.rhsTargetType = ScalarType.createDecimalV3Type(commonPtype, rhsPrecision, rhsScale);
return result;
} else if (returnScale <= maxDecimalPrecision) {
// returnPrecision > 38 and returnScale <= 38, the multiplication is computable but the result maybe
// overflow, so use decimal128 arithmetic and adopt maximum decimal precision(38) as precision of
// the result.
// for examples:
// decimal128(23,5) * decimal64(18,4) => decimal128(38, 9).
result.returnType = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL128, maxDecimalPrecision, returnScale);
result.lhsTargetType = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL128, lhsPrecision, lhsScale);
result.rhsTargetType = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL128, rhsPrecision, rhsScale);
return result;
} else {
// returnScale > 38, so it is cannot be represented as decimal.
throw new AnalysisException(
String.format("Return scale(%d) exceeds maximum value(%d), please cast decimal type to low-precision one",
returnScale, maxDecimalPrecision));
}
break;
case INT_DIVIDE:
case DIVIDE:
if (lhsScale <= 6) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,121 @@ public void testDecimal32Add() throws IOException {
}

}

private ScalarType dec(int bits, int precision, int scale) {

PrimitiveType pType = PrimitiveType.INVALID_TYPE;
switch (bits) {
case 32:
pType = PrimitiveType.DECIMAL32;
break;
case 64:
pType = PrimitiveType.DECIMAL64;
break;
case 128:
pType = PrimitiveType.DECIMAL128;
break;
}
return ScalarType.createDecimalV3Type(pType, precision, scale);
}

@Test
public void testDecimalMultiply() throws AnalysisException {
Object[][] cases = new Object[][]{
{
dec(32, 4, 3),
dec(32, 4, 3),
dec(32, 8, 6),
dec(32, 4, 3),
dec(32, 4, 3),
},
{
dec(64, 7, 2),
dec(64, 7, 2),
dec(64, 14, 4),
dec(64, 7, 2),
dec(64, 7, 2),
},
{
dec(32, 7, 2),
dec(32, 9, 4),
dec(64, 16, 6),
dec(64, 7, 2),
dec(64, 9, 4),
},
{
dec(64, 14, 4),
dec(32, 7, 2),
dec(128, 21, 6),
dec(128, 14, 4),
dec(128, 7, 2),
},
{
dec(64, 14, 4),
dec(64, 18, 14),
dec(128, 32, 18),
dec(128, 14, 4),
dec(128, 18, 14),
},
{
dec(128, 35, 18),
dec(128, 35, 20),
dec(128, 38, 38),
dec(128, 35, 18),
dec(128, 35, 20),
},
{
dec(128, 35, 30),
dec(32, 8, 7),
dec(128, 38, 37),
dec(128, 35, 30),
dec(128, 8, 7),
},
{
dec(128, 36, 31),
dec(64, 18, 7),
dec(128, 38, 38),
dec(128, 36, 31),
dec(128, 18, 7),
}
};
for (Object[] c : cases) {
ScalarType lhsType = (ScalarType) c[0];
ScalarType rhsType = (ScalarType) c[1];
ScalarType expectReturnType = (ScalarType) c[2];
ScalarType expectLhsType = (ScalarType) c[3];
ScalarType expectRhsType = (ScalarType) c[4];
ArithmeticExpr.TypeTriple tr = ArithmeticExpr.getReturnTypeOfDecimal(ArithmeticExpr.Operator.MULTIPLY, lhsType, rhsType);
Assert.assertEquals(tr.returnType, expectReturnType);
Assert.assertEquals(tr.lhsTargetType, expectLhsType);
Assert.assertEquals(tr.rhsTargetType, expectRhsType);
}
}

@Test
public void testDecimalMultiplyFail() {
Object[][] cases = new Object[][]{
{
dec(128, 38, 36),
dec(32, 4, 3),
},
{
dec(128, 37, 24),
dec(64, 18, 15),
},
{
dec(128, 30, 19),
dec(128, 30, 20),
},
};
for (Object[] c : cases) {
ScalarType lhsType = (ScalarType) c[0];
ScalarType rhsType = (ScalarType) c[1];
try {
ArithmeticExpr.getReturnTypeOfDecimal(ArithmeticExpr.Operator.MULTIPLY, lhsType, rhsType);
Assert.fail("should throw exception");
} catch (AnalysisException ignored) {
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,16 @@ public void testMoneyFormat() throws Exception {
@Test
public void testMultiply() throws Exception {
String sql = "select col_decimal128p20s3 * 3.14 from db1.decimal_table";
String expectString = "TExprNode(node_type:ARITHMETIC_EXPR, type:TTypeDesc(types:[TTypeNode(type:SCALAR, " +
"scalar_type:TScalarType(type:DECIMAL128, precision:38, scale:5))]), opcode:MULTIPLY, num_children:2," +
" output_scale:-1, output_column:-1, has_nullable_child:true, is_nullable:true, is_monotonic:false)," +
" TExprNode(node_type:CAST_EXPR, type:TTypeDesc(types:[TTypeNode(type:SCALAR, scalar_type:TScalarType(type:DECIMAL128, precision:38, scale:3))])," +
" opcode:INVALID_OPCODE, num_children:1, output_scale:-1, output_column:-1, child_type:DECIMAL128, has_nullable_child:true, is_nullable:true, is_monotonic:false), " +
"TExprNode(node_type:SLOT_REF, type:TTypeDesc(types:[TTypeNode(type:SCALAR, scalar_type:TScalarType" +
"(type:DECIMAL128, precision:20, scale:3))]), num_children:0, slot_ref:TSlotRef(slot_id:5, tuple_id:0)," +
" output_scale:-1, output_column:-1, has_nullable_child:false, is_nullable:true, is_monotonic:true)," +
" TExprNode(node_type:DECIMAL_LITERAL, type:TTypeDesc(types:[TTypeNode(type:SCALAR, scalar_type:" +
"TScalarType(type:DECIMAL128, precision:38, scale:2))]), num_children:0, decimal_literal:" +
"TDecimalLiteral(value:3.14, integer_value:3A 01 00 00 00 00 00 00 00 00 00 00 00 00 00 00), output_scale:-1, has_nullable_child:false," +
" is_nullable:false, is_monotonic:true)";
String expectString = "TExpr(nodes:[TExprNode(node_type:ARITHMETIC_EXPR, type:TTypeDesc(types:[TTypeNode(type:SCALAR," +
" scalar_type:TScalarType(type:DECIMAL128, precision:23, scale:5))]), opcode:MULTIPLY, num_children:2, " +
"output_scale:-1, output_column:-1, has_nullable_child:true, is_nullable:true, is_monotonic:true), " +
"TExprNode(node_type:SLOT_REF, type:TTypeDesc(types:[TTypeNode(type:SCALAR, scalar_type:TScalarType(type:" +
"DECIMAL128, precision:20, scale:3))]), num_children:0, slot_ref:TSlotRef(slot_id:5, tuple_id:0), " +
"output_scale:-1, output_column:-1, has_nullable_child:false, is_nullable:true, is_monotonic:true), " +
"TExprNode(node_type:DECIMAL_LITERAL, type:TTypeDesc(types:[TTypeNode(type:SCALAR, scalar_type:TScalarType" +
"(type:DECIMAL128, precision:3, scale:2))]), num_children:0, decimal_literal:TDecimalLiteral(value:3.14, " +
"integer_value:3A 01 00 00 00 00 00 00 00 00 00 00 00 00 00 00), output_scale:-1, has_nullable_child:false, " +
"is_nullable:false, is_monotonic:true)])})";
String plan = UtFrameUtils.getPlanThriftString(ctx, sql);
System.out.println(plan);
Assert.assertTrue(plan.contains(expectString));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ public void testBinaryPredicateGe() throws Exception {
stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
Expr expr = stmt.selectList.getItems().get(0).getExpr();
Assert.assertEquals(expr.type, Type.BOOLEAN);
Type decimal128p38s5 = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL64, 18, 12);
Type decimal128p38s5 = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL128, 24, 12);
Assert.assertEquals(expr.getChild(0).type, decimal128p38s5);
Assert.assertEquals(expr.getChild(1).type, decimal128p38s5);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ public void testBinaryPredicateGe() throws Exception {
((LogicalProjectOperator) logicalPlan.getRoot().getOp()).getColumnRefMap()
.get(logicalPlan.getOutputColumn().get(0));

Type decimal128p38s5 = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL64, 18, 12);
Type decimal128p38s5 = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL128, 24, 12);
Assert.assertEquals(op.getChild(0).getType(), decimal128p38s5);
Assert.assertEquals(op.getChild(1).getType(), decimal128p38s5);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,9 @@ public void testNullableSameWithChildrenFunctions() throws Exception {

sql = "select distinct cast(2.0 as decimal) * v1 from t0_not_null";
plan = getVerboseExplain(sql);
System.out.println(plan);
Assert.assertTrue(plan.contains("2:AGGREGATE (update finalize)\n" +
" | group by: [4: expr, DECIMAL64(18,0), true]"));
" | group by: [4: expr, DECIMAL128(28,0), true]"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ public void testExpression7() throws Exception {
public void testExpression8() throws Exception {
String sql = "select cast(v1 as decimal128(10,5)) * cast(v2 as decimal64(9,7)) from t0";
String planFragment = getFragmentPlan(sql);
Assert.assertTrue(planFragment.contains(" 1:Project\n" +
" | <slot 4> : CAST(CAST(1: v1 AS DECIMAL128(10,5)) AS DECIMAL128(38,5)) * CAST(CAST(2: v2 AS DECIMAL64(9,7)) AS DECIMAL128(38,7))\n"));
Assert.assertTrue(planFragment.contains("1:Project\n" +
" | <slot 4> : CAST(1: v1 AS DECIMAL128(10,5)) * CAST(CAST(2: v2 AS DECIMAL64(9,7)) AS DECIMAL128(9,7))"));
}

@Test
Expand Down Expand Up @@ -620,7 +620,7 @@ public void testArithmeticCommutative() throws Exception {

sql = "select k5 from bigtable where k5 * 2 <= 3";
planFragment = getFragmentPlan(sql);
Assert.assertTrue(planFragment.contains("PREDICATES: CAST(5: k5 AS DECIMAL64(18,3)) * 2 <= 3"));
Assert.assertTrue(planFragment.contains("PREDICATES: 5: k5 * 2 <= 3"));

sql = "select k5 from bigtable where 2 / k5 <= 3";
planFragment = getFragmentPlan(sql);
Expand Down

0 comments on commit 2db3785

Please sign in to comment.