Skip to content

Commit bce2f3e

Browse files
committed
rechunk padded values, handle 1 sized datasets
1 parent 72330ce commit bce2f3e

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

xarray/core/computation.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,8 +1507,9 @@ def cross(a, b, dim):
15071507
for i, arr in enumerate(arrays):
15081508
if isinstance(arr, Dataset):
15091509
is_dataset = True
1510-
# TODO: How make sure this temporary dimension is matches
1511-
# the orther dataset?
1510+
# Turn the dataset to a stacked dataarray to follow the
1511+
# normal code path. Then at the end turn it back to a
1512+
# dataset.
15121513
arrays[i] = arr = arr.to_stacked_array(
15131514
variable_dim=dim, new_dim="variable", sample_dims=arr.dims
15141515
).unstack("variable")
@@ -1546,6 +1547,8 @@ def cross(a, b, dim):
15461547
# If the array doesn't have coords we can can only infer
15471548
# that it is composite values if the size is 2:
15481549
arrays[i] = arrays[i].pad({dim: (0, 1)}, constant_values=0)
1550+
if is_duck_dask_array(arrays[i].data):
1551+
arrays[i] = arrays[i].chunk({dim: -1})
15491552
else:
15501553
# Size is 1, then we do not know if the array is a constant or
15511554
# composite value:
@@ -1565,6 +1568,9 @@ def cross(a, b, dim):
15651568
c = c.transpose(*[d for d in all_dims if d in c.dims])
15661569
if is_dataset:
15671570
c = c.stack(variable=[dim]).to_unstacked_dataset("variable")
1571+
c = c.expand_dims(
1572+
[dim for ds in arrays for dim, size in ds.sizes.items() if size == 1]
1573+
)
15681574

15691575
return c
15701576

0 commit comments

Comments
 (0)