diff --git a/geojson_pydantic/base.py b/geojson_pydantic/base.py new file mode 100644 index 0000000..85dfa02 --- /dev/null +++ b/geojson_pydantic/base.py @@ -0,0 +1,68 @@ +"""pydantic BaseModel for GeoJSON objects.""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Set + +from pydantic import BaseModel, SerializationInfo, field_validator, model_serializer + +from geojson_pydantic.types import BBox + + +class _GeoJsonBase(BaseModel): + bbox: Optional[BBox] = None + + # These fields will not be included when serializing in json mode + # `.model_dump_json()` or `.model_dump(mode="json")` + __geojson_exclude_if_none__: Set[str] = {"bbox"} + + @property + def __geo_interface__(self) -> Dict[str, Any]: + """GeoJSON-like protocol for geo-spatial (GIS) vector data. + + ref: https://gist.github.com/sgillies/2217756#__geo_interface + """ + return self.model_dump(mode="json") + + @field_validator("bbox") + def validate_bbox(cls, bbox: Optional[BBox]) -> Optional[BBox]: + """Validate BBox values are ordered correctly.""" + # If bbox is None, there is nothing to validate. + if bbox is None: + return None + + # A list to store any errors found so we can raise them all at once. + errors: List[str] = [] + + # Determine where the second position starts. 2 for 2D, 3 for 3D. + offset = len(bbox) // 2 + + # Check X + if bbox[0] > bbox[offset]: + errors.append(f"Min X ({bbox[0]}) must be <= Max X ({bbox[offset]}).") + # Check Y + if bbox[1] > bbox[1 + offset]: + errors.append(f"Min Y ({bbox[1]}) must be <= Max Y ({bbox[1 + offset]}).") + # If 3D, check Z values. + if offset > 2 and bbox[2] > bbox[2 + offset]: + errors.append(f"Min Z ({bbox[2]}) must be <= Max Z ({bbox[2 + offset]}).") + + # Raise any errors found. + if errors: + raise ValueError("Invalid BBox. Error(s): " + " ".join(errors)) + + return bbox + + @model_serializer(when_used="json", mode="wrap") + def clean_model(self, serializer: Any, _info: SerializationInfo) -> Dict[str, Any]: + """Custom Model serializer to match the GeoJSON specification. + + Used to remove fields which are optional but cannot be null values. + """ + # This seems like the best way to have the least amount of unexpected consequences. + # We want to avoid forcing values in `exclude_none` or `exclude_unset` which could + # cause issues or unexpected behavior for downstream users. + data: Dict[str, Any] = serializer(self) + for field in self.__geojson_exclude_if_none__: + if field in data and data[field] is None: + del data[field] + return data diff --git a/geojson_pydantic/features.py b/geojson_pydantic/features.py index e613c65..738cdce 100644 --- a/geojson_pydantic/features.py +++ b/geojson_pydantic/features.py @@ -2,48 +2,24 @@ from typing import Any, Dict, Generic, Iterator, List, Literal, Optional, TypeVar, Union -from pydantic import ( - BaseModel, - Field, - StrictInt, - StrictStr, - field_validator, - model_serializer, -) - -from geojson_pydantic.geo_interface import GeoInterfaceMixin +from pydantic import BaseModel, Field, StrictInt, StrictStr, field_validator + +from geojson_pydantic.base import _GeoJsonBase from geojson_pydantic.geometries import Geometry -from geojson_pydantic.types import BBox, validate_bbox Props = TypeVar("Props", bound=Union[Dict[str, Any], BaseModel]) Geom = TypeVar("Geom", bound=Geometry) -class Feature(BaseModel, Generic[Geom, Props], GeoInterfaceMixin): +class Feature(_GeoJsonBase, Generic[Geom, Props]): """Feature Model""" type: Literal["Feature"] geometry: Union[Geom, None] = Field(...) properties: Union[Props, None] = Field(...) id: Optional[Union[StrictInt, StrictStr]] = None - bbox: Optional[BBox] = None - - _validate_bbox = field_validator("bbox")(validate_bbox) - - @model_serializer(when_used="json") - def ser_model(self) -> Dict[str, Any]: - """Custom Model serializer to match the GeoJSON specification.""" - model: Dict[str, Any] = { - "type": self.type, - "geometry": self.geometry, - "properties": self.properties, - } - if self.id is not None: - model["id"] = self.id - if self.bbox: - model["bbox"] = self.bbox - return model + __geojson_exclude_if_none__ = {"bbox", "id"} @field_validator("geometry", mode="before") def set_geometry(cls, geometry: Any) -> Any: @@ -57,24 +33,11 @@ def set_geometry(cls, geometry: Any) -> Any: Feat = TypeVar("Feat", bound=Feature) -class FeatureCollection(BaseModel, Generic[Feat], GeoInterfaceMixin): +class FeatureCollection(_GeoJsonBase, Generic[Feat]): """FeatureCollection Model""" type: Literal["FeatureCollection"] features: List[Feat] - bbox: Optional[BBox] = None - - @model_serializer(when_used="json") - def ser_model(self) -> Dict[str, Any]: - """Custom Model serializer to match the GeoJSON specification.""" - model: Dict[str, Any] = { - "type": self.type, - "features": self.features, - } - if self.bbox: - model["bbox"] = self.bbox - - return model def __iter__(self) -> Iterator[Feat]: # type: ignore [override] """iterate over features""" @@ -87,5 +50,3 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> Feat: """get feature at a given index""" return self.features[index] - - _validate_bbox = field_validator("bbox")(validate_bbox) diff --git a/geojson_pydantic/geo_interface.py b/geojson_pydantic/geo_interface.py deleted file mode 100644 index ca8edc8..0000000 --- a/geojson_pydantic/geo_interface.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Mixin for __geo_interface__ on GeoJSON objects.""" - -from typing import Any, Dict, Protocol - - -class _ModelDumpProtocol(Protocol): - """Protocol for use as the type of self in the mixin.""" - - def model_dump(self, *, exclude_unset: bool, **args: Any) -> Dict[str, Any]: - """Define a dict function so the mixin knows it exists.""" - ... - - -class GeoInterfaceMixin: - """Mixin for __geo_interface__ on GeoJSON objects.""" - - @property - def __geo_interface__(self: _ModelDumpProtocol) -> Dict[str, Any]: - """GeoJSON-like protocol for geo-spatial (GIS) vector data. - - ref: https://gist.github.com/sgillies/2217756#__geo_interface - """ - return self.model_dump(exclude_unset=True) diff --git a/geojson_pydantic/geometries.py b/geojson_pydantic/geometries.py index 2d6787c..ecd2d3a 100644 --- a/geojson_pydantic/geometries.py +++ b/geojson_pydantic/geometries.py @@ -3,14 +3,13 @@ import abc import warnings -from typing import Any, Dict, Iterator, List, Literal, Optional, Union +from typing import Any, Iterator, List, Literal, Union -from pydantic import BaseModel, Field, field_validator, model_serializer +from pydantic import Field, field_validator from typing_extensions import Annotated -from geojson_pydantic.geo_interface import GeoInterfaceMixin +from geojson_pydantic.base import _GeoJsonBase from geojson_pydantic.types import ( - BBox, LinearRing, LineStringCoords, MultiLineStringCoords, @@ -18,7 +17,6 @@ MultiPolygonCoords, PolygonCoords, Position, - validate_bbox, ) @@ -72,24 +70,11 @@ def _polygons_wkt_coordinates( ) -class _GeometryBase(BaseModel, abc.ABC, GeoInterfaceMixin): +class _GeometryBase(_GeoJsonBase, abc.ABC): """Base class for geometry models""" type: str coordinates: Any - bbox: Optional[BBox] = None - - @model_serializer(when_used="json") - def ser_model(self) -> Dict[str, Any]: - """Custom Model serializer to match the GeoJSON specification.""" - model: Dict[str, Any] = { - "type": self.type, - "coordinates": self.coordinates, - } - if self.bbox: - model["bbox"] = self.bbox - - return model @abc.abstractmethod def __wkt_coordinates__(self, coordinates: Any, force_z: bool) -> str: @@ -119,8 +104,6 @@ def wkt(self) -> str: return wkt - _validate_bbox = field_validator("bbox")(validate_bbox) - class Point(_GeometryBase): """Point Model""" @@ -261,24 +244,11 @@ def check_closure(cls, coordinates: List) -> List: return coordinates -class GeometryCollection(BaseModel, GeoInterfaceMixin): +class GeometryCollection(_GeoJsonBase): """GeometryCollection Model""" type: Literal["GeometryCollection"] geometries: List[Geometry] - bbox: Optional[BBox] = None - - @model_serializer(when_used="json") - def ser_model(self) -> Dict[str, Any]: - """Custom Model serializer to match the GeoJSON specification.""" - model: Dict[str, Any] = { - "type": self.type, - "geometries": self.geometries, - } - if self.bbox: - model["bbox"] = self.bbox - - return model def __iter__(self) -> Iterator[Geometry]: # type: ignore [override] """iterate over geometries""" @@ -310,8 +280,6 @@ def wkt(self) -> str: z = " Z " if "Z" in geometries else " " return f"{self.type.upper()}{z}{geometries}" - _validate_bbox = field_validator("bbox")(validate_bbox) - @field_validator("geometries") def check_geometries(cls, geometries: List) -> List: """Add warnings for conditions the spec does not explicitly forbid.""" diff --git a/geojson_pydantic/types.py b/geojson_pydantic/types.py index 509df1e..d02c244 100644 --- a/geojson_pydantic/types.py +++ b/geojson_pydantic/types.py @@ -1,47 +1,15 @@ """Types for geojson_pydantic models""" -from typing import List, Optional, Tuple, TypeVar, Union +from typing import List, Tuple, Union from pydantic import Field from typing_extensions import Annotated -T = TypeVar("T") - BBox = Union[ Tuple[float, float, float, float], # 2D bbox Tuple[float, float, float, float, float, float], # 3D bbox ] - -def validate_bbox(bbox: Optional[BBox]) -> Optional[BBox]: - """Validate BBox values are ordered correctly.""" - # If bbox is None, there is nothing to validate. - if bbox is None: - return None - - # A list to store any errors found so we can raise them all at once. - errors: List[str] = [] - - # Determine where the second position starts. 2 for 2D, 3 for 3D. - offset = len(bbox) // 2 - - # Check X - if bbox[0] > bbox[offset]: - errors.append(f"Min X ({bbox[0]}) must be <= Max X ({bbox[offset]}).") - # Check Y - if bbox[1] > bbox[1 + offset]: - errors.append(f"Min Y ({bbox[1]}) must be <= Max Y ({bbox[1 + offset]}).") - # If 3D, check Z values. - if offset > 2 and bbox[2] > bbox[2 + offset]: - errors.append(f"Min Z ({bbox[2]}) must be <= Max Z ({bbox[2 + offset]}).") - - # Raise any errors found. - if errors: - raise ValueError("Invalid BBox. Error(s): " + " ".join(errors)) - - return bbox - - Position = Union[Tuple[float, float], Tuple[float, float, float]] # Coordinate arrays diff --git a/pyproject.toml b/pyproject.toml index c04b6c5..d74b373 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,4 @@ ignore = [ ] [tool.ruff.per-file-ignores] -"tests/test_geometries.py" = ["D1"] -"tests/test_features.py" = ["D1"] -"tests/test_package.py" = ["D1"] +"tests/*.py" = ["D1"] diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 0000000..971fb6d --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,98 @@ +from typing import Set, Tuple, Union + +import pytest +from pydantic import Field, ValidationError + +from geojson_pydantic.base import _GeoJsonBase + +BBOXES = ( + (100, 0, 0, 0), # Incorrect Order + (0, 100, 0, 0), + (0, 0, 100, 0, 0, 0), + (0, "a", 0, 0), # Invalid Type +) + + +@pytest.mark.parametrize("values", BBOXES) +def test_bbox_validation(values: Tuple) -> None: + # Ensure validation is happening correctly on the base model + with pytest.raises(ValidationError): + _GeoJsonBase(bbox=values) + + +@pytest.mark.parametrize("values", BBOXES) +def test_bbox_validation_subclass(values: Tuple) -> None: + # Ensure validation is happening correctly when subclassed + class TestClass(_GeoJsonBase): + test_field: str = None + + with pytest.raises(ValidationError): + TestClass(bbox=values) + + +@pytest.mark.parametrize("values", BBOXES) +def test_bbox_validation_field(values: Tuple) -> None: + # Ensure validation is happening correctly when used as a field + class TestClass(_GeoJsonBase): + geo: _GeoJsonBase + + with pytest.raises(ValidationError): + TestClass(geo={"bbox": values}) + + +def test_exclude_if_none() -> None: + model = _GeoJsonBase() + # Included in default dump + assert model.model_dump() == {"bbox": None} + # Not included when in json mode + assert model.model_dump(mode="json") == {} + # And not included in the output json string. + assert model.model_dump_json() == "{}" + + # Included if it has a value + model = _GeoJsonBase(bbox=(0, 0, 0, 0)) + assert model.model_dump() == {"bbox": (0, 0, 0, 0)} + assert model.model_dump(mode="json") == {"bbox": [0, 0, 0, 0]} + assert model.model_dump_json() == '{"bbox":[0.0,0.0,0.0,0.0]}' + + # Since `validate_assignment` is not set, you can do this without an error. + # The solution should handle this and not just look at if the field is set. + model.bbox = None + assert model.model_dump() == {"bbox": None} + assert model.model_dump(mode="json") == {} + assert model.model_dump_json() == "{}" + + +def test_exclude_if_none_subclass() -> None: + # Create a subclass that adds a field, and ensure it works. + class TestClass(_GeoJsonBase): + test_field: str = None + __geojson_exclude_if_none__: Set[str] = {"bbox", "test_field"} + + assert TestClass().model_dump_json() == "{}" + assert TestClass(test_field="a").model_dump_json() == '{"test_field":"a"}' + assert ( + TestClass(bbox=(0, 0, 0, 0)).model_dump_json() == '{"bbox":[0.0,0.0,0.0,0.0]}' + ) + + +def test_exclude_if_none_kwargs() -> None: + # Create a subclass that adds fields and dumps it with kwargs to ensure + # the kwargs are still being utilized. + class TestClass(_GeoJsonBase): + test_field: str = Field(default="test", alias="field") + null_field: Union[str, None] = None + + model = TestClass(bbox=(0, 0, 0, 0)) + assert ( + model.model_dump_json(indent=2, by_alias=True, exclude_none=True) + == """{ + "bbox": [ + 0.0, + 0.0, + 0.0, + 0.0 + ], + "field": "test" +}""" + )