Skip to content

Commit

Permalink
Merge pull request #144 from eseglem/CustomSerializer
Browse files Browse the repository at this point in the history
Create a base model and generic serialization.
  • Loading branch information
eseglem authored Jul 21, 2023
2 parents 50a09b3 + c6712a2 commit 2fda6cc
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 141 deletions.
68 changes: 68 additions & 0 deletions geojson_pydantic/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""pydantic BaseModel for GeoJSON objects."""
from __future__ import annotations

from typing import Any, Dict, List, Optional, Set

from pydantic import BaseModel, SerializationInfo, field_validator, model_serializer

from geojson_pydantic.types import BBox


class _GeoJsonBase(BaseModel):
bbox: Optional[BBox] = None

# These fields will not be included when serializing in json mode
# `.model_dump_json()` or `.model_dump(mode="json")`
__geojson_exclude_if_none__: Set[str] = {"bbox"}

@property
def __geo_interface__(self) -> Dict[str, Any]:
"""GeoJSON-like protocol for geo-spatial (GIS) vector data.
ref: https://gist.github.com/sgillies/2217756#__geo_interface
"""
return self.model_dump(mode="json")

@field_validator("bbox")
def validate_bbox(cls, bbox: Optional[BBox]) -> Optional[BBox]:
"""Validate BBox values are ordered correctly."""
# If bbox is None, there is nothing to validate.
if bbox is None:
return None

# A list to store any errors found so we can raise them all at once.
errors: List[str] = []

# Determine where the second position starts. 2 for 2D, 3 for 3D.
offset = len(bbox) // 2

# Check X
if bbox[0] > bbox[offset]:
errors.append(f"Min X ({bbox[0]}) must be <= Max X ({bbox[offset]}).")
# Check Y
if bbox[1] > bbox[1 + offset]:
errors.append(f"Min Y ({bbox[1]}) must be <= Max Y ({bbox[1 + offset]}).")
# If 3D, check Z values.
if offset > 2 and bbox[2] > bbox[2 + offset]:
errors.append(f"Min Z ({bbox[2]}) must be <= Max Z ({bbox[2 + offset]}).")

# Raise any errors found.
if errors:
raise ValueError("Invalid BBox. Error(s): " + " ".join(errors))

return bbox

@model_serializer(when_used="json", mode="wrap")
def clean_model(self, serializer: Any, _info: SerializationInfo) -> Dict[str, Any]:
"""Custom Model serializer to match the GeoJSON specification.
Used to remove fields which are optional but cannot be null values.
"""
# This seems like the best way to have the least amount of unexpected consequences.
# We want to avoid forcing values in `exclude_none` or `exclude_unset` which could
# cause issues or unexpected behavior for downstream users.
data: Dict[str, Any] = serializer(self)
for field in self.__geojson_exclude_if_none__:
if field in data and data[field] is None:
del data[field]
return data
51 changes: 6 additions & 45 deletions geojson_pydantic/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,24 @@

from typing import Any, Dict, Generic, Iterator, List, Literal, Optional, TypeVar, Union

from pydantic import (
BaseModel,
Field,
StrictInt,
StrictStr,
field_validator,
model_serializer,
)

from geojson_pydantic.geo_interface import GeoInterfaceMixin
from pydantic import BaseModel, Field, StrictInt, StrictStr, field_validator

from geojson_pydantic.base import _GeoJsonBase
from geojson_pydantic.geometries import Geometry
from geojson_pydantic.types import BBox, validate_bbox

Props = TypeVar("Props", bound=Union[Dict[str, Any], BaseModel])
Geom = TypeVar("Geom", bound=Geometry)


class Feature(BaseModel, Generic[Geom, Props], GeoInterfaceMixin):
class Feature(_GeoJsonBase, Generic[Geom, Props]):
"""Feature Model"""

type: Literal["Feature"]
geometry: Union[Geom, None] = Field(...)
properties: Union[Props, None] = Field(...)
id: Optional[Union[StrictInt, StrictStr]] = None
bbox: Optional[BBox] = None

_validate_bbox = field_validator("bbox")(validate_bbox)

@model_serializer(when_used="json")
def ser_model(self) -> Dict[str, Any]:
"""Custom Model serializer to match the GeoJSON specification."""
model: Dict[str, Any] = {
"type": self.type,
"geometry": self.geometry,
"properties": self.properties,
}
if self.id is not None:
model["id"] = self.id
if self.bbox:
model["bbox"] = self.bbox

return model
__geojson_exclude_if_none__ = {"bbox", "id"}

@field_validator("geometry", mode="before")
def set_geometry(cls, geometry: Any) -> Any:
Expand All @@ -57,24 +33,11 @@ def set_geometry(cls, geometry: Any) -> Any:
Feat = TypeVar("Feat", bound=Feature)


class FeatureCollection(BaseModel, Generic[Feat], GeoInterfaceMixin):
class FeatureCollection(_GeoJsonBase, Generic[Feat]):
"""FeatureCollection Model"""

type: Literal["FeatureCollection"]
features: List[Feat]
bbox: Optional[BBox] = None

@model_serializer(when_used="json")
def ser_model(self) -> Dict[str, Any]:
"""Custom Model serializer to match the GeoJSON specification."""
model: Dict[str, Any] = {
"type": self.type,
"features": self.features,
}
if self.bbox:
model["bbox"] = self.bbox

return model

def __iter__(self) -> Iterator[Feat]: # type: ignore [override]
"""iterate over features"""
Expand All @@ -87,5 +50,3 @@ def __len__(self) -> int:
def __getitem__(self, index: int) -> Feat:
"""get feature at a given index"""
return self.features[index]

_validate_bbox = field_validator("bbox")(validate_bbox)
23 changes: 0 additions & 23 deletions geojson_pydantic/geo_interface.py

This file was deleted.

42 changes: 5 additions & 37 deletions geojson_pydantic/geometries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,20 @@

import abc
import warnings
from typing import Any, Dict, Iterator, List, Literal, Optional, Union
from typing import Any, Iterator, List, Literal, Union

from pydantic import BaseModel, Field, field_validator, model_serializer
from pydantic import Field, field_validator
from typing_extensions import Annotated

from geojson_pydantic.geo_interface import GeoInterfaceMixin
from geojson_pydantic.base import _GeoJsonBase
from geojson_pydantic.types import (
BBox,
LinearRing,
LineStringCoords,
MultiLineStringCoords,
MultiPointCoords,
MultiPolygonCoords,
PolygonCoords,
Position,
validate_bbox,
)


Expand Down Expand Up @@ -72,24 +70,11 @@ def _polygons_wkt_coordinates(
)


class _GeometryBase(BaseModel, abc.ABC, GeoInterfaceMixin):
class _GeometryBase(_GeoJsonBase, abc.ABC):
"""Base class for geometry models"""

type: str
coordinates: Any
bbox: Optional[BBox] = None

@model_serializer(when_used="json")
def ser_model(self) -> Dict[str, Any]:
"""Custom Model serializer to match the GeoJSON specification."""
model: Dict[str, Any] = {
"type": self.type,
"coordinates": self.coordinates,
}
if self.bbox:
model["bbox"] = self.bbox

return model

@abc.abstractmethod
def __wkt_coordinates__(self, coordinates: Any, force_z: bool) -> str:
Expand Down Expand Up @@ -119,8 +104,6 @@ def wkt(self) -> str:

return wkt

_validate_bbox = field_validator("bbox")(validate_bbox)


class Point(_GeometryBase):
"""Point Model"""
Expand Down Expand Up @@ -261,24 +244,11 @@ def check_closure(cls, coordinates: List) -> List:
return coordinates


class GeometryCollection(BaseModel, GeoInterfaceMixin):
class GeometryCollection(_GeoJsonBase):
"""GeometryCollection Model"""

type: Literal["GeometryCollection"]
geometries: List[Geometry]
bbox: Optional[BBox] = None

@model_serializer(when_used="json")
def ser_model(self) -> Dict[str, Any]:
"""Custom Model serializer to match the GeoJSON specification."""
model: Dict[str, Any] = {
"type": self.type,
"geometries": self.geometries,
}
if self.bbox:
model["bbox"] = self.bbox

return model

def __iter__(self) -> Iterator[Geometry]: # type: ignore [override]
"""iterate over geometries"""
Expand Down Expand Up @@ -310,8 +280,6 @@ def wkt(self) -> str:
z = " Z " if "Z" in geometries else " "
return f"{self.type.upper()}{z}{geometries}"

_validate_bbox = field_validator("bbox")(validate_bbox)

@field_validator("geometries")
def check_geometries(cls, geometries: List) -> List:
"""Add warnings for conditions the spec does not explicitly forbid."""
Expand Down
34 changes: 1 addition & 33 deletions geojson_pydantic/types.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,15 @@
"""Types for geojson_pydantic models"""

from typing import List, Optional, Tuple, TypeVar, Union
from typing import List, Tuple, Union

from pydantic import Field
from typing_extensions import Annotated

T = TypeVar("T")

BBox = Union[
Tuple[float, float, float, float], # 2D bbox
Tuple[float, float, float, float, float, float], # 3D bbox
]


def validate_bbox(bbox: Optional[BBox]) -> Optional[BBox]:
"""Validate BBox values are ordered correctly."""
# If bbox is None, there is nothing to validate.
if bbox is None:
return None

# A list to store any errors found so we can raise them all at once.
errors: List[str] = []

# Determine where the second position starts. 2 for 2D, 3 for 3D.
offset = len(bbox) // 2

# Check X
if bbox[0] > bbox[offset]:
errors.append(f"Min X ({bbox[0]}) must be <= Max X ({bbox[offset]}).")
# Check Y
if bbox[1] > bbox[1 + offset]:
errors.append(f"Min Y ({bbox[1]}) must be <= Max Y ({bbox[1 + offset]}).")
# If 3D, check Z values.
if offset > 2 and bbox[2] > bbox[2 + offset]:
errors.append(f"Min Z ({bbox[2]}) must be <= Max Z ({bbox[2 + offset]}).")

# Raise any errors found.
if errors:
raise ValueError("Invalid BBox. Error(s): " + " ".join(errors))

return bbox


Position = Union[Tuple[float, float], Tuple[float, float, float]]

# Coordinate arrays
Expand Down
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,4 @@ ignore = [
]

[tool.ruff.per-file-ignores]
"tests/test_geometries.py" = ["D1"]
"tests/test_features.py" = ["D1"]
"tests/test_package.py" = ["D1"]
"tests/*.py" = ["D1"]
Loading

0 comments on commit 2fda6cc

Please sign in to comment.