diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala index 40518982958cd..7d73cf211a6e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/linearRegression.scala @@ -271,8 +271,14 @@ case class RegrSlope(left: Expression, right: Expression) extends DeclarativeAgg override lazy val initialValues: Seq[Expression] = covarPop.initialValues ++ varPop.initialValues - override lazy val updateExpressions: Seq[Expression] = - covarPop.updateExpressions ++ varPop.updateExpressions + override lazy val updateExpressions: Seq[Expression] = { + // RegrSlope only handles paris where both y and x are non-empty, so we need additional + // judgment for calculating VariancePop. + val isNull = left.isNull || right.isNull + covarPop.updateExpressions ++ varPop.updateExpressions.zip(varPop.aggBufferAttributes).map { + case (newValue, oldValue) => If(isNull, oldValue, newValue) + } + } override lazy val mergeExpressions: Seq[Expression] = covarPop.mergeExpressions ++ varPop.mergeExpressions @@ -324,8 +330,14 @@ case class RegrIntercept(left: Expression, right: Expression) extends Declarativ override lazy val initialValues: Seq[Expression] = covarPop.initialValues ++ varPop.initialValues - override lazy val updateExpressions: Seq[Expression] = - covarPop.updateExpressions ++ varPop.updateExpressions + override lazy val updateExpressions: Seq[Expression] = { + // RegrIntercept only handles paris where both y and x are non-empty, so we need additional + // judgment for calculating VariancePop. + val isNull = left.isNull || right.isNull + covarPop.updateExpressions ++ varPop.updateExpressions.zip(varPop.aggBufferAttributes).map { + case (newValue, oldValue) => If(isNull, oldValue, newValue) + } + } override lazy val mergeExpressions: Seq[Expression] = covarPop.mergeExpressions ++ varPop.mergeExpressions diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/linear-regression.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/linear-regression.sql.out index 5f130cd1d422c..3a33dd7c84ed2 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/linear-regression.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/linear-regression.sql.out @@ -1,11 +1,11 @@ -- Automatically generated by SQLQueryTestSuite -- !query CREATE OR REPLACE TEMPORARY VIEW testRegression AS SELECT * FROM VALUES -(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35) +(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40) AS testRegression(k, y, x) -- !query analysis CreateViewCommand `testRegression`, SELECT * FROM VALUES -(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35) +(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40) AS testRegression(k, y, x), false, true, LocalTempView, UNSUPPORTED, true +- Project [k#x, y#x, x#x] +- SubqueryAlias testRegression diff --git a/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql b/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql index c7cb5bf1117a7..df286d2a9b0a9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/linear-regression.sql @@ -1,6 +1,6 @@ -- Test data. CREATE OR REPLACE TEMPORARY VIEW testRegression AS SELECT * FROM VALUES -(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35) +(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40) AS testRegression(k, y, x); -- SPARK-37613: Support ANSI Aggregate Function: regr_count diff --git a/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out b/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out index 1379713a9fb0d..e511ea75aae5a 100644 --- a/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/linear-regression.sql.out @@ -1,7 +1,7 @@ -- Automatically generated by SQLQueryTestSuite -- !query CREATE OR REPLACE TEMPORARY VIEW testRegression AS SELECT * FROM VALUES -(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35) +(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40) AS testRegression(k, y, x) -- !query schema struct<> @@ -31,7 +31,7 @@ SELECT k, count(*), regr_count(y, x) FROM testRegression GROUP BY k struct -- !query output 1 1 0 -2 4 3 +2 5 3 -- !query @@ -40,7 +40,7 @@ SELECT k, count(*) FILTER (WHERE x IS NOT NULL), regr_count(y, x) FROM testRegre struct -- !query output 1 0 0 -2 3 3 +2 4 3 -- !query @@ -99,7 +99,7 @@ SELECT k, avg(x), avg(y), regr_avgx(y, x), regr_avgy(y, x) FROM testRegression G struct -- !query output 1 NULL 10.0 NULL NULL -2 22.666666666666668 21.25 22.666666666666668 20.0 +2 27.0 21.25 22.666666666666668 20.0 -- !query @@ -116,7 +116,7 @@ SELECT regr_sxx(y, x) FROM testRegression -- !query schema struct -- !query output -288.66666666666663 +288.6666666666667 -- !query @@ -124,7 +124,7 @@ SELECT regr_sxx(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL -- !query schema struct -- !query output -288.66666666666663 +288.6666666666667 -- !query @@ -133,7 +133,7 @@ SELECT k, regr_sxx(y, x) FROM testRegression GROUP BY k struct -- !query output 1 NULL -2 288.66666666666663 +2 288.6666666666667 -- !query @@ -141,7 +141,7 @@ SELECT k, regr_sxx(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NU -- !query schema struct -- !query output -2 288.66666666666663 +2 288.6666666666667 -- !query @@ -215,7 +215,7 @@ SELECT regr_slope(y, x) FROM testRegression -- !query schema struct -- !query output -0.8314087759815244 +0.8314087759815242 -- !query @@ -223,7 +223,7 @@ SELECT regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NUL -- !query schema struct -- !query output -0.8314087759815244 +0.8314087759815242 -- !query @@ -232,7 +232,7 @@ SELECT k, regr_slope(y, x) FROM testRegression GROUP BY k struct -- !query output 1 NULL -2 0.8314087759815244 +2 0.8314087759815242 -- !query @@ -240,7 +240,7 @@ SELECT k, regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT -- !query schema struct -- !query output -2 0.8314087759815244 +2 0.8314087759815242 -- !query