Skip to content

Commit cad4474

Browse files
Fix polyval overloads (#6593)
* add polyval overloads * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 11041bd commit cad4474

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

xarray/core/computation.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Iterable,
1818
Mapping,
1919
Sequence,
20+
overload,
2021
)
2122

2223
import numpy as np
@@ -1845,26 +1846,31 @@ def where(cond, x, y, keep_attrs=None):
18451846
)
18461847

18471848

1848-
# These overloads seem not to work — mypy says it can't find a matching overload for
1849-
# `DataArray` & `DataArray`, despite that being in the first overload. Would be nice to
1850-
# have overloaded functions rather than just `T_Xarray` for everything.
1849+
@overload
1850+
def polyval(coord: DataArray, coeffs: DataArray, degree_dim: Hashable) -> DataArray:
1851+
...
18511852

1852-
# @overload
1853-
# def polyval(coord: DataArray, coeffs: DataArray, degree_dim: Hashable) -> DataArray:
1854-
# ...
18551853

1854+
@overload
1855+
def polyval(coord: DataArray, coeffs: Dataset, degree_dim: Hashable) -> Dataset:
1856+
...
18561857

1857-
# @overload
1858-
# def polyval(coord: T_Xarray, coeffs: Dataset, degree_dim: Hashable) -> Dataset:
1859-
# ...
18601858

1859+
@overload
1860+
def polyval(coord: Dataset, coeffs: DataArray, degree_dim: Hashable) -> Dataset:
1861+
...
18611862

1862-
# @overload
1863-
# def polyval(coord: Dataset, coeffs: T_Xarray, degree_dim: Hashable) -> Dataset:
1864-
# ...
1863+
1864+
@overload
1865+
def polyval(coord: Dataset, coeffs: Dataset, degree_dim: Hashable) -> Dataset:
1866+
...
18651867

18661868

1867-
def polyval(coord: T_Xarray, coeffs: T_Xarray, degree_dim="degree") -> T_Xarray:
1869+
def polyval(
1870+
coord: Dataset | DataArray,
1871+
coeffs: Dataset | DataArray,
1872+
degree_dim: Hashable = "degree",
1873+
) -> Dataset | DataArray:
18681874
"""Evaluate a polynomial at specific values
18691875
18701876
Parameters
@@ -1899,7 +1905,7 @@ def polyval(coord: T_Xarray, coeffs: T_Xarray, degree_dim="degree") -> T_Xarray:
18991905
coeffs = coeffs.reindex(
19001906
{degree_dim: np.arange(max_deg + 1)}, fill_value=0, copy=False
19011907
)
1902-
coord = _ensure_numeric(coord)
1908+
coord = _ensure_numeric(coord) # type: ignore # https://github.com/python/mypy/issues/1533 ?
19031909

19041910
# using Horner's method
19051911
# https://en.wikipedia.org/wiki/Horner%27s_method

xarray/tests/test_computation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2012,14 +2012,19 @@ def test_where_attrs() -> None:
20122012
),
20132013
],
20142014
)
2015-
def test_polyval(use_dask, x: T_Xarray, coeffs: T_Xarray, expected) -> None:
2015+
def test_polyval(
2016+
use_dask: bool,
2017+
x: xr.DataArray | xr.Dataset,
2018+
coeffs: xr.DataArray | xr.Dataset,
2019+
expected: xr.DataArray | xr.Dataset,
2020+
) -> None:
20162021
if use_dask:
20172022
if not has_dask:
20182023
pytest.skip("requires dask")
20192024
coeffs = coeffs.chunk({"degree": 2})
20202025
x = x.chunk({"x": 2})
20212026
with raise_if_dask_computes():
2022-
actual = xr.polyval(coord=x, coeffs=coeffs)
2027+
actual = xr.polyval(coord=x, coeffs=coeffs) # type: ignore
20232028
xr.testing.assert_allclose(actual, expected)
20242029

20252030

0 commit comments

Comments
 (0)