Skip to content

Commit efabe74

Browse files
aulemahal, mathause, keewis
authored
Fix polyfit fail on deficient rank (#4193)
* Fix polyfit fail on deficient rank
* Add docs and RankWarning
* Fix deficient ranks outputs | workaround dask bug | add tests
* Add a note to the doc | whats new entry
* Update xarray/core/nputils.py
  Apply suggestion from review.
  Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>
* Fix test and catch warnings
* forgot to run black
* adapt polyfit test to properly test issue #4190
* Fix syntax in doc/whats-new.rst
  Co-authored-by: keewis <keewis@users.noreply.github.com>

Co-authored-by: Mathias Hauser <mathause@users.noreply.github.com>
Co-authored-by: keewis <keewis@users.noreply.github.com>
1 parent 526f735 commit efabe74

File tree

7 files changed

+68
-18
lines changed

7 files changed

+68
-18
lines changed

doc/whats-new.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ Bug fixes
6565
unchunked dimensions (:pull:`4312`) By `Tobias Kölling <https://github.com/d70-t>`_.
6666
- Fixed a bug in backend caused by basic installation of Dask (:issue:`4164`, :pull:`4318`)
6767
`Sam Morley <https://github.com/inakleinbottle>`_.
68+
- Fixed a few bugs with :py:meth:`Dataset.polyfit` when encountering deficient matrix ranks (:issue:`4190`, :pull:`4193`). By `Pascal Bourgault <https://github.com/aulemahal>`_.
6869
- Fixed inconsistencies between docstring and functionality for :py:meth:`DataArray.str.get`
6970
and :py:meth:`DataArray.str.wrap` (:issue:`4334`). By `Mathias Hauser <https://github.com/mathause>`_.
7071
- Fixed overflow issue causing incorrect results in computing means of :py:class:`cftime.datetime`

xarray/core/dask_array_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,5 +131,7 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
131131
coeffs = coeffs.reshape(coeffs.shape[0])
132132
residuals = residuals.reshape(residuals.shape[0])
133133
else:
134+
# Residuals here are (1, 1) but should be (K,) as rhs is (N, K)
135+
# See issue dask/dask#6516
134136
coeffs, residuals, _, _ = da.linalg.lstsq(lhs_da, rhs)
135137
return coeffs, residuals

xarray/core/dataarray.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3478,7 +3478,8 @@ def polyfit(
34783478
polyfit_coefficients
34793479
The coefficients of the best fit.
34803480
polyfit_residuals
3481-
The residuals of the least-square computation (only included if `full=True`)
3481+
The residuals of the least-square computation (only included if `full=True`).
3482+
When the matrix rank is deficient, np.nan is returned.
34823483
[dim]_matrix_rank
34833484
The effective rank of the scaled Vandermonde coefficient matrix (only included if `full=True`)
34843485
[dim]_singular_values

xarray/core/dataset.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5957,13 +5957,21 @@ def polyfit(
59575957
The coefficients of the best fit for each variable in this dataset.
59585958
[var]_polyfit_residuals
59595959
The residuals of the least-square computation for each variable (only included if `full=True`).
5960+
When the matrix rank is deficient, np.nan is returned.
59605961
[dim]_matrix_rank
59615962
The effective rank of the scaled Vandermonde coefficient matrix (only included if `full=True`).
5963+
The rank is computed ignoring the NaN values that might be skipped.
59625964
[dim]_singular_values
59635965
The singular values of the scaled Vandermonde coefficient matrix (only included if `full=True`)
59645966
[var]_polyfit_covariance
59655967
The covariance matrix of the polynomial coefficient estimates (only included if `full=False` and `cov=True`)
59665968
5969+
Warns
5970+
-----
5971+
RankWarning
5972+
The rank of the coefficient matrix in the least-squares fit is deficient.
5973+
The warning is not raised with in-memory (not dask) data and `full=True`.
5974+
59675975
See also
59685976
--------
59695977
numpy.polyfit
@@ -5997,10 +6005,6 @@ def polyfit(
59976005
degree_dim = utils.get_temp_dimname(self.dims, "degree")
59986006

59996007
rank = np.linalg.matrix_rank(lhs)
6000-
if rank != order and not full:
6001-
warnings.warn(
6002-
"Polyfit may be poorly conditioned", np.RankWarning, stacklevel=4
6003-
)
60046008

60056009
if full:
60066010
rank = xr.DataArray(rank, name=xname + "matrix_rank")
@@ -6009,7 +6013,7 @@ def polyfit(
60096013
sing = xr.DataArray(
60106014
sing,
60116015
dims=(degree_dim,),
6012-
coords={degree_dim: np.arange(order)[::-1]},
6016+
coords={degree_dim: np.arange(rank - 1, -1, -1)},
60136017
name=xname + "singular_values",
60146018
)
60156019
variables[sing.name] = sing
@@ -6018,11 +6022,14 @@ def polyfit(
60186022
if dim not in da.dims:
60196023
continue
60206024

6021-
if skipna is None:
6022-
if isinstance(da.data, dask_array_type):
6023-
skipna_da = True
6024-
else:
6025-
skipna_da = np.any(da.isnull())
6025+
if isinstance(da.data, dask_array_type) and (
6026+
rank != order or full or skipna is None
6027+
):
6028+
# Current algorithm with dask and skipna=False neither supports
6029+
# deficient ranks nor does it output the "full" info (issue dask/dask#6516)
6030+
skipna_da = True
6031+
elif skipna is None:
6032+
skipna_da = np.any(da.isnull())
60266033

60276034
dims_to_stack = [dimname for dimname in da.dims if dimname != dim]
60286035
stacked_coords: Dict[Hashable, DataArray] = {}
@@ -6040,9 +6047,15 @@ def polyfit(
60406047
if w is not None:
60416048
rhs *= w[:, np.newaxis]
60426049

6043-
coeffs, residuals = duck_array_ops.least_squares(
6044-
lhs, rhs.data, rcond=rcond, skipna=skipna_da
6045-
)
6050+
with warnings.catch_warnings():
6051+
if full: # Copy np.polyfit behavior
6052+
warnings.simplefilter("ignore", np.RankWarning)
6053+
else: # Raise only once per variable
6054+
warnings.simplefilter("once", np.RankWarning)
6055+
6056+
coeffs, residuals = duck_array_ops.least_squares(
6057+
lhs, rhs.data, rcond=rcond, skipna=skipna_da
6058+
)
60466059

60476060
if isinstance(name, str):
60486061
name = "{}_".format(name)

xarray/core/nputils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,10 +232,17 @@ def _nanpolyfit_1d(arr, x, rcond=None):
232232
out = np.full((x.shape[1] + 1,), np.nan)
233233
mask = np.isnan(arr)
234234
if not np.all(mask):
235-
out[:-1], out[-1], _, _ = np.linalg.lstsq(x[~mask, :], arr[~mask], rcond=rcond)
235+
out[:-1], resid, rank, _ = np.linalg.lstsq(x[~mask, :], arr[~mask], rcond=rcond)
236+
out[-1] = resid if resid.size > 0 else np.nan
237+
warn_on_deficient_rank(rank, x.shape[1])
236238
return out
237239

238240

241+
def warn_on_deficient_rank(rank, order):
242+
if rank != order:
243+
warnings.warn("Polyfit may be poorly conditioned", np.RankWarning, stacklevel=2)
244+
245+
239246
def least_squares(lhs, rhs, rcond=None, skipna=False):
240247
if skipna:
241248
added_dim = rhs.ndim == 1
@@ -248,16 +255,21 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
248255
_nanpolyfit_1d, 0, rhs[:, nan_cols], lhs
249256
)
250257
if np.any(~nan_cols):
251-
out[:-1, ~nan_cols], out[-1, ~nan_cols], _, _ = np.linalg.lstsq(
258+
out[:-1, ~nan_cols], resids, rank, _ = np.linalg.lstsq(
252259
lhs, rhs[:, ~nan_cols], rcond=rcond
253260
)
261+
out[-1, ~nan_cols] = resids if resids.size > 0 else np.nan
262+
warn_on_deficient_rank(rank, lhs.shape[1])
254263
coeffs = out[:-1, :]
255264
residuals = out[-1, :]
256265
if added_dim:
257266
coeffs = coeffs.reshape(coeffs.shape[0])
258267
residuals = residuals.reshape(residuals.shape[0])
259268
else:
260-
coeffs, residuals, _, _ = np.linalg.lstsq(lhs, rhs, rcond=rcond)
269+
coeffs, residuals, rank, _ = np.linalg.lstsq(lhs, rhs, rcond=rcond)
270+
if residuals.size == 0:
271+
residuals = coeffs[0] * np.nan
272+
warn_on_deficient_rank(rank, lhs.shape[1])
261273
return coeffs, residuals
262274

263275

xarray/tests/test_dataarray.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4301,8 +4301,14 @@ def test_polyfit(self, use_dask, use_datetime):
43014301
).T
43024302
assert_allclose(out.polyfit_coefficients, expected, rtol=1e-3)
43034303

4304+
# Full output and deficient rank
4305+
with warnings.catch_warnings():
4306+
warnings.simplefilter("ignore", np.RankWarning)
4307+
out = da.polyfit("x", 12, full=True)
4308+
assert out.polyfit_residuals.isnull().all()
4309+
43044310
# With NaN
4305-
da_raw[0, 1] = np.nan
4311+
da_raw[0, 1:3] = np.nan
43064312
if use_dask:
43074313
da = da_raw.chunk({"d": 1})
43084314
else:
@@ -4317,6 +4323,11 @@ def test_polyfit(self, use_dask, use_datetime):
43174323
assert out.x_matrix_rank == 3
43184324
np.testing.assert_almost_equal(out.polyfit_residuals, [0, 0])
43194325

4326+
with warnings.catch_warnings():
4327+
warnings.simplefilter("ignore", np.RankWarning)
4328+
out = da.polyfit("x", 8, full=True)
4329+
np.testing.assert_array_equal(out.polyfit_residuals.isnull(), [True, False])
4330+
43204331
def test_pad_constant(self):
43214332
ar = DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5))
43224333
actual = ar.pad(dim_0=(1, 3))

xarray/tests/test_dataset.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5621,6 +5621,16 @@ def test_polyfit_output(self):
56215621
out = ds.polyfit("time", 2)
56225622
assert len(out.data_vars) == 0
56235623

5624+
def test_polyfit_warnings(self):
5625+
ds = create_test_data(seed=1)
5626+
5627+
with warnings.catch_warnings(record=True) as ws:
5628+
ds.var1.polyfit("dim2", 10, full=False)
5629+
assert len(ws) == 1
5630+
assert ws[0].category == np.RankWarning
5631+
ds.var1.polyfit("dim2", 10, full=True)
5632+
assert len(ws) == 1
5633+
56245634
def test_pad(self):
56255635
ds = create_test_data(seed=1)
56265636
padded = ds.pad(dim2=(1, 1), constant_values=42)

0 commit comments

Comments
 (0)