Skip to content

Commit

Permalink
add support for rest_framework_gis
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed May 24, 2022
1 parent 68e2b1b commit b51575c
Show file tree
Hide file tree
Showing 9 changed files with 1,470 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ jobs:
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.setup.python-version }}
- name: Install system dependencies
run: sudo apt-get install -y gdal-bin libsqlite3-mod-spatialite
- name: Install tox
run: pip install tox
- name: Run Tox
Expand Down
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Features
- `drf-nested-routers <https://github.com/alanjds/drf-nested-routers>`_
- `djangorestframework-recursive <https://github.com/heywbj/django-rest-framework-recursive>`_
- `djangorestframework-dataclasses <https://github.com/oxan/djangorestframework-dataclasses>`_
- `django-rest-framework-gis <https://github.com/openwisp/django-rest-framework-gis>`_


For more information visit the `documentation <https://drf-spectacular.readthedocs.io/>`_.
Expand Down
1 change: 1 addition & 0 deletions drf_spectacular/contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
'rest_framework_simplejwt',
'django_filters',
'rest_framework_recursive',
'rest_framework_gis',
]
216 changes: 216 additions & 0 deletions drf_spectacular/contrib/rest_framework_gis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
from django.contrib.gis.db import models
from rest_framework.utils.model_meta import get_field_info

from drf_spectacular.drainage import warn
from drf_spectacular.extensions import OpenApiSerializerExtension, OpenApiSerializerFieldExtension
from drf_spectacular.plumbing import ResolvedComponent, build_array_type, build_object_type, get_doc


def build_point_schema():
return {
"type": "array",
"items": {"type": "number", "format": "float"},
"example": [12.9721, 77.5933],
"minItems": 2,
"maxItems": 3,
}


def build_linestring_schema():
return {
"type": "array",
"items": build_point_schema(),
"example": [[22.4707, 70.0577], [12.9721, 77.5933]],
"minItems": 2,
}


def build_polygon_schema():
return {
"type": "array",
"items": {**build_linestring_schema(), "minItems": 4},
"example": [
[
[0.0, 0.0],
[0.0, 50.0],
[50.0, 50.0],
[50.0, 0.0],
[0.0, 0.0],
],
]
}


def build_geo_container_schema(name, coords):
return build_object_type(
properties={
"type": {"type": "string", "enum": [name]},
"coordinates": coords,
}
)


def build_point_geo_schema():
return build_geo_container_schema("Point", build_point_schema())


def build_linestring_geo_schema():
return build_geo_container_schema("LineString", build_linestring_schema())


def build_polygon_geo_schema():
return build_geo_container_schema("Polygon", build_polygon_schema())


def build_geometry_geo_schema():
return {
'oneOf': [
build_point_geo_schema(),
build_linestring_geo_schema(),
build_polygon_geo_schema(),
]
}


def build_bbox_schema():
return {
"type": "array",
"items": {"type": "number"},
"minItems": 4,
"maxItems": 4,
"example": [12.9721, 77.5933, 12.9721, 77.5933],
}


def build_geo_schema(model_field):
if isinstance(model_field, models.PointField):
return build_point_geo_schema()
elif isinstance(model_field, models.LineStringField):
return build_linestring_geo_schema()
elif isinstance(model_field, models.PolygonField):
return build_polygon_geo_schema()
elif isinstance(model_field, models.MultiPointField):
return build_geo_container_schema(
"MultiPoint", build_array_type(build_point_schema())
)
elif isinstance(model_field, models.MultiLineStringField):
return build_geo_container_schema(
"MultiLineString", build_array_type(build_linestring_schema())
)
elif isinstance(model_field, models.MultiPolygonField):
return build_geo_container_schema(
"MultiPolygon", build_array_type(build_polygon_schema())
)
elif isinstance(model_field, models.GeometryCollectionField):
return build_geo_container_schema(
"GeometryCollection", build_array_type(build_geometry_geo_schema())
)
elif isinstance(model_field, models.GeometryField):
return build_geometry_geo_schema()
else:
warn("Encountered unknown GIS geometry field")
return {}


def map_geo_field(serializer, geo_field_name):
from rest_framework_gis.fields import GeometrySerializerMethodField

field = serializer.fields[geo_field_name]
if isinstance(field, GeometrySerializerMethodField):
warn("Geometry generation for GeometrySerializerMethodField is not supported.")
return {}
model_field = get_field_info(serializer.Meta.model).fields[geo_field_name]
return build_geo_schema(model_field)


def _inject_enum_collision_fix(collection):
from drf_spectacular.settings import spectacular_settings
if not collection and 'GisFeatureEnum' not in spectacular_settings.ENUM_NAME_OVERRIDES:
spectacular_settings.ENUM_NAME_OVERRIDES['GisFeatureEnum'] = ('Feature',)
if collection and 'GisFeatureCollectionEnum' not in spectacular_settings.ENUM_NAME_OVERRIDES:
spectacular_settings.ENUM_NAME_OVERRIDES['GisFeatureCollectionEnum'] = ('FeatureCollection',)


class GeoFeatureModelSerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_framework_gis.serializers.GeoFeatureModelSerializer'
match_subclasses = True

def map_serializer(self, auto_schema, direction):
_inject_enum_collision_fix(collection=False)

base_schema = auto_schema._map_serializer(self.target, direction, bypass_extensions=True)
return self.map_geo_feature_model_serializer(self.target, base_schema)

def map_geo_feature_model_serializer(self, serializer, base_schema):
from rest_framework_gis.serializers import GeoFeatureModelSerializer

geo_properties = {
"type": {"type": "string", "enum": ["Feature"]}
}
if serializer.Meta.id_field:
geo_properties["id"] = base_schema["properties"].pop(serializer.Meta.id_field)

geo_properties["geometry"] = map_geo_field(serializer, serializer.Meta.geo_field)
base_schema["properties"].pop(serializer.Meta.geo_field)

if serializer.Meta.auto_bbox or serializer.Meta.bbox_geo_field:
geo_properties["bbox"] = build_bbox_schema()
base_schema["properties"].pop(serializer.Meta.bbox_geo_field, None)

# only expose if description comes from the user
description = base_schema.pop('description', None)
if description == get_doc(GeoFeatureModelSerializer):
description = None

# ignore this aspect for now
base_schema.pop('required', None)

# nest remaining fields under property "properties"
geo_properties["properties"] = base_schema

return build_object_type(
properties=geo_properties,
description=description,
)


class GeoFeatureModelListSerializerExtension(OpenApiSerializerExtension):
target_class = 'rest_framework_gis.serializers.GeoFeatureModelListSerializer'

def map_serializer(self, auto_schema, direction):
_inject_enum_collision_fix(collection=True)

# build/retrieve feature component generated by GeoFeatureModelSerializerExtension.
# wrap the ref in the special list structure and build another component based on that.
feature_component = auto_schema.resolve_serializer(self.target.child, direction)
collection_schema = build_object_type(
properties={
"type": {"type": "string", "enum": ["FeatureCollection"]},
"features": build_array_type(feature_component.ref)
}
)
list_component = ResolvedComponent(
name=f'{feature_component.name}List',
type=ResolvedComponent.SCHEMA,
object=self.target.child,
schema=collection_schema
)
auto_schema.registry.register_on_missing(list_component)
return list_component.ref


class GeometryFieldExtension(OpenApiSerializerFieldExtension):
target_class = 'rest_framework_gis.fields.GeometryField'
match_subclasses = True

def map_serializer_field(self, auto_schema, direction):
# running this extension for GeoFeatureModelSerializer's geo_field is superfluous
# as above extension already handles that individually. We run it anyway because
# robustly checking the proper condition is harder.
try:
model = self.target.parent.Meta.model
model_field = get_field_info(model).fields[self.target.source]
return build_geo_schema(model_field)
except: # noqa: E722
warn(f'Encountered an issue resolving field {self.target}. defaulting to generic object.')
return {}
1 change: 1 addition & 0 deletions requirements/optionals.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ drf-nested-routers>=0.93.3
djangorestframework-recursive>=0.1.2
drf-spectacular-sidecar
djangorestframework-dataclasses>=1.0.0; python_version >= '3.7'
djangorestframework-gis>=1.0.0
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def pytest_configure(config):
'allauth.account',
'oauth2_provider',
'django_filters',
'rest_framework_gis',
# this is not strictly required and when added django-polymorphic
# currently breaks the whole Django/DRF upstream testing.
# 'polymorphic',
Expand Down
82 changes: 82 additions & 0 deletions tests/contrib/test_rest_framework_gis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from unittest import mock

import pytest
from django.db import models
from rest_framework import __version__ as DRF_VERSION # type: ignore[attr-defined]
from rest_framework import mixins, routers, serializers, viewsets

from drf_spectacular.utils import extend_schema_serializer
from tests import assert_schema, generate_schema


@pytest.mark.contrib('rest_framework_gis')
@pytest.mark.skipif(DRF_VERSION < '3.12', reason='DRF pagination schema broken')
@mock.patch('drf_spectacular.settings.spectacular_settings.ENUM_NAME_OVERRIDES', {})
def test_rest_framework_gis(no_warnings, clear_caches):
from django.contrib.gis.db.models import (
GeometryCollectionField, GeometryField, LineStringField, MultiLineStringField,
MultiPointField, MultiPolygonField, PointField, PolygonField,
)
from rest_framework_gis.pagination import GeoJsonPagination
from rest_framework_gis.serializers import GeoFeatureModelSerializer

class GeoModel(models.Model):
field_random1 = models.CharField(max_length=32)
field_random2 = models.IntegerField()
field_gis_plain = PointField()

field_polygon = PolygonField()
field_point = PointField()
field_linestring = LineStringField()
field_geometry = GeometryField()
field_multipolygon = MultiPolygonField()
field_multipoint = MultiPointField()
field_multilinestring = MultiLineStringField()
field_geometrycollection = GeometryCollectionField()

router = routers.SimpleRouter()

# all GIS fields as GeoJSON in singular and list form
fields = [
'Point', 'Polygon', 'Linestring', 'Geometry',
'Multipoint', 'Multipolygon', 'Multilinestring', 'Geometrycollection'
]
for name in fields:
@extend_schema_serializer(component_name=name)
class XSerializer(GeoFeatureModelSerializer):
class Meta:
model = GeoModel
geo_field = f'field_{name.lower()}'
auto_bbox = name == 'Polygon'
fields = ['id', 'field_random1', 'field_random2', 'field_gis_plain']

class XViewset(mixins.RetrieveModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet):
serializer_class = XSerializer
queryset = GeoModel.objects.none()

router.register(name.lower(), XViewset, basename=name)

# plain serializer with GIS fields but without restructured container object
class PlainSerializer(serializers.ModelSerializer):
class Meta:
model = GeoModel
fields = ['id', 'field_random1', 'field_random2', 'field_gis_plain']

class PlainViewset(mixins.RetrieveModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet):
serializer_class = PlainSerializer
queryset = GeoModel.objects.none()

router.register('plain', PlainViewset, basename='plain')

# GIS specific pagination
class PlainViewset(mixins.RetrieveModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet):
serializer_class = PlainSerializer
queryset = GeoModel.objects.none()
pagination_class = GeoJsonPagination

router.register('paginated', PlainViewset, basename='paginated')

assert_schema(
generate_schema(None, patterns=router.urls),
'tests/contrib/test_rest_framework_gis.yml'
)
Loading

0 comments on commit b51575c

Please sign in to comment.