Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions spatialpandas/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,22 @@
import pyarrow.parquet as pq
from dask import delayed
from dask.dataframe.core import get_parallel_type
from dask.dataframe.dispatch import make_meta_dispatch
from dask.dataframe.extensions import make_array_nonempty
from dask.dataframe.utils import make_meta_obj, meta_nonempty
from dask.dataframe.utils import meta_nonempty
from retrying import retry

from .geodataframe import GeoDataFrame
from .geometry.base import GeometryDtype, _BaseCoordinateIndexer
from .geometry.base import GeometryArray, GeometryDtype, _BaseCoordinateIndexer
from .geoseries import GeoSeries
from .spatialindex import HilbertRtree


@make_array_nonempty.register(GeometryDtype)
def make_geometry_array(dtype):
return GeometryArray([], dtype=dtype)


class DaskGeoSeries(dd.Series):
def __init__(self, expr, *args, **kwargs):
super().__init__(expr, *args, **kwargs)
Expand Down Expand Up @@ -98,10 +104,8 @@ def persist(self, **kwargs):
)


@make_meta_obj.register(GeoSeries)
@make_meta_dispatch.register(GeoSeries)
def make_meta_series(s, index=None):
if hasattr(s, "__array__") or isinstance(s, np.ndarray):
return s[:0]
result = s.head(0)
if index is not None:
result = result.reindex(index[:0])
Expand Down Expand Up @@ -581,10 +585,8 @@ def __getitem__(self, key):
return result


@make_meta_obj.register(GeoDataFrame)
@make_meta_dispatch.register(GeoDataFrame)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was wrongly changed to obj in #171

def make_meta_dataframe(df, index=None):
if hasattr(df, "__array__") or isinstance(df, np.ndarray):
return df[:0]
result = df.head(0)
if index is not None:
result = result.reindex(index[:0])
Expand All @@ -597,14 +599,15 @@ def meta_nonempty_dataframe(df, index=None):


@get_parallel_type.register(GeoDataFrame)
def get_parallel_type_dataframe(s):
def get_parallel_type_dataframe(df):
return DaskGeoDataFrame


@dd.get_collection_type.register(GeoDataFrame)
def get_collection_type_dataframe(df):
return DaskGeoDataFrame


class _DaskCoordinateIndexer(_BaseCoordinateIndexer):
def __init__(self, obj, sindex):
super().__init__(sindex)
Expand Down