Skip to content

Commit

Permalink
[fix](nereids)the common type of decimalv2 and decimalv3 shoud be dec…
Browse files Browse the repository at this point in the history
…imalv3 in BinaryArithmetic operator (apache#24215)

the common type of decimalv2 and decimalv3 shoud be decimalv3 in BinaryArithmetic operator
  • Loading branch information
starocean999 authored Sep 14, 2023
1 parent 51a5895 commit 40e1c2a
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,11 @@ public static Expression processBinaryArithmetic(BinaryArithmetic binaryArithmet
commonType = DoubleType.INSTANCE;
}

if (t1.isDecimalV3Type() && t2.isDecimalV2Type()
|| t1.isDecimalV2Type() && t2.isDecimalV3Type()) {
return processDecimalV3BinaryArithmetic(binaryArithmetic, left, right);
}

if (t1.isDecimalV2Type() || t2.isDecimalV2Type()) {
// to be consitent with old planner
// see findCommonType() method in ArithmeticExpr.java
Expand Down Expand Up @@ -735,26 +740,7 @@ public static Expression processBinaryArithmetic(BinaryArithmetic binaryArithmet

// double and float already process, we only process decimalv2 and fixed point number.
if (t1 instanceof DecimalV3Type || t2 instanceof DecimalV3Type) {
DecimalV3Type dt1 = DecimalV3Type.forType(t1);
DecimalV3Type dt2 = DecimalV3Type.forType(t2);

// check return type whether overflow, if true, turn to double
DecimalV3Type retType;
try {
retType = binaryArithmetic.getDataTypeForDecimalV3(dt1, dt2);
} catch (Exception e) {
// exception means overflow.
return castChildren(binaryArithmetic, left, right, DoubleType.INSTANCE);
}

// add, subtract and mod should cast children to exactly same type as return type
if (binaryArithmetic instanceof Add
|| binaryArithmetic instanceof Subtract
|| binaryArithmetic instanceof Mod) {
return castChildren(binaryArithmetic, left, right, retType);
}
// multiply do not need to cast children to same type
return binaryArithmetic.withChildren(castIfNotSameType(left, dt1), castIfNotSameType(right, dt2));
return processDecimalV3BinaryArithmetic(binaryArithmetic, left, right);
}

// double, float and decimalv3 already process, we only process fixed point number
Expand Down Expand Up @@ -1443,4 +1429,30 @@ public static BoundFunction fillJsonTypeArgument(BoundFunction function, boolean
throw new AnalysisException(t.getMessage());
}
}

private static Expression processDecimalV3BinaryArithmetic(BinaryArithmetic binaryArithmetic,
Expression left, Expression right) {
DecimalV3Type dt1 =
DecimalV3Type.forType(TypeCoercionUtils.getNumResultType(left.getDataType()));
DecimalV3Type dt2 =
DecimalV3Type.forType(TypeCoercionUtils.getNumResultType(right.getDataType()));

// check return type whether overflow, if true, turn to double
DecimalV3Type retType;
try {
retType = binaryArithmetic.getDataTypeForDecimalV3(dt1, dt2);
} catch (Exception e) {
// exception means overflow.
return castChildren(binaryArithmetic, left, right, DoubleType.INSTANCE);
}

// add, subtract and mod should cast children to exactly same type as return type
if (binaryArithmetic instanceof Add || binaryArithmetic instanceof Subtract
|| binaryArithmetic instanceof Mod) {
return castChildren(binaryArithmetic, left, right, retType);
}
// multiply do not need to cast children to same type
return binaryArithmetic.withChildren(castIfNotSameType(left, dt1),
castIfNotSameType(right, dt2));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@

package org.apache.doris.nereids.util;

import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BigIntType;
Expand Down Expand Up @@ -49,6 +56,7 @@
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
import java.util.Optional;

public class TypeCoercionUtilsTest {
Expand Down Expand Up @@ -688,4 +696,31 @@ public void testCastIfNotSameType() {
Assertions.assertEquals(new Cast(new DoubleLiteral(5L), BooleanType.INSTANCE),
TypeCoercionUtils.castIfNotMatchType(new DoubleLiteral(5L), BooleanType.INSTANCE));
}

@Test
public void testDecimalArithmetic() {
Multiply multiply = new Multiply(new DecimalLiteral(new BigDecimal("987654.321")),
new DecimalV3Literal(new BigDecimal("123.45")));
Expression expression = TypeCoercionUtils.processBinaryArithmetic(multiply);
Assertions.assertEquals(expression.child(0),
new Cast(multiply.child(0), DecimalV3Type.createDecimalV3Type(9, 3)));

Divide divide = new Divide(new DecimalLiteral(new BigDecimal("987654.321")),
new DecimalV3Literal(new BigDecimal("123.45")));
expression = TypeCoercionUtils.processBinaryArithmetic(divide);
Assertions.assertEquals(expression.child(0),
new Cast(multiply.child(0), DecimalV3Type.createDecimalV3Type(9, 3)));

Add add = new Add(new DecimalLiteral(new BigDecimal("987654.321")),
new DecimalV3Literal(new BigDecimal("123.45")));
expression = TypeCoercionUtils.processBinaryArithmetic(add);
Assertions.assertEquals(expression.child(0),
new Cast(multiply.child(0), DecimalV3Type.createDecimalV3Type(9, 3)));

Subtract sub = new Subtract(new DecimalLiteral(new BigDecimal("987654.321")),
new DecimalV3Literal(new BigDecimal("123.45")));
expression = TypeCoercionUtils.processBinaryArithmetic(sub);
Assertions.assertEquals(expression.child(0),
new Cast(multiply.child(0), DecimalV3Type.createDecimalV3Type(9, 3)));
}
}

0 comments on commit 40e1c2a

Please sign in to comment.