Faster unstacking of dask arrays #5582

Open
@dcherian

Recent dask versions support assigning to a list of ints along one dimension. We can use this for unstacking (the diff below builds on #5577):

diff --git i/xarray/core/variable.py w/xarray/core/variable.py
index 222e8dab9..a50dfc574 100644
--- i/xarray/core/variable.py
+++ w/xarray/core/variable.py
@@ -1593,11 +1593,9 @@ class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic):
         else:
             dtype = self.dtype
 
-        if sparse:
+        if sparse and not is_duck_dask_array(reordered):
             # unstacking a dense multitindexed array to a sparse array
-            # Use the sparse.COO constructor until sparse supports advanced indexing
-            # https://github.com/pydata/sparse/issues/114
-            # TODO: how do we allow different sparse array types
+            # Use the sparse.COO constructor since we cannot assign to sparse.COO
             from sparse import COO
 
             codes = zip(*index.codes)
@@ -1618,19 +1616,23 @@ class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic):
             )
 
         else:
+            # dask supports assigning to a list of ints along one axis only.
+            # So we construct an array with the last dimension flattened,
+            # assign the values, then reshape to the final shape.
+            intermediate_shape = reordered.shape[:-1] + (np.prod(new_dim_sizes),)
+            indexer = np.ravel_multi_index(index.codes, new_dim_sizes)
             data = np.full_like(
                 self.data,
                 fill_value=fill_value,
-                shape=new_shape,
+                shape=intermediate_shape,
                 dtype=dtype,
             )
 
             # Indexer is a list of lists of locations. Each list is the locations
             # on the new dimension. This is robust to the data being sparse; in that
             # case the destinations will be NaN / zero.
-            # sparse doesn't support item assigment,
-            # https://github.com/pydata/sparse/issues/114
-            data[(..., *indexer)] = reordered
+            data[(..., indexer)] = reordered
+            data = data.reshape(new_shape)
 
         return self._replace(dims=new_dims, data=data)
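
To illustrate the flatten, assign, reshape idea from the diff outside of xarray, here is a minimal NumPy-only sketch. The example data (`codes`, `new_dim_sizes`, `reordered`) is made up for illustration and stands in for `index.codes` and the variables of the same name in the diff. It uses plain NumPy arrays, so it only demonstrates the indexing logic, not dask's behaviour.

```python
import numpy as np

# Hypothetical example: 4 stacked entries that unstack onto a 2 x 3 grid.
new_dim_sizes = (2, 3)
# Stand-in for index.codes: one integer code array per new dimension.
codes = (np.array([0, 0, 1, 1]), np.array([0, 2, 1, 2]))
# One leading dimension is kept; the stacked dimension comes last.
reordered = np.arange(8.0).reshape(2, 4)

# Flatten the new dimensions into a single trailing axis so that one
# 1-D list of ints is enough to address every destination.
intermediate_shape = reordered.shape[:-1] + (int(np.prod(new_dim_sizes)),)
indexer = np.ravel_multi_index(codes, new_dim_sizes)

# Fill with the missing-value marker, assign along the last axis only,
# then reshape to the final unstacked shape.
data = np.full(intermediate_shape, np.nan)
data[..., indexer] = reordered
data = data.reshape(reordered.shape[:-1] + new_dim_sizes)
```

Positions on the 2 x 3 grid that have no entry in `codes` keep the fill value, which matches the comment in the diff about missing destinations.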

This should be what alignment.reindex_variables is doing, but I don't fully understand that function.

The annoying bit is figuring out when to use this version and what to do with things like dask wrapping sparse. I think we want to loop over each variable in Dataset.unstack, calling Variable.unstack and dispatching on the type of Variable.data, so that all the edge cases are handled easily.
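
A rough sketch of what that per-variable dispatch could look like is below. All names here (`is_dask_like`, `pick_unstack_path`, `plan_unstack`) are hypothetical and not xarray API. The diff itself uses `is_duck_dask_array`, which is replaced here by a module-name check so the example runs without dask installed. It mirrors only the two branches in the diff and does not settle the dask-wrapping-sparse case.

```python
import numpy as np


def is_dask_like(data):
    # Assumption: recognise dask collections by their top-level module
    # name, so this sketch does not need to import dask.
    return type(data).__module__.split(".")[0] == "dask"


def pick_unstack_path(data, sparse=False):
    # Same condition as the diff: use the COO constructor only when a
    # sparse result is requested and the data is not a dask array.
    if sparse and not is_dask_like(data):
        return "coo_constructor"
    return "flatten_assign_reshape"


def plan_unstack(variables, sparse=False):
    # Decide per variable, rather than once for the whole dataset.
    return {
        name: pick_unstack_path(data, sparse)
        for name, data in variables.items()
    }


variables = {"a": np.zeros(4), "b": np.ones((2, 4))}
plan = plan_unstack(variables, sparse=True)
```

Deciding per variable keeps the type checks in one place, so a further array type would need only one more branch in `pick_unstack_path`.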

cc @Illviljan if you're interested in implementing this
