Skip to content

Commit

Permalink
Add var and std to weighted computations (#5870)
Browse files Browse the repository at this point in the history
Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
  • Loading branch information
cjauvin and Illviljan authored Oct 28, 2021
1 parent 7b93333 commit b3b77f5
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 8 deletions.
6 changes: 6 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -779,12 +779,18 @@ Weighted objects

core.weighted.DataArrayWeighted
core.weighted.DataArrayWeighted.mean
core.weighted.DataArrayWeighted.std
core.weighted.DataArrayWeighted.sum
core.weighted.DataArrayWeighted.sum_of_squares
core.weighted.DataArrayWeighted.sum_of_weights
core.weighted.DataArrayWeighted.var
core.weighted.DatasetWeighted
core.weighted.DatasetWeighted.mean
core.weighted.DatasetWeighted.std
core.weighted.DatasetWeighted.sum
core.weighted.DatasetWeighted.sum_of_squares
core.weighted.DatasetWeighted.sum_of_weights
core.weighted.DatasetWeighted.var


Coarsen objects
Expand Down
20 changes: 17 additions & 3 deletions doc/user-guide/computation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ Weighted array reductions

:py:class:`DataArray` and :py:class:`Dataset` objects include :py:meth:`DataArray.weighted`
and :py:meth:`Dataset.weighted` array reduction methods. They currently
support weighted ``sum`` and weighted ``mean``.
support weighted ``sum``, ``mean``, ``std`` and ``var``.

.. ipython:: python
Expand Down Expand Up @@ -298,13 +298,27 @@ The weighted sum corresponds to:
weighted_sum = (prec * weights).sum()
weighted_sum
and the weighted mean to:
the weighted mean to:

.. ipython:: python
weighted_mean = weighted_sum / weights.sum()
weighted_mean
the weighted variance to:

.. ipython:: python
weighted_var = weighted_prec.sum_of_squares() / weights.sum()
weighted_var
and the weighted standard deviation to:

.. ipython:: python
weighted_std = np.sqrt(weighted_var)
weighted_std
However, the functions also take missing values in the data into account:

.. ipython:: python
Expand All @@ -327,7 +341,7 @@ If the weights add up to to 0, ``sum`` returns 0:
data.weighted(weights).sum()
and ``mean`` returns ``NaN``:
and ``mean``, ``std`` and ``var`` return ``NaN``:

.. ipython:: python
Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ v0.19.1 (unreleased)
New Features
~~~~~~~~~~~~
- Add :py:meth:`var`, :py:meth:`std` and :py:meth:`sum_of_squares` to :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted`.
By `Christian Jauvin <https://github.com/cjauvin>`_.
- Added a :py:func:`get_options` method to xarray's root namespace (:issue:`5698`, :pull:`5716`)
By `Pushkar Kopparla <https://github.com/pkopparla>`_.
- Xarray now does a better job rendering variable names that are long LaTeX sequences when plotting (:issue:`5681`, :pull:`5682`).
Expand Down
89 changes: 86 additions & 3 deletions xarray/core/weighted.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, Union
from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, Union, cast

import numpy as np

from . import duck_array_ops
from .computation import dot
Expand Down Expand Up @@ -35,7 +37,7 @@
"""

_SUM_OF_WEIGHTS_DOCSTRING = """
Calculate the sum of weights, accounting for missing values in the data
Calculate the sum of weights, accounting for missing values in the data.
Parameters
----------
Expand Down Expand Up @@ -177,13 +179,25 @@ def _sum_of_weights(

return sum_of_weights.where(valid_weights)

def _sum_of_squares(
self,
da: "DataArray",
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
) -> "DataArray":
"""Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s)."""

demeaned = da - da.weighted(self.weights).mean(dim=dim)

return self._reduce((demeaned ** 2), self.weights, dim=dim, skipna=skipna)

def _weighted_sum(
self,
da: "DataArray",
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
) -> "DataArray":
"""Reduce a DataArray by a by a weighted ``sum`` along some dimension(s)."""
"""Reduce a DataArray by a weighted ``sum`` along some dimension(s)."""

return self._reduce(da, self.weights, dim=dim, skipna=skipna)

Expand All @@ -201,6 +215,30 @@ def _weighted_mean(

return weighted_sum / sum_of_weights

def _weighted_var(
self,
da: "DataArray",
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
) -> "DataArray":
"""Reduce a DataArray by a weighted ``var`` along some dimension(s)."""

sum_of_squares = self._sum_of_squares(da, dim=dim, skipna=skipna)

sum_of_weights = self._sum_of_weights(da, dim=dim)

return sum_of_squares / sum_of_weights

def _weighted_std(
self,
da: "DataArray",
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
) -> "DataArray":
"""Reduce a DataArray by a weighted ``std`` along some dimension(s)."""

return cast("DataArray", np.sqrt(self._weighted_var(da, dim, skipna)))

def _implementation(self, func, dim, **kwargs):

raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`")
Expand All @@ -215,6 +253,17 @@ def sum_of_weights(
self._sum_of_weights, dim=dim, keep_attrs=keep_attrs
)

def sum_of_squares(
self,
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
keep_attrs: Optional[bool] = None,
) -> T_Xarray:

return self._implementation(
self._sum_of_squares, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)

def sum(
self,
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
Expand All @@ -237,6 +286,28 @@ def mean(
self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)

def var(
self,
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
keep_attrs: Optional[bool] = None,
) -> T_Xarray:

return self._implementation(
self._weighted_var, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)

def std(
self,
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
keep_attrs: Optional[bool] = None,
) -> T_Xarray:

return self._implementation(
self._weighted_std, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)

def __repr__(self):
"""provide a nice str repr of our Weighted object"""

Expand Down Expand Up @@ -275,6 +346,18 @@ def _inject_docstring(cls, cls_name):
cls=cls_name, fcn="mean", on_zero="NaN"
)

cls.sum_of_squares.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
cls=cls_name, fcn="sum_of_squares", on_zero="0"
)

cls.var.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
cls=cls_name, fcn="var", on_zero="NaN"
)

cls.std.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
cls=cls_name, fcn="std", on_zero="NaN"
)


_inject_docstring(DataArrayWeighted, "DataArray")
_inject_docstring(DatasetWeighted, "Dataset")
Loading

0 comments on commit b3b77f5

Please sign in to comment.