-
Notifications
You must be signed in to change notification settings - Fork 270
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
1,470 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,4 +8,5 @@ | |
'rest_framework_simplejwt', | ||
'django_filters', | ||
'rest_framework_recursive', | ||
'rest_framework_gis', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
from django.contrib.gis.db import models | ||
from rest_framework.utils.model_meta import get_field_info | ||
|
||
from drf_spectacular.drainage import warn | ||
from drf_spectacular.extensions import OpenApiSerializerExtension, OpenApiSerializerFieldExtension | ||
from drf_spectacular.plumbing import ResolvedComponent, build_array_type, build_object_type, get_doc | ||
|
||
|
||
def build_point_schema(): | ||
return { | ||
"type": "array", | ||
"items": {"type": "number", "format": "float"}, | ||
"example": [12.9721, 77.5933], | ||
"minItems": 2, | ||
"maxItems": 3, | ||
} | ||
|
||
|
||
def build_linestring_schema(): | ||
return { | ||
"type": "array", | ||
"items": build_point_schema(), | ||
"example": [[22.4707, 70.0577], [12.9721, 77.5933]], | ||
"minItems": 2, | ||
} | ||
|
||
|
||
def build_polygon_schema(): | ||
return { | ||
"type": "array", | ||
"items": {**build_linestring_schema(), "minItems": 4}, | ||
"example": [ | ||
[ | ||
[0.0, 0.0], | ||
[0.0, 50.0], | ||
[50.0, 50.0], | ||
[50.0, 0.0], | ||
[0.0, 0.0], | ||
], | ||
] | ||
} | ||
|
||
|
||
def build_geo_container_schema(name, coords): | ||
return build_object_type( | ||
properties={ | ||
"type": {"type": "string", "enum": [name]}, | ||
"coordinates": coords, | ||
} | ||
) | ||
|
||
|
||
def build_point_geo_schema(): | ||
return build_geo_container_schema("Point", build_point_schema()) | ||
|
||
|
||
def build_linestring_geo_schema(): | ||
return build_geo_container_schema("LineString", build_linestring_schema()) | ||
|
||
|
||
def build_polygon_geo_schema(): | ||
return build_geo_container_schema("Polygon", build_polygon_schema()) | ||
|
||
|
||
def build_geometry_geo_schema(): | ||
return { | ||
'oneOf': [ | ||
build_point_geo_schema(), | ||
build_linestring_geo_schema(), | ||
build_polygon_geo_schema(), | ||
] | ||
} | ||
|
||
|
||
def build_bbox_schema(): | ||
return { | ||
"type": "array", | ||
"items": {"type": "number"}, | ||
"minItems": 4, | ||
"maxItems": 4, | ||
"example": [12.9721, 77.5933, 12.9721, 77.5933], | ||
} | ||
|
||
|
||
def build_geo_schema(model_field): | ||
if isinstance(model_field, models.PointField): | ||
return build_point_geo_schema() | ||
elif isinstance(model_field, models.LineStringField): | ||
return build_linestring_geo_schema() | ||
elif isinstance(model_field, models.PolygonField): | ||
return build_polygon_geo_schema() | ||
elif isinstance(model_field, models.MultiPointField): | ||
return build_geo_container_schema( | ||
"MultiPoint", build_array_type(build_point_schema()) | ||
) | ||
elif isinstance(model_field, models.MultiLineStringField): | ||
return build_geo_container_schema( | ||
"MultiLineString", build_array_type(build_linestring_schema()) | ||
) | ||
elif isinstance(model_field, models.MultiPolygonField): | ||
return build_geo_container_schema( | ||
"MultiPolygon", build_array_type(build_polygon_schema()) | ||
) | ||
elif isinstance(model_field, models.GeometryCollectionField): | ||
return build_geo_container_schema( | ||
"GeometryCollection", build_array_type(build_geometry_geo_schema()) | ||
) | ||
elif isinstance(model_field, models.GeometryField): | ||
return build_geometry_geo_schema() | ||
else: | ||
warn("Encountered unknown GIS geometry field") | ||
return {} | ||
|
||
|
||
def map_geo_field(serializer, geo_field_name): | ||
from rest_framework_gis.fields import GeometrySerializerMethodField | ||
|
||
field = serializer.fields[geo_field_name] | ||
if isinstance(field, GeometrySerializerMethodField): | ||
warn("Geometry generation for GeometrySerializerMethodField is not supported.") | ||
return {} | ||
model_field = get_field_info(serializer.Meta.model).fields[geo_field_name] | ||
return build_geo_schema(model_field) | ||
|
||
|
||
def _inject_enum_collision_fix(collection): | ||
from drf_spectacular.settings import spectacular_settings | ||
if not collection and 'GisFeatureEnum' not in spectacular_settings.ENUM_NAME_OVERRIDES: | ||
spectacular_settings.ENUM_NAME_OVERRIDES['GisFeatureEnum'] = ('Feature',) | ||
if collection and 'GisFeatureCollectionEnum' not in spectacular_settings.ENUM_NAME_OVERRIDES: | ||
spectacular_settings.ENUM_NAME_OVERRIDES['GisFeatureCollectionEnum'] = ('FeatureCollection',) | ||
|
||
|
||
class GeoFeatureModelSerializerExtension(OpenApiSerializerExtension): | ||
target_class = 'rest_framework_gis.serializers.GeoFeatureModelSerializer' | ||
match_subclasses = True | ||
|
||
def map_serializer(self, auto_schema, direction): | ||
_inject_enum_collision_fix(collection=False) | ||
|
||
base_schema = auto_schema._map_serializer(self.target, direction, bypass_extensions=True) | ||
return self.map_geo_feature_model_serializer(self.target, base_schema) | ||
|
||
def map_geo_feature_model_serializer(self, serializer, base_schema): | ||
from rest_framework_gis.serializers import GeoFeatureModelSerializer | ||
|
||
geo_properties = { | ||
"type": {"type": "string", "enum": ["Feature"]} | ||
} | ||
if serializer.Meta.id_field: | ||
geo_properties["id"] = base_schema["properties"].pop(serializer.Meta.id_field) | ||
|
||
geo_properties["geometry"] = map_geo_field(serializer, serializer.Meta.geo_field) | ||
base_schema["properties"].pop(serializer.Meta.geo_field) | ||
|
||
if serializer.Meta.auto_bbox or serializer.Meta.bbox_geo_field: | ||
geo_properties["bbox"] = build_bbox_schema() | ||
base_schema["properties"].pop(serializer.Meta.bbox_geo_field, None) | ||
|
||
# only expose if description comes from the user | ||
description = base_schema.pop('description', None) | ||
if description == get_doc(GeoFeatureModelSerializer): | ||
description = None | ||
|
||
# ignore this aspect for now | ||
base_schema.pop('required', None) | ||
|
||
# nest remaining fields under property "properties" | ||
geo_properties["properties"] = base_schema | ||
|
||
return build_object_type( | ||
properties=geo_properties, | ||
description=description, | ||
) | ||
|
||
|
||
class GeoFeatureModelListSerializerExtension(OpenApiSerializerExtension): | ||
target_class = 'rest_framework_gis.serializers.GeoFeatureModelListSerializer' | ||
|
||
def map_serializer(self, auto_schema, direction): | ||
_inject_enum_collision_fix(collection=True) | ||
|
||
# build/retrieve feature component generated by GeoFeatureModelSerializerExtension. | ||
# wrap the ref in the special list structure and build another component based on that. | ||
feature_component = auto_schema.resolve_serializer(self.target.child, direction) | ||
collection_schema = build_object_type( | ||
properties={ | ||
"type": {"type": "string", "enum": ["FeatureCollection"]}, | ||
"features": build_array_type(feature_component.ref) | ||
} | ||
) | ||
list_component = ResolvedComponent( | ||
name=f'{feature_component.name}List', | ||
type=ResolvedComponent.SCHEMA, | ||
object=self.target.child, | ||
schema=collection_schema | ||
) | ||
auto_schema.registry.register_on_missing(list_component) | ||
return list_component.ref | ||
|
||
|
||
class GeometryFieldExtension(OpenApiSerializerFieldExtension): | ||
target_class = 'rest_framework_gis.fields.GeometryField' | ||
match_subclasses = True | ||
|
||
def map_serializer_field(self, auto_schema, direction): | ||
# running this extension for GeoFeatureModelSerializer's geo_field is superfluous | ||
# as above extension already handles that individually. We run it anyway because | ||
# robustly checking the proper condition is harder. | ||
try: | ||
model = self.target.parent.Meta.model | ||
model_field = get_field_info(model).fields[self.target.source] | ||
return build_geo_schema(model_field) | ||
except: # noqa: E722 | ||
warn(f'Encountered an issue resolving field {self.target}. defaulting to generic object.') | ||
return {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from unittest import mock | ||
|
||
import pytest | ||
from django.db import models | ||
from rest_framework import __version__ as DRF_VERSION # type: ignore[attr-defined] | ||
from rest_framework import mixins, routers, serializers, viewsets | ||
|
||
from drf_spectacular.utils import extend_schema_serializer | ||
from tests import assert_schema, generate_schema | ||
|
||
|
||
@pytest.mark.contrib('rest_framework_gis') | ||
@pytest.mark.skipif(DRF_VERSION < '3.12', reason='DRF pagination schema broken') | ||
@mock.patch('drf_spectacular.settings.spectacular_settings.ENUM_NAME_OVERRIDES', {}) | ||
def test_rest_framework_gis(no_warnings, clear_caches): | ||
from django.contrib.gis.db.models import ( | ||
GeometryCollectionField, GeometryField, LineStringField, MultiLineStringField, | ||
MultiPointField, MultiPolygonField, PointField, PolygonField, | ||
) | ||
from rest_framework_gis.pagination import GeoJsonPagination | ||
from rest_framework_gis.serializers import GeoFeatureModelSerializer | ||
|
||
class GeoModel(models.Model): | ||
field_random1 = models.CharField(max_length=32) | ||
field_random2 = models.IntegerField() | ||
field_gis_plain = PointField() | ||
|
||
field_polygon = PolygonField() | ||
field_point = PointField() | ||
field_linestring = LineStringField() | ||
field_geometry = GeometryField() | ||
field_multipolygon = MultiPolygonField() | ||
field_multipoint = MultiPointField() | ||
field_multilinestring = MultiLineStringField() | ||
field_geometrycollection = GeometryCollectionField() | ||
|
||
router = routers.SimpleRouter() | ||
|
||
# all GIS fields as GeoJSON in singular and list form | ||
fields = [ | ||
'Point', 'Polygon', 'Linestring', 'Geometry', | ||
'Multipoint', 'Multipolygon', 'Multilinestring', 'Geometrycollection' | ||
] | ||
for name in fields: | ||
@extend_schema_serializer(component_name=name) | ||
class XSerializer(GeoFeatureModelSerializer): | ||
class Meta: | ||
model = GeoModel | ||
geo_field = f'field_{name.lower()}' | ||
auto_bbox = name == 'Polygon' | ||
fields = ['id', 'field_random1', 'field_random2', 'field_gis_plain'] | ||
|
||
class XViewset(mixins.RetrieveModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet): | ||
serializer_class = XSerializer | ||
queryset = GeoModel.objects.none() | ||
|
||
router.register(name.lower(), XViewset, basename=name) | ||
|
||
# plain serializer with GIS fields but without restructured container object | ||
class PlainSerializer(serializers.ModelSerializer): | ||
class Meta: | ||
model = GeoModel | ||
fields = ['id', 'field_random1', 'field_random2', 'field_gis_plain'] | ||
|
||
class PlainViewset(mixins.RetrieveModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet): | ||
serializer_class = PlainSerializer | ||
queryset = GeoModel.objects.none() | ||
|
||
router.register('plain', PlainViewset, basename='plain') | ||
|
||
# GIS specific pagination | ||
class PlainViewset(mixins.RetrieveModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet): | ||
serializer_class = PlainSerializer | ||
queryset = GeoModel.objects.none() | ||
pagination_class = GeoJsonPagination | ||
|
||
router.register('paginated', PlainViewset, basename='paginated') | ||
|
||
assert_schema( | ||
generate_schema(None, patterns=router.urls), | ||
'tests/contrib/test_rest_framework_gis.yml' | ||
) |
Oops, something went wrong.