7
7
import matplotlib as mpl
8
8
import mercantile
9
9
import numpy as np
10
+ import pandas as pd
10
11
import xarray as xr
11
12
from ipyleaflet import DrawControl , LocalTileLayer , WidgetControl
12
13
from ipyspin import Spinner
30
31
from .vector import Zvect
31
32
32
33
33
- @ xr . register_dataarray_accessor ( "leaflet" )
34
- class LeafletMap ( HasTraits ):
35
- """A xarray.DataArray extension for tiled map plotting, based on (ipy)leaflet."""
34
+ class Leaflet ( HasTraits ):
35
+
36
+ is_vector : bool
36
37
37
38
map_ready = Bool (False )
38
39
39
40
@observe ("map_ready" )
40
41
def _map_ready_changed (self , change ):
41
42
self ._start ()
42
43
43
- def __init__ (self , da : xr .DataArray = None , df : gpd .GeoDataFrame = None ):
44
- self ._da = da
45
- self ._df = df
46
- self ._da_selected = None
47
-
48
44
def plot (
49
45
self ,
50
46
m ,
@@ -127,14 +123,7 @@ def plot(
127
123
128
124
self .layer = LocalTileLayer ()
129
125
130
- source_nb = sum ([0 if i is None else 1 for i in (self ._da , self ._df )])
131
- if source_nb == 0 :
132
- raise RuntimeError ("No DataArray or GeoDataFrame provided" )
133
-
134
- if source_nb > 1 :
135
- raise RuntimeError ("Only one of DataArray or GeoDataFrame must be provided" )
136
-
137
- if self ._df is not None :
126
+ if self .is_vector :
138
127
# source is a GeoDataFrame (vector)
139
128
if measurement is None :
140
129
raise RuntimeError ("You must provide a 'measurement'." )
@@ -146,7 +135,7 @@ def plot(
146
135
)
147
136
if colormap is None :
148
137
colormap = plt .cm .viridis
149
- elif self . _da is not None :
138
+ else :
150
139
# source is a DataArray (raster)
151
140
if "proj4def" in m .crs :
152
141
# it's a custom projection
@@ -252,7 +241,7 @@ def plot(
252
241
else :
253
242
self .base_url = get_base_url (self .m .window_url )
254
243
255
- if fit_bounds and self . _da is not None :
244
+ if fit_bounds and not self . is_vector :
256
245
asyncio .ensure_future (self .async_fit_bounds ())
257
246
else :
258
247
asyncio .ensure_future (self .async_wait_for_bounds ())
@@ -302,7 +291,7 @@ def _get_selection(self, *args, **kwargs):
302
291
303
292
def _start (self ):
304
293
self .m .add_control (self .spinner_control )
305
- if self . _da is not None :
294
+ if not self . is_vector :
306
295
self ._da , self .transform0_args = get_transform (self .transform0 (self ._da ))
307
296
else :
308
297
self .layer .name = self .measurement
@@ -318,14 +307,16 @@ def _start(self):
318
307
self .layer .path = self .url
319
308
320
309
self .m .remove_control (self .spinner_control )
321
- if self . _da is not None :
310
+ if not self . is_vector :
322
311
get_tiles = self ._get_raster_tiles
323
- else :
312
+ elif self . _df is not None :
324
313
get_tiles = self ._get_vector_tiles
314
+ else :
315
+ raise RuntimeError ("No DataArray or GeoDataFrame provided." )
325
316
get_tiles ()
326
317
self .m .observe (get_tiles , names = "pixel_bounds" )
327
318
if not self .dynamic :
328
- if self . _da is not None :
319
+ if not self . is_vector :
329
320
self ._show_colorbar (self ._da_notransform )
330
321
self .m .add_layer (self .layer )
331
322
@@ -603,3 +594,22 @@ async def async_fit_bounds(self):
603
594
if self .base_url is None :
604
595
self .base_url = (await self .url_widget .get_url ()).rstrip ("/" )
605
596
self .map_ready = True
597
+
598
+
599
+ @xr .register_dataarray_accessor ("leaflet" )
600
+ class DataArrayLeaflet (Leaflet ):
601
+ """A DataArraye extension for tiled map plotting, based on (ipy)leaflet."""
602
+
603
+ def __init__ (self , da : xr .DataArray = None ):
604
+ self ._da = da
605
+ self ._da_selected = None
606
+ self .is_vector = False
607
+
608
+
609
+ @pd .api .extensions .register_dataframe_accessor ("leaflet" )
610
+ class GeoDataFrameLeaflet (Leaflet ):
611
+ """A GeoDataFrame extension for tiled map plotting, based on (ipy)leaflet."""
612
+
613
+ def __init__ (self , df : gpd .GeoDataFrame = None ):
614
+ self ._df = df
615
+ self .is_vector = True
0 commit comments