@@ -1507,8 +1507,9 @@ def cross(a, b, dim):
15071507 for i , arr in enumerate (arrays ):
15081508 if isinstance (arr , Dataset ):
15091509 is_dataset = True
1510- # TODO: How make sure this temporary dimension is matches
1511- # the orther dataset?
1510+ # Turn the dataset to a stacked dataarray to follow the
1511+ # normal code path. Then at the end turn it back to a
1512+ # dataset.
15121513 arrays [i ] = arr = arr .to_stacked_array (
15131514 variable_dim = dim , new_dim = "variable" , sample_dims = arr .dims
15141515 ).unstack ("variable" )
@@ -1546,6 +1547,8 @@ def cross(a, b, dim):
15461547 # If the array doesn't have coords we can can only infer
15471548 # that it is composite values if the size is 2:
15481549 arrays [i ] = arrays [i ].pad ({dim : (0 , 1 )}, constant_values = 0 )
1550+ if is_duck_dask_array (arrays [i ].data ):
1551+ arrays [i ] = arrays [i ].chunk ({dim : - 1 })
15491552 else :
15501553 # Size is 1, then we do not know if the array is a constant or
15511554 # composite value:
@@ -1565,6 +1568,9 @@ def cross(a, b, dim):
15651568 c = c .transpose (* [d for d in all_dims if d in c .dims ])
15661569 if is_dataset :
15671570 c = c .stack (variable = [dim ]).to_unstacked_dataset ("variable" )
1571+ c = c .expand_dims (
1572+ [dim for ds in arrays for dim , size in ds .sizes .items () if size == 1 ]
1573+ )
15681574
15691575 return c
15701576
0 commit comments