Skip to content

Commit

Permalink
differentiate should not cast to numpy.array (pydata#5408)
Browse files Browse the repository at this point in the history
  • Loading branch information
keewis authored Jun 7, 2021
1 parent 34dc577 commit da0489f
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 15 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ Bug fixes
- Fix 1-level multi-index incorrectly converted to single index (:issue:`5384`,
:pull:`5385`).
By `Benoit Bovy <https://github.com/benbovy>`_.
- Don't cast a duck array in a coordinate to :py:class:`numpy.ndarray` in
:py:meth:`DataArray.differentiate` (:pull:`5408`)
By `Justus Magin <https://github.com/keewis>`_.
- Fix the ``repr`` of :py:class:`Variable` objects with ``display_expand_data=True``
(:pull:`5406`)
By `Justus Magin <https://github.com/keewis>`_.
Expand Down
5 changes: 4 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6244,7 +6244,10 @@ def differentiate(self, coord, edge_order=1, datetime_unit=None):
if _contains_datetime_like_objects(v):
v = v._to_numeric(datetime_unit=datetime_unit)
grad = duck_array_ops.gradient(
v.data, coord_var, edge_order=edge_order, axis=v.get_axis_num(dim)
v.data,
coord_var.data,
edge_order=edge_order,
axis=v.get_axis_num(dim),
)
variables[k] = Variable(v.dims, grad)
else:
Expand Down
87 changes: 73 additions & 14 deletions xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

import xarray as xr
from xarray.core import dtypes
from xarray.core import dtypes, duck_array_ops

from . import assert_allclose, assert_duckarray_allclose, assert_equal, assert_identical
from .test_variable import _PAD_XR_NP_ARGS
Expand Down Expand Up @@ -276,13 +276,13 @@ class method:
This is works a bit similar to using `partial(Class.method, arg, kwarg)`
"""

def __init__(self, name, *args, **kwargs):
def __init__(self, name, *args, fallback_func=None, **kwargs):
self.name = name
self.fallback = fallback_func
self.args = args
self.kwargs = kwargs

def __call__(self, obj, *args, **kwargs):
from collections.abc import Callable
from functools import partial

all_args = merge_args(self.args, args)
Expand All @@ -298,21 +298,23 @@ def __call__(self, obj, *args, **kwargs):
if not isinstance(obj, xarray_classes):
# remove typical xarray args like "dim"
exclude_kwargs = ("dim", "dims")
# TODO: figure out a way to replace dim / dims with axis
all_kwargs = {
key: value
for key, value in all_kwargs.items()
if key not in exclude_kwargs
}

func = getattr(obj, self.name, None)

if func is None or not isinstance(func, Callable):
# fall back to module level numpy functions if not a xarray object
if not isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)):
numpy_func = getattr(np, self.name)
func = partial(numpy_func, obj)
if self.fallback is not None:
func = partial(self.fallback, obj)
else:
raise AttributeError(f"{obj} has no method named '{self.name}'")
func = getattr(obj, self.name, None)

if func is None or not callable(func):
# fall back to module level numpy functions
numpy_func = getattr(np, self.name)
func = partial(numpy_func, obj)
else:
func = getattr(obj, self.name)

return func(*all_args, **all_kwargs)

Expand Down Expand Up @@ -3662,6 +3664,65 @@ def test_stacking_reordering(self, func, dtype):
assert_units_equal(expected, actual)
assert_identical(expected, actual)

@pytest.mark.parametrize(
"variant",
(
pytest.param(
"dims", marks=pytest.mark.skip(reason="indexes don't support units")
),
"coords",
),
)
@pytest.mark.parametrize(
"func",
(
method("differentiate", fallback_func=np.gradient),
method("integrate", fallback_func=duck_array_ops.cumulative_trapezoid),
method("cumulative_integrate", fallback_func=duck_array_ops.trapz),
),
ids=repr,
)
def test_differentiate_integrate(self, func, variant, dtype):
data_unit = unit_registry.m
unit = unit_registry.s

variants = {
"dims": ("x", unit, 1),
"coords": ("u", 1, unit),
}
coord, dim_unit, coord_unit = variants.get(variant)

array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit

x = np.arange(array.shape[0]) * dim_unit
y = np.arange(array.shape[1]) * dim_unit

u = np.linspace(0, 1, array.shape[0]) * coord_unit

data_array = xr.DataArray(
data=array, coords={"x": x, "y": y, "u": ("x", u)}, dims=("x", "y")
)
# we want to make sure the output unit is correct
units = extract_units(data_array)
units.update(
extract_units(
func(
data_array.data,
getattr(data_array, coord).data,
axis=0,
)
)
)

expected = attach_units(
func(strip_units(data_array), coord=strip_units(coord)),
units,
)
actual = func(data_array, coord=coord)

assert_units_equal(expected, actual)
assert_identical(expected, actual)

@pytest.mark.parametrize(
"variant",
(
Expand All @@ -3676,8 +3737,6 @@ def test_stacking_reordering(self, func, dtype):
"func",
(
method("diff", dim="x"),
method("differentiate", coord="x"),
method("integrate", coord="x"),
method("quantile", q=[0.25, 0.75]),
method("reduce", func=np.sum, dim="x"),
pytest.param(lambda x: x.dot(x), id="method_dot"),
Expand Down

0 comments on commit da0489f

Please sign in to comment.