Skip to content

Commit 0be5fb4

Browse files
aokolnychyigatorsmile
aokolnychyi
authored andcommitted
[SPARK-21332][SQL] Incorrect result type inferred for some decimal expressions
## What changes were proposed in this pull request? This PR changes the direction of expression transformation in the DecimalPrecision rule. Previously, the expressions were transformed down, which led to incorrect result types when decimal expressions had other decimal expressions as their operands. The root cause of this issue was in visiting outer nodes before their children. Consider the example below: ``` val inputSchema = StructType(StructField("col", DecimalType(26, 6)) :: Nil) val sc = spark.sparkContext val rdd = sc.parallelize(1 to 2).map(_ => Row(BigDecimal(12))) val df = spark.createDataFrame(rdd, inputSchema) // Works correctly since no nested decimal expression is involved // Expected result type: (26, 6) * (26, 6) = (38, 12) df.select($"col" * $"col").explain(true) df.select($"col" * $"col").printSchema() // Gives a wrong result since there is a nested decimal expression that should be visited first // Expected result type: ((26, 6) * (26, 6)) * (26, 6) = (38, 12) * (26, 6) = (38, 18) df.select($"col" * $"col" * $"col").explain(true) df.select($"col" * $"col" * $"col").printSchema() ``` The example above gives the following output: ``` // Correct result without sub-expressions == Parsed Logical Plan == 'Project [('col * 'col) AS (col * col)#4] +- LogicalRDD [col#1] == Analyzed Logical Plan == (col * col): decimal(38,12) Project [CheckOverflow((promote_precision(cast(col#1 as decimal(26,6))) * promote_precision(cast(col#1 as decimal(26,6)))), DecimalType(38,12)) AS (col * col)#4] +- LogicalRDD [col#1] == Optimized Logical Plan == Project [CheckOverflow((col#1 * col#1), DecimalType(38,12)) AS (col * col)#4] +- LogicalRDD [col#1] == Physical Plan == *Project [CheckOverflow((col#1 * col#1), DecimalType(38,12)) AS (col * col)#4] +- Scan ExistingRDD[col#1] // Schema root |-- (col * col): decimal(38,12) (nullable = true) // Incorrect result with sub-expressions == Parsed Logical Plan == 'Project [(('col * 'col) * 'col) AS ((col * col) * col)apache#11] +- LogicalRDD [col#1] == Analyzed Logical Plan == ((col * col) * col): decimal(38,12) Project [CheckOverflow((promote_precision(cast(CheckOverflow((promote_precision(cast(col#1 as decimal(26,6))) * promote_precision(cast(col#1 as decimal(26,6)))), DecimalType(38,12)) as decimal(26,6))) * promote_precision(cast(col#1 as decimal(26,6)))), DecimalType(38,12)) AS ((col * col) * col)apache#11] +- LogicalRDD [col#1] == Optimized Logical Plan == Project [CheckOverflow((cast(CheckOverflow((col#1 * col#1), DecimalType(38,12)) as decimal(26,6)) * col#1), DecimalType(38,12)) AS ((col * col) * col)apache#11] +- LogicalRDD [col#1] == Physical Plan == *Project [CheckOverflow((cast(CheckOverflow((col#1 * col#1), DecimalType(38,12)) as decimal(26,6)) * col#1), DecimalType(38,12)) AS ((col * col) * col)apache#11] +- Scan ExistingRDD[col#1] // Schema root |-- ((col * col) * col): decimal(38,12) (nullable = true) ``` ## How was this patch tested? This PR was tested with available unit tests. Moreover, there are tests to cover previously failing scenarios. Author: aokolnychyi <anton.okolnychyi@sap.com> Closes apache#18583 from aokolnychyi/spark-21332.
1 parent 5952ad2 commit 0be5fb4

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ object DecimalPrecision extends Rule[LogicalPlan] {
8080

8181
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
8282
// fix decimal precision for expressions
83-
case q => q.transformExpressions(
83+
case q => q.transformExpressionsUp(
8484
decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal))
8585
}
8686

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,14 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter {
9090
checkType(Average(d1), DecimalType(6, 5))
9191

9292
checkType(Add(Add(d1, d2), d1), DecimalType(7, 2))
93+
checkType(Add(Add(d1, d1), d1), DecimalType(4, 1))
94+
checkType(Add(d1, Add(d1, d1)), DecimalType(4, 1))
9395
checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2))
9496
checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2))
97+
checkType(Subtract(Subtract(d2, d1), d1), DecimalType(7, 2))
98+
checkType(Multiply(Multiply(d1, d1), d2), DecimalType(11, 4))
99+
checkType(Divide(d2, Add(d1, d1)), DecimalType(10, 6))
100+
checkType(Sum(Add(d1, d1)), DecimalType(13, 1))
95101
}
96102

97103
test("Comparison operations") {

0 commit comments

Comments
 (0)