Skip to content

Commit

Permalink
Fix inheritance bugs with @extend_schema_view().
Browse files Browse the repository at this point in the history
When creating a copy of a method from a parent class we now:

- Ensure that `__qualname__` is defined correctly
  - i.e. `Child.method` instead of `Parent.method`.
  - This isn't essential but helps diagnosing issues when debugging.
- Move application of the decorator to the last moment.
- Deep copy the existing schema extensions before applying decorator.

This fixes tfranzel#218 where two child classes with @extend_schema_view affect
each other - schema extensions are applied to the parent such that the
second child overwrites the changes applied to the first child.

This also fixes my case where a child with @extend_schema_view clobbered
the schema extensions of the parent which also used @extend_schema_view.
  • Loading branch information
ngnpope committed Oct 11, 2021
1 parent 2a1f1ec commit d8bbb22
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 6 deletions.
16 changes: 12 additions & 4 deletions drf_spectacular/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import inspect
import sys
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

from rest_framework.fields import Field, empty
Expand Down Expand Up @@ -472,13 +473,20 @@ def extend_schema_view(**kwargs) -> Callable[[F], F]:
:param kwargs: method names as argument names and :func:`@extend_schema <.extend_schema>`
calls as values
"""
def wrapping_decorator(method_decorator, method):
@method_decorator
def wrapping_decorator(method_decorator, view, method):
@functools.wraps(method)
def wrapped_method(self, request, *args, **kwargs):
return method(self, request, *args, **kwargs)

return wrapped_method
# Construct a new __qualname__ based on the __name__ of the target view.
wrapped_method.__qualname__ = f'{view.__name__}.{method.__name__}'

# Clone the extended schema if the source method has it.
if hasattr(method, 'kwargs'):
wrapped_method.kwargs = deepcopy(method.kwargs)

# Finally apply any additional schema extensions applied to the target view.
return method_decorator(wrapped_method)

def decorator(view):
view_methods = {m.__name__: m for m in get_view_methods(view)}
Expand All @@ -498,7 +506,7 @@ def decorator(view):
if method_name in view.__dict__:
method_decorator(method)
else:
setattr(view, method_name, wrapping_decorator(method_decorator, method))
setattr(view, method_name, wrapping_decorator(method_decorator, view, method))
return view

return decorator
Expand Down
19 changes: 17 additions & 2 deletions tests/test_extend_schema_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Meta:
extended_action=extend_schema(description='view extended action description'),
raw_action=extend_schema(description='view raw action description'),
)
class XViewset(mixins.ListModelMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet):
class XViewSet(mixins.ListModelMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet):
queryset = ESVModel.objects.all()
serializer_class = ESVSerializer

Expand All @@ -52,9 +52,24 @@ class YViewSet(viewsets.ModelViewSet):
queryset = ESVModel.objects.all()


# view to make sure that schema applied to a subclass does not affect its parent.
@extend_schema_view(
list=extend_schema(exclude=True),
retrieve=extend_schema(description='overridden description for child only'),
extended_action=extend_schema(responses={200: {'type': 'string', 'pattern': r'^[0-9]{4}(?:-[0-9]{2}){2}$'}}),
raw_action=extend_schema(summary="view raw action summary"),
)
class ZViewSet(XViewSet):
@extend_schema(tags=['child-tag'])
@action(detail=False, methods=['GET'])
def raw_action(self, request):
return Response('2019-03-01')


router = routers.SimpleRouter()
router.register('x', XViewset)
router.register('x', XViewSet)
router.register('y', YViewSet)
router.register('z', ZViewSet)
urlpatterns = router.urls


Expand Down
59 changes: 59 additions & 0 deletions tests/test_extend_schema_view.yml
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,65 @@ paths:
responses:
'204':
description: No response body
/z/{id}/:
get:
operationId: z_retrieve
description: overridden description for child only
parameters:
- in: path
name: id
schema:
type: integer
description: A unique integer value identifying this esv model.
required: true
tags:
- custom-retrieve-tag
security:
- cookieAuth: []
- basicAuth: []
- {}
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/ESV'
description: ''
/z/extended_action/:
get:
operationId: z_extended_action_retrieve
description: view extended action description
tags:
- global-tag
security:
- cookieAuth: []
- basicAuth: []
- {}
responses:
'200':
content:
application/json:
schema:
type: string
pattern: ^[0-9]{4}(?:-[0-9]{2}){2}$
description: ''
/z/raw_action/:
get:
operationId: z_raw_action_retrieve
summary: view raw action summary
tags:
- child-tag
security:
- cookieAuth: []
- basicAuth: []
- {}
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/ESV'
description: ''
components:
schemas:
ESV:
Expand Down
35 changes: 35 additions & 0 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2569,3 +2569,38 @@ def custom_action(self):

schema = generate_schema('x', viewset=XViewSet)
schema['paths']['/x/{id}/custom_action/']['get']['summary'] == 'A custom action!'


def test_extend_schema_view_isolation(no_warnings):

class Animal(models.Model):
pass

class AnimalSerializer(serializers.ModelSerializer):
class Meta:
model = Animal
fields = '__all__'

class AnimalViewSet(viewsets.GenericViewSet):
serializer_class = AnimalSerializer
queryset = Animal.objects.all()

@action(detail=False)
def notes(self, request):
pass # pragma: no cover

@extend_schema_view(notes=extend_schema(summary='List mammals.'))
class MammalViewSet(AnimalViewSet):
pass

@extend_schema_view(notes=extend_schema(summary='List insects.'))
class InsectViewSet(AnimalViewSet):
pass

router = routers.SimpleRouter()
router.register('api/mammals', MammalViewSet)
router.register('api/insects', InsectViewSet)

schema = generate_schema(None, patterns=router.urls)
assert schema['paths']['/api/mammals/notes/']['get']['summary'] == 'List mammals.'
assert schema['paths']['/api/insects/notes/']['get']['summary'] == 'List insects.'

0 comments on commit d8bbb22

Please sign in to comment.