Faster unstacking of dask arrays #5582

Open
dcherian opened this issue Jul 6, 2021 · 0 comments

dcherian commented Jul 6, 2021

Recent dask versions support assigning to a list of ints along one dimension. We can use this for unstacking (the diff below builds on #5577):

diff --git i/xarray/core/variable.py w/xarray/core/variable.py
index 222e8dab9..a50dfc574 100644
--- i/xarray/core/variable.py
+++ w/xarray/core/variable.py
@@ -1593,11 +1593,9 @@ class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic):
         else:
             dtype = self.dtype
 
-        if sparse:
+        if sparse and not is_duck_dask_array(reordered):
             # unstacking a dense multitindexed array to a sparse array
-            # Use the sparse.COO constructor until sparse supports advanced indexing
-            # https://github.com/pydata/sparse/issues/114
-            # TODO: how do we allow different sparse array types
+            # Use the sparse.COO constructor since we cannot assign to sparse.COO
             from sparse import COO
 
             codes = zip(*index.codes)
@@ -1618,19 +1616,23 @@ class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic):
             )
 
         else:
+            # dask supports assigning to a list of ints along one axis only.
+            # So we construct an array with the last dimension flattened,
+            # assign the values, then reshape to the final shape.
+            intermediate_shape = reordered.shape[:-1] + (np.prod(new_dim_sizes),)
+            indexer = np.ravel_multi_index(index.codes, new_dim_sizes)
             data = np.full_like(
                 self.data,
                 fill_value=fill_value,
-                shape=new_shape,
+                shape=intermediate_shape,
                 dtype=dtype,
             )
 
             # Indexer is a list of lists of locations. Each list is the locations
             # on the new dimension. This is robust to the data being sparse; in that
             # case the destinations will be NaN / zero.
-            # sparse doesn't support item assigment,
-            # https://github.com/pydata/sparse/issues/114
-            data[(..., *indexer)] = reordered
+            data[(..., indexer)] = reordered
+            data = data.reshape(new_shape)
 
         return self._replace(dims=new_dims, data=data)
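
To make the flatten, assign, reshape idea in the diff concrete, here is a minimal NumPy-only sketch. The function name `unstack_last_axis` and its arguments are hypothetical and not part of xarray; it assumes a float result and a stacked last axis, and it only illustrates that assigning along one flattened axis gives the same result as the multi-axis assignment it replaces. The same pattern should carry over to a dask array on a dask version that supports this kind of assignment, but that is not exercised here.

```python
import numpy as np


def unstack_last_axis(reordered, codes, new_dim_sizes, fill_value=np.nan):
    """Hypothetical helper: unstack the last axis of `reordered`.

    `codes` holds one integer sequence per new dimension (like
    `MultiIndex.codes`), and `new_dim_sizes` the sizes of those dimensions.
    """
    new_dim_sizes = tuple(int(s) for s in new_dim_sizes)
    new_shape = reordered.shape[:-1] + new_dim_sizes
    # Flatten all new dimensions into a single trailing axis.
    intermediate_shape = reordered.shape[:-1] + (int(np.prod(new_dim_sizes)),)
    # Convert per-dimension codes into positions on the flattened axis.
    indexer = np.ravel_multi_index(codes, new_dim_sizes)

    # Assumption for this sketch: the result is always float.
    data = np.full(intermediate_shape, fill_value, dtype=float)
    # Assignment uses a list of ints along one axis only.
    data[..., indexer] = reordered
    return data.reshape(new_shape)


reordered = np.arange(6.0).reshape(2, 3)
codes = [[0, 0, 1], [0, 1, 1]]  # (0, 0), (0, 1), (1, 1); (1, 0) is missing
result = unstack_last_axis(reordered, codes, (2, 2))

# Reference: the multi-axis assignment used before the change.
expected = np.full((2, 2, 2), np.nan)
expected[(..., *codes)] = reordered
```

Positions that have no entry in `codes` (here `(1, 0)`) keep the fill value, matching the behaviour of the original multi-axis assignment.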

This should be what alignment.reindex_variables is doing, but I don't fully understand that function.

The annoying bit is figuring out when to use this version, and what to do with cases like dask wrapping sparse. I think we want to loop over each variable in Dataset.unstack, call Variable.unstack on it, and dispatch based on the type of Variable.data so that all the edge cases are handled in one place.
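
A rough sketch of what such a type-based dispatch could look like follows. Everything in it is an assumption for illustration: `pick_unstack_path` and the returned labels are hypothetical, the check on the module name stands in for xarray's own `is_duck_dask_array`, and inspecting `_meta` to spot dask wrapping sparse is only one possible approach. It is not the proposed implementation.

```python
import numpy as np


def _module_root(obj):
    # Top-level package that defines the object's type, e.g. "numpy" or "dask".
    return type(obj).__module__.split(".")[0]


def pick_unstack_path(data, sparse=False):
    """Hypothetical helper: choose an unstacking strategy from the array type."""
    if _module_root(data) == "dask":
        # Assumption: the chunk type of a dask array can be read from `_meta`.
        meta = getattr(data, "_meta", None)
        if meta is not None and _module_root(meta) == "sparse":
            # The issue leaves open what to do in this case.
            return "dask-wrapping-sparse"
        return "flatten-assign-reshape"
    if sparse:
        # Dense in-memory data unstacked to sparse: build via sparse.COO.
        return "sparse-coo-constructor"
    return "flatten-assign-reshape"


dense_path = pick_unstack_path(np.zeros(3))
sparse_path = pick_unstack_path(np.zeros(3), sparse=True)
```

Keeping the decision in one function per variable would let Dataset.unstack stay a simple loop, with each array type handled where it is detected.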

cc @Illviljan if you're interested in implementing this
