2828import iris .exceptions
2929
3030
31- def broadcast_to_shape (array , shape , dim_map ):
31+ def broadcast_to_shape (array , shape , dim_map , chunks = None ):
3232 """Broadcast an array to a given shape.
3333
3434 Each dimension of the array must correspond to a dimension in the
@@ -49,6 +49,14 @@ def broadcast_to_shape(array, shape, dim_map):
4949 the index in *shape* which the dimension of *array* corresponds
5050 to, so the first element of *dim_map* gives the index of *shape*
5151 that corresponds to the first dimension of *array* etc.
52+ chunks : :class:`tuple`, optional
53+ If the source array is a :class:`dask.array.Array` and a value is
54+ provided, then the result will use these chunks instead of the same
55+ chunks as the source array. Setting chunks explicitly as part of
56+ broadcast_to_shape is more efficient than rechunking afterwards. See
57+ also :func:`dask.array.broadcast_to`. Note that the values provided
58+ here will only be used along dimensions that are new on the result or
59+ have size 1 on the source array.
5260
5361 Examples
5462 --------
@@ -71,27 +79,39 @@ def broadcast_to_shape(array, shape, dim_map):
7179 See more at :doc:`/userguide/real_and_lazy_data`.
7280
7381 """
82+ if isinstance (array , da .Array ):
83+ if chunks is not None :
84+ chunks = list (chunks )
85+ for src_idx , tgt_idx in enumerate (dim_map ):
86+ # Only use the specified chunks along new dimensions or on
87+ # dimensions that have size 1 in the source array.
88+ if array .shape [src_idx ] != 1 :
89+ chunks [tgt_idx ] = array .chunks [src_idx ]
90+ broadcast = functools .partial (da .broadcast_to , shape = shape , chunks = chunks )
91+ else :
92+ broadcast = functools .partial (np .broadcast_to , shape = shape )
93+
7494 n_orig_dims = len (array .shape )
7595 n_new_dims = len (shape ) - n_orig_dims
7696 array = array .reshape (array .shape + (1 ,) * n_new_dims )
7797
7898 # Get dims in required order.
7999 array = np .moveaxis (array , range (n_orig_dims ), dim_map )
80- new_array = np . broadcast_to (array , shape )
100+ new_array = broadcast (array )
81101
82102 if ma .isMA (array ):
83103 # broadcast_to strips masks so we need to handle them explicitly.
84104 mask = ma .getmask (array )
85105 if mask is ma .nomask :
86106 new_mask = ma .nomask
87107 else :
88- new_mask = np . broadcast_to (mask , shape )
108+ new_mask = broadcast (mask )
89109 new_array = ma .array (new_array , mask = new_mask )
90110
91111 elif is_lazy_masked_data (array ):
92112 # broadcast_to strips masks so we need to handle them explicitly.
93113 mask = da .ma .getmaskarray (array )
94- new_mask = da . broadcast_to (mask , shape )
114+ new_mask = broadcast (mask )
95115 new_array = da .ma .masked_array (new_array , new_mask )
96116
97117 return new_array
0 commit comments