Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions lonboard/_geoarrow/sanitize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Remove custom geoarrow.pyarrow types from input geoarrow data
"""
import json
from typing import Tuple

import pyarrow as pa
from pyproj import CRS


def sanitize_table(table: pa.Table) -> pa.Table:
"""
Convert any registered geoarrow.pyarrow extension fields and arrays to plain
metadata
"""
for field_idx in range(len(table.schema)):
field = table.field(field_idx)
column = table.column(field_idx)

if isinstance(field.type, pa.ExtensionType):
assert all(isinstance(chunk, pa.ExtensionArray) for chunk in column.chunks)
new_field, new_column = sanitize_column(field, column)
table = table.set_column(field_idx, new_field, new_column)

return table


def sanitize_column(
field: pa.Field, column: pa.ChunkedArray
) -> Tuple[pa.Field, pa.ChunkedArray]:
"""
Convert a registered geoarrow.pyarrow extension field and column to plain metadata
"""
import geoarrow.pyarrow as gap

extension_metadata = {}
if field.type.crs:
extension_metadata["crs"] = CRS.from_user_input(field.type.crs).to_json()

if field.type.edge_type == gap.EdgeType.SPHERICAL:
extension_metadata["edges"] = "spherical"

metadata = {
"ARROW:extension:name": field.type.extension_name,
}
if extension_metadata:
metadata["ARROW:extension:metadata"] = json.dumps(extension_metadata)

new_field = pa.field(
field.name, field.type.storage_type, nullable=field.nullable, metadata=metadata
)

new_chunks = []
for chunk in column.chunks:
if hasattr(chunk, "storage"):
new_chunks.append(chunk.storage)
else:
new_chunks.append(chunk.cast(new_field.type))

return new_field, pa.chunked_array(new_chunks)
8 changes: 8 additions & 0 deletions lonboard/_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from lonboard._geoarrow.ops import reproject_table
from lonboard._geoarrow.ops.bbox import Bbox, total_bounds
from lonboard._geoarrow.ops.centroid import WeightedCentroid, weighted_centroid
from lonboard._geoarrow.sanitize import sanitize_table
from lonboard._serialization import infer_rows_per_chunk
from lonboard._utils import auto_downcast as _auto_downcast
from lonboard._utils import get_geometry_column_index, remove_extension_kwargs
Expand Down Expand Up @@ -231,6 +232,13 @@ class BaseArrowLayer(BaseLayer):
def __init__(
self, *, table: pa.Table, _rows_per_chunk: Optional[int] = None, **kwargs
):
# Check for Arrow PyCapsule Interface
# https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
if not isinstance(table, pa.Table) and hasattr(table, "__arrow_c_stream__"):
table = pa.table(table)

table = sanitize_table(table)

# Reproject table to WGS84 if needed
# Note this must happen before calculating the default viewport
table = reproject_table(table, to_crs=OGC_84)
Expand Down
11 changes: 5 additions & 6 deletions lonboard/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ACCESSOR_SERIALIZATION,
TABLE_SERIALIZATION,
)
from lonboard._utils import get_geometry_column_index


# This is a custom subclass of traitlets.TraitType because its `error` method ignores
Expand Down Expand Up @@ -139,21 +140,19 @@ def __init__(
)

def validate(self, obj: Self, value: Any):
# Check for Arrow PyCapsule Interface
# https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
if not isinstance(value, pa.Table) and hasattr(value, "__arrow_c_stream__"):
value = pa.table(value)

if not isinstance(value, pa.Table):
self.error(obj, value)

allowed_geometry_types = self.metadata.get("allowed_geometry_types")
# No restriction on the allowed geometry types in this table
if not allowed_geometry_types:
return value

geometry_extension_type = value.schema.field("geometry").metadata.get(
geom_col_idx = get_geometry_column_index(value.schema)
geometry_extension_type = value.schema.field(geom_col_idx).metadata.get(
b"ARROW:extension:name"
)

if (
allowed_geometry_types
and geometry_extension_type not in allowed_geometry_types
Expand Down
98 changes: 97 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ black = "^23.10.1"
geoarrow-rust-core = "^0.1.0"
geodatasets = "^2023.12.0"
pyogrio = "^0.7.2"
geoarrow-pyarrow = "^0.1.1"

[tool.poetry.group.docs.dependencies]
mkdocs = "^1.4.3"
Expand Down
14 changes: 14 additions & 0 deletions tests/test_layer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import geopandas as gpd
import numpy as np
import pyarrow as pa
import pytest
import shapely
from traitlets import TraitError
Expand Down Expand Up @@ -58,3 +59,16 @@ def test_layer_outside_4326_range():

with pytest.raises(ValueError, match="outside of WGS84 bounds"):
_layer = ScatterplotLayer.from_geopandas(gdf)


def test_layer_from_geoarrow_pyarrow():
ga = pytest.importorskip("geoarrow.pyarrow")

points = gpd.GeoSeries(shapely.points([1, 2], [3, 4]))

# convert to geoarrow.pyarrow Table (currently requires to ensure interleaved
# coordinates manually)
points = ga.with_coord_type(ga.as_geoarrow(points), ga.CoordType.INTERLEAVED)
table = pa.table({"geometry": points})

_layer = ScatterplotLayer(table=table)