Skip to content

Commit 15b63ca

Browse files
committed
Add (Geo)DataFrame accessor
1 parent 13ae9f4 commit 15b63ca

File tree

6 files changed

+40
-30
lines changed

6 files changed

+40
-30
lines changed

examples/introduction.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@
246246
"name": "python",
247247
"nbconvert_exporter": "python",
248248
"pygments_lexer": "ipython3",
249-
"version": "3.9.6"
249+
"version": "3.10.5"
250250
}
251251
},
252252
"nbformat": 4,

examples/vector.ipynb

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"import geopandas as gpd\n",
1111
"from ipyleaflet import LayersControl, Map, WidgetControl, basemaps\n",
1212
"from ipywidgets import FloatSlider\n",
13-
"from xarray_leaflet import LeafletMap\n",
13+
"import xarray_leaflet\n",
1414
"import matplotlib.pyplot as plt"
1515
]
1616
},
@@ -54,8 +54,7 @@
5454
"metadata": {},
5555
"outputs": [],
5656
"source": [
57-
"lm = LeafletMap(df=df)\n",
58-
"l = lm.plot(m, measurement=\"mask\", dynamic=True, colormap=plt.cm.inferno)"
57+
"l = df.leaflet.plot(m, measurement=\"mask\", colormap=plt.cm.inferno)"
5958
]
6059
},
6160
{

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ install_requires =
2929
mercantile >=1
3030
ipyspin >=0.1.1
3131
ipyurl >=0.1.2
32-
geocube
32+
geocube <1.0.0
3333
pygeos >=0.12,<1.0.0
3434
zarr >=2.0.0,<3.0.0
3535

ui-tests/notebooks/test_vector.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"import geopandas\n",
1010
"import matplotlib.pyplot as plt\n",
1111
"from ipyleaflet import Map\n",
12-
"from xarray_leaflet import LeafletMap"
12+
"import xarray_leaflet"
1313
]
1414
},
1515
{
@@ -39,7 +39,7 @@
3939
"metadata": {},
4040
"outputs": [],
4141
"source": [
42-
"l = LeafletMap(df=df).plot(m, fit_bounds=False, colormap=plt.cm.inferno, measurement=\"mask\")"
42+
"l = df.leaflet.plot(m, fit_bounds=False, colormap=plt.cm.inferno, measurement=\"mask\")"
4343
]
4444
}
4545
],

xarray_leaflet/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .server_extension import _jupyter_nbextension_paths # noqa
44
from .server_extension import _jupyter_server_extension_paths # noqa
55
from .server_extension import _load_jupyter_server_extension
6-
from .xarray_leaflet import LeafletMap # noqa
6+
from .xarray_leaflet import DataArrayLeaflet # noqa
7+
from .xarray_leaflet import GeoDataFrameLeaflet # noqa
78

89
load_jupyter_server_extension = _load_jupyter_server_extension

xarray_leaflet/xarray_leaflet.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import matplotlib as mpl
88
import mercantile
99
import numpy as np
10+
import pandas as pd
1011
import xarray as xr
1112
from ipyleaflet import DrawControl, LocalTileLayer, WidgetControl
1213
from ipyspin import Spinner
@@ -30,21 +31,16 @@
3031
from .vector import Zvect
3132

3233

33-
@xr.register_dataarray_accessor("leaflet")
34-
class LeafletMap(HasTraits):
35-
"""A xarray.DataArray extension for tiled map plotting, based on (ipy)leaflet."""
34+
class Leaflet(HasTraits):
35+
36+
is_vector: bool
3637

3738
map_ready = Bool(False)
3839

3940
@observe("map_ready")
4041
def _map_ready_changed(self, change):
4142
self._start()
4243

43-
def __init__(self, da: xr.DataArray = None, df: gpd.GeoDataFrame = None):
44-
self._da = da
45-
self._df = df
46-
self._da_selected = None
47-
4844
def plot(
4945
self,
5046
m,
@@ -127,14 +123,7 @@ def plot(
127123

128124
self.layer = LocalTileLayer()
129125

130-
source_nb = sum([0 if i is None else 1 for i in (self._da, self._df)])
131-
if source_nb == 0:
132-
raise RuntimeError("No DataArray or GeoDataFrame provided")
133-
134-
if source_nb > 1:
135-
raise RuntimeError("Only one of DataArray or GeoDataFrame must be provided")
136-
137-
if self._df is not None:
126+
if self.is_vector:
138127
# source is a GeoDataFrame (vector)
139128
if measurement is None:
140129
raise RuntimeError("You must provide a 'measurement'.")
@@ -146,7 +135,7 @@ def plot(
146135
)
147136
if colormap is None:
148137
colormap = plt.cm.viridis
149-
elif self._da is not None:
138+
else:
150139
# source is a DataArray (raster)
151140
if "proj4def" in m.crs:
152141
# it's a custom projection
@@ -252,7 +241,7 @@ def plot(
252241
else:
253242
self.base_url = get_base_url(self.m.window_url)
254243

255-
if fit_bounds and self._da is not None:
244+
if fit_bounds and not self.is_vector:
256245
asyncio.ensure_future(self.async_fit_bounds())
257246
else:
258247
asyncio.ensure_future(self.async_wait_for_bounds())
@@ -302,7 +291,7 @@ def _get_selection(self, *args, **kwargs):
302291

303292
def _start(self):
304293
self.m.add_control(self.spinner_control)
305-
if self._da is not None:
294+
if not self.is_vector:
306295
self._da, self.transform0_args = get_transform(self.transform0(self._da))
307296
else:
308297
self.layer.name = self.measurement
@@ -318,14 +307,16 @@ def _start(self):
318307
self.layer.path = self.url
319308

320309
self.m.remove_control(self.spinner_control)
321-
if self._da is not None:
310+
if not self.is_vector:
322311
get_tiles = self._get_raster_tiles
323-
else:
312+
elif self._df is not None:
324313
get_tiles = self._get_vector_tiles
314+
else:
315+
raise RuntimeError("No DataArray or GeoDataFrame provided.")
325316
get_tiles()
326317
self.m.observe(get_tiles, names="pixel_bounds")
327318
if not self.dynamic:
328-
if self._da is not None:
319+
if not self.is_vector:
329320
self._show_colorbar(self._da_notransform)
330321
self.m.add_layer(self.layer)
331322

@@ -603,3 +594,22 @@ async def async_fit_bounds(self):
603594
if self.base_url is None:
604595
self.base_url = (await self.url_widget.get_url()).rstrip("/")
605596
self.map_ready = True
597+
598+
599+
@xr.register_dataarray_accessor("leaflet")
600+
class DataArrayLeaflet(Leaflet):
601+
"""A DataArraye extension for tiled map plotting, based on (ipy)leaflet."""
602+
603+
def __init__(self, da: xr.DataArray = None):
604+
self._da = da
605+
self._da_selected = None
606+
self.is_vector = False
607+
608+
609+
@pd.api.extensions.register_dataframe_accessor("leaflet")
610+
class GeoDataFrameLeaflet(Leaflet):
611+
"""A GeoDataFrame extension for tiled map plotting, based on (ipy)leaflet."""
612+
613+
def __init__(self, df: gpd.GeoDataFrame = None):
614+
self._df = df
615+
self.is_vector = True

0 commit comments

Comments
 (0)