Skip to content

Commit ec6fc74

Browse files
committed
[SPARK-39210][SQL] Provide query context of Decimal overflow in AVG when WSCG is off
### What changes were proposed in this pull request? Similar to #36525, this PR provides runtime error query context for the Average expression when WSCG is off. ### Why are the changes needed? Enhance the runtime error query context of Average function. After changes, it works when the whole stage codegen is not available. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? New UT Closes #36582 from gengliangwang/fixAvgContext. Authored-by: Gengliang Wang <gengliang@apache.org> Signed-off-by: Gengliang Wang <gengliang@apache.org> (cherry picked from commit 8b5b3e9) Signed-off-by: Gengliang Wang <gengliang@apache.org>
1 parent 72eb58a commit ec6fc74

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ abstract class AverageBase
8181

8282
// If all input are nulls, count will be 0 and we will get null after the division.
8383
// We can't directly use `/` as it throws an exception under ansi mode.
84-
protected def getEvaluateExpression = child.dataType match {
84+
protected def getEvaluateExpression(queryContext: String) = child.dataType match {
8585
case _: DecimalType =>
8686
DecimalPrecision.decimalAndDecimal()(
8787
Divide(
88-
CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !useAnsiAdd),
88+
CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !useAnsiAdd, queryContext),
8989
count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
9090
case _: YearMonthIntervalType =>
9191
If(EqualTo(count, Literal(0L)),
@@ -123,7 +123,7 @@ abstract class AverageBase
123123
since = "1.0.0")
124124
case class Average(
125125
child: Expression,
126-
useAnsiAdd: Boolean = SQLConf.get.ansiEnabled) extends AverageBase {
126+
useAnsiAdd: Boolean = SQLConf.get.ansiEnabled) extends AverageBase with SupportQueryContext {
127127
def this(child: Expression) = this(child, useAnsiAdd = SQLConf.get.ansiEnabled)
128128

129129
override protected def withNewChildInternal(newChild: Expression): Average =
@@ -133,7 +133,13 @@ case class Average(
133133

134134
override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions
135135

136-
override lazy val evaluateExpression: Expression = getEvaluateExpression
136+
override lazy val evaluateExpression: Expression = getEvaluateExpression(queryContext)
137+
138+
override def initQueryContext(): String = if (useAnsiAdd) {
139+
origin.context
140+
} else {
141+
""
142+
}
137143
}
138144

139145
// scalastyle:off line.size.limit
@@ -192,7 +198,7 @@ case class TryAverage(child: Expression) extends AverageBase {
192198
}
193199

194200
override lazy val evaluateExpression: Expression = {
195-
addTryEvalIfNeeded(getEvaluateExpression)
201+
addTryEvalIfNeeded(getEvaluateExpression(""))
196202
}
197203

198204
override protected def withNewChildInternal(newChild: Expression): Expression =

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4423,16 +4423,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
44234423
}
44244424
}
44254425

4426-
test("SPARK-39190, SPARK-39208: Query context of decimal overflow error should be serialized " +
4427-
"to executors when WSCG is off") {
4426+
test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal overflow error should " +
4427+
"be serialized to executors when WSCG is off") {
44284428
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
44294429
SQLConf.ANSI_ENABLED.key -> "true") {
44304430
withTable("t") {
44314431
sql("create table t(d decimal(38, 0)) using parquet")
44324432
sql("insert into t values (6e37BD),(6e37BD)")
44334433
Seq(
44344434
"select d / 0.1 from t",
4435-
"select sum(d) from t").foreach { query =>
4435+
"select sum(d) from t",
4436+
"select avg(d) from t").foreach { query =>
44364437
val msg = intercept[SparkException] {
44374438
sql(query).collect()
44384439
}.getMessage

0 commit comments

Comments
 (0)