
[SPARK-34037][SQL] Remove unnecessary upcasting for Avg & Sum, which handle it themselves internally #31079


Closed · wants to merge 14 commits
@@ -634,17 +634,6 @@ object TypeCoercion {

m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })

// Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
Comment on lines -637 to -638
Member

There was a reason for promoting these aggregation functions; why remove it?

Member Author

The type coercion for the numeric input types of Average and Sum is not necessary at all, since their resultType and sumType already prevent overflow, and it causes the issue here.

Contributor

Yeah, Sum/Average already cast the inputs internally.
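
To make the point concrete, below is a minimal, self-contained sketch of the widening being described. It is written in plain Scala with hypothetical stand-in types (SqlType, widenForSum, and friends), not the actual Catalyst classes; the widening rules mirror the promotions that the removed TypeCoercion cases performed and that Sum/Average now apply on their own via their result/sum types.

// Illustrative sketch only: stand-in types, not the real Catalyst API.
object SumWideningSketch {
  sealed trait SqlType
  case object IntegerType extends SqlType
  case object LongType extends SqlType
  case object FloatType extends SqlType
  case object DoubleType extends SqlType
  final case class DecimalType(precision: Int, scale: Int) extends SqlType

  // Integrals accumulate as Long, fractionals as Double, and decimals stay
  // decimal with widened precision, so an extra Cast inserted at analysis
  // time during type coercion adds nothing.
  def widenForSum(input: SqlType): SqlType = input match {
    case IntegerType | LongType => LongType
    case FloatType | DoubleType => DoubleType
    case DecimalType(p, s)      => DecimalType(math.min(p + 10, 38), s) // assumed precision cap of 38
  }

  def main(args: Array[String]): Unit = {
    println(widenForSum(IntegerType))        // LongType
    println(widenForSum(DecimalType(10, 2))) // DecimalType(20,2)
  }
}

With the analysis-time casts gone, the cast(... as bigint) wrappers also disappear from the plans, which is what the golden-file updates below reflect.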

case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType))

case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest.
case Average(e @ IntegralType()) if e.dataType != LongType =>
Average(Cast(e, LongType))
case Average(e @ FractionalType()) if e.dataType != DoubleType =>
Average(Cast(e, DoubleType))

// Hive lets you do aggregation of timestamps... for some reason
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
46 changes: 23 additions & 23 deletions sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
@@ -55,23 +55,23 @@ struct<plan:string>

== Analyzed Logical Plan ==
sum(DISTINCT val): bigint
Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
Aggregate [sum(distinct val#x) AS sum(DISTINCT val)#xL]
+- SubqueryAlias spark_catalog.default.explain_temp1
+- Relation[key#x,val#x] parquet

== Optimized Logical Plan ==
Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
Aggregate [sum(distinct val#x) AS sum(DISTINCT val)#xL]
+- Project [val#x]
+- Relation[key#x,val#x] parquet

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL])
+- HashAggregate(keys=[], functions=[sum(distinct val#x)], output=[sum(DISTINCT val)#xL])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
+- HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL])
+- HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
+- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x]
+- HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
+- HashAggregate(keys=[], functions=[partial_sum(distinct val#x)], output=[sum#xL])
+- HashAggregate(keys=[val#x], functions=[], output=[val#x])
+- Exchange hashpartitioning(val#x, 4), ENSURE_REQUIREMENTS, [id=#x]
+- HashAggregate(keys=[val#x], functions=[], output=[val#x])
+- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<val:int>


@@ -615,7 +615,7 @@ Input [2]: [key#x, val#x]
(14) HashAggregate
Input [1]: [key#x]
Keys: []
Functions [1]: [partial_avg(cast(key#x as bigint))]
Functions [1]: [partial_avg(key#x)]
Aggregate Attributes [2]: [sum#x, count#xL]
Results [2]: [sum#x, count#xL]

@@ -626,9 +626,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(16) HashAggregate
Input [2]: [sum#x, count#xL]
Keys: []
Functions [1]: [avg(cast(key#x as bigint))]
Aggregate Attributes [1]: [avg(cast(key#x as bigint))#x]
Results [1]: [avg(cast(key#x as bigint))#x AS avg(key)#x]
Functions [1]: [avg(key#x)]
Aggregate Attributes [1]: [avg(key#x)#x]
Results [1]: [avg(key#x)#x AS avg(key)#x]

(17) AdaptiveSparkPlan
Output [1]: [avg(key)#x]
@@ -681,7 +681,7 @@ ReadSchema: struct<key:int>
(5) HashAggregate
Input [1]: [key#x]
Keys: []
Functions [1]: [partial_avg(cast(key#x as bigint))]
Functions [1]: [partial_avg(key#x)]
Aggregate Attributes [2]: [sum#x, count#xL]
Results [2]: [sum#x, count#xL]

@@ -692,9 +692,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(7) HashAggregate
Input [2]: [sum#x, count#xL]
Keys: []
Functions [1]: [avg(cast(key#x as bigint))]
Aggregate Attributes [1]: [avg(cast(key#x as bigint))#x]
Results [1]: [avg(cast(key#x as bigint))#x AS avg(key)#x]
Functions [1]: [avg(key#x)]
Aggregate Attributes [1]: [avg(key#x)#x]
Results [1]: [avg(key#x)#x AS avg(key)#x]

(8) AdaptiveSparkPlan
Output [1]: [avg(key)#x]
@@ -717,7 +717,7 @@ ReadSchema: struct<key:int>
(10) HashAggregate
Input [1]: [key#x]
Keys: []
Functions [1]: [partial_avg(cast(key#x as bigint))]
Functions [1]: [partial_avg(key#x)]
Aggregate Attributes [2]: [sum#x, count#xL]
Results [2]: [sum#x, count#xL]

@@ -728,9 +728,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(12) HashAggregate
Input [2]: [sum#x, count#xL]
Keys: []
Functions [1]: [avg(cast(key#x as bigint))]
Aggregate Attributes [1]: [avg(cast(key#x as bigint))#x]
Results [1]: [avg(cast(key#x as bigint))#x AS avg(key)#x]
Functions [1]: [avg(key#x)]
Aggregate Attributes [1]: [avg(key#x)#x]
Results [1]: [avg(key#x)#x AS avg(key)#x]

(13) AdaptiveSparkPlan
Output [1]: [avg(key)#x]
@@ -947,7 +947,7 @@ ReadSchema: struct<key:int,val:int>
(2) HashAggregate
Input [2]: [key#x, val#x]
Keys: []
Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))]
Functions [3]: [partial_count(val#x), partial_sum(key#x), partial_count(key#x) FILTER (WHERE (val#x > 1))]
Aggregate Attributes [3]: [count#xL, sum#xL, count#xL]
Results [3]: [count#xL, sum#xL, count#xL]

@@ -958,9 +958,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(4) HashAggregate
Input [3]: [count#xL, sum#xL, count#xL]
Keys: []
Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)]
Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL]
Results [2]: [(count(val#x)#xL + sum(cast(key#x as bigint))#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL]
Functions [3]: [count(val#x), sum(key#x), count(key#x)]
Aggregate Attributes [3]: [count(val#x)#xL, sum(key#x)#xL, count(key#x)#xL]
Results [2]: [(count(val#x)#xL + sum(key#x)#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL]

(5) AdaptiveSparkPlan
Output [2]: [TOTAL#xL, count(key) FILTER (WHERE (val > 1))#xL]
38 changes: 19 additions & 19 deletions sql/core/src/test/resources/sql-tests/results/explain.sql.out
@@ -55,22 +55,22 @@ struct<plan:string>

== Analyzed Logical Plan ==
sum(DISTINCT val): bigint
Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
Aggregate [sum(distinct val#x) AS sum(DISTINCT val)#xL]
+- SubqueryAlias spark_catalog.default.explain_temp1
+- Relation[key#x,val#x] parquet

== Optimized Logical Plan ==
Aggregate [sum(distinct cast(val#x as bigint)) AS sum(DISTINCT val)#xL]
Aggregate [sum(distinct val#x) AS sum(DISTINCT val)#xL]
+- Project [val#x]
+- Relation[key#x,val#x] parquet

== Physical Plan ==
*HashAggregate(keys=[], functions=[sum(distinct cast(val#x as bigint)#xL)], output=[sum(DISTINCT val)#xL])
*HashAggregate(keys=[], functions=[sum(distinct val#x)], output=[sum(DISTINCT val)#xL])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
+- *HashAggregate(keys=[], functions=[partial_sum(distinct cast(val#x as bigint)#xL)], output=[sum#xL])
+- *HashAggregate(keys=[cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
+- Exchange hashpartitioning(cast(val#x as bigint)#xL, 4), ENSURE_REQUIREMENTS, [id=#x]
+- *HashAggregate(keys=[cast(val#x as bigint) AS cast(val#x as bigint)#xL], functions=[], output=[cast(val#x as bigint)#xL])
+- *HashAggregate(keys=[], functions=[partial_sum(distinct val#x)], output=[sum#xL])
+- *HashAggregate(keys=[val#x], functions=[], output=[val#x])
+- Exchange hashpartitioning(val#x, 4), ENSURE_REQUIREMENTS, [id=#x]
+- *HashAggregate(keys=[val#x], functions=[], output=[val#x])
+- *ColumnarToRow
+- FileScan parquet default.explain_temp1[val#x] Batched: true, DataFilters: [], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<val:int>

@@ -620,7 +620,7 @@ Input [2]: [key#x, val#x]
(15) HashAggregate [codegen id : 1]
Input [1]: [key#x]
Keys: []
Functions [1]: [partial_avg(cast(key#x as bigint))]
Functions [1]: [partial_avg(key#x)]
Aggregate Attributes [2]: [sum#x, count#xL]
Results [2]: [sum#x, count#xL]

@@ -631,9 +631,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(17) HashAggregate [codegen id : 2]
Input [2]: [sum#x, count#xL]
Keys: []
Functions [1]: [avg(cast(key#x as bigint))]
Aggregate Attributes [1]: [avg(cast(key#x as bigint))#x]
Results [1]: [avg(cast(key#x as bigint))#x AS avg(key)#x]
Functions [1]: [avg(key#x)]
Aggregate Attributes [1]: [avg(key#x)#x]
Results [1]: [avg(key#x)#x AS avg(key)#x]


-- !query
@@ -684,7 +684,7 @@ Input [1]: [key#x]
(6) HashAggregate [codegen id : 1]
Input [1]: [key#x]
Keys: []
Functions [1]: [partial_avg(cast(key#x as bigint))]
Functions [1]: [partial_avg(key#x)]
Aggregate Attributes [2]: [sum#x, count#xL]
Results [2]: [sum#x, count#xL]

@@ -695,9 +695,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(8) HashAggregate [codegen id : 2]
Input [2]: [sum#x, count#xL]
Keys: []
Functions [1]: [avg(cast(key#x as bigint))]
Aggregate Attributes [1]: [avg(cast(key#x as bigint))#x]
Results [1]: [avg(cast(key#x as bigint))#x AS avg(key)#x]
Functions [1]: [avg(key#x)]
Aggregate Attributes [1]: [avg(key#x)#x]
Results [1]: [avg(key#x)#x AS avg(key)#x]

Subquery:2 Hosting operator id = 3 Hosting Expression = ReusedSubquery Subquery scalar-subquery#x, [id=#x]

@@ -895,7 +895,7 @@ Input [2]: [key#x, val#x]
(3) HashAggregate [codegen id : 1]
Input [2]: [key#x, val#x]
Keys: []
Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))]
Functions [3]: [partial_count(val#x), partial_sum(key#x), partial_count(key#x) FILTER (WHERE (val#x > 1))]
Aggregate Attributes [3]: [count#xL, sum#xL, count#xL]
Results [3]: [count#xL, sum#xL, count#xL]

@@ -906,9 +906,9 @@ Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#x]
(5) HashAggregate [codegen id : 2]
Input [3]: [count#xL, sum#xL, count#xL]
Keys: []
Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)]
Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL]
Results [2]: [(count(val#x)#xL + sum(cast(key#x as bigint))#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL]
Functions [3]: [count(val#x), sum(key#x), count(key#x)]
Aggregate Attributes [3]: [count(val#x)#xL, sum(key#x)#xL, count(key#x)#xL]
Results [2]: [(count(val#x)#xL + sum(key#x)#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL]


-- !query
@@ -122,7 +122,7 @@ select a, b, sum(b) from data group by 3
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
aggregate functions are not allowed in GROUP BY, but found sum(CAST(data.`b` AS BIGINT))
aggregate functions are not allowed in GROUP BY, but found sum(data.`b`)


-- !query
@@ -131,7 +131,7 @@ select a, b, sum(b) + 2 from data group by 3
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS BIGINT)) + CAST(2 AS BIGINT))
aggregate functions are not allowed in GROUP BY, but found (sum(data.`b`) + CAST(2 AS BIGINT))


-- !query
@@ -381,8 +381,8 @@ struct<>
org.apache.spark.sql.AnalysisException

Aggregate/Window/Generate expressions are not valid in where clause of the query.
Expression in where clause: [(sum(DISTINCT CAST((outer(a.`four`) + b.`four`) AS BIGINT)) = CAST(b.`four` AS BIGINT))]
Invalid expressions: [sum(DISTINCT CAST((outer(a.`four`) + b.`four`) AS BIGINT))]
Expression in where clause: [(sum(DISTINCT (outer(a.`four`) + b.`four`)) = CAST(b.`four` AS BIGINT))]
Invalid expressions: [sum(DISTINCT (outer(a.`four`) + b.`four`))]


-- !query
@@ -46,7 +46,7 @@ AND t2b = (SELECT max(avg)
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
grouping expressions sequence is empty, and 't2.`t2b`' is not an aggregate function. Wrap '(avg(CAST(t2.`t2b` AS BIGINT)) AS `avg`)' in windowing function(s) or wrap 't2.`t2b`' in first() (or first_value) if you don't care which value you get.
grouping expressions sequence is empty, and 't2.`t2b`' is not an aggregate function. Wrap '(avg(t2.`t2b`) AS `avg`)' in windowing function(s) or wrap 't2.`t2b`' in first() (or first_value) if you don't care which value you get.


-- !query
@@ -372,8 +372,8 @@ struct<>
org.apache.spark.sql.AnalysisException

Aggregate/Window/Generate expressions are not valid in where clause of the query.
Expression in where clause: [(sum(DISTINCT CAST((outer(a.`four`) + b.`four`) AS BIGINT)) = CAST(CAST(udf(ansi_cast(four as string)) AS INT) AS BIGINT))]
Invalid expressions: [sum(DISTINCT CAST((outer(a.`four`) + b.`four`) AS BIGINT))]
Expression in where clause: [(sum(DISTINCT (outer(a.`four`) + b.`four`)) = CAST(CAST(udf(ansi_cast(four as string)) AS INT) AS BIGINT))]
Invalid expressions: [sum(DISTINCT (outer(a.`four`) + b.`four`))]


-- !query
@@ -206,7 +206,7 @@ Results [5]: [i_brand#21, i_brand_id#20, i_manufact_id#22, i_manufact#23, sum#27

(37) Exchange
Input [5]: [i_brand#21, i_brand_id#20, i_manufact_id#22, i_manufact#23, sum#27]
Arguments: hashpartitioning(i_brand#21, i_brand_id#20, i_manufact_id#22, i_manufact#23, 5), true, [id=#28]
Arguments: hashpartitioning(i_brand#21, i_brand_id#20, i_manufact_id#22, i_manufact#23, 5), ENSURE_REQUIREMENTS, [id=#28]

(38) HashAggregate [codegen id : 7]
Input [5]: [i_brand#21, i_brand_id#20, i_manufact_id#22, i_manufact#23, sum#27]
@@ -206,7 +206,7 @@ Results [5]: [i_brand#12, i_brand_id#11, i_manufact_id#13, i_manufact#14, sum#27

(37) Exchange
Input [5]: [i_brand#12, i_brand_id#11, i_manufact_id#13, i_manufact#14, sum#27]
Arguments: hashpartitioning(i_brand#12, i_brand_id#11, i_manufact_id#13, i_manufact#14, 5), true, [id=#28]
Arguments: hashpartitioning(i_brand#12, i_brand_id#11, i_manufact_id#13, i_manufact#14, 5), ENSURE_REQUIREMENTS, [id=#28]

(38) HashAggregate [codegen id : 7]
Input [5]: [i_brand#12, i_brand_id#11, i_manufact_id#13, i_manufact#14, sum#27]