Skip to content

Commit f8ea86f

Browse files
authored
Add option to specify chunks in iris.util.broadcast_to_shape (#5620)
* Add option to specify chunks in broadcast_to_shape
* Update docstring
1 parent 212ef56 commit f8ea86f

File tree

3 files changed: +51 -4 lines changed

docs/src/whatsnew/latest.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ This document explains the changes made to Iris for this release
5353

5454
#. N/A
5555

56+
#. `@bouweandela`_ added the option to specify the Dask chunks of the target
57+
array in :func:`iris.util.broadcast_to_shape`. (:pull:`5620`)
5658

5759
🔥 Deprecations
5860
===============

lib/iris/tests/unit/util/test_broadcast_to_shape.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,31 @@ def test_lazy_masked(self, mocked_compute):
8181
for j in range(4):
8282
self.assertMaskedArrayEqual(b[i, :, j, :].compute().T, m.compute())
8383

84+
@mock.patch.object(dask.base, "compute", wraps=dask.base.compute)
85+
def test_lazy_chunks(self, mocked_compute):
86+
# chunks can be specified along with the target shape and are only used
87+
# along new dimensions or on dimensions that have size 1 in the source
88+
# array.
89+
m = da.ma.masked_array(
90+
data=[[1, 2, 3, 4, 5]],
91+
mask=[[0, 1, 0, 0, 0]],
92+
).rechunk((1, 2))
93+
b = broadcast_to_shape(
94+
m,
95+
dim_map=(1, 2),
96+
shape=(3, 4, 5),
97+
chunks=(
98+
1, # used because target is new dim
99+
2, # used because input size 1
100+
3, # not used because broadcast does not rechunk
101+
),
102+
)
103+
mocked_compute.assert_not_called()
104+
for i in range(3):
105+
for j in range(4):
106+
self.assertMaskedArrayEqual(b[i, j, :].compute(), m[0].compute())
107+
assert b.chunks == ((1, 1, 1), (2, 2), (2, 2, 1))
108+
84109
def test_masked_degenerate(self):
85110
# masked arrays can have degenerate masks too
86111
rng = np.random.default_rng()

lib/iris/util.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import iris.exceptions
2929

3030

31-
def broadcast_to_shape(array, shape, dim_map):
31+
def broadcast_to_shape(array, shape, dim_map, chunks=None):
3232
"""Broadcast an array to a given shape.
3333
3434
Each dimension of the array must correspond to a dimension in the
@@ -49,6 +49,14 @@ def broadcast_to_shape(array, shape, dim_map):
4949
the index in *shape* which the dimension of *array* corresponds
5050
to, so the first element of *dim_map* gives the index of *shape*
5151
that corresponds to the first dimension of *array* etc.
52+
chunks : :class:`tuple`, optional
53+
If the source array is a :class:`dask.array.Array` and a value is
54+
provided, then the result will use these chunks instead of the same
55+
chunks as the source array. Setting chunks explicitly as part of
56+
broadcast_to_shape is more efficient than rechunking afterwards. See
57+
also :func:`dask.array.broadcast_to`. Note that the values provided
58+
here will only be used along dimensions that are new in the result or
59+
have size 1 in the source array.
5260
5361
Examples
5462
--------
@@ -71,27 +79,39 @@ def broadcast_to_shape(array, shape, dim_map):
7179
See more at :doc:`/userguide/real_and_lazy_data`.
7280
7381
"""
82+
if isinstance(array, da.Array):
83+
if chunks is not None:
84+
chunks = list(chunks)
85+
for src_idx, tgt_idx in enumerate(dim_map):
86+
# Only use the specified chunks along new dimensions or on
87+
# dimensions that have size 1 in the source array.
88+
if array.shape[src_idx] != 1:
89+
chunks[tgt_idx] = array.chunks[src_idx]
90+
broadcast = functools.partial(da.broadcast_to, shape=shape, chunks=chunks)
91+
else:
92+
broadcast = functools.partial(np.broadcast_to, shape=shape)
93+
7494
n_orig_dims = len(array.shape)
7595
n_new_dims = len(shape) - n_orig_dims
7696
array = array.reshape(array.shape + (1,) * n_new_dims)
7797

7898
# Get dims in required order.
7999
array = np.moveaxis(array, range(n_orig_dims), dim_map)
80-
new_array = np.broadcast_to(array, shape)
100+
new_array = broadcast(array)
81101

82102
if ma.isMA(array):
83103
# broadcast_to strips masks so we need to handle them explicitly.
84104
mask = ma.getmask(array)
85105
if mask is ma.nomask:
86106
new_mask = ma.nomask
87107
else:
88-
new_mask = np.broadcast_to(mask, shape)
108+
new_mask = broadcast(mask)
89109
new_array = ma.array(new_array, mask=new_mask)
90110

91111
elif is_lazy_masked_data(array):
92112
# broadcast_to strips masks so we need to handle them explicitly.
93113
mask = da.ma.getmaskarray(array)
94-
new_mask = da.broadcast_to(mask, shape)
114+
new_mask = broadcast(mask)
95115
new_array = da.ma.masked_array(new_array, new_mask)
96116

97117
return new_array

0 commit comments

Comments
 (0)