Skip to content

Commit b8e3be8

Browse files
authored
Do not rely on np.broadcast_to to perform trivial dimension insertion (#10277)
1 parent 47fbbb6 commit b8e3be8

File tree

3 files changed

+83
-5
lines changed

3 files changed

+83
-5
lines changed

doc/whats-new.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ Internal Changes
104104
~~~~~~~~~~~~~~~~
105105
- Avoid stacking when grouping by a chunked array. This can be a large performance improvement.
106106
By `Deepak Cherian <https://github.com/dcherian>`_.
107+
- The implementation of ``Variable.set_dims`` has changed to use array indexing syntax
108+
instead of ``np.broadcast_to`` to perform dimension expansions where
109+
all new dimensions have a size of 1. This should improve compatibility with
110+
duck arrays that do not support broadcasting (:issue:`9462`, :pull:`10277`).
111+
By `Mark Harfouche <https://github.com/hmaarrfk>`_.
107112

108113
.. _whats-new.2025.03.1:
109114

xarray/core/variable.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,7 +1355,7 @@ def set_dims(self, dim, shape=None):
13551355
dim = [dim]
13561356

13571357
if shape is None and is_dict_like(dim):
1358-
shape = dim.values()
1358+
shape = tuple(dim.values())
13591359

13601360
missing_dims = set(self.dims) - set(dim)
13611361
if missing_dims:
@@ -1371,13 +1371,18 @@ def set_dims(self, dim, shape=None):
13711371
# don't use broadcast_to unless necessary so the result remains
13721372
# writeable if possible
13731373
expanded_data = self.data
1374-
elif shape is not None:
1374+
elif shape is None or all(
1375+
s == 1 for s, e in zip(shape, dim, strict=True) if e not in self_dims
1376+
):
1377+
# "Trivial" broadcasting, i.e. simply inserting a new dimension
1378+
# This is typically easier for duck arrays to implement
1379+
# than the full "broadcast_to" semantics
1380+
indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,)
1381+
expanded_data = self.data[indexer]
1382+
else: # shape is given and at least one new dimension has size != 1
13751383
dims_map = dict(zip(dim, shape, strict=True))
13761384
tmp_shape = tuple(dims_map[d] for d in expanded_dims)
13771385
expanded_data = duck_array_ops.broadcast_to(self._data, tmp_shape)
1378-
else:
1379-
indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,)
1380-
expanded_data = self.data[indexer]
13811386

13821387
expanded_var = Variable(
13831388
expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True

xarray/tests/test_variable.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1653,6 +1653,74 @@ def test_set_dims_object_dtype(self):
16531653
expected = Variable(["x"], exp_values)
16541654
assert_identical(actual, expected)
16551655

1656+
def test_set_dims_without_broadcast(self):
1657+
class ArrayWithoutBroadcastTo(NDArrayMixin, indexing.ExplicitlyIndexed):
1658+
def __init__(self, array):
1659+
self.array = array
1660+
1661+
# Broadcasting with __getitem__ is "easier" to implement
1662+
# especially for dims of 1
1663+
def __getitem__(self, key):
1664+
return self.array[key]
1665+
1666+
def __array_function__(self, *args, **kwargs):
1667+
raise NotImplementedError(
1668+
"Not we don't want to use broadcast_to here "
1669+
"https://github.com/pydata/xarray/issues/9462"
1670+
)
1671+
1672+
arr = ArrayWithoutBroadcastTo(np.zeros((3, 4)))
1673+
# We should be able to add a new axis without broadcasting
1674+
assert arr[np.newaxis, :, :].shape == (1, 3, 4)
1675+
with pytest.raises(NotImplementedError):
1676+
np.broadcast_to(arr, (1, 3, 4))
1677+
1678+
v = Variable(["x", "y"], arr)
1679+
v_expanded = v.set_dims(["z", "x", "y"])
1680+
assert v_expanded.dims == ("z", "x", "y")
1681+
assert v_expanded.shape == (1, 3, 4)
1682+
1683+
v_expanded = v.set_dims(["x", "z", "y"])
1684+
assert v_expanded.dims == ("x", "z", "y")
1685+
assert v_expanded.shape == (3, 1, 4)
1686+
1687+
v_expanded = v.set_dims(["x", "y", "z"])
1688+
assert v_expanded.dims == ("x", "y", "z")
1689+
assert v_expanded.shape == (3, 4, 1)
1690+
1691+
# Explicitly asking for a shape of 1 triggers a different
1692+
# codepath in set_dims
1693+
# https://github.com/pydata/xarray/issues/9462
1694+
v_expanded = v.set_dims(["z", "x", "y"], shape=(1, 3, 4))
1695+
assert v_expanded.dims == ("z", "x", "y")
1696+
assert v_expanded.shape == (1, 3, 4)
1697+
1698+
v_expanded = v.set_dims(["x", "z", "y"], shape=(3, 1, 4))
1699+
assert v_expanded.dims == ("x", "z", "y")
1700+
assert v_expanded.shape == (3, 1, 4)
1701+
1702+
v_expanded = v.set_dims(["x", "y", "z"], shape=(3, 4, 1))
1703+
assert v_expanded.dims == ("x", "y", "z")
1704+
assert v_expanded.shape == (3, 4, 1)
1705+
1706+
v_expanded = v.set_dims({"z": 1, "x": 3, "y": 4})
1707+
assert v_expanded.dims == ("z", "x", "y")
1708+
assert v_expanded.shape == (1, 3, 4)
1709+
1710+
v_expanded = v.set_dims({"x": 3, "z": 1, "y": 4})
1711+
assert v_expanded.dims == ("x", "z", "y")
1712+
assert v_expanded.shape == (3, 1, 4)
1713+
1714+
v_expanded = v.set_dims({"x": 3, "y": 4, "z": 1})
1715+
assert v_expanded.dims == ("x", "y", "z")
1716+
assert v_expanded.shape == (3, 4, 1)
1717+
1718+
with pytest.raises(NotImplementedError):
1719+
v.set_dims({"z": 2, "x": 3, "y": 4})
1720+
1721+
with pytest.raises(NotImplementedError):
1722+
v.set_dims(["z", "x", "y"], shape=(2, 3, 4))
1723+
16561724
def test_stack(self):
16571725
v = Variable(["x", "y"], [[0, 1], [2, 3]], {"foo": "bar"})
16581726
actual = v.stack(z=("x", "y"))

0 commit comments

Comments
 (0)