Skip to content

Commit

Permalink
add mocked request for view processing. #81 #141
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Sep 12, 2020
1 parent 37f8a64 commit e3103ed
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 27 deletions.
9 changes: 9 additions & 0 deletions drf_spectacular/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from rest_framework import views, viewsets
from rest_framework.schemas.generators import BaseSchemaGenerator # type: ignore
from rest_framework.schemas.generators import EndpointEnumerator as BaseEndpointEnumerator
from rest_framework.test import APIRequestFactory

from drf_spectacular.extensions import OpenApiViewExtension
from drf_spectacular.plumbing import (
Expand Down Expand Up @@ -127,6 +128,14 @@ def parse(self, request, public):
if not self.has_view_permissions(path, method, view):
continue

# mocked request to allow certain operations in get_queryset and get_serializer[_class]
# without exceptions being raised due to no request.
if not request:
request = getattr(APIRequestFactory(), method.lower())(path=path)
request = view.initialize_request(request)

view.request = request

if view.versioning_class and not is_versioning_supported(view.versioning_class):
warn(
f'using unsupported versioning class "{view.versioning_class}". view will be '
Expand Down
15 changes: 7 additions & 8 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from drf_spectacular.plumbing import (
ComponentRegistry, ResolvedComponent, anyisinstance, append_meta, build_array_type,
build_basic_type, build_choice_field, build_object_type, build_parameter_type, error,
follow_field_source, force_instance, get_doc, get_override, has_override, is_basic_type,
is_field, is_serializer, resolve_regex_path_parameter, safe_ref, warn,
follow_field_source, force_instance, get_doc, get_override, get_view_model, has_override,
is_basic_type, is_field, is_serializer, resolve_regex_path_parameter, safe_ref, warn,
)
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes
Expand Down Expand Up @@ -284,9 +284,7 @@ def _tokenize_path(self):
return [t for t in path if t]

def _resolve_path_parameters(self, variables):
model = getattr(getattr(self.view, 'queryset', None), 'model', None)
parameters = []

for variable in variables:
schema = build_basic_type(OpenApiTypes.STR)
description = ''
Expand All @@ -297,15 +295,16 @@ def _resolve_path_parameters(self, variables):

if resolved_parameter:
schema = resolved_parameter['schema']
elif not model:
elif get_view_model(self.view) is None:
warn(
f'could not derive type of path parameter "{variable}" because because it '
f'is untyped and {self.view.__class__} has no queryset. consider adding a '
f'type to the path (e.g. <int:{variable}>) or annotating the parameter '
f'type with @extend_schema. defaulting to "string".'
f'is untyped and obtaining queryset from {self.view.__class__} failed. '
f'consider adding a type to the path (e.g. <int:{variable}>) or annotating '
f'the parameter type with @extend_schema. defaulting to "string".'
)
else:
try:
model = get_view_model(self.view)
model_field = model._meta.get_field(variable)
schema = self._map_model_field(model_field, direction=None)
# strip irrelevant meta data
Expand Down
47 changes: 29 additions & 18 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,27 @@ def get_lib_doc_excludes():
]


def get_view_model(view):
"""
obtain model from view via view's queryset. try safer view attribute first
before going through get_queryset(), which may perform arbitrary operations.
"""
model = getattr(getattr(view, 'queryset', None), 'model', None)

if model is not None:
return model

try:
return view.get_queryset().model
except Exception as exc:
warn(
f'failed to obtain model through view\'s queryset due to raised exception. '
f'prevent this either by setting "queryset = Model.objects.none()" on the view, '
f'having an empty fallback in get_queryset() or by using @extend_schema. '
f'(Exception: {exc})'
)


def get_doc(obj):
""" get doc string with fallback on obj's base classes (ignoring DRF documentation). """
if not inspect.isclass(obj):
Expand Down Expand Up @@ -617,22 +638,9 @@ def operation_matches_version(view, requested_version):


def modify_for_versioning(patterns, method, path, view, requested_version):
assert view.versioning_class

from rest_framework.test import APIRequestFactory

params = {'path': path}
if issubclass(view.versioning_class, versioning.AcceptHeaderVersioning):
renderer = view.get_renderers()[0]
params['HTTP_ACCEPT'] = f'{renderer.media_type}; version={requested_version}'
assert view.versioning_class and view.request

request = getattr(APIRequestFactory(), method.lower())(**params)
view.request = request

# wrap request in DRF's Request, necessary for content negotiation
view.request = view.initialize_request(view.request)

request.version = requested_version
view.request.version = requested_version

if issubclass(view.versioning_class, versioning.URLPathVersioning):
version_param = view.versioning_class.version_param
Expand All @@ -644,14 +652,17 @@ def modify_for_versioning(patterns, method, path, view, requested_version):
view.kwargs[version_param] = requested_version
elif issubclass(view.versioning_class, versioning.NamespaceVersioning):
try:
request.resolver_match = get_resolver(
view.request.resolver_match = get_resolver(
urlconf=tuple(detype_pattern(p) for p in patterns)
).resolve(path)
except Resolver404:
error(f"namespace versioning path resolution failed for {path}. path will be ignored.")
elif issubclass(view.versioning_class, versioning.AcceptHeaderVersioning):
neg = view.perform_content_negotiation(view.request)
view.request.accepted_renderer, view.request.accepted_media_type = neg
renderer = view.get_renderers()[0]

This comment has been minimized.

Copy link
@ticosax

ticosax Sep 30, 2020

Contributor

This seems to break content negotiation, because the first renderer is not necessarily the one requested.

This comment has been minimized.

Copy link
@tfranzel

tfranzel Oct 1, 2020

Author Owner

@ticosax good catch! in case of accept header version together with SpectacularApiView (and a UI) unexpectedly modifies the request. can we make this a proper issue? inline code comments are hard to follow.

This comment has been minimized.

Copy link
@ticosax

ticosax Oct 1, 2020

Contributor

I'm working already on a fix, PR will follow soon

view.request.META['HTTP_ACCEPT'] = f'{renderer.media_type}; version={requested_version}'

negotiated = view.perform_content_negotiation(view.request)
view.request.accepted_renderer, view.request.accepted_media_type = negotiated

return path

Expand Down
24 changes: 24 additions & 0 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,3 +922,27 @@ def view_func(request, format=None):
operation = schema['paths']['/multi/step/path/{someName}/']['get']
assert operation['parameters'][0]['name'] == 'someName'
assert operation['operationId'] == 'multiStepPathRetrieve'


def test_mocked_request_with_get_queryset_get_serializer_class(no_warnings):
class M4(models.Model):
pass

class XSerializer(serializers.ModelSerializer):
class Meta:
fields = '__all__'
model = M4

class XViewset(viewsets.ReadOnlyModelViewSet):
def get_serializer_class(self):
assert not self.request.user.is_authenticated
assert self.action in ['retrieve', 'list']
return XSerializer

def get_queryset(self):
assert not self.request.user.is_authenticated
assert self.request.method == 'GET'
return M4.objects.none()

schema = generate_schema('x', XViewset)
validate_schema(schema)
4 changes: 3 additions & 1 deletion tests/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ class X1Viewset(viewsets.ReadOnlyModelViewSet):
serializer_class = X1Serializer

generate_schema('x1', X1Viewset)
assert 'no queryset' in capsys.readouterr().err
stderr = capsys.readouterr().err
assert 'obtaining queryset from' in stderr # warning 1
assert 'failed to obtain model through view\'s queryset' in stderr # warning 2


def test_path_param_not_in_model(capsys):
Expand Down

0 comments on commit e3103ed

Please sign in to comment.