[SPARK-39316][SQL] Merge PromotePrecision and CheckOverflow into decimal binary arithmetic #36698

Closed · wants to merge 19 commits
Changes from all commits
DecimalPrecision.scala
@@ -34,17 +34,8 @@ import org.apache.spark.sql.types._
  *
  *   Operation    Result Precision                        Result Scale
  *   ------------------------------------------------------------------------
- *   e1 + e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
- *   e1 - e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
- *   e1 * e2      p1 + p2 + 1                             s1 + s2
- *   e1 / e2      p1 - s1 + s2 + max(6, s1 + p2 + 1)      max(6, s1 + p2 + 1)
- *   e1 % e2      min(p1-s1, p2-s2) + max(s1, s2)         max(s1, s2)
  *   e1 union e2  max(s1, s2) + max(p1-s1, p2-s2)         max(s1, s2)
  *
- * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale
- * needed are out of the range of available values, the scale is reduced up to 6, in order to
- * prevent the truncation of the integer part of the decimals.
- *
  * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited
  * precision, do the math on unlimited-precision numbers, then introduce casts back to the
  * required fixed precision. This allows us to do all rounding and overflow handling in the
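
Since the rows for +, -, *, / and % move out of this scaladoc and into the binary arithmetic expressions themselves, a minimal sketch of what the addition row computes may help (the helper name addResultType is hypothetical, for illustration only):

    import org.apache.spark.sql.types.DecimalType

    // e1 + e2 per the table above:
    //   precision = max(s1, s2) + max(p1 - s1, p2 - s2) + 1
    //   scale     = max(s1, s2)
    def addResultType(t1: DecimalType, t2: DecimalType): DecimalType = {
      val scale = math.max(t1.scale, t2.scale)
      val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
      // Cap at the 38-digit limit, as DecimalType.bounded does.
      DecimalType(math.min(range + scale + 1, DecimalType.MAX_PRECISION), scale)
    }

    // Example: addResultType(DecimalType(10, 2), DecimalType(12, 4)) == DecimalType(13, 4):
    // max(8, 8) integer digits, scale 4, plus one carry digit.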
@@ -60,7 +51,7 @@ import org.apache.spark.sql.types._
  */
 // scalastyle:on
 object DecimalPrecision extends TypeCoercionRule {
-  import scala.math.{max, min}
+  import scala.math.max

   private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType

@@ -75,132 +66,17 @@ object DecimalPrecision extends TypeCoercionRule {
     DecimalType.bounded(range + scale, scale)
   }

-  private def promotePrecision(e: Expression, dataType: DataType): Expression = {
-    PromotePrecision(Cast(e, dataType))
-  }
-
   override def transform: PartialFunction[Expression, Expression] = {
     decimalAndDecimal()
       .orElse(integralAndDecimalLiteral)
       .orElse(nondecimalAndDecimal(conf.literalPickMinimumPrecision))
   }

-  private[catalyst] def decimalAndDecimal(): PartialFunction[Expression, Expression] = {
-    decimalAndDecimal(conf.decimalOperationsAllowPrecisionLoss, !conf.ansiEnabled)
-  }
-
-  /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */
-  private[catalyst] def decimalAndDecimal(allowPrecisionLoss: Boolean, nullOnOverflow: Boolean)
-    : PartialFunction[Expression, Expression] = {
+  /** Decimal precision promotion for binary comparison. */
+  private def decimalAndDecimal(): PartialFunction[Expression, Expression] = {
     // Skip nodes whose children have not been resolved yet
     case e if !e.childrenResolved => e

-    // Skip nodes who is already promoted
-    case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e
-
-    case a @ Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
-      val resultScale = max(s1, s2)
-      val resultType = if (allowPrecisionLoss) {
-        DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
-          resultScale)
-      } else {
-        DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
-      }
-      CheckOverflow(
-        a.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)),
-        resultType, nullOnOverflow)
-
-    case s @ Subtract(e1 @ DecimalType.Expression(p1, s1),
-        e2 @ DecimalType.Expression(p2, s2), _) =>
-      val resultScale = max(s1, s2)
-      val resultType = if (allowPrecisionLoss) {
-        DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
-          resultScale)
-      } else {
-        DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
-      }
-      CheckOverflow(
-        s.copy(left = promotePrecision(e1, resultType), right = promotePrecision(e2, resultType)),
-        resultType, nullOnOverflow)
-
-    case m @ Multiply(
-        e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
-      val resultType = if (allowPrecisionLoss) {
-        DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2)
-      } else {
-        DecimalType.bounded(p1 + p2 + 1, s1 + s2)
-      }
-      val widerType = widerDecimalType(p1, s1, p2, s2)
-      CheckOverflow(
-        m.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
-        resultType, nullOnOverflow)
-
-    case d @ Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
-      val resultType = if (allowPrecisionLoss) {
-        // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
-        // Scale: max(6, s1 + p2 + 1)
-        val intDig = p1 - s1 + s2
-        val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1)
-        val prec = intDig + scale
-        DecimalType.adjustPrecisionScale(prec, scale)
-      } else {
-        var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
-        var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
-        val diff = (intDig + decDig) - DecimalType.MAX_SCALE
-        if (diff > 0) {
-          decDig -= diff / 2 + 1
-          intDig = DecimalType.MAX_SCALE - decDig
-        }
-        DecimalType.bounded(intDig + decDig, decDig)
-      }
-      val widerType = widerDecimalType(p1, s1, p2, s2)
-      CheckOverflow(
-        d.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
-        resultType, nullOnOverflow)
-
-    case r @ Remainder(
-        e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
-      val resultType = if (allowPrecisionLoss) {
-        DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
-      } else {
-        DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
-      }
-      // resultType may have lower precision, so we cast them into wider type first.
-      val widerType = widerDecimalType(p1, s1, p2, s2)
-      CheckOverflow(
-        r.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
-        resultType, nullOnOverflow)
-
-    case p @ Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
-      val resultType = if (allowPrecisionLoss) {
-        DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
-      } else {
-        DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
-      }
-      // resultType may have lower precision, so we cast them into wider type first.
-      val widerType = widerDecimalType(p1, s1, p2, s2)
-      CheckOverflow(
-        p.copy(left = promotePrecision(e1, widerType), right = promotePrecision(e2, widerType)),
-        resultType, nullOnOverflow)
-
-    case expr @ IntegralDivide(
-        e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2), _) =>
-      val widerType = widerDecimalType(p1, s1, p2, s2)
-      val promotedExpr = expr.copy(
-        left = promotePrecision(e1, widerType),
-        right = promotePrecision(e2, widerType))
-      if (expr.dataType.isInstanceOf[DecimalType]) {
-        // This follows division rule
-        val intDig = p1 - s1 + s2
-        // No precision loss can happen as the result scale is 0.
-        // Overflow can happen only in the promote precision of the operands, but if none of them
-        // overflows in that phase, no overflow can happen, but CheckOverflow is needed in order
-        // to return a decimal with the proper scale and precision
-        CheckOverflow(promotedExpr, DecimalType.bounded(intDig, 0), nullOnOverflow)
-      } else {
-        promotedExpr
-      }
-
     case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
         e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
       val resultType = widerDecimalType(p1, s1, p2, s2)
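
Two worked examples of the removed rules (numbers derived from the formulas in the diff, assuming the stock adjustPrecisionScale behavior of capping precision at 38 digits while keeping at least 6 fractional digits when possible):

    // 1) Division with allowPrecisionLoss = true: DECIMAL(38, 10) / DECIMAL(38, 10)
    val (p1, s1, p2, s2) = (38, 10, 38, 10)
    val intDig = p1 - s1 + s2              // 38 integer digits
    val scale  = math.max(6, s1 + p2 + 1)  // max(6, 49) = 49
    val prec   = intDig + scale            // 87, beyond the 38-digit maximum
    // adjustPrecisionScale keeps the integer digits and shrinks the scale to
    // max(38 - intDig, min(scale, 6)) = 6, so the result is DECIMAL(38, 6).

    // 2) The "cast to wider type first" comment, for DECIMAL(10, 2) % DECIMAL(5, 3):
    //   result type: min(10-2, 5-3) + max(2, 3) = 5 digits, scale 3 -> DECIMAL(5, 3)
    //   wider type:  max(10-2, 5-3) + max(2, 3) = 11 digits, scale 3 -> DECIMAL(11, 3)
    // Operands are promoted to DECIMAL(11, 3) so neither loses digits during
    // the mod itself; CheckOverflow then enforces the narrower DECIMAL(5, 3).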
StreamingJoinHelper.scala
@@ -237,8 +237,6 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
           collect(child, !negate)
         case CheckOverflow(child, _, _) =>
           collect(child, negate)
-        case PromotePrecision(child) =>
-          collect(child, negate)
         case Cast(child, dataType, _, _) =>
           dataType match {
             case _: NumericType | _: TimestampType => collect(child, negate)
Average.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst.expressions.aggregate

-import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry, TypeCheckResult}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.TreePattern.{AVERAGE, TreePattern}
@@ -67,6 +67,11 @@ abstract class AverageBase
   lazy val sum = AttributeReference("sum", sumDataType)()
   lazy val count = AttributeReference("count", LongType)()

+  protected def add(left: Expression, right: Expression): Expression = left.dataType match {
+    case _: DecimalType => DecimalAddNoOverflowCheck(left, right, left.dataType)
+    case _ => Add(left, right, useAnsiAdd)
+  }
+
   override lazy val aggBufferAttributes = sum :: count :: Nil

   override lazy val initialValues = Seq(
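
Why dispatch on the data type instead of always using Add? Our reading, not text from the PR: with overflow checking merged into the arithmetic, a decimal Add derives a result type wider than its operands, while the buffer attribute `sum` must keep one fixed type across updates and merges. DecimalAddNoOverflowCheck evaluates at `left.dataType` and defers overflow detection to the single check at evaluation time. Schematically:

    // Add(sum: DECIMAL(p, s), v: DECIMAL(p, s))       => result type DECIMAL(p + 1, s)
    // DecimalAddNoOverflowCheck(sum, v, sum.dataType)  => result type DECIMAL(p, s)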
@@ -75,18 +80,17 @@
   )

   protected def getMergeExpressions = Seq(
-    /* sum = */ Add(sum.left, sum.right, useAnsiAdd),
+    /* sum = */ add(sum.left, sum.right),
     /* count = */ count.left + count.right
   )

   // If all input are nulls, count will be 0 and we will get null after the division.
   // We can't directly use `/` as it throws an exception under ansi mode.
   protected def getEvaluateExpression(queryContext: String) = child.dataType match {
     case _: DecimalType =>
-      DecimalPrecision.decimalAndDecimal()(
-        Divide(
-          CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !useAnsiAdd, queryContext),
-          count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
+      Divide(
+        CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !useAnsiAdd, queryContext),
+        count.cast(DecimalType.LongDecimal), failOnError = false).cast(resultType)
     case _: YearMonthIntervalType =>
       If(EqualTo(count, Literal(0L)),
         Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count))
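
The comment above is the key to failOnError = false: an empty or all-null group leaves count = 0, and avg must return null rather than fail. A behavior sketch (our reading, not PR text):

    // For an all-null group the final projection evaluates
    //   Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)
    // with count = 0, yielding null instead of an ANSI divide-by-zero error.
    // The old code routed the same Divide through DecimalPrecision.decimalAndDecimal()
    // only to fix up its result type; the expression now derives that type itself.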
@@ -99,10 +103,9 @@

   protected def getUpdateExpressions: Seq[Expression] = Seq(
     /* sum = */
-    Add(
+    add(
       sum,
-      coalesce(child.cast(sumDataType), Literal.default(sumDataType)),
-      failOnError = useAnsiAdd),
+      coalesce(child.cast(sumDataType), Literal.default(sumDataType))),
     /* count = */ If(child.isNull, count, count + 1L)
   )

@@ -190,7 +193,7 @@ case class TryAverage(child: Expression) extends AverageBase {
         Literal.create(null, resultType),
         // If both the buffer and the input do not overflow, just add them, as they can't be
         // null.
-        TryEval(Add(KnownNotNull(sum.left), KnownNotNull(sum.right), useAnsiAdd))),
+        TryEval(add(KnownNotNull(sum.left), KnownNotNull(sum.right)))),
         expressions(1))
     } else {
       expressions
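
A note on the TryAverage merge path (a reading of the shown code):

    // `sum` can be null only if an earlier addition overflowed, so the merge
    // short-circuits to null when either side is null; otherwise both sides
    // are marked KnownNotNull and the addition is wrapped in TryEval, which
    // turns a runtime arithmetic error into null instead of failing the query.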
Sum.scala
@@ -63,6 +63,11 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate

   private lazy val zero = Literal.default(resultType)

+  private def add(left: Expression, right: Expression): Expression = left.dataType match {
+    case _: DecimalType => DecimalAddNoOverflowCheck(left, right, left.dataType)
+    case _ => Add(left, right, useAnsiAdd)
+  }
+
   override lazy val aggBufferAttributes = if (shouldTrackIsEmpty) {
     sum :: isEmpty :: Nil
   } else {
@@ -82,9 +87,9 @@
     // null if overflow happens under non-ansi mode.
     val sumExpr = if (child.nullable) {
       If(child.isNull, sum,
-        Add(sum, KnownNotNull(child).cast(resultType), failOnError = useAnsiAdd))
+        add(sum, KnownNotNull(child).cast(resultType)))
     } else {
-      Add(sum, child.cast(resultType), failOnError = useAnsiAdd)
+      add(sum, child.cast(resultType))
     }
     // The buffer becomes non-empty after seeing the first not-null input.
     val isEmptyExpr = if (child.nullable) {
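
The update path mirrors Average (a reading of the shown code): a null input row leaves the buffer unchanged, and a non-null row is cast to the buffer type and added without a per-row overflow check:

    // Before: every row paid for Add(..., failOnError = useAnsiAdd).
    // After:  add() folds rows with DecimalAddNoOverflowCheck; the one
    //         overflow decision happens in CheckOverflowInSum (outside this
    //         diff) when the group is evaluated.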
@@ -99,10 +104,10 @@
     // in case the input is nullable. The `sum` can only be null if there is no value, as
     // non-decimal type can produce overflowed value under non-ansi mode.
     if (child.nullable) {
-      Seq(coalesce(Add(coalesce(sum, zero), child.cast(resultType), failOnError = useAnsiAdd),
+      Seq(coalesce(add(coalesce(sum, zero), child.cast(resultType)),
         sum))
     } else {
-      Seq(Add(coalesce(sum, zero), child.cast(resultType), failOnError = useAnsiAdd))
+      Seq(add(coalesce(sum, zero), child.cast(resultType)))
     }
   }

@@ -128,11 +133,11 @@
           // If both the buffer and the input do not overflow, just add them, as they can't be
           // null. See the comments inside `updateExpressions`: `sum` can only be null if
           // overflow happens.
-          Add(KnownNotNull(sum.left), KnownNotNull(sum.right), useAnsiAdd)),
+          add(KnownNotNull(sum.left), KnownNotNull(sum.right))),
         isEmpty.left && isEmpty.right)
     } else {
       Seq(coalesce(
-        Add(coalesce(sum.left, zero), sum.right, failOnError = useAnsiAdd),
+        add(coalesce(sum.left, zero), sum.right),
         sum.left))
     }

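
Taken together, the decimal Sum strategy after this PR reads as follows (our summary; CheckOverflowInSum itself is outside this diff):

    // update/merge: DecimalAddNoOverflowCheck keeps the buffer at its declared
    //   DecimalType and may transiently hold a value beyond that precision.
    // evaluate:     a single CheckOverflowInSum per group raises an error
    //   (ANSI mode) or returns null (non-ANSI), replacing the old per-row
    //   failOnError = useAnsiAdd checks.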