Skip to content

Commit c5ee050

Browse files
TomNicholas, max-sixty, dcherian, and shoyer
authored
Add to_numpy() and as_numpy() methods (#5568)
* added to_numpy() and as_numpy() methods * remove special-casing of cupy arrays in .values in favour of using .to_numpy() * lint * Fix mypy (I think?) * added Dataset.as_numpy() * improved docstrings * add what's new * add to API docs * linting * fix failures by only importing pint when needed * refactor pycompat into class * compute instead of load * added tests * fixed sparse test * tests and fixes for ds.as_numpy() * fix sparse tests * fix linting * tests for Variable * test IndexVariable too * use numpy.asarray to avoid a copy * also convert coords * Force tests again after #5600 * Apply suggestions from code review * Update xarray/core/variable.py * fix import * formatting * remove type check Co-authored-by: Stephan Hoyer <shoyer@google.com> * remove attempt to call to_numpy Co-authored-by: Maximilian Roos <m@maxroos.com> Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com> Co-authored-by: Stephan Hoyer <shoyer@google.com>
1 parent 86ca67e commit c5ee050

File tree

10 files changed

+363
-36
lines changed

10 files changed

+363
-36
lines changed

doc/api.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,7 @@ Dataset methods
686686
open_zarr
687687
Dataset.to_netcdf
688688
Dataset.to_pandas
689+
Dataset.as_numpy
689690
Dataset.to_zarr
690691
save_mfdataset
691692
Dataset.to_array
@@ -716,6 +717,8 @@ DataArray methods
716717
DataArray.to_pandas
717718
DataArray.to_series
718719
DataArray.to_dataframe
720+
DataArray.to_numpy
721+
DataArray.as_numpy
719722
DataArray.to_index
720723
DataArray.to_masked_array
721724
DataArray.to_cdms2

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ New Features
5656
- Allow removal of the coordinate attribute ``coordinates`` on variables by setting ``.attrs['coordinates']= None``
5757
(:issue:`5510`).
5858
By `Elle Smith <https://github.com/ellesmith88>`_.
59+
- Added :py:meth:`DataArray.to_numpy`, :py:meth:`DataArray.as_numpy`, and :py:meth:`Dataset.as_numpy`. (:pull:`5568`).
60+
By `Tom Nicholas <https://github.com/TomNicholas>`_.
5961

6062
Breaking changes
6163
~~~~~~~~~~~~~~~~

xarray/core/dataarray.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -426,12 +426,12 @@ def __init__(
426426
self._close = None
427427

428428
def _replace(
429-
self,
429+
self: T_DataArray,
430430
variable: Variable = None,
431431
coords=None,
432432
name: Union[Hashable, None, Default] = _default,
433433
indexes=None,
434-
) -> "DataArray":
434+
) -> T_DataArray:
435435
if variable is None:
436436
variable = self.variable
437437
if coords is None:
@@ -623,7 +623,16 @@ def __len__(self) -> int:
623623

624624
@property
625625
def data(self) -> Any:
626-
"""The array's data as a dask or numpy array"""
626+
"""
627+
The DataArray's data as an array. The underlying array type
628+
(e.g. dask, sparse, pint) is preserved.
629+
630+
See Also
631+
--------
632+
DataArray.to_numpy
633+
DataArray.as_numpy
634+
DataArray.values
635+
"""
627636
return self.variable.data
628637

629638
@data.setter
@@ -632,13 +641,46 @@ def data(self, value: Any) -> None:
632641

633642
@property
634643
def values(self) -> np.ndarray:
635-
"""The array's data as a numpy.ndarray"""
644+
"""
645+
The array's data as a numpy.ndarray.
646+
647+
If the array's data is not a numpy.ndarray this will attempt to convert
648+
it naively using np.array(), which will raise an error if the array
649+
type does not support coercion like this (e.g. cupy).
650+
"""
636651
return self.variable.values
637652

638653
@values.setter
639654
def values(self, value: Any) -> None:
640655
self.variable.values = value
641656

657+
def to_numpy(self) -> np.ndarray:
658+
"""
659+
Coerces wrapped data to numpy and returns a numpy.ndarray.
660+
661+
See also
662+
--------
663+
DataArray.as_numpy : Same but returns the surrounding DataArray instead.
664+
Dataset.as_numpy
665+
DataArray.values
666+
DataArray.data
667+
"""
668+
return self.variable.to_numpy()
669+
670+
def as_numpy(self: T_DataArray) -> T_DataArray:
671+
"""
672+
Coerces wrapped data and coordinates into numpy arrays, returning a DataArray.
673+
674+
See also
675+
--------
676+
DataArray.to_numpy : Same but returns only the data as a numpy.ndarray object.
677+
Dataset.as_numpy : Converts all variables in a Dataset.
678+
DataArray.values
679+
DataArray.data
680+
"""
681+
coords = {k: v.as_numpy() for k, v in self._coords.items()}
682+
return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes)
683+
642684
@property
643685
def _in_memory(self) -> bool:
644686
return self.variable._in_memory
@@ -931,7 +973,7 @@ def persist(self, **kwargs) -> "DataArray":
931973
ds = self._to_temp_dataset().persist(**kwargs)
932974
return self._from_temp_dataset(ds)
933975

934-
def copy(self, deep: bool = True, data: Any = None) -> "DataArray":
976+
def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray:
935977
"""Returns a copy of this array.
936978
937979
If `deep=True`, a deep copy is made of the data array.

xarray/core/dataset.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,18 @@ def copy(self, deep: bool = False, data: Mapping = None) -> "Dataset":
13231323

13241324
return self._replace(variables, attrs=attrs)
13251325

1326+
def as_numpy(self: "Dataset") -> "Dataset":
1327+
"""
1328+
Coerces wrapped data and coordinates into numpy arrays, returning a Dataset.
1329+
1330+
See also
1331+
--------
1332+
DataArray.as_numpy
1333+
DataArray.to_numpy : Returns only the data as a numpy.ndarray object.
1334+
"""
1335+
numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()}
1336+
return self._replace(variables=numpy_variables)
1337+
13261338
@property
13271339
def _level_coords(self) -> Dict[str, Hashable]:
13281340
"""Return a mapping of all MultiIndex levels and their corresponding

xarray/core/pycompat.py

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,63 @@
11
from distutils.version import LooseVersion
2+
from importlib import import_module
23

34
import numpy as np
45

56
from .utils import is_duck_array
67

78
integer_types = (int, np.integer)
89

9-
try:
10-
import dask
11-
import dask.array
12-
from dask.base import is_dask_collection
1310

14-
dask_version = LooseVersion(dask.__version__)
11+
class DuckArrayModule:
12+
"""
13+
Solely for internal isinstance and version checks.
1514
16-
# solely for isinstance checks
17-
dask_array_type = (dask.array.Array,)
15+
Motivated by having to only import pint when required (as pint currently imports xarray)
16+
https://github.com/pydata/xarray/pull/5561#discussion_r664815718
17+
"""
1818

19-
def is_duck_dask_array(x):
20-
return is_duck_array(x) and is_dask_collection(x)
19+
def __init__(self, mod):
20+
try:
21+
duck_array_module = import_module(mod)
22+
duck_array_version = LooseVersion(duck_array_module.__version__)
23+
24+
if mod == "dask":
25+
duck_array_type = (import_module("dask.array").Array,)
26+
elif mod == "pint":
27+
duck_array_type = (duck_array_module.Quantity,)
28+
elif mod == "cupy":
29+
duck_array_type = (duck_array_module.ndarray,)
30+
elif mod == "sparse":
31+
duck_array_type = (duck_array_module.SparseArray,)
32+
else:
33+
raise NotImplementedError
34+
35+
except ImportError: # pragma: no cover
36+
duck_array_module = None
37+
duck_array_version = LooseVersion("0.0.0")
38+
duck_array_type = ()
2139

40+
self.module = duck_array_module
41+
self.version = duck_array_version
42+
self.type = duck_array_type
43+
self.available = duck_array_module is not None
2244

23-
except ImportError: # pragma: no cover
24-
dask_version = LooseVersion("0.0.0")
25-
dask_array_type = ()
26-
is_duck_dask_array = lambda _: False
27-
is_dask_collection = lambda _: False
2845

29-
try:
30-
# solely for isinstance checks
31-
import sparse
46+
def is_duck_dask_array(x):
47+
if DuckArrayModule("dask").available:
48+
from dask.base import is_dask_collection
49+
50+
return is_duck_array(x) and is_dask_collection(x)
51+
else:
52+
return False
53+
3254

33-
sparse_version = LooseVersion(sparse.__version__)
34-
sparse_array_type = (sparse.SparseArray,)
35-
except ImportError: # pragma: no cover
36-
sparse_version = LooseVersion("0.0.0")
37-
sparse_array_type = ()
55+
dsk = DuckArrayModule("dask")
56+
dask_version = dsk.version
57+
dask_array_type = dsk.type
3858

39-
try:
40-
# solely for isinstance checks
41-
import cupy
59+
sp = DuckArrayModule("sparse")
60+
sparse_array_type = sp.type
61+
sparse_version = sp.version
4262

43-
cupy_version = LooseVersion(cupy.__version__)
44-
cupy_array_type = (cupy.ndarray,)
45-
except ImportError: # pragma: no cover
46-
cupy_version = LooseVersion("0.0.0")
47-
cupy_array_type = ()
63+
cupy_array_type = DuckArrayModule("cupy").type

xarray/core/variable.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@
2929
from .indexing import BasicIndexer, OuterIndexer, VectorizedIndexer, as_indexable
3030
from .options import _get_keep_attrs
3131
from .pycompat import (
32+
DuckArrayModule,
3233
cupy_array_type,
3334
dask_array_type,
3435
integer_types,
3536
is_duck_dask_array,
37+
sparse_array_type,
3638
)
3739
from .utils import (
3840
NdimSizeLenMixin,
@@ -259,7 +261,7 @@ def _as_array_or_item(data):
259261
260262
TODO: remove this (replace with np.asarray) once these issues are fixed
261263
"""
262-
data = data.get() if isinstance(data, cupy_array_type) else np.asarray(data)
264+
data = np.asarray(data)
263265
if data.ndim == 0:
264266
if data.dtype.kind == "M":
265267
data = np.datetime64(data, "ns")
@@ -1069,6 +1071,29 @@ def chunk(self, chunks={}, name=None, lock=False):
10691071

10701072
return self._replace(data=data)
10711073

1074+
def to_numpy(self) -> np.ndarray:
1075+
"""Coerces wrapped data to numpy and returns a numpy.ndarray"""
1076+
# TODO an entrypoint so array libraries can choose coercion method?
1077+
data = self.data
1078+
# TODO first attempt to call .to_numpy() once some libraries implement it
1079+
if isinstance(data, dask_array_type):
1080+
data = data.compute()
1081+
if isinstance(data, cupy_array_type):
1082+
data = data.get()
1083+
# pint has to be imported dynamically as pint imports xarray
1084+
pint_array_type = DuckArrayModule("pint").type
1085+
if isinstance(data, pint_array_type):
1086+
data = data.magnitude
1087+
if isinstance(data, sparse_array_type):
1088+
data = data.todense()
1089+
data = np.asarray(data)
1090+
1091+
return data
1092+
1093+
def as_numpy(self: VariableType) -> VariableType:
1094+
"""Coerces wrapped data into a numpy array, returning a Variable."""
1095+
return self._replace(data=self.to_numpy())
1096+
10721097
def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA):
10731098
"""
10741099
use sparse-array as backend.

xarray/tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def LooseVersion(vstring):
8383
has_numbagg, requires_numbagg = _importorskip("numbagg")
8484
has_seaborn, requires_seaborn = _importorskip("seaborn")
8585
has_sparse, requires_sparse = _importorskip("sparse")
86+
has_cupy, requires_cupy = _importorskip("cupy")
8687
has_cartopy, requires_cartopy = _importorskip("cartopy")
8788
# Need Pint 0.15 for __dask_tokenize__ tests for Quantity wrapped Dask Arrays
8889
has_pint_0_15, requires_pint_0_15 = _importorskip("pint", minversion="0.15")

xarray/tests/test_dataarray.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@
3636
has_dask,
3737
raise_if_dask_computes,
3838
requires_bottleneck,
39+
requires_cupy,
3940
requires_dask,
4041
requires_iris,
4142
requires_numbagg,
4243
requires_numexpr,
44+
requires_pint_0_15,
4345
requires_scipy,
4446
requires_sparse,
4547
source_ndarray,
@@ -7375,3 +7377,87 @@ def test_drop_duplicates(keep):
73757377
expected = xr.DataArray(data, dims="time", coords={"time": time}, name="test")
73767378
result = ds.drop_duplicates("time", keep=keep)
73777379
assert_equal(expected, result)
7380+
7381+
7382+
class TestNumpyCoercion:
7383+
# TODO once flexible indexes refactor complete also test coercion of dimension coords
7384+
def test_from_numpy(self):
7385+
da = xr.DataArray([1, 2, 3], dims="x", coords={"lat": ("x", [4, 5, 6])})
7386+
7387+
assert_identical(da.as_numpy(), da)
7388+
np.testing.assert_equal(da.to_numpy(), np.array([1, 2, 3]))
7389+
np.testing.assert_equal(da["lat"].to_numpy(), np.array([4, 5, 6]))
7390+
7391+
@requires_dask
7392+
def test_from_dask(self):
7393+
da = xr.DataArray([1, 2, 3], dims="x", coords={"lat": ("x", [4, 5, 6])})
7394+
da_chunked = da.chunk(1)
7395+
7396+
assert_identical(da_chunked.as_numpy(), da.compute())
7397+
np.testing.assert_equal(da.to_numpy(), np.array([1, 2, 3]))
7398+
np.testing.assert_equal(da["lat"].to_numpy(), np.array([4, 5, 6]))
7399+
7400+
@requires_pint_0_15
7401+
def test_from_pint(self):
7402+
from pint import Quantity
7403+
7404+
arr = np.array([1, 2, 3])
7405+
da = xr.DataArray(
7406+
Quantity(arr, units="Pa"),
7407+
dims="x",
7408+
coords={"lat": ("x", Quantity(arr + 3, units="m"))},
7409+
)
7410+
7411+
expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr + 3)})
7412+
assert_identical(da.as_numpy(), expected)
7413+
np.testing.assert_equal(da.to_numpy(), arr)
7414+
np.testing.assert_equal(da["lat"].to_numpy(), arr + 3)
7415+
7416+
@requires_sparse
7417+
def test_from_sparse(self):
7418+
import sparse
7419+
7420+
arr = np.diagflat([1, 2, 3])
7421+
sparr = sparse.COO.from_numpy(arr)
7422+
da = xr.DataArray(
7423+
sparr, dims=["x", "y"], coords={"elev": (("x", "y"), sparr + 3)}
7424+
)
7425+
7426+
expected = xr.DataArray(
7427+
arr, dims=["x", "y"], coords={"elev": (("x", "y"), arr + 3)}
7428+
)
7429+
assert_identical(da.as_numpy(), expected)
7430+
np.testing.assert_equal(da.to_numpy(), arr)
7431+
7432+
@requires_cupy
7433+
def test_from_cupy(self):
7434+
import cupy as cp
7435+
7436+
arr = np.array([1, 2, 3])
7437+
da = xr.DataArray(
7438+
cp.array(arr), dims="x", coords={"lat": ("x", cp.array(arr + 3))}
7439+
)
7440+
7441+
expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr + 3)})
7442+
assert_identical(da.as_numpy(), expected)
7443+
np.testing.assert_equal(da.to_numpy(), arr)
7444+
7445+
@requires_dask
7446+
@requires_pint_0_15
7447+
def test_from_pint_wrapping_dask(self):
7448+
import dask
7449+
from pint import Quantity
7450+
7451+
arr = np.array([1, 2, 3])
7452+
d = dask.array.from_array(arr)
7453+
da = xr.DataArray(
7454+
Quantity(d, units="Pa"),
7455+
dims="x",
7456+
coords={"lat": ("x", Quantity(d, units="m") * 2)},
7457+
)
7458+
7459+
result = da.as_numpy()
7460+
result.name = None # remove dask-assigned name
7461+
expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr * 2)})
7462+
assert_identical(result, expected)
7463+
np.testing.assert_equal(da.to_numpy(), arr)

0 commit comments

Comments (0)