Skip to content

Commit b1ce4b5

Browse files
committed
[SPARK-32816][SQL] Fix analyzer bug when aggregating multiple distinct DECIMAL columns
1 parent 125cbe3 commit b1ce4b5

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
142142
RewriteNonCorrelatedExists,
143143
ComputeCurrentTime,
144144
GetCurrentDatabaseAndCatalog(catalogManager),
145-
RewriteDistinctAggregates,
146145
ReplaceDeduplicateWithAggregate) ::
147146
//////////////////////////////////////////////////////////////////////////////////////////
148147
// Optimizer rules start here
@@ -196,6 +195,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
196195
EliminateSorts) :+
197196
Batch("Decimal Optimizations", fixedPoint,
198197
DecimalAggregates) :+
198+
Batch("Distinct Aggregate Rewrite", Once,
199+
RewriteDistinctAggregates) :+
199200
Batch("Object Expressions Optimization", fixedPoint,
200201
EliminateMapObjects,
201202
CombineTypedFilters,

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2555,6 +2555,19 @@ class DataFrameSuite extends QueryTest
25552555
val df = Seq(0.0 -> -0.0).toDF("pos", "neg")
25562556
checkAnswer(df.select($"pos" > $"neg"), Row(false))
25572557
}
2558+
2559+
test("SPARK-32816: aggregating multiple distinct DECIMAL columns") {
2560+
withTempPath { path =>
2561+
spark.range(0, 100, 1, 1)
2562+
.selectExpr("id", "cast(id as decimal(9, 0)) as decimal_col")
2563+
.write.mode("overwrite")
2564+
.parquet(path.getAbsolutePath)
2565+
spark.read.parquet(path.getAbsolutePath).createOrReplaceTempView("test_table")
2566+
checkAnswer(
2567+
sql("select avg(distinct decimal_col), sum(distinct decimal_col) from test_table"),
2568+
Row(49.5, 4950))
2569+
}
2570+
}
25582571
}
25592572

25602573
case class GroupByKey(a: Int, b: Int)

0 commit comments

Comments
 (0)