Skip to content

Commit 6925f89

Browse files
authored
fix: Some dask behavior (#185)
1 parent 3072832 commit 6925f89

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

spatialpandas/dask.py

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -12,16 +12,22 @@
1212
import pyarrow.parquet as pq
1313
from dask import delayed
1414
from dask.dataframe.core import get_parallel_type
15+
from dask.dataframe.dispatch import make_meta_dispatch
1516
from dask.dataframe.extensions import make_array_nonempty
16-
from dask.dataframe.utils import make_meta_obj, meta_nonempty
17+
from dask.dataframe.utils import meta_nonempty
1718
from retrying import retry
1819

1920
from .geodataframe import GeoDataFrame
20-
from .geometry.base import GeometryDtype, _BaseCoordinateIndexer
21+
from .geometry.base import GeometryArray, GeometryDtype, _BaseCoordinateIndexer
2122
from .geoseries import GeoSeries
2223
from .spatialindex import HilbertRtree
2324

2425

26+
@make_array_nonempty.register(GeometryDtype)
27+
def make_geometry_array(dtype):
28+
return GeometryArray([], dtype=dtype)
29+
30+
2531
class DaskGeoSeries(dd.Series):
2632
def __init__(self, expr, *args, **kwargs):
2733
super().__init__(expr, *args, **kwargs)
@@ -98,10 +104,8 @@ def persist(self, **kwargs):
98104
)
99105

100106

101-
@make_meta_obj.register(GeoSeries)
107+
@make_meta_dispatch.register(GeoSeries)
102108
def make_meta_series(s, index=None):
103-
if hasattr(s, "__array__") or isinstance(s, np.ndarray):
104-
return s[:0]
105109
result = s.head(0)
106110
if index is not None:
107111
result = result.reindex(index[:0])
@@ -581,10 +585,8 @@ def __getitem__(self, key):
581585
return result
582586

583587

584-
@make_meta_obj.register(GeoDataFrame)
588+
@make_meta_dispatch.register(GeoDataFrame)
585589
def make_meta_dataframe(df, index=None):
586-
if hasattr(df, "__array__") or isinstance(df, np.ndarray):
587-
return df[:0]
588590
result = df.head(0)
589591
if index is not None:
590592
result = result.reindex(index[:0])
@@ -597,14 +599,15 @@ def meta_nonempty_dataframe(df, index=None):
597599

598600

599601
@get_parallel_type.register(GeoDataFrame)
600-
def get_parallel_type_dataframe(s):
602+
def get_parallel_type_dataframe(df):
601603
return DaskGeoDataFrame
602604

603605

604606
@dd.get_collection_type.register(GeoDataFrame)
605607
def get_collection_type_dataframe(df):
606608
return DaskGeoDataFrame
607609

610+
608611
class _DaskCoordinateIndexer(_BaseCoordinateIndexer):
609612
def __init__(self, obj, sindex):
610613
super().__init__(sindex)

0 commit comments

Comments
 (0)