Commit f92b494
[SPARK-40386][PS][SQL] Implement ddof in DataFrame.cov
### What changes were proposed in this pull request?

1. Add a dedicated expression for `DataFrame.cov`.
2. Add the missing parameter `ddof` to `DataFrame.cov`.

### Why are the changes needed?

For API coverage.

### Does this PR introduce _any_ user-facing change?

Yes, an API change:

```python
>>> np.random.seed(42)
>>> df = ps.DataFrame(np.random.randn(1000, 5),
...                   columns=['a', 'b', 'c', 'd', 'e'])
>>> df.cov()
          a         b         c         d         e
a  0.998438 -0.020161  0.059277 -0.008943  0.014144
b -0.020161  1.059352 -0.008543 -0.024738  0.009826
c  0.059277 -0.008543  1.010670 -0.001486 -0.000271
d -0.008943 -0.024738 -0.001486  0.921297 -0.013692
e  0.014144  0.009826 -0.000271 -0.013692  0.977795
>>> df.cov(ddof=2)
          a         b         c         d         e
a  0.999439 -0.020181  0.059336 -0.008952  0.014159
b -0.020181  1.060413 -0.008551 -0.024762  0.009836
c  0.059336 -0.008551  1.011683 -0.001487 -0.000271
d -0.008952 -0.024762 -0.001487  0.922220 -0.013705
e  0.014159  0.009836 -0.000271 -0.013705  0.978775
>>> df.cov(ddof=-1)
          a         b         c         d         e
a  0.996444 -0.020121  0.059158 -0.008926  0.014116
b -0.020121  1.057235 -0.008526 -0.024688  0.009807
c  0.059158 -0.008526  1.008650 -0.001483 -0.000270
d -0.008926 -0.024688 -0.001483  0.919456 -0.013664
e  0.014116  0.009807 -0.000270 -0.013664  0.975842
```

### How was this patch tested?

Added tests.

Closes #37829 from zhengruifeng/ps_cov_ddof.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent c597331 commit f92b494
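The new `ddof` follows the numpy.cov / pandas convention: the divisor is `N - ddof`. A minimal cross-check sketch against plain pandas (whose `DataFrame.cov` has accepted `ddof` since pandas 1.1), reusing the seed from the example above:

```python
import numpy as np
import pandas as pd

np.random.seed(42)
pdf = pd.DataFrame(np.random.randn(1000, 5), columns=["a", "b", "c", "d", "e"])

# pandas uses the same N - ddof divisor, so this should agree with the
# pandas-on-Spark output above up to floating-point noise.
print(pdf.cov(ddof=2).round(6))

# With no missing values the divisors differ only by a constant factor:
# cov(ddof=k) == cov(ddof=1) * (N - 1) / (N - k)
n = len(pdf)
assert np.allclose(pdf.cov(ddof=1) * (n - 1) / (n - 2), pdf.cov(ddof=2))
```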

File tree: 5 files changed, +66 −6 lines

python/pyspark/pandas/frame.py — 25 additions & 6 deletions

```diff
@@ -8738,8 +8738,7 @@ def update(self, other: "DataFrame", join: str = "left", overwrite: bool = True)
         internal = self._internal.with_new_sdf(sdf, data_fields=data_fields)
         self._update_internal_frame(internal, check_same_anchor=False)
 
-    # TODO: ddof should be implemented.
-    def cov(self, min_periods: Optional[int] = None) -> "DataFrame":
+    def cov(self, min_periods: Optional[int] = None, ddof: int = 1) -> "DataFrame":
         """
         Compute pairwise covariance of columns, excluding NA/null values.
@@ -8755,8 +8754,7 @@ def cov(self, min_periods: Optional[int] = None) -> "DataFrame":
         below this threshold will be returned as ``NaN``.
 
         This method is generally used for the analysis of time series data to
-        understand the relationship between different measures
-        across time.
+        understand the relationship between different measures across time.
 
         .. versionadded:: 3.3.0
@@ -8765,6 +8763,11 @@ def cov(self, min_periods: Optional[int] = None) -> "DataFrame":
         min_periods : int, optional
             Minimum number of observations required per pair of columns
             to have a valid result.
+        ddof : int, default 1
+            Delta degrees of freedom. The divisor used in calculations
+            is ``N - ddof``, where ``N`` represents the number of elements.
+
+            .. versionadded:: 3.4.0
 
         Returns
         -------
@@ -8794,6 +8797,20 @@ def cov(self, min_periods: Optional[int] = None) -> "DataFrame":
         c  0.059277 -0.008543  1.010670 -0.001486 -0.000271
         d -0.008943 -0.024738 -0.001486  0.921297 -0.013692
         e  0.014144  0.009826 -0.000271 -0.013692  0.977795
+        >>> df.cov(ddof=2)
+                  a         b         c         d         e
+        a  0.999439 -0.020181  0.059336 -0.008952  0.014159
+        b -0.020181  1.060413 -0.008551 -0.024762  0.009836
+        c  0.059336 -0.008551  1.011683 -0.001487 -0.000271
+        d -0.008952 -0.024762 -0.001487  0.922220 -0.013705
+        e  0.014159  0.009836 -0.000271 -0.013705  0.978775
+        >>> df.cov(ddof=-1)
+                  a         b         c         d         e
+        a  0.996444 -0.020121  0.059158 -0.008926  0.014116
+        b -0.020121  1.057235 -0.008526 -0.024688  0.009807
+        c  0.059158 -0.008526  1.008650 -0.001483 -0.000270
+        d -0.008926 -0.024688 -0.001483  0.919456 -0.013664
+        e  0.014116  0.009807 -0.000270 -0.013664  0.975842
 
         **Minimum number of periods**
@@ -8813,6 +8830,8 @@ def cov(self, min_periods: Optional[int] = None) -> "DataFrame":
         b       NaN  1.248003  0.191417
         c -0.150812  0.191417  0.895202
         """
+        if not isinstance(ddof, int):
+            raise TypeError("ddof must be integer")
         min_periods = 1 if min_periods is None else min_periods
 
         # Only compute covariance for Boolean and Numeric except Decimal
@@ -8891,8 +8910,8 @@ def cov(self, min_periods: Optional[int] = None) -> "DataFrame":
                 step += r
                 for c in range(r, num_cols):
                     cov_scols.append(
-                        F.covar_samp(
-                            F.col(data_cols[r]).cast("double"), F.col(data_cols[c]).cast("double")
+                        SF.covar(
+                            F.col(data_cols[r]).cast("double"), F.col(data_cols[c]).cast("double"), ddof
                         )
                         if count_not_null[r * num_cols + c - step] >= min_periods
                         else F.lit(None)
```
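A short usage sketch of the new signature (assumes a live Spark session with pandas API on Spark; the data is random, so outputs vary):

```python
import numpy as np
import pyspark.pandas as ps

psdf = ps.DataFrame(np.random.randn(100, 3), columns=["a", "b", "c"])

print(psdf.cov())        # default ddof=1: sample covariance, divisor N - 1
print(psdf.cov(ddof=0))  # population covariance, divisor N

try:
    psdf.cov(ddof=1.5)   # rejected by the isinstance guard added above
except TypeError as e:
    print(e)             # ddof must be integer
```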

python/pyspark/pandas/spark/functions.py — 5 additions & 0 deletions

```diff
@@ -51,6 +51,11 @@ def mode(col: Column, dropna: bool) -> Column:
     return Column(sc._jvm.PythonSQLUtils.pandasMode(col._jc, dropna))
 
 
+def covar(col1: Column, col2: Column, ddof: int) -> Column:
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.PythonSQLUtils.pandasCovar(col1._jc, col2._jc, ddof))
+
+
 def repeat(col: Column, n: Union[int, Column]) -> Column:
     """
     Repeats a string column n times, and returns it as a new string column.
```
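`covar` is an internal helper (it calls through the Py4J gateway into `PythonSQLUtils.pandasCovar` below), not a public API; still, a minimal sketch of exercising it directly, assuming an active SparkSession:

```python
from pyspark.sql import SparkSession, functions as F
from pyspark.pandas.spark import functions as SF

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame([(1.0, 2.0), (2.0, 4.5), (3.0, 5.5)], ["x", "y"])

# Aggregates to a single covariance with divisor N - ddof; inputs are
# cast to double, matching how frame.py invokes this helper.
sdf.agg(SF.covar(F.col("x").cast("double"), F.col("y").cast("double"), 0)).show()
```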

python/pyspark/pandas/tests/test_dataframe.py — 10 additions & 0 deletions

```diff
@@ -6958,6 +6958,16 @@ def test_cov(self):
         self.assert_eq(pdf.cov(min_periods=4), psdf.cov(min_periods=4), almost=True)
         self.assert_eq(pdf.cov(min_periods=5), psdf.cov(min_periods=5))
 
+        # ddof
+        with self.assertRaisesRegex(TypeError, "ddof must be integer"):
+            psdf.cov(ddof="ddof")
+        for ddof in [-1, 0, 2]:
+            self.assert_eq(pdf.cov(ddof=ddof), psdf.cov(ddof=ddof), almost=True)
+            self.assert_eq(
+                pdf.cov(min_periods=4, ddof=ddof), psdf.cov(min_periods=4, ddof=ddof), almost=True
+            )
+            self.assert_eq(pdf.cov(min_periods=5, ddof=ddof), psdf.cov(min_periods=5, ddof=ddof))
+
         # bool
         pdf = pd.DataFrame(
             {
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala — 22 additions & 0 deletions

```diff
@@ -143,3 +143,25 @@ case class CovSample(
   override protected def withNewChildrenInternal(
     newLeft: Expression, newRight: Expression): CovSample = copy(left = newLeft, right = newRight)
 }
+
+/**
+ * Covariance in Pandas' fashion. This expression is dedicated only for Pandas API on Spark.
+ * Refer to numpy.cov.
+ */
+case class PandasCovar(
+    override val left: Expression,
+    override val right: Expression,
+    ddof: Int)
+  extends Covariance(left, right, true) {
+
+  override val evaluateExpression: Expression = {
+    If(n === 0.0, Literal.create(null, DoubleType),
+      If(n === ddof, divideByZeroEvalResult, ck / (n - ddof)))
+  }
+  override def prettyName: String = "pandas_covar"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): PandasCovar =
+    copy(left = newLeft, right = newRight)
+}
```
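Here `n` and `ck` come from the parent `Covariance` aggregate: `n` is the count of rows where both inputs are non-null, and `ck` is the running co-moment Σ(x − x̄)(y − ȳ). A plain-Python mirror of the evaluation rule (illustrative names only, not the Catalyst API):

```python
from typing import Optional, Sequence

def pandas_covar(xs: Sequence, ys: Sequence, ddof: int) -> Optional[float]:
    # Keep only pairs where both sides are non-null, as the aggregate does.
    pairs = [(x, y) for x, y in zip(xs, ys) if x is not None and y is not None]
    n = len(pairs)
    if n == 0:
        return None  # If(n === 0.0, null, ...)
    mx = sum(x for x, _ in pairs) / n
    my = sum(y for _, y in pairs) / n
    ck = sum((x - mx) * (y - my) for x, y in pairs)  # co-moment
    if n == ddof:
        # divideByZeroEvalResult: since the expression is constructed with
        # the superclass flag set to true, this appears to yield null.
        return None
    return ck / (n - ddof)

print(pandas_covar([1.0, 2.0, 3.0], [2.0, 4.5, 5.5], ddof=1))  # 1.75
```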

sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala — 4 additions & 0 deletions

```diff
@@ -138,6 +138,10 @@ private[sql] object PythonSQLUtils extends Logging {
   def pandasMode(e: Column, ignoreNA: Boolean): Column = {
     Column(PandasMode(e.expr, ignoreNA).toAggregateExpression(false))
   }
+
+  def pandasCovar(col1: Column, col2: Column, ddof: Int): Column = {
+    Column(PandasCovar(col1.expr, col2.expr, ddof).toAggregateExpression(false))
+  }
 }
 
 /**
```
