12
12
import pyarrow .parquet as pq
13
13
from dask import delayed
14
14
from dask .dataframe .core import get_parallel_type
15
+ from dask .dataframe .dispatch import make_meta_dispatch
15
16
from dask .dataframe .extensions import make_array_nonempty
16
- from dask .dataframe .utils import make_meta_obj , meta_nonempty
17
+ from dask .dataframe .utils import meta_nonempty
17
18
from retrying import retry
18
19
19
20
from .geodataframe import GeoDataFrame
20
- from .geometry .base import GeometryDtype , _BaseCoordinateIndexer
21
+ from .geometry .base import GeometryArray , GeometryDtype , _BaseCoordinateIndexer
21
22
from .geoseries import GeoSeries
22
23
from .spatialindex import HilbertRtree
23
24
24
25
26
+ @make_array_nonempty .register (GeometryDtype )
27
+ def make_geometry_array (dtype ):
28
+ return GeometryArray ([], dtype = dtype )
29
+
30
+
25
31
class DaskGeoSeries (dd .Series ):
26
32
def __init__ (self , expr , * args , ** kwargs ):
27
33
super ().__init__ (expr , * args , ** kwargs )
@@ -98,10 +104,8 @@ def persist(self, **kwargs):
98
104
)
99
105
100
106
101
- @make_meta_obj .register (GeoSeries )
107
+ @make_meta_dispatch .register (GeoSeries )
102
108
def make_meta_series (s , index = None ):
103
- if hasattr (s , "__array__" ) or isinstance (s , np .ndarray ):
104
- return s [:0 ]
105
109
result = s .head (0 )
106
110
if index is not None :
107
111
result = result .reindex (index [:0 ])
@@ -581,10 +585,8 @@ def __getitem__(self, key):
581
585
return result
582
586
583
587
584
- @make_meta_obj .register (GeoDataFrame )
588
+ @make_meta_dispatch .register (GeoDataFrame )
585
589
def make_meta_dataframe (df , index = None ):
586
- if hasattr (df , "__array__" ) or isinstance (df , np .ndarray ):
587
- return df [:0 ]
588
590
result = df .head (0 )
589
591
if index is not None :
590
592
result = result .reindex (index [:0 ])
@@ -597,14 +599,15 @@ def meta_nonempty_dataframe(df, index=None):
597
599
598
600
599
601
@get_parallel_type .register (GeoDataFrame )
600
- def get_parallel_type_dataframe (s ):
602
+ def get_parallel_type_dataframe (df ):
601
603
return DaskGeoDataFrame
602
604
603
605
604
606
@dd .get_collection_type .register (GeoDataFrame )
605
607
def get_collection_type_dataframe (df ):
606
608
return DaskGeoDataFrame
607
609
610
+
608
611
class _DaskCoordinateIndexer (_BaseCoordinateIndexer ):
609
612
def __init__ (self , obj , sindex ):
610
613
super ().__init__ (sindex )
0 commit comments