@@ -898,7 +898,7 @@ class ShimsImpl extends Shims with Logging {
       fallback: Expression => pb.PhysicalExprNode): Option[pb.PhysicalExprNode] = {
     import org.apache.spark.sql.catalyst.expressions.PromotePrecision
     e match {
-      case PromotePrecision(_1) =>
+      case PromotePrecision(_1) if NativeConverters.decimalArithOpEnabled =>
         Some(NativeConverters.convertExprWithFallback(_1, isPruningExpr, fallback))
      case _ => None
    }
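
For orientation: PromotePrecision (present through Spark 3.3, removed in later Spark versions) is a marker wrapper that the DecimalPrecision rule inserts and that evaluates to exactly its child, which is why converting only the child above is semantics-preserving. A minimal sketch of that identity, assuming a Spark 3.3 classpath (not code from this PR; the literal value is illustrative):

import org.apache.spark.sql.catalyst.expressions.{Literal, PromotePrecision}
import org.apache.spark.sql.types.{Decimal, DecimalType}

// PromotePrecision.eval simply delegates to its child, so the wrapper
// can be unwrapped before conversion without changing results.
val child = Literal(Decimal("12.34"), DecimalType(4, 2))
assert(PromotePrecision(child).eval() == child.eval())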
@@ -84,8 +84,8 @@ import org.blaze.protobuf.PhysicalExprNode
 object NativeConverters extends Logging {
   val udfJsonEnabled: Boolean =
     SparkEnv.get.conf.getBoolean("spark.blaze.udf.UDFJson.enabled", defaultValue = true)
-  private val decimalBinaryEnabled: Boolean =
-    SparkEnv.get.conf.getBoolean("spark.blaze.decimal.binary.enabled", defaultValue = false)
+  val decimalArithOpEnabled: Boolean =
+    SparkEnv.get.conf.getBoolean("spark.blaze.decimal.arithOp.enabled", defaultValue = false)
 
   def scalarTypeSupported(dataType: DataType): Boolean = {
     dataType match {
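
The renamed flag is read once from the SparkConf when NativeConverters is first loaded, so it would be set at submit time. A minimal sketch of enabling it (hypothetical job setup; only the two config keys come from this diff):

// Equivalently: spark-submit --conf spark.blaze.decimal.arithOp.enabled=true ...
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.blaze.udf.UDFJson.enabled", "true")     // existing flag, default true
  .set("spark.blaze.decimal.arithOp.enabled", "true") // new flag, default false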
@@ -365,10 +365,6 @@ object NativeConverters extends Logging {
       this.buildScalarFunctionNode(_, _, _, isPruningExpr, fallback)
     val buildExtScalarFunction: (String, Seq[Expression], DataType) => PhysicalExprNode =
       this.buildExtScalarFunctionNode(_, _, _, isPruningExpr, fallback)
-    def isDecimalType(lhs: Expression, rhs: Expression): Boolean =
-      lhs.dataType.isInstanceOf[DecimalType] && rhs.dataType.isInstanceOf[DecimalType]
-    def enableDecimalBinaryOperator(lhs: Expression, rhs: Expression): Boolean =
-      decimalBinaryEnabled || (!decimalBinaryEnabled && !isDecimalType(lhs, rhs))
 
     sparkExpr match {
       case e: NativeExprWrapperBase => e.wrapped
@@ -511,10 +507,10 @@ object NativeConverters extends Logging {
       case GreaterThanOrEqual(lhs, rhs) => buildBinaryExprNode(lhs, rhs, "GtEq")
       case LessThanOrEqual(lhs, rhs) => buildBinaryExprNode(lhs, rhs, "LtEq")
 
-      case e: Add if enableDecimalBinaryOperator(e.left, e.right) =>
+      case e: Add if e.dataType.isInstanceOf[DecimalType] || decimalArithOpEnabled =>
         val lhs = e.left
         val rhs = e.right
-        if (isDecimalType(lhs, rhs)) {
+        if (e.dataType.isInstanceOf[DecimalType]) {
           def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
             val resultScale = max(s1, s2)
             val resultPrecision = max(p1 - s1, p2 - s2) + resultScale + 1
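
A worked instance of the Add rule above: for Decimal(10, 2) + Decimal(5, 4), resultScale = max(2, 4) = 4 and resultPrecision = max(10 - 2, 5 - 4) + 4 + 1 = 13, so the sum widens to DecimalType(13, 4). Subtract below uses the same formula. A self-contained sketch (hypothetical helper mirroring the lines above, without the cap Spark applies at precision 38):

import scala.math.max
import org.apache.spark.sql.types.DecimalType

def addResultType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
  val resultScale = max(s1, s2)
  val resultPrecision = max(p1 - s1, p2 - s2) + resultScale + 1
  DecimalType(resultPrecision, resultScale)
}

assert(addResultType(10, 2, 5, 4) == DecimalType(13, 4))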
@@ -550,10 +546,10 @@ object NativeConverters extends Logging {
           buildBinaryExprNode(lhs, rhs, "Plus")
         }
 
-      case e: Subtract if enableDecimalBinaryOperator(e.left, e.right) =>
+      case e: Subtract if e.dataType.isInstanceOf[DecimalType] || decimalArithOpEnabled =>
         val lhs = e.left
         val rhs = e.right
-        if (isDecimalType(lhs, rhs)) {
+        if (e.dataType.isInstanceOf[DecimalType]) {
           // copied from spark3.5
           def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
             val resultScale = max(s1, s2)
@@ -590,10 +586,10 @@ object NativeConverters extends Logging {
           buildBinaryExprNode(lhs, rhs, "Minus")
         }
 
-      case e: Multiply if enableDecimalBinaryOperator(e.left, e.right) =>
+      case e: Multiply if e.dataType.isInstanceOf[DecimalType] || decimalArithOpEnabled =>
         val lhs = e.left
         val rhs = e.right
-        if (isDecimalType(lhs, rhs)) {
+        if (e.dataType.isInstanceOf[DecimalType]) {
           // copied from spark3.5
           def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
             val resultScale = s1 + s2
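
For Multiply, the Spark 3.5 formula being copied is scale = s1 + s2 and precision = p1 + p2 + 1 (before the 38-digit cap), so Decimal(10, 2) * Decimal(5, 4) widens to DecimalType(16, 6). A quick check (hypothetical helper; the precision line is recalled from Spark's DecimalPrecision rule, since the diff truncates it here):

import org.apache.spark.sql.types.DecimalType

def multiplyResultType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType =
  DecimalType(p1 + p2 + 1, s1 + s2) // assumes the result stays within 38 digits

assert(multiplyResultType(10, 2, 5, 4) == DecimalType(16, 6))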
@@ -630,10 +626,10 @@ object NativeConverters extends Logging {
           buildBinaryExprNode(lhs, rhs, "Multiply")
         }
 
-      case e: Divide if enableDecimalBinaryOperator(e.left, e.right) =>
+      case e: Divide if e.dataType.isInstanceOf[DecimalType] || decimalArithOpEnabled =>
         val lhs = e.left
         val rhs = e.right
-        if (isDecimalType(lhs, rhs)) {
+        if (e.dataType.isInstanceOf[DecimalType]) {
           // copied from spark3.5
           def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
             if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
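
The branch body is collapsed in the diff; for reference, Spark 3.5's rule with decimalOperationsAllowPrecisionLoss enabled is scale = max(6, s1 + p2 + 1) and precision = (p1 - s1 + s2) + scale, both subsequently adjusted to fit 38 digits. A hedged sketch of just that branch (recalled from Spark's DecimalPrecision, not visible in this hunk):

import scala.math.max
import org.apache.spark.sql.types.DecimalType

def divideResultType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
  val scale = max(6, s1 + p2 + 1) // 6 is Spark's MINIMUM_ADJUSTED_SCALE
  DecimalType(p1 - s1 + s2 + scale, scale) // before the 38-digit adjustment
}

// Decimal(10, 2) / Decimal(5, 4): scale = max(6, 2 + 5 + 1) = 8,
// precision = 10 - 2 + 4 + 8 = 20
assert(divideResultType(10, 2, 5, 4) == DecimalType(20, 8))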
@@ -941,22 +937,19 @@ object NativeConverters extends Logging {
         pb.PhysicalExprNode.newBuilder().setCase(caseExpr).build()
 
       // expressions for DecimalPrecision rule
-      case UnscaledValue(_1) =>
+      case UnscaledValue(_1) if decimalArithOpEnabled =>
         val args = _1 :: Nil
         buildExtScalarFunction("UnscaledValue", args, LongType)
 
-      case e: MakeDecimal =>
-        // case MakeDecimal(_1, precision, scale) =>
-        //   assert(!SQLConf.get.ansiEnabled)
+      case e: MakeDecimal if decimalArithOpEnabled =>
         val precision = e.precision
         val scale = e.scale
         val args =
           e.child :: Literal
             .apply(precision, IntegerType) :: Literal.apply(scale, IntegerType) :: Nil
         buildExtScalarFunction("MakeDecimal", args, DecimalType(precision, scale))
 
-      case e: CheckOverflow =>
-        // case CheckOverflow(_1, DecimalType(precision, scale)) =>
+      case e: CheckOverflow if decimalArithOpEnabled =>
         val precision = e.dataType.precision
         val scale = e.dataType.scale
         val args =
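
For orientation, the two helpers gated above follow standard Spark semantics: UnscaledValue yields a decimal's unscaled Long and MakeDecimal rebuilds the decimal from one, so they round-trip. A minimal sketch using Spark's Decimal type directly (not code from this PR):

import org.apache.spark.sql.types.Decimal

val d = Decimal("12.34") // precision 4, scale 2
val unscaled: Long = d.toUnscaledLong // 1234L, what UnscaledValue computes
assert(Decimal(unscaled, 4, 2) == d)  // what MakeDecimal(1234L, 4, 2) rebuilds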