Skip to content

Commit 15249cc

Browse files
committed
More optimizations, change API
1 parent 792d38e commit 15249cc

File tree

4 files changed

+88
-77
lines changed

4 files changed

+88
-77
lines changed

examples/vector.ipynb

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
"metadata": {},
88
"outputs": [],
99
"source": [
10+
"from functools import partial\n",
11+
"from geocube.rasterize import rasterize_image\n",
12+
"from rasterio.enums import MergeAlg\n",
1013
"import geopandas as gpd\n",
11-
"from ipyleaflet import LayersControl, Map, WidgetControl, basemaps\n",
14+
"from ipyleaflet import LocalTileLayer, LayersControl, Map, WidgetControl, basemaps\n",
1215
"from ipywidgets import FloatSlider\n",
1316
"import xarray_leaflet\n",
1417
"import matplotlib.pyplot as plt"
@@ -32,8 +35,7 @@
3235
"metadata": {},
3336
"outputs": [],
3437
"source": [
35-
"df = gpd.read_file(\"bldg_footprints.shp\")\n",
36-
"df[\"mask\"] = 1"
38+
"df = gpd.read_file(\"bldg_footprints.shp\")"
3739
]
3840
},
3941
{
@@ -54,7 +56,9 @@
5456
"metadata": {},
5557
"outputs": [],
5658
"source": [
57-
"l = df.leaflet.plot(m, measurement=\"mask\", colormap=plt.cm.inferno)"
59+
"rasterize_function = partial(rasterize_image, merge_alg=MergeAlg.add, all_touched=False)\n",
60+
"layer = partial(LocalTileLayer, max_zoom=20)\n",
61+
"l = df.leaflet.plot(m, measurement=\"Height\", layer=layer, dynamic=False, rasterize_function=rasterize_function, colormap=plt.cm.viridis)"
5862
]
5963
},
6064
{
@@ -94,7 +98,7 @@
9498
"name": "python",
9599
"nbconvert_exporter": "python",
96100
"pygments_lexer": "ipython3",
97-
"version": "3.10.5"
101+
"version": "3.10.6"
98102
}
99103
},
100104
"nbformat": 4,

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ install_requires =
2323
jupyter_server >=0.2.0
2424
rioxarray >=0.0.30
2525
ipyleaflet >=0.13.1
26+
ipywidgets >=7.7.2
2627
pillow >=7
2728
matplotlib >=3
2829
affine >=2
2930
mercantile >=1
3031
ipyspin >=0.1.6
3132
ipyurl >=0.1.3
32-
jupyterlab-widgets >=1.0.0,<2
3333
geocube <1.0.0
3434
pygeos >=0.12,<1.0.0
3535
zarr >=2.0.0,<3.0.0

xarray_leaflet/vector.py

Lines changed: 26 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import json
2+
import math
23
from functools import partial
34
from pathlib import Path
4-
from typing import Optional
5+
from typing import Callable, Optional
56

67
import mercantile
78
import numpy as np
8-
import pyproj
99
import xarray as xr
1010
import zarr
1111
from geocube.api.core import make_geocube
@@ -23,13 +23,17 @@ def __init__(
2323
self,
2424
df: GeoDataFrame,
2525
measurement: str,
26+
rasterize_function: Optional[Callable],
2627
width: int,
2728
height: int,
2829
root_path: str = "",
2930
):
3031
# reproject to Web Mercator
3132
self.df = df.to_crs(epsg=3857)
3233
self.measurement = measurement
34+
self.rasterize_function = rasterize_function or partial(
35+
rasterize_image, merge_alg=MergeAlg.add, all_touched=True
36+
)
3337
self.width = width
3438
self.height = height
3539
self.zzarr = Zzarr(root_path, width, height)
@@ -56,9 +60,7 @@ def get_da_tile(self, tile: mercantile.Tile) -> Optional[xr.DataArray]:
5660
vector_data=df_tile,
5761
resolution=(-dy, dx),
5862
measurements=[self.measurement],
59-
rasterize_function=partial(
60-
rasterize_image, merge_alg=MergeAlg.add, all_touched=True
61-
),
63+
rasterize_function=self.rasterize_function,
6264
fill=0,
6365
geom=geom,
6466
)
@@ -82,15 +84,10 @@ def get_da_llbbox(
8284
self.tiles.append(tile)
8385
if all_none:
8486
return None
85-
project = pyproj.Transformer.from_crs(
86-
pyproj.CRS("EPSG:4326"), pyproj.CRS("EPSG:3857"), always_xy=True
87-
).transform
88-
b = box(*bbox)
89-
polygon = transform(project, b)
90-
left, bottom, right, top = polygon.bounds
91-
return self.zzarr.get_ds(z)["da"].sel(
92-
x=slice(left, right), y=slice(top, bottom)
93-
)
87+
da = self.get_da(z)
88+
y0, x0 = deg2idx(bbox.north, bbox.west, z, self.height, self.width, math.floor)
89+
y1, x1 = deg2idx(bbox.south, bbox.east, z, self.height, self.width, math.ceil)
90+
return da[y0:y1, x0:x1]
9491

9592
def get_da(self, z: int) -> xr.DataArray:
9693
return self.zzarr.get_ds(z)["da"]
@@ -101,7 +98,7 @@ def __init__(self, root_path: str, width: int, height: int):
10198
self.root_path = Path(root_path)
10299
self.width = width
103100
self.height = height
104-
self.ds = {}
101+
self.z = None
105102

106103
def open_zarr(self, mode: str, z: int) -> zarr.Array:
107104
path = self.root_path / str(z)
@@ -114,32 +111,6 @@ def open_zarr(self, mode: str, z: int) -> zarr.Array:
114111
)
115112
if mode == "w":
116113
# write Dataset to zarr
117-
mi, ma = mercantile.minmax(z)
118-
ul = mercantile.xy_bounds(mi, mi, z)
119-
lr = mercantile.xy_bounds(ma, ma, z)
120-
bbox = mercantile.Bbox(ul.left, lr.bottom, lr.right, ul.top)
121-
x = zarr.open(
122-
path / "x",
123-
mode="w",
124-
shape=(2**z * self.width,),
125-
chunks=(2**z * self.width,),
126-
dtype="<f8",
127-
)
128-
x[:] = np.linspace(bbox.left, bbox.right, 2**z * self.width)
129-
x_zattrs = dict(_ARRAY_DIMENSIONS=["x"])
130-
(path / "x" / ".zattrs").write_text(json.dumps(x_zattrs))
131-
y = zarr.open(
132-
path / "y",
133-
mode="w",
134-
shape=(2**z * self.height,),
135-
chunks=(2**z * self.height,),
136-
dtype="<f8",
137-
)
138-
y[:] = np.linspace(bbox.top, bbox.bottom, 2**z * self.height)
139-
x_zarray = json.loads((path / "x" / ".zarray").read_text())
140-
y_zarray = json.loads((path / "y" / ".zarray").read_text())
141-
y_zattrs = dict(_ARRAY_DIMENSIONS=["y"])
142-
(path / "y" / ".zattrs").write_text(json.dumps(y_zattrs))
143114
(path / ".zattrs").write_text(json.dumps(dict()))
144115
zarray = json.loads((path / "da" / ".zarray").read_text())
145116
zattrs = dict(_ARRAY_DIMENSIONS=["y", "x"])
@@ -153,10 +124,6 @@ def open_zarr(self, mode: str, z: int) -> zarr.Array:
153124
".zgroup": zgroup,
154125
"da/.zarray": zarray,
155126
"da/.zattrs": zattrs,
156-
"x/.zarray": x_zarray,
157-
"x/.zattrs": x_zattrs,
158-
"y/.zarray": y_zarray,
159-
"y/.zattrs": y_zattrs,
160127
},
161128
zarr_consolidated_format=1,
162129
)
@@ -172,14 +139,23 @@ def write_to_zarr(self, tile: mercantile.Tile, data: np.ndarray):
172139
mode = "a"
173140
else:
174141
mode = "w"
175-
self.array = self.open_zarr(mode, z)
176-
self.array[
142+
array = self.open_zarr(mode, z)
143+
array[
177144
y * self.height : (y + 1) * self.height, # noqa
178145
x * self.width : (x + 1) * self.width, # noqa
179146
] = data
180147

181148
def get_ds(self, z: int) -> xr.Dataset:
182149
path = self.root_path / str(z)
183-
if z not in self.ds:
184-
self.ds[z] = xr.open_zarr(path)
185-
return self.ds[z]
150+
if z != self.z:
151+
self.ds_z = xr.open_zarr(path)
152+
self.z = z
153+
return self.ds_z
154+
155+
156+
def deg2idx(lat_deg, lon_deg, zoom, height, width, round_fun):
157+
lat_rad = math.radians(lat_deg)
158+
n = 2**zoom
159+
xtile = round_fun(((lon_deg + 180) % 360) / 360 * n * width)
160+
ytile = round_fun((1 - math.asinh(math.tan(lat_rad)) / math.pi) / 2 * n * height)
161+
return ytile, xtile

xarray_leaflet/xarray_leaflet.py

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ def _map_ready_changed(self, change):
4444
def plot(
4545
self,
4646
m,
47+
*,
48+
# raster or vector options:
49+
get_base_url: Optional[Callable] = None,
50+
dynamic: Optional[bool] = None,
51+
persist: bool = True,
52+
tile_dir=None,
53+
tile_height: int = 256,
54+
tile_width: int = 256,
55+
layer: Callable = LocalTileLayer,
56+
# raster-only options:
4757
x_dim="x",
4858
y_dim="y",
4959
fit_bounds=True,
@@ -54,15 +64,11 @@ def plot(
5464
transform3=passthrough,
5565
colormap=None,
5666
colorbar_position="topright",
57-
persist=True,
58-
dynamic=False,
59-
tile_dir=None,
60-
tile_height=256,
61-
tile_width=256,
6267
resampling=Resampling.nearest,
63-
get_base_url=None,
68+
# vector-only options:
6469
measurement: Optional[str] = None,
6570
visible_callback: Optional[Callable] = None,
71+
rasterize_function: Optional[Callable] = None,
6672
):
6773
"""Display an array as an interactive map.
6874
@@ -122,30 +128,44 @@ def plot(
122128
- the mercantile.LngLatBbox of the visible region
123129
124130
and returning True if the layer should be shown, False otherwise.
131+
rasterize_function: callable, optional
132+
A callable passed to make_geocube. Defaults to:
133+
partial(rasterize_image, merge_alg=MergeAlg.add, all_touched=True)
125134
126135
Returns
127136
-------
128137
layer : ipyleaflet.LocalTileLayer
129138
A handler to the layer that is added to the map.
130139
"""
131140

132-
self.layer = LocalTileLayer()
141+
self.layer = layer()
133142

134143
if self.is_vector:
135144
# source is a GeoDataFrame (vector)
136145
self.visible_callback = visible_callback
137146
if measurement is None:
138147
raise RuntimeError("You must provide a 'measurement'.")
148+
if dynamic is None:
149+
dynamic = True
150+
if not dynamic:
151+
self.vmin = self._df[measurement].min()
152+
self.vmax = self._df[measurement].max()
139153
self.measurement = measurement
140-
dynamic = True
141154
zarr_temp_dir = tempfile.TemporaryDirectory(prefix="xarray_leaflet_zarr_")
142155
self.zvect = Zvect(
143-
self._df, measurement, tile_width, tile_height, zarr_temp_dir.name
156+
self._df,
157+
measurement,
158+
rasterize_function,
159+
tile_width,
160+
tile_height,
161+
zarr_temp_dir.name,
144162
)
145163
if colormap is None:
146164
colormap = plt.cm.viridis
147165
else:
148166
# source is a DataArray (raster)
167+
if dynamic is None:
168+
dynamic = False
149169
if "proj4def" in m.crs:
150170
# it's a custom projection
151171
if dynamic:
@@ -363,6 +383,7 @@ def _get_vector_tiles(self, change=None):
363383
tiles = mercantile.tiles(west, south, east, north, z)
364384

365385
if self.dynamic:
386+
# get DataArray for the visible map
366387
llbbox = mercantile.LngLatBbox(west, south, east, north)
367388
da_visible = self.zvect.get_da_llbbox(llbbox, z)
368389
# check if we must show the layer
@@ -372,32 +393,42 @@ def _get_vector_tiles(self, change=None):
372393
self.m.remove_control(self.spinner_control)
373394
return
374395
if da_visible is None:
375-
self.max_value = 0
396+
vmin = vmax = 0
376397
else:
377-
self.max_value = da_visible.max()
398+
vmin = da_visible.min()
399+
vmax = da_visible.max()
400+
else:
401+
vmin = self.vmin
402+
vmax = self.vmax
403+
da_visible_computed = False
378404

379405
for tile in tiles:
380406
x, y, z = tile
381407
path = f"{self.tile_path}/{z}/{x}/{y}.png"
382408
if self.dynamic or not os.path.exists(path):
383-
xy_bbox = mercantile.xy_bounds(tile)
384-
if self.dynamic:
385-
if da_visible is not None:
386-
da_tile = self.zvect.get_da(z).sel(
387-
y=slice(xy_bbox.top, xy_bbox.bottom),
388-
x=slice(xy_bbox.left, xy_bbox.right),
389-
)
390-
else:
391-
da_tile = None
409+
if not self.dynamic and not da_visible_computed:
410+
# get DataArray for the visible map
411+
llbbox = mercantile.LngLatBbox(west, south, east, north)
412+
da_visible = self.zvect.get_da_llbbox(llbbox, z)
413+
da_visible_computed = True
414+
if self.dynamic and da_visible is None:
415+
da_tile = None
416+
else:
417+
da_tile = self.zvect.get_da(z)[
418+
y * self.tile_height : (y + 1) * self.tile_height,
419+
x * self.tile_width : (x + 1) * self.tile_width,
420+
]
392421
if da_tile is None:
393422
write_image(path, None)
394423
else:
395-
da_tile /= self.max_value
424+
# normalize
425+
da_tile = (da_tile - vmin) / (vmax - vmin)
396426
da_tile = self.colormap(da_tile)
397427
write_image(path, da_tile * 255)
398428

399429
if self.dynamic:
400430
self.layer.redraw()
431+
401432
self.m.remove_control(self.spinner_control)
402433

403434
def _get_raster_tiles(self, change=None):

0 commit comments

Comments
 (0)