@@ -84,6 +84,8 @@ import org.blaze.protobuf.PhysicalExprNode
 object NativeConverters extends Logging {
   val udfJsonEnabled: Boolean =
     SparkEnv.get.conf.getBoolean("spark.blaze.udf.UDFJson.enabled", defaultValue = true)
+  private val decimalBinaryEnabled: Boolean =
+    SparkEnv.get.conf.getBoolean("spark.blaze.decimal.binary.enabled", defaultValue = false)
 
   def scalarTypeSupported(dataType: DataType): Boolean = {
     dataType match {
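
The new flag is read once from SparkEnv when NativeConverters is loaded, so it has to be set before the session (and its SparkEnv) exists. A minimal usage sketch, assuming the standard SparkSession builder API; only the flag name comes from this diff, and the app name and master URL are hypothetical:

import org.apache.spark.sql.SparkSession

object EnableDecimalBinaryDemo extends App {
  // Opt in to native decimal binary arithmetic (the flag defaults to false).
  val spark = SparkSession
    .builder()
    .appName("blaze-decimal-demo") // hypothetical
    .master("local[*]") // for a local demo only
    .config("spark.blaze.decimal.binary.enabled", "true")
    .getOrCreate()
}
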
@@ -360,6 +362,10 @@ object NativeConverters extends Logging {
     val buildBinaryExprNode = this.buildBinaryExprNode(_, _, _, isPruningExpr, fallback)
     val buildScalarFunction = this.buildScalarFunctionNode(_, _, _, isPruningExpr, fallback)
     val buildExtScalarFunction = this.buildExtScalarFunctionNode(_, _, _, isPruningExpr, fallback)
+    def isDecimalType(lhs: Expression, rhs: Expression): Boolean =
+      lhs.dataType.isInstanceOf[DecimalType] && rhs.dataType.isInstanceOf[DecimalType]
+    def enableDecimalBinaryOperator(lhs: Expression, rhs: Expression): Boolean =
+      decimalBinaryEnabled || (!decimalBinaryEnabled && !isDecimalType(lhs, rhs))
 
     sparkExpr match {
       case e: NativeExprWrapperBase => e.wrapped
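
By Boolean absorption, a || (!a && b) equals a || b, so the new guard is equivalent to decimalBinaryEnabled || !isDecimalType(lhs, rhs): an expression is converted natively unless the flag is off and both operands are decimal. A self-contained check of that equivalence, with plain Booleans standing in for the Expression machinery:

object GateCheck extends App {
  // Mirrors enableDecimalBinaryOperator from the diff.
  def guard(flagOn: Boolean, bothDecimal: Boolean): Boolean =
    flagOn || (!flagOn && !bothDecimal)

  // The simplified form agrees for all four input combinations.
  for (flagOn <- Seq(true, false); bothDecimal <- Seq(true, false))
    assert(guard(flagOn, bothDecimal) == (flagOn || !bothDecimal))

  // Exactly one combination falls back to Spark: flag off, both sides decimal.
  assert(!guard(flagOn = false, bothDecimal = true))
}
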
@@ -502,11 +508,10 @@
       case GreaterThanOrEqual(lhs, rhs) => buildBinaryExprNode(lhs, rhs, "GtEq")
       case LessThanOrEqual(lhs, rhs) => buildBinaryExprNode(lhs, rhs, "LtEq")
 
-      case e: Add =>
+      case e: Add if enableDecimalBinaryOperator(e.left, e.right) =>
         val lhs = e.left
         val rhs = e.right
         val resultType = e.dataType
-        if (lhs.dataType.isInstanceOf[DecimalType] && rhs.dataType.isInstanceOf[DecimalType]) {
+        if (isDecimalType(lhs, rhs)) {
           def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
             val resultScale = max(s1, s2)
             val resultPrecision = max(p1 - s1, p2 - s2) + resultScale + 1
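
To make the addition rule concrete, a worked sketch of the formula shown above; (precision, scale) pairs stand in for Spark's DecimalType, and clamping to Spark's 38-digit maximum is ignored:

object AddRuleDemo extends App {
  // Result type of decimal addition, per the rule in the diff:
  // scale = max(s1, s2); precision = max(p1 - s1, p2 - s2) + scale + 1.
  def addResultType(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
    val resultScale = math.max(s1, s2)
    val resultPrecision = math.max(p1 - s1, p2 - s2) + resultScale + 1
    (resultPrecision, resultScale)
  }

  // Decimal(10, 2) + Decimal(12, 4): scale = 4, precision = max(8, 8) + 4 + 1 = 13.
  assert(addResultType(10, 2, 12, 4) == (13, 4))
}
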
@@ -542,11 +547,10 @@
           buildBinaryExprNode(lhs, rhs, "Plus")
         }
 
-      case e: Subtract =>
+      case e: Subtract if enableDecimalBinaryOperator(e.left, e.right) =>
         val lhs = e.left
         val rhs = e.right
         val resultType = e.dataType
-        if (lhs.dataType.isInstanceOf[DecimalType] && rhs.dataType.isInstanceOf[DecimalType]) {
+        if (isDecimalType(lhs, rhs)) {
           // copied from spark3.5
           def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
             val resultScale = max(s1, s2)
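
In Spark 3.5 subtraction uses the same precision/scale rule as addition (the visible lines already show the identical resultScale), so addResultType from the sketch above applies unchanged; a worked asymmetric case:

// Continuation of AddRuleDemo: subtraction reuses the addition rule.
// Decimal(5, 0) - Decimal(5, 4): scale = 4, precision = max(5, 1) + 4 + 1 = 10.
assert(addResultType(5, 0, 5, 4) == (10, 4))
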
@@ -583,10 +587,10 @@
           buildBinaryExprNode(lhs, rhs, "Minus")
         }
 
-      case e: Multiply =>
+      case e: Multiply if enableDecimalBinaryOperator(e.left, e.right) =>
         val lhs = e.left
         val rhs = e.right
-        if (lhs.dataType.isInstanceOf[DecimalType] && rhs.dataType.isInstanceOf[DecimalType]) {
+        if (isDecimalType(lhs, rhs)) {
           // copied from spark3.5
           def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
             val resultScale = s1 + s2
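
Multiplication adds the scales. The diff view cuts off before the precision line, so the precision half of this sketch is an assumption drawn from Spark 3.5's rule (precision = p1 + p2 + 1), not from this PR:

object MultiplyRuleDemo extends App {
  // scale = s1 + s2 is visible in the diff; precision = p1 + p2 + 1 is assumed
  // from Spark 3.5. Clamping to 38 digits is ignored here.
  def multiplyResultType(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) =
    (p1 + p2 + 1, s1 + s2)

  // Decimal(10, 2) * Decimal(12, 4): scale = 6, precision = 10 + 12 + 1 = 23.
  assert(multiplyResultType(10, 2, 12, 4) == (23, 6))
}
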
@@ -623,10 +627,10 @@
           buildBinaryExprNode(lhs, rhs, "Multiply")
         }
 
-      case e: Divide =>
+      case e: Divide if enableDecimalBinaryOperator(e.left, e.right) =>
         val lhs = e.left
         val rhs = e.right
-        if (lhs.dataType.isInstanceOf[DecimalType] && rhs.dataType.isInstanceOf[DecimalType]) {
+        if (isDecimalType(lhs, rhs)) {
          // copied from spark3.5
           def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
             if (SQLConf.get.decimalOperationsAllowPrecisionLoss) {
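
The division branch is truncated right after the decimalOperationsAllowPrecisionLoss check. For orientation only, a sketch of the precision-loss-allowing branch as Spark 3.5 defines it; this is an assumption drawn from Spark, not from this PR, and Spark additionally clamps the result to 38 digits via adjustPrecisionScale, which is omitted here:

object DivideRuleDemo extends App {
  // Spark 3.5's division rule when precision loss is allowed (assumed):
  // integral digits = p1 - s1 + s2; scale = max(6, s1 + p2 + 1);
  // precision = integral digits + scale.
  def divideResultTypeAllowLoss(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
    val intDig = p1 - s1 + s2
    val scale = math.max(6, s1 + p2 + 1) // 6 = DecimalType.MINIMUM_ADJUSTED_SCALE
    (intDig + scale, scale)
  }

  // Decimal(10, 2) / Decimal(12, 4): intDig = 12, scale = max(6, 15) = 15,
  // so the unclamped result is Decimal(27, 15).
  assert(divideResultTypeAllowLoss(10, 2, 12, 4) == (27, 15))
}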