Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor decimal multiplication #4211

Merged
merged 1 commit into from
Mar 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion be/src/exprs/vectorized/arithmetic_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class VectorizedArithmeticExpr final : public Expr {
auto l = _children[0]->evaluate(context, ptr);
auto r = _children[1]->evaluate(context, ptr);
if constexpr (pt_is_decimal<Type>) {
return VectorizedStrictDecimalBinaryFunction<OP, false>::template evaluate<Type>(l, r);
// Enable overflow checking in decimal arithmetic
return VectorizedStrictDecimalBinaryFunction<OP, true>::template evaluate<Type>(l, r);
satanson marked this conversation as resolved.
Show resolved Hide resolved
} else {
using ArithmeticOp = ArithmeticBinaryOperator<OP, Type>;
return VectorizedStrictBinaryFunction<ArithmeticOp>::template evaluate<Type>(l, r);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ public static TypeTriple getReturnTypeOfDecimal(Operator op, ScalarType lhsType,
"Types of lhs and rhs must be DecimalV3");
final PrimitiveType lhsPtype = lhsType.getPrimitiveType();
final PrimitiveType rhsPtype = rhsType.getPrimitiveType();
final int lhsPrecision = lhsType.getPrecision();
final int rhsPrecision = rhsType.getPrecision();
final int lhsScale = lhsType.getScalarScale();
final int rhsScale = rhsType.getScalarScale();

Expand All @@ -160,6 +162,7 @@ public static TypeTriple getReturnTypeOfDecimal(Operator op, ScalarType lhsType,
result.lhsTargetType = ScalarType.createDecimalV3Type(widerType, maxPrecision, lhsScale);
result.rhsTargetType = ScalarType.createDecimalV3Type(widerType, maxPrecision, rhsScale);
int returnScale = 0;
int returnPrecision = 0;
switch (op) {
case ADD:
case SUBTRACT:
Expand All @@ -168,20 +171,37 @@ public static TypeTriple getReturnTypeOfDecimal(Operator op, ScalarType lhsType,
break;
case MULTIPLY:
returnScale = lhsScale + rhsScale;
// promote type result type of multiplication if it is too narrow to hold all significant bits
if (returnScale > maxPrecision) {
final int maxPrecisionOfDecimal128 =
PrimitiveType.getMaxPrecisionOfDecimal(PrimitiveType.DECIMAL128);
// decimal128 is already the widest decimal types, so throw an error if scale of result exceeds 38
Preconditions.checkState(widerType != PrimitiveType.DECIMAL128,
String.format("Return scale(%d) exceeds maximum value(%d)", returnScale,
maxPrecisionOfDecimal128));
widerType = PrimitiveType.DECIMAL128;
maxPrecision = maxPrecisionOfDecimal128;
result.lhsTargetType = ScalarType.createDecimalV3Type(widerType, maxPrecision, lhsScale);
result.rhsTargetType = ScalarType.createDecimalV3Type(widerType, maxPrecision, rhsScale);
returnPrecision = lhsPrecision + rhsPrecision;
final int maxDecimalPrecision = PrimitiveType.getMaxPrecisionOfDecimal(PrimitiveType.DECIMAL128);
if (returnPrecision <= maxDecimalPrecision) {
// returnPrecision <= 38, result never overflows, use the narrowest decimal type that can holds the result.
// for examples:
// decimal32(4,3) * decimal32(4,3) => decimal32(8,6);
// decimal64(15,3) * decimal32(9,4) => decimal128(24,7).
PrimitiveType commonPtype = ScalarType.createDecimalV3NarrowestType(returnPrecision, returnScale).getPrimitiveType();
// a common type shall never be narrower than type of lhs and rhs
commonPtype = PrimitiveType.getWiderDecimalV3Type(commonPtype, lhsPtype);
commonPtype = PrimitiveType.getWiderDecimalV3Type(commonPtype, rhsPtype);
result.returnType = ScalarType.createDecimalV3Type(commonPtype, returnPrecision, returnScale);
result.lhsTargetType = ScalarType.createDecimalV3Type(commonPtype, lhsPrecision, lhsScale);
result.rhsTargetType = ScalarType.createDecimalV3Type(commonPtype, rhsPrecision, rhsScale);
return result;
} else if (returnScale <= maxDecimalPrecision) {
// returnPrecision > 38 and returnScale <= 38, the multiplication is computable but the result maybe
// overflow, so use decimal128 arithmetic and adopt maximum decimal precision(38) as precision of
// the result.
// for examples:
// decimal128(23,5) * decimal64(18,4) => decimal128(38, 9).
result.returnType = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL128, maxDecimalPrecision, returnScale);
result.lhsTargetType = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL128, lhsPrecision, lhsScale);
result.rhsTargetType = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL128, rhsPrecision, rhsScale);
return result;
} else {
// returnScale > 38, so it is cannot be represented as decimal.
throw new AnalysisException(
String.format("Return scale(%d) exceeds maximum value(%d), please cast decimal type to low-precision one",
returnScale, maxDecimalPrecision));
}
break;
case INT_DIVIDE:
case DIVIDE:
if (lhsScale <= 6) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,121 @@ public void testDecimal32Add() throws IOException {
}

}

private ScalarType dec(int bits, int precision, int scale) {

PrimitiveType pType = PrimitiveType.INVALID_TYPE;
switch (bits) {
case 32:
pType = PrimitiveType.DECIMAL32;
break;
case 64:
pType = PrimitiveType.DECIMAL64;
break;
case 128:
pType = PrimitiveType.DECIMAL128;
break;
}
return ScalarType.createDecimalV3Type(pType, precision, scale);
}

@Test
public void testDecimalMultiply() throws AnalysisException {
Object[][] cases = new Object[][]{
{
dec(32, 4, 3),
dec(32, 4, 3),
dec(32, 8, 6),
dec(32, 4, 3),
dec(32, 4, 3),
},
{
dec(64, 7, 2),
dec(64, 7, 2),
dec(64, 14, 4),
dec(64, 7, 2),
dec(64, 7, 2),
},
{
dec(32, 7, 2),
dec(32, 9, 4),
dec(64, 16, 6),
dec(64, 7, 2),
dec(64, 9, 4),
},
{
dec(64, 14, 4),
dec(32, 7, 2),
dec(128, 21, 6),
dec(128, 14, 4),
dec(128, 7, 2),
},
{
dec(64, 14, 4),
dec(64, 18, 14),
dec(128, 32, 18),
dec(128, 14, 4),
dec(128, 18, 14),
},
{
dec(128, 35, 18),
dec(128, 35, 20),
dec(128, 38, 38),
dec(128, 35, 18),
dec(128, 35, 20),
},
{
dec(128, 35, 30),
dec(32, 8, 7),
dec(128, 38, 37),
dec(128, 35, 30),
dec(128, 8, 7),
},
{
dec(128, 36, 31),
dec(64, 18, 7),
dec(128, 38, 38),
dec(128, 36, 31),
dec(128, 18, 7),
}
};
for (Object[] c : cases) {
ScalarType lhsType = (ScalarType) c[0];
ScalarType rhsType = (ScalarType) c[1];
ScalarType expectReturnType = (ScalarType) c[2];
ScalarType expectLhsType = (ScalarType) c[3];
ScalarType expectRhsType = (ScalarType) c[4];
ArithmeticExpr.TypeTriple tr = ArithmeticExpr.getReturnTypeOfDecimal(ArithmeticExpr.Operator.MULTIPLY, lhsType, rhsType);
Assert.assertEquals(tr.returnType, expectReturnType);
Assert.assertEquals(tr.lhsTargetType, expectLhsType);
Assert.assertEquals(tr.rhsTargetType, expectRhsType);
}
}

@Test
public void testDecimalMultiplyFail() {
Object[][] cases = new Object[][]{
{
dec(128, 38, 36),
dec(32, 4, 3),
},
{
dec(128, 37, 24),
dec(64, 18, 15),
},
{
dec(128, 30, 19),
dec(128, 30, 20),
},
};
for (Object[] c : cases) {
ScalarType lhsType = (ScalarType) c[0];
ScalarType rhsType = (ScalarType) c[1];
try {
ArithmeticExpr.getReturnTypeOfDecimal(ArithmeticExpr.Operator.MULTIPLY, lhsType, rhsType);
Assert.fail("should throw exception");
} catch (AnalysisException ignored) {
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,16 @@ public void testMoneyFormat() throws Exception {
@Test
public void testMultiply() throws Exception {
String sql = "select col_decimal128p20s3 * 3.14 from db1.decimal_table";
String expectString = "TExprNode(node_type:ARITHMETIC_EXPR, type:TTypeDesc(types:[TTypeNode(type:SCALAR, " +
"scalar_type:TScalarType(type:DECIMAL128, precision:38, scale:5))]), opcode:MULTIPLY, num_children:2," +
" output_scale:-1, output_column:-1, has_nullable_child:true, is_nullable:true, is_monotonic:false)," +
" TExprNode(node_type:CAST_EXPR, type:TTypeDesc(types:[TTypeNode(type:SCALAR, scalar_type:TScalarType(type:DECIMAL128, precision:38, scale:3))])," +
" opcode:INVALID_OPCODE, num_children:1, output_scale:-1, output_column:-1, child_type:DECIMAL128, has_nullable_child:true, is_nullable:true, is_monotonic:false), " +
"TExprNode(node_type:SLOT_REF, type:TTypeDesc(types:[TTypeNode(type:SCALAR, scalar_type:TScalarType" +
"(type:DECIMAL128, precision:20, scale:3))]), num_children:0, slot_ref:TSlotRef(slot_id:5, tuple_id:0)," +
" output_scale:-1, output_column:-1, has_nullable_child:false, is_nullable:true, is_monotonic:true)," +
" TExprNode(node_type:DECIMAL_LITERAL, type:TTypeDesc(types:[TTypeNode(type:SCALAR, scalar_type:" +
"TScalarType(type:DECIMAL128, precision:38, scale:2))]), num_children:0, decimal_literal:" +
"TDecimalLiteral(value:3.14, integer_value:3A 01 00 00 00 00 00 00 00 00 00 00 00 00 00 00), output_scale:-1, has_nullable_child:false," +
" is_nullable:false, is_monotonic:true)";
String expectString = "TExpr(nodes:[TExprNode(node_type:ARITHMETIC_EXPR, type:TTypeDesc(types:[TTypeNode(type:SCALAR," +
" scalar_type:TScalarType(type:DECIMAL128, precision:23, scale:5))]), opcode:MULTIPLY, num_children:2, " +
"output_scale:-1, output_column:-1, has_nullable_child:true, is_nullable:true, is_monotonic:true), " +
"TExprNode(node_type:SLOT_REF, type:TTypeDesc(types:[TTypeNode(type:SCALAR, scalar_type:TScalarType(type:" +
"DECIMAL128, precision:20, scale:3))]), num_children:0, slot_ref:TSlotRef(slot_id:5, tuple_id:0), " +
"output_scale:-1, output_column:-1, has_nullable_child:false, is_nullable:true, is_monotonic:true), " +
"TExprNode(node_type:DECIMAL_LITERAL, type:TTypeDesc(types:[TTypeNode(type:SCALAR, scalar_type:TScalarType" +
"(type:DECIMAL128, precision:3, scale:2))]), num_children:0, decimal_literal:TDecimalLiteral(value:3.14, " +
"integer_value:3A 01 00 00 00 00 00 00 00 00 00 00 00 00 00 00), output_scale:-1, has_nullable_child:false, " +
"is_nullable:false, is_monotonic:true)])})";
String plan = UtFrameUtils.getPlanThriftString(ctx, sql);
System.out.println(plan);
Assert.assertTrue(plan.contains(expectString));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ public void testBinaryPredicateGe() throws Exception {
stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter());
Expr expr = stmt.selectList.getItems().get(0).getExpr();
Assert.assertEquals(expr.type, Type.BOOLEAN);
Type decimal128p38s5 = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL64, 18, 12);
Type decimal128p38s5 = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL128, 24, 12);
Assert.assertEquals(expr.getChild(0).type, decimal128p38s5);
Assert.assertEquals(expr.getChild(1).type, decimal128p38s5);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ public void testBinaryPredicateGe() throws Exception {
((LogicalProjectOperator) logicalPlan.getRoot().getOp()).getColumnRefMap()
.get(logicalPlan.getOutputColumn().get(0));

Type decimal128p38s5 = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL64, 18, 12);
Type decimal128p38s5 = ScalarType.createDecimalV3Type(PrimitiveType.DECIMAL128, 24, 12);
Assert.assertEquals(op.getChild(0).getType(), decimal128p38s5);
Assert.assertEquals(op.getChild(1).getType(), decimal128p38s5);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,9 @@ public void testNullableSameWithChildrenFunctions() throws Exception {

sql = "select distinct cast(2.0 as decimal) * v1 from t0_not_null";
plan = getVerboseExplain(sql);
System.out.println(plan);
Assert.assertTrue(plan.contains("2:AGGREGATE (update finalize)\n" +
" | group by: [4: expr, DECIMAL64(18,0), true]"));
" | group by: [4: expr, DECIMAL128(28,0), true]"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ public void testExpression7() throws Exception {
public void testExpression8() throws Exception {
String sql = "select cast(v1 as decimal128(10,5)) * cast(v2 as decimal64(9,7)) from t0";
String planFragment = getFragmentPlan(sql);
Assert.assertTrue(planFragment.contains(" 1:Project\n" +
" | <slot 4> : CAST(CAST(1: v1 AS DECIMAL128(10,5)) AS DECIMAL128(38,5)) * CAST(CAST(2: v2 AS DECIMAL64(9,7)) AS DECIMAL128(38,7))\n"));
Assert.assertTrue(planFragment.contains("1:Project\n" +
" | <slot 4> : CAST(1: v1 AS DECIMAL128(10,5)) * CAST(CAST(2: v2 AS DECIMAL64(9,7)) AS DECIMAL128(9,7))"));
}

@Test
Expand Down Expand Up @@ -620,7 +620,7 @@ public void testArithmeticCommutative() throws Exception {

sql = "select k5 from bigtable where k5 * 2 <= 3";
planFragment = getFragmentPlan(sql);
Assert.assertTrue(planFragment.contains("PREDICATES: CAST(5: k5 AS DECIMAL64(18,3)) * 2 <= 3"));
Assert.assertTrue(planFragment.contains("PREDICATES: 5: k5 * 2 <= 3"));

sql = "select k5 from bigtable where 2 / k5 <= 3";
planFragment = getFragmentPlan(sql);
Expand Down