Skip to content

Commit 6dd089f

Browse files
committed
More optimizations, change API
1 parent 792d38e commit 6dd089f

File tree

4 files changed

+80
-77
lines changed

4 files changed

+80
-77
lines changed

examples/vector.ipynb

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
"metadata": {},
88
"outputs": [],
99
"source": [
10+
"from functools import partial\n",
11+
"from geocube.rasterize import rasterize_image\n",
12+
"from rasterio.enums import MergeAlg\n",
1013
"import geopandas as gpd\n",
11-
"from ipyleaflet import LayersControl, Map, WidgetControl, basemaps\n",
14+
"from ipyleaflet import LocalTileLayer, LayersControl, Map, WidgetControl, basemaps\n",
1215
"from ipywidgets import FloatSlider\n",
1316
"import xarray_leaflet\n",
1417
"import matplotlib.pyplot as plt"
@@ -32,8 +35,7 @@
3235
"metadata": {},
3336
"outputs": [],
3437
"source": [
35-
"df = gpd.read_file(\"bldg_footprints.shp\")\n",
36-
"df[\"mask\"] = 1"
38+
"df = gpd.read_file(\"bldg_footprints.shp\")"
3739
]
3840
},
3941
{
@@ -54,7 +56,9 @@
5456
"metadata": {},
5557
"outputs": [],
5658
"source": [
57-
"l = df.leaflet.plot(m, measurement=\"mask\", colormap=plt.cm.inferno)"
59+
"rasterize_function = partial(rasterize_image, merge_alg=MergeAlg.add, all_touched=False)\n",
60+
"layer = partial(LocalTileLayer, max_zoom=20)\n",
61+
"l = df.leaflet.plot(m, measurement=\"Height\", layer=layer, dynamic=False, rasterize_function=rasterize_function, colormap=plt.cm.viridis)"
5862
]
5963
},
6064
{
@@ -94,7 +98,7 @@
9498
"name": "python",
9599
"nbconvert_exporter": "python",
96100
"pygments_lexer": "ipython3",
97-
"version": "3.10.5"
101+
"version": "3.10.6"
98102
}
99103
},
100104
"nbformat": 4,

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ install_requires =
2323
jupyter_server >=0.2.0
2424
rioxarray >=0.0.30
2525
ipyleaflet >=0.13.1
26+
ipywidgets >=7.7.2
2627
pillow >=7
2728
matplotlib >=3
2829
affine >=2
2930
mercantile >=1
3031
ipyspin >=0.1.6
3132
ipyurl >=0.1.3
32-
jupyterlab-widgets >=1.0.0,<2
3333
geocube <1.0.0
3434
pygeos >=0.12,<1.0.0
3535
zarr >=2.0.0,<3.0.0

xarray_leaflet/vector.py

Lines changed: 23 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import json
2+
import math
23
from functools import partial
34
from pathlib import Path
4-
from typing import Optional
5+
from typing import Callable, Optional
56

67
import mercantile
78
import numpy as np
8-
import pyproj
99
import xarray as xr
1010
import zarr
1111
from geocube.api.core import make_geocube
@@ -23,13 +23,15 @@ def __init__(
2323
self,
2424
df: GeoDataFrame,
2525
measurement: str,
26+
rasterize_function: Optional[Callable],
2627
width: int,
2728
height: int,
2829
root_path: str = "",
2930
):
3031
# reproject to Web Mercator
3132
self.df = df.to_crs(epsg=3857)
3233
self.measurement = measurement
34+
self.rasterize_function = rasterize_function or partial(rasterize_image, merge_alg=MergeAlg.add, all_touched=True)
3335
self.width = width
3436
self.height = height
3537
self.zzarr = Zzarr(root_path, width, height)
@@ -56,9 +58,7 @@ def get_da_tile(self, tile: mercantile.Tile) -> Optional[xr.DataArray]:
5658
vector_data=df_tile,
5759
resolution=(-dy, dx),
5860
measurements=[self.measurement],
59-
rasterize_function=partial(
60-
rasterize_image, merge_alg=MergeAlg.add, all_touched=True
61-
),
61+
rasterize_function=self.rasterize_function,
6262
fill=0,
6363
geom=geom,
6464
)
@@ -82,15 +82,10 @@ def get_da_llbbox(
8282
self.tiles.append(tile)
8383
if all_none:
8484
return None
85-
project = pyproj.Transformer.from_crs(
86-
pyproj.CRS("EPSG:4326"), pyproj.CRS("EPSG:3857"), always_xy=True
87-
).transform
88-
b = box(*bbox)
89-
polygon = transform(project, b)
90-
left, bottom, right, top = polygon.bounds
91-
return self.zzarr.get_ds(z)["da"].sel(
92-
x=slice(left, right), y=slice(top, bottom)
93-
)
85+
da = self.get_da(z)
86+
y0, x0 = deg2idx(bbox.north, bbox.west, z, self.height, self.width, math.floor)
87+
y1, x1 = deg2idx(bbox.south, bbox.east, z, self.height, self.width, math.ceil)
88+
return da[y0:y1, x0:x1]
9489

9590
def get_da(self, z: int) -> xr.DataArray:
9691
return self.zzarr.get_ds(z)["da"]
@@ -101,7 +96,7 @@ def __init__(self, root_path: str, width: int, height: int):
10196
self.root_path = Path(root_path)
10297
self.width = width
10398
self.height = height
104-
self.ds = {}
99+
self.z = None
105100

106101
def open_zarr(self, mode: str, z: int) -> zarr.Array:
107102
path = self.root_path / str(z)
@@ -114,32 +109,6 @@ def open_zarr(self, mode: str, z: int) -> zarr.Array:
114109
)
115110
if mode == "w":
116111
# write Dataset to zarr
117-
mi, ma = mercantile.minmax(z)
118-
ul = mercantile.xy_bounds(mi, mi, z)
119-
lr = mercantile.xy_bounds(ma, ma, z)
120-
bbox = mercantile.Bbox(ul.left, lr.bottom, lr.right, ul.top)
121-
x = zarr.open(
122-
path / "x",
123-
mode="w",
124-
shape=(2**z * self.width,),
125-
chunks=(2**z * self.width,),
126-
dtype="<f8",
127-
)
128-
x[:] = np.linspace(bbox.left, bbox.right, 2**z * self.width)
129-
x_zattrs = dict(_ARRAY_DIMENSIONS=["x"])
130-
(path / "x" / ".zattrs").write_text(json.dumps(x_zattrs))
131-
y = zarr.open(
132-
path / "y",
133-
mode="w",
134-
shape=(2**z * self.height,),
135-
chunks=(2**z * self.height,),
136-
dtype="<f8",
137-
)
138-
y[:] = np.linspace(bbox.top, bbox.bottom, 2**z * self.height)
139-
x_zarray = json.loads((path / "x" / ".zarray").read_text())
140-
y_zarray = json.loads((path / "y" / ".zarray").read_text())
141-
y_zattrs = dict(_ARRAY_DIMENSIONS=["y"])
142-
(path / "y" / ".zattrs").write_text(json.dumps(y_zattrs))
143112
(path / ".zattrs").write_text(json.dumps(dict()))
144113
zarray = json.loads((path / "da" / ".zarray").read_text())
145114
zattrs = dict(_ARRAY_DIMENSIONS=["y", "x"])
@@ -153,10 +122,6 @@ def open_zarr(self, mode: str, z: int) -> zarr.Array:
153122
".zgroup": zgroup,
154123
"da/.zarray": zarray,
155124
"da/.zattrs": zattrs,
156-
"x/.zarray": x_zarray,
157-
"x/.zattrs": x_zattrs,
158-
"y/.zarray": y_zarray,
159-
"y/.zattrs": y_zattrs,
160125
},
161126
zarr_consolidated_format=1,
162127
)
@@ -172,14 +137,22 @@ def write_to_zarr(self, tile: mercantile.Tile, data: np.ndarray):
172137
mode = "a"
173138
else:
174139
mode = "w"
175-
self.array = self.open_zarr(mode, z)
176-
self.array[
140+
array = self.open_zarr(mode, z)
141+
array[
177142
y * self.height : (y + 1) * self.height, # noqa
178143
x * self.width : (x + 1) * self.width, # noqa
179144
] = data
180145

181146
def get_ds(self, z: int) -> xr.Dataset:
182147
path = self.root_path / str(z)
183-
if z not in self.ds:
184-
self.ds[z] = xr.open_zarr(path)
185-
return self.ds[z]
148+
if z != self.z:
149+
self.ds_z = xr.open_zarr(path)
150+
self.z = z
151+
return self.ds_z
152+
153+
def deg2idx(lat_deg, lon_deg, zoom, height, width, round_fun):
154+
lat_rad = math.radians(lat_deg)
155+
n = 2 ** zoom
156+
xtile = round_fun(((lon_deg + 180) % 360) / 360 * n * width)
157+
ytile = round_fun((1 - math.asinh(math.tan(lat_rad)) / math.pi) / 2 * n * height)
158+
return ytile, xtile

xarray_leaflet/xarray_leaflet.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ def _map_ready_changed(self, change):
4444
def plot(
4545
self,
4646
m,
47+
*,
48+
# raster or vector options:
49+
get_base_url: Optional[Callable] = None,
50+
dynamic: Optional[bool] = None,
51+
persist: bool = True,
52+
tile_dir=None,
53+
tile_height: int = 256,
54+
tile_width: int = 256,
55+
layer: Callable = LocalTileLayer,
56+
# raster-only options:
4757
x_dim="x",
4858
y_dim="y",
4959
fit_bounds=True,
@@ -54,15 +64,11 @@ def plot(
5464
transform3=passthrough,
5565
colormap=None,
5666
colorbar_position="topright",
57-
persist=True,
58-
dynamic=False,
59-
tile_dir=None,
60-
tile_height=256,
61-
tile_width=256,
6267
resampling=Resampling.nearest,
63-
get_base_url=None,
68+
# vector-only options:
6469
measurement: Optional[str] = None,
6570
visible_callback: Optional[Callable] = None,
71+
rasterize_function: Optional[Callable] = None,
6672
):
6773
"""Display an array as an interactive map.
6874
@@ -122,30 +128,39 @@ def plot(
122128
- the mercantile.LngLatBbox of the visible region
123129
124130
and returning True if the layer should be shown, False otherwise.
131+
rasterize_function: callable, optional
132+
A callable passed to make_geocube. Defaults to:
133+
partial(rasterize_image, merge_alg=MergeAlg.add, all_touched=True)
125134
126135
Returns
127136
-------
128137
layer : ipyleaflet.LocalTileLayer
129138
A handler to the layer that is added to the map.
130139
"""
131140

132-
self.layer = LocalTileLayer()
141+
self.layer = layer()
133142

134143
if self.is_vector:
135144
# source is a GeoDataFrame (vector)
136145
self.visible_callback = visible_callback
137146
if measurement is None:
138147
raise RuntimeError("You must provide a 'measurement'.")
148+
if dynamic is None:
149+
dynamic = True
150+
if not dynamic:
151+
self.vmin = self._df[measurement].min()
152+
self.vmax = self._df[measurement].max()
139153
self.measurement = measurement
140-
dynamic = True
141154
zarr_temp_dir = tempfile.TemporaryDirectory(prefix="xarray_leaflet_zarr_")
142155
self.zvect = Zvect(
143-
self._df, measurement, tile_width, tile_height, zarr_temp_dir.name
156+
self._df, measurement, rasterize_function, tile_width, tile_height, zarr_temp_dir.name
144157
)
145158
if colormap is None:
146159
colormap = plt.cm.viridis
147160
else:
148161
# source is a DataArray (raster)
162+
if dynamic is None:
163+
dynamic = False
149164
if "proj4def" in m.crs:
150165
# it's a custom projection
151166
if dynamic:
@@ -363,6 +378,7 @@ def _get_vector_tiles(self, change=None):
363378
tiles = mercantile.tiles(west, south, east, north, z)
364379

365380
if self.dynamic:
381+
# get DataArray for the visible map
366382
llbbox = mercantile.LngLatBbox(west, south, east, north)
367383
da_visible = self.zvect.get_da_llbbox(llbbox, z)
368384
# check if we must show the layer
@@ -372,32 +388,42 @@ def _get_vector_tiles(self, change=None):
372388
self.m.remove_control(self.spinner_control)
373389
return
374390
if da_visible is None:
375-
self.max_value = 0
391+
vmin = vmax = 0
376392
else:
377-
self.max_value = da_visible.max()
393+
vmin = da_visible.min()
394+
vmax = da_visible.max()
395+
else:
396+
vmin = self.vmin
397+
vmax = self.vmax
398+
da_visible_computed = False
378399

379400
for tile in tiles:
380401
x, y, z = tile
381402
path = f"{self.tile_path}/{z}/{x}/{y}.png"
382403
if self.dynamic or not os.path.exists(path):
383-
xy_bbox = mercantile.xy_bounds(tile)
384-
if self.dynamic:
385-
if da_visible is not None:
386-
da_tile = self.zvect.get_da(z).sel(
387-
y=slice(xy_bbox.top, xy_bbox.bottom),
388-
x=slice(xy_bbox.left, xy_bbox.right),
389-
)
390-
else:
391-
da_tile = None
404+
if not self.dynamic and not da_visible_computed:
405+
# get DataArray for the visible map
406+
llbbox = mercantile.LngLatBbox(west, south, east, north)
407+
da_visible = self.zvect.get_da_llbbox(llbbox, z)
408+
da_visible_computed = True
409+
if self.dynamic and da_visible is None:
410+
da_tile = None
411+
else:
412+
da_tile = self.zvect.get_da(z)[
413+
y * self.tile_height : (y + 1) * self.tile_height,
414+
x * self.tile_width : (x + 1) * self.tile_width,
415+
]
392416
if da_tile is None:
393417
write_image(path, None)
394418
else:
395-
da_tile /= self.max_value
419+
# normalize
420+
da_tile = (da_tile - vmin) / (vmax - vmin)
396421
da_tile = self.colormap(da_tile)
397422
write_image(path, da_tile * 255)
398423

399424
if self.dynamic:
400425
self.layer.redraw()
426+
401427
self.m.remove_control(self.spinner_control)
402428

403429
def _get_raster_tiles(self, change=None):

0 commit comments

Comments
 (0)