Skip to content

Commit

Permalink
[SPARK-48719][SQL] Fix the calculation bug of RegrSlope & `RegrInte…
Browse files Browse the repository at this point in the history
…rcept` when the first parameter is null

### What changes were proposed in this pull request?

This PR fixes a calculation bug in `RegrSlope` and `RegrIntercept` that occurs when the first parameter is null. Regardless of whether the first parameter (y) or the second parameter (x) is null, the tuple should be filtered out.

### Why are the changes needed?

Fix bug.

### Does this PR introduce _any_ user-facing change?

Yes. The result changes when the first value of a tuple is null; the new result is the correct one.

### How was this patch tested?

Pass GA and test with `build/sbt "~sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z linear-regression.sql"`

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47105 from wayneguow/SPARK-48719.

Authored-by: Wei Guo <guow93@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
wayneguow authored and cloud-fan committed Jul 5, 2024
1 parent 6161632 commit f1eca90
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,14 @@ case class RegrSlope(left: Expression, right: Expression) extends DeclarativeAgg

override lazy val initialValues: Seq[Expression] = covarPop.initialValues ++ varPop.initialValues

override lazy val updateExpressions: Seq[Expression] =
covarPop.updateExpressions ++ varPop.updateExpressions
override lazy val updateExpressions: Seq[Expression] = {
  // RegrSlope only handles pairs where both y and x are non-null.
  // covarPop's update expressions are used unchanged; varPop's are wrapped so that its
  // buffer attributes keep their previous values whenever either input is null.
  val anyInputNull = left.isNull || right.isNull
  val covarUpdates = covarPop.updateExpressions
  val guardedVarUpdates = varPop.updateExpressions.zip(varPop.aggBufferAttributes).map {
    case (updated, current) => If(anyInputNull, current, updated)
  }
  covarUpdates ++ guardedVarUpdates
}

override lazy val mergeExpressions: Seq[Expression] =
covarPop.mergeExpressions ++ varPop.mergeExpressions
Expand Down Expand Up @@ -324,8 +330,14 @@ case class RegrIntercept(left: Expression, right: Expression) extends Declarativ

override lazy val initialValues: Seq[Expression] = covarPop.initialValues ++ varPop.initialValues

override lazy val updateExpressions: Seq[Expression] =
covarPop.updateExpressions ++ varPop.updateExpressions
override lazy val updateExpressions: Seq[Expression] = {
  // RegrIntercept only handles pairs where both y and x are non-null.
  // covarPop's update expressions are used unchanged; varPop's are wrapped so that its
  // buffer attributes keep their previous values whenever either input is null.
  val anyInputNull = left.isNull || right.isNull
  val covarUpdates = covarPop.updateExpressions
  val guardedVarUpdates = varPop.updateExpressions.zip(varPop.aggBufferAttributes).map {
    case (updated, current) => If(anyInputNull, current, updated)
  }
  covarUpdates ++ guardedVarUpdates
}

override lazy val mergeExpressions: Seq[Expression] =
covarPop.mergeExpressions ++ varPop.mergeExpressions
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
-- Automatically generated by SQLQueryTestSuite
-- !query
CREATE OR REPLACE TEMPORARY VIEW testRegression AS SELECT * FROM VALUES
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35)
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40)
AS testRegression(k, y, x)
-- !query analysis
CreateViewCommand `testRegression`, SELECT * FROM VALUES
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35)
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40)
AS testRegression(k, y, x), false, true, LocalTempView, UNSUPPORTED, true
+- Project [k#x, y#x, x#x]
+- SubqueryAlias testRegression
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
-- Test data.
CREATE OR REPLACE TEMPORARY VIEW testRegression AS SELECT * FROM VALUES
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35)
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40)
AS testRegression(k, y, x);

-- SPARK-37613: Support ANSI Aggregate Function: regr_count
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
-- Automatically generated by SQLQueryTestSuite
-- !query
CREATE OR REPLACE TEMPORARY VIEW testRegression AS SELECT * FROM VALUES
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35)
(1, 10, null), (2, 10, 11), (2, 20, 22), (2, 25, null), (2, 30, 35), (2, null, 40)
AS testRegression(k, y, x)
-- !query schema
struct<>
Expand Down Expand Up @@ -31,7 +31,7 @@ SELECT k, count(*), regr_count(y, x) FROM testRegression GROUP BY k
struct<k:int,count(1):bigint,regr_count(y, x):bigint>
-- !query output
1 1 0
2 4 3
2 5 3


-- !query
Expand All @@ -40,7 +40,7 @@ SELECT k, count(*) FILTER (WHERE x IS NOT NULL), regr_count(y, x) FROM testRegre
struct<k:int,count(1) FILTER (WHERE (x IS NOT NULL)):bigint,regr_count(y, x):bigint>
-- !query output
1 0 0
2 3 3
2 4 3


-- !query
Expand Down Expand Up @@ -99,7 +99,7 @@ SELECT k, avg(x), avg(y), regr_avgx(y, x), regr_avgy(y, x) FROM testRegression G
struct<k:int,avg(x):double,avg(y):double,regr_avgx(y, x):double,regr_avgy(y, x):double>
-- !query output
1 NULL 10.0 NULL NULL
2 22.666666666666668 21.25 22.666666666666668 20.0
2 27.0 21.25 22.666666666666668 20.0


-- !query
Expand All @@ -116,15 +116,15 @@ SELECT regr_sxx(y, x) FROM testRegression
-- !query schema
struct<regr_sxx(y, x):double>
-- !query output
288.66666666666663
288.6666666666667


-- !query
SELECT regr_sxx(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL
-- !query schema
struct<regr_sxx(y, x):double>
-- !query output
288.66666666666663
288.6666666666667


-- !query
Expand All @@ -133,15 +133,15 @@ SELECT k, regr_sxx(y, x) FROM testRegression GROUP BY k
struct<k:int,regr_sxx(y, x):double>
-- !query output
1 NULL
2 288.66666666666663
2 288.6666666666667


-- !query
SELECT k, regr_sxx(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY k
-- !query schema
struct<k:int,regr_sxx(y, x):double>
-- !query output
2 288.66666666666663
2 288.6666666666667


-- !query
Expand Down Expand Up @@ -215,15 +215,15 @@ SELECT regr_slope(y, x) FROM testRegression
-- !query schema
struct<regr_slope(y, x):double>
-- !query output
0.8314087759815244
0.8314087759815242


-- !query
SELECT regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL
-- !query schema
struct<regr_slope(y, x):double>
-- !query output
0.8314087759815244
0.8314087759815242


-- !query
Expand All @@ -232,15 +232,15 @@ SELECT k, regr_slope(y, x) FROM testRegression GROUP BY k
struct<k:int,regr_slope(y, x):double>
-- !query output
1 NULL
2 0.8314087759815244
2 0.8314087759815242


-- !query
SELECT k, regr_slope(y, x) FROM testRegression WHERE x IS NOT NULL AND y IS NOT NULL GROUP BY k
-- !query schema
struct<k:int,regr_slope(y, x):double>
-- !query output
2 0.8314087759815244
2 0.8314087759815242


-- !query
Expand Down

0 comments on commit f1eca90

Please sign in to comment.