Skip to content

Commit fc1dbb5

Browse files
Make broadcast and concat work with the Array API (#7387)
* Make `broadcast` and `concat` work with the Array API. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b93dae4 commit fc1dbb5

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

xarray/core/duck_array_ops.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
import pandas as pd
1717
from numpy import all as array_all # noqa
1818
from numpy import any as array_any # noqa
19+
from numpy import around # noqa
1920
from numpy import zeros_like # noqa
20-
from numpy import around, broadcast_to # noqa
2121
from numpy import concatenate as _concatenate
2222
from numpy import ( # noqa
2323
einsum,
@@ -207,6 +207,11 @@ def as_shared_dtype(scalars_or_arrays, xp=np):
207207
return [astype(x, out_type, copy=False) for x in arrays]
208208

209209

210+
def broadcast_to(array, shape):
211+
xp = get_array_namespace(array)
212+
return xp.broadcast_to(array, shape)
213+
214+
210215
def lazy_array_equiv(arr1, arr2):
211216
"""Like array_equal, but doesn't actually compare values.
212217
Returns True when arr1, arr2 identical or their dask tokens are equal.
@@ -311,6 +316,9 @@ def fillna(data, other):
311316

312317
def concatenate(arrays, axis=0):
313318
"""concatenate() with better dtype promotion rules."""
319+
if hasattr(arrays[0], "__array_namespace__"):
320+
xp = get_array_namespace(arrays[0])
321+
return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis)
314322
return _concatenate(as_shared_dtype(arrays), axis=axis)
315323

316324

xarray/tests/test_array_api.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,27 @@ def test_astype(arrays) -> None:
6464
assert_equal(actual, expected)
6565

6666

67+
def test_broadcast(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
68+
np_arr, xp_arr = arrays
69+
np_arr2 = xr.DataArray(np.array([1.0, 2.0]), dims="x")
70+
xp_arr2 = xr.DataArray(xp.asarray([1.0, 2.0]), dims="x")
71+
72+
expected = xr.broadcast(np_arr, np_arr2)
73+
actual = xr.broadcast(xp_arr, xp_arr2)
74+
assert len(actual) == len(expected)
75+
for a, e in zip(actual, expected):
76+
assert isinstance(a.data, Array)
77+
assert_equal(a, e)
78+
79+
80+
def test_concat(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
81+
np_arr, xp_arr = arrays
82+
expected = xr.concat((np_arr, np_arr), dim="x")
83+
actual = xr.concat((xp_arr, xp_arr), dim="x")
84+
assert isinstance(actual.data, Array)
85+
assert_equal(actual, expected)
86+
87+
6788
def test_indexing(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
6889
np_arr, xp_arr = arrays
6990
expected = np_arr[:, 0]

0 commit comments

Comments
 (0)