Skip to content

Commit

Permalink
Fix inheritance bugs with @extend_schema_view().
Browse files Browse the repository at this point in the history
When creating a copy of a method from a parent class we now:

- Ensure that `__qualname__` is defined correctly
  - i.e. `Child.method` instead of `Parent.method`.
  - This isn't essential but helps diagnosing issues when debugging.
- Move application of the decorator to the last moment.
- Deep copy the existing schema extensions before applying decorator.

This fixes tfranzel#218 where two child classes with @extend_schema_view affect
each other - schema extensions are applied to the parent such that the
second child overwrites the changes applied to the first child.

This also fixes my case where a child with @extend_schema_view clobbered
the schema extensions of the parent which also used @extend_schema_view.
  • Loading branch information
ngnpope committed Oct 7, 2021
1 parent 2c91cbd commit 34cd906
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 6 deletions.
16 changes: 12 additions & 4 deletions drf_spectacular/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import inspect
import sys
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

from rest_framework.fields import Field, empty
Expand Down Expand Up @@ -469,13 +470,20 @@ def extend_schema_view(**kwargs) -> Callable[[F], F]:
:param kwargs: method names as argument names and :func:`@extend_schema <.extend_schema>`
calls as values
"""
def wrapping_decorator(method_decorator, method):
@method_decorator
def wrapping_decorator(method_decorator, view, method):
@functools.wraps(method)
def wrapped_method(self, request, *args, **kwargs):
return method(self, request, *args, **kwargs)

return wrapped_method
# Construct a new __qualname__ based on the __name__ of the target view.
wrapped_method.__qualname__ = f'{view.__name__}.{method.__name__}'

# Clone the extended schema if the source method has it.
if hasattr(method, 'kwargs'):
wrapped_method.kwargs = deepcopy(method.kwargs)

# Finally apply any additional schema extensions applied to the target view.
return method_decorator(wrapped_method)

def decorator(view):
view_methods = {m.__name__: m for m in get_view_methods(view)}
Expand All @@ -495,7 +503,7 @@ def decorator(view):
if method_name in view.__dict__:
method_decorator(method)
else:
setattr(view, method_name, wrapping_decorator(method_decorator, method))
setattr(view, method_name, wrapping_decorator(method_decorator, view, method))
return view

return decorator
Expand Down
62 changes: 60 additions & 2 deletions tests/test_extend_schema_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class Meta:
extended_action=extend_schema(description='view extended action description'),
raw_action=extend_schema(description='view raw action description'),
)
class XViewset(mixins.ListModelMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet):
class XViewSet(mixins.ListModelMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet):
queryset = ESVModel.objects.all()
serializer_class = ESVSerializer

Expand All @@ -52,9 +52,67 @@ class YViewSet(viewsets.ModelViewSet):
queryset = ESVModel.objects.all()


# view to make sure that schema applied to a subclass does not affect its parent.
@extend_schema_view(
list=extend_schema(exclude=True),
retrieve=extend_schema(description='overridden description for child only'),
extended_action=extend_schema(responses={200: {'type': 'string', 'pattern': r'^[0-9]{4}(?:-[0-9]{2}){2}$'}}),
raw_action=extend_schema(summary="view raw action summary"),
)
class ZViewSet(XViewSet):
@extend_schema(tags=['child-tag'])
@action(detail=False, methods=['GET'])
def raw_action(self, request):
return Response('2019-03-01')


class Appointment(models.Model):
pass


class AppointmentSerializer(serializers.ModelSerializer):
class Meta:
model = Appointment
fields = '__all__'


class AppointmentViewSet(viewsets.GenericViewSet):
serializer_class = AppointmentSerializer
queryset = Appointment.objects.all()

@action(detail=False)
def notes(self, request):
return Response(['A note.'])


@extend_schema_view(
notes=extend_schema(
tags=["Doctor Appointments"],
summary="Get appointment notes",
description="Retrieves the detailed appointment notes as written by the doctor.",
),
)
class DoctorAppointmentViewSet(AppointmentViewSet):
pass


@extend_schema_view(
notes=extend_schema(
tags=["Patient Appointments"],
summary="Get appointment notes",
description="Retrieves the summarized appointment notes for viewing by the patient.",
),
)
class PatientAppointmentViewSet(AppointmentViewSet):
pass


router = routers.SimpleRouter()
router.register('x', XViewset)
router.register('x', XViewSet)
router.register('y', YViewSet)
router.register('z', ZViewSet)
router.register('api/doctors', DoctorAppointmentViewSet)
router.register('api/patients', PatientAppointmentViewSet)
urlpatterns = router.urls


Expand Down
103 changes: 103 additions & 0 deletions tests/test_extend_schema_view.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,42 @@ info:
title: ''
version: 0.0.0
paths:
/api/doctors/notes/:
get:
operationId: api_doctors_notes_retrieve
description: Retrieves the detailed appointment notes as written by the doctor.
summary: Get appointment notes
tags:
- Doctor Appointments
security:
- cookieAuth: []
- basicAuth: []
- {}
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/Appointment'
description: ''
/api/patients/notes/:
get:
operationId: api_patients_notes_retrieve
description: Retrieves the summarized appointment notes for viewing by the patient.
summary: Get appointment notes
tags:
- Patient Appointments
security:
- cookieAuth: []
- basicAuth: []
- {}
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/Appointment'
description: ''
/x/:
get:
operationId: x_list
Expand Down Expand Up @@ -232,8 +268,75 @@ paths:
responses:
'204':
description: No response body
/z/{id}/:
get:
operationId: z_retrieve
description: overridden description for child only
parameters:
- in: path
name: id
schema:
type: integer
description: A unique integer value identifying this esv model.
required: true
tags:
- custom-retrieve-tag
security:
- cookieAuth: []
- basicAuth: []
- {}
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/ESV'
description: ''
/z/extended_action/:
get:
operationId: z_extended_action_retrieve
description: view extended action description
tags:
- global-tag
security:
- cookieAuth: []
- basicAuth: []
- {}
responses:
'200':
content:
application/json:
schema:
type: string
pattern: ^[0-9]{4}(?:-[0-9]{2}){2}$
description: ''
/z/raw_action/:
get:
operationId: z_raw_action_retrieve
summary: view raw action summary
tags:
- child-tag
security:
- cookieAuth: []
- basicAuth: []
- {}
responses:
'200':
content:
application/json:
schema:
$ref: '#/components/schemas/ESV'
description: ''
components:
schemas:
Appointment:
type: object
properties:
id:
type: integer
readOnly: true
required:
- id
ESV:
type: object
properties:
Expand Down

0 comments on commit 34cd906

Please sign in to comment.