Skip to content

Commit 73d6c78

Browse files
committed
bugfix in corr_cov for multiple dims
1 parent d9d4098 commit 73d6c78

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

xarray/core/computation.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from xarray.core.merge import merge_attrs, merge_coordinates_without_align
3131
from xarray.core.options import OPTIONS, _get_keep_attrs
3232
from xarray.core.pycompat import is_duck_dask_array
33-
from xarray.core.types import T_DataArray
33+
from xarray.core.types import Dims, T_DataArray
3434
from xarray.core.utils import is_dict_like, is_scalar
3535
from xarray.core.variable import Variable
3636

@@ -1219,7 +1219,7 @@ def apply_ufunc(
12191219

12201220

12211221
def cov(
1222-
da_a: T_DataArray, da_b: T_DataArray, dim: Hashable | None = None, ddof: int = 1
1222+
da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None, ddof: int = 1
12231223
) -> T_DataArray:
12241224
"""
12251225
Compute covariance between two DataArray objects along a shared dimension.
@@ -1230,7 +1230,7 @@ def cov(
12301230
Array to compute.
12311231
da_b : DataArray
12321232
Array to compute.
1233-
dim : Hashable, optional
1233+
dim : str, iterable of hashable, "..." or None, optional
12341234
The dimension along which the covariance will be computed
12351235
ddof : int, default: 1
12361236
If ddof=1, covariance is normalized by N-1, giving an unbiased estimate,
@@ -1300,9 +1300,7 @@ def cov(
13001300
return _cov_corr(da_a, da_b, dim=dim, ddof=ddof, method="cov")
13011301

13021302

1303-
def corr(
1304-
da_a: T_DataArray, da_b: T_DataArray, dim: Hashable | None = None
1305-
) -> T_DataArray:
1303+
def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims = None) -> T_DataArray:
13061304
"""
13071305
Compute the Pearson correlation coefficient between
13081306
two DataArray objects along a shared dimension.
@@ -1313,7 +1311,7 @@ def corr(
13131311
Array to compute.
13141312
da_b : DataArray
13151313
Array to compute.
1316-
dim : Hashable, optional
1314+
dim : str, iterable of hashable, "..." or None, optional
13171315
The dimension along which the correlation will be computed
13181316
13191317
Returns
@@ -1383,15 +1381,14 @@ def corr(
13831381
def _cov_corr(
13841382
da_a: T_DataArray,
13851383
da_b: T_DataArray,
1386-
dim: Hashable | None = None,
1384+
dim: Dims = None,
13871385
ddof: int = 0,
13881386
method: Literal["cov", "corr", None] = None,
13891387
) -> T_DataArray:
13901388
"""
13911389
Internal method for xr.cov() and xr.corr() so only have to
13921390
sanitize the input arrays once and we don't repeat code.
13931391
"""
1394-
dim = None if dim is None else (dim,)
13951392
# 1. Broadcast the two arrays
13961393
da_a, da_b = align(da_a, da_b, join="inner", copy=False)
13971394

@@ -1633,7 +1630,7 @@ def cross(
16331630

16341631
def dot(
16351632
*arrays,
1636-
dims: str | Iterable[Hashable] | ellipsis | None = None,
1633+
dims: Dims = None,
16371634
**kwargs: Any,
16381635
):
16391636
"""Generalized dot product for xarray objects. Like np.einsum, but
@@ -1643,7 +1640,7 @@ def dot(
16431640
----------
16441641
*arrays : DataArray or Variable
16451642
Arrays to compute.
1646-
dims : ..., str or tuple of str, optional
1643+
dims : str, iterable of hashable, "..." or None, optional
16471644
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
16481645
If not specified, then all the common dimensions are summed over.
16491646
**kwargs : dict

0 commit comments

Comments
 (0)