diff --git a/drf_spectacular/generators.py b/drf_spectacular/generators.py index 3cfcbf9c..9a2fa839 100644 --- a/drf_spectacular/generators.py +++ b/drf_spectacular/generators.py @@ -5,6 +5,7 @@ from rest_framework import views, viewsets from rest_framework.schemas.generators import BaseSchemaGenerator # type: ignore from rest_framework.schemas.generators import EndpointEnumerator as BaseEndpointEnumerator +from rest_framework.test import APIRequestFactory from drf_spectacular.extensions import OpenApiViewExtension from drf_spectacular.plumbing import ( @@ -127,6 +128,14 @@ def parse(self, request, public): if not self.has_view_permissions(path, method, view): continue + # mocked request to allow certain operations in get_queryset and get_serializer[_class] + # without exceptions being raised due to no request. + if not request: + request = getattr(APIRequestFactory(), method.lower())(path=path) + request = view.initialize_request(request) + + view.request = request + if view.versioning_class and not is_versioning_supported(view.versioning_class): warn( f'using unsupported versioning class "{view.versioning_class}". view will be ' diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index 3a84c9ac..6a61e26e 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -23,8 +23,8 @@ from drf_spectacular.plumbing import ( ComponentRegistry, ResolvedComponent, anyisinstance, append_meta, build_array_type, build_basic_type, build_choice_field, build_object_type, build_parameter_type, error, - follow_field_source, force_instance, get_doc, get_override, has_override, is_basic_type, - is_field, is_serializer, resolve_regex_path_parameter, safe_ref, warn, + follow_field_source, force_instance, get_doc, get_override, get_view_model, has_override, + is_basic_type, is_field, is_serializer, resolve_regex_path_parameter, safe_ref, warn, ) from drf_spectacular.settings import spectacular_settings from drf_spectacular.types import OpenApiTypes @@ -284,9 +284,7 @@ def _tokenize_path(self): return [t for t in path if t] def _resolve_path_parameters(self, variables): - model = getattr(getattr(self.view, 'queryset', None), 'model', None) parameters = [] - for variable in variables: schema = build_basic_type(OpenApiTypes.STR) description = '' @@ -297,15 +295,16 @@ def _resolve_path_parameters(self, variables): if resolved_parameter: schema = resolved_parameter['schema'] - elif not model: + elif get_view_model(self.view) is None: warn( f'could not derive type of path parameter "{variable}" because because it ' - f'is untyped and {self.view.__class__} has no queryset. consider adding a ' - f'type to the path (e.g. ) or annotating the parameter ' - f'type with @extend_schema. defaulting to "string".' + f'is untyped and obtaining queryset from {self.view.__class__} failed. ' + f'consider adding a type to the path (e.g. ) or annotating ' + f'the parameter type with @extend_schema. defaulting to "string".' ) else: try: + model = get_view_model(self.view) model_field = model._meta.get_field(variable) schema = self._map_model_field(model_field, direction=None) # strip irrelevant meta data diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index 71d970e9..10e5ea62 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -153,6 +153,27 @@ def get_lib_doc_excludes(): ] +def get_view_model(view): + """ + obtain model from view via view's queryset. try safer view attribute first + before going through get_queryset(), which may perform arbitrary operations. + """ + model = getattr(getattr(view, 'queryset', None), 'model', None) + + if model is not None: + return model + + try: + return view.get_queryset().model + except Exception as exc: + warn( + f'failed to obtain model through view\'s queryset due to raised exception. ' + f'prevent this either by setting "queryset = Model.objects.none()" on the view, ' + f'having an empty fallback in get_queryset() or by using @extend_schema. ' + f'(Exception: {exc})' + ) + + def get_doc(obj): """ get doc string with fallback on obj's base classes (ignoring DRF documentation). """ if not inspect.isclass(obj): @@ -617,22 +638,9 @@ def operation_matches_version(view, requested_version): def modify_for_versioning(patterns, method, path, view, requested_version): - assert view.versioning_class - - from rest_framework.test import APIRequestFactory - - params = {'path': path} - if issubclass(view.versioning_class, versioning.AcceptHeaderVersioning): - renderer = view.get_renderers()[0] - params['HTTP_ACCEPT'] = f'{renderer.media_type}; version={requested_version}' + assert view.versioning_class and view.request - request = getattr(APIRequestFactory(), method.lower())(**params) - view.request = request - - # wrap request in DRF's Request, necessary for content negotiation - view.request = view.initialize_request(view.request) - - request.version = requested_version + view.request.version = requested_version if issubclass(view.versioning_class, versioning.URLPathVersioning): version_param = view.versioning_class.version_param @@ -644,14 +652,17 @@ def modify_for_versioning(patterns, method, path, view, requested_version): view.kwargs[version_param] = requested_version elif issubclass(view.versioning_class, versioning.NamespaceVersioning): try: - request.resolver_match = get_resolver( + view.request.resolver_match = get_resolver( urlconf=tuple(detype_pattern(p) for p in patterns) ).resolve(path) except Resolver404: error(f"namespace versioning path resolution failed for {path}. path will be ignored.") elif issubclass(view.versioning_class, versioning.AcceptHeaderVersioning): - neg = view.perform_content_negotiation(view.request) - view.request.accepted_renderer, view.request.accepted_media_type = neg + renderer = view.get_renderers()[0] + view.request.META['HTTP_ACCEPT'] = f'{renderer.media_type}; version={requested_version}' + + negotiated = view.perform_content_negotiation(view.request) + view.request.accepted_renderer, view.request.accepted_media_type = negotiated return path diff --git a/tests/test_regressions.py b/tests/test_regressions.py index e3abe7a8..6ec50ff9 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -922,3 +922,27 @@ def view_func(request, format=None): operation = schema['paths']['/multi/step/path/{someName}/']['get'] assert operation['parameters'][0]['name'] == 'someName' assert operation['operationId'] == 'multiStepPathRetrieve' + + +def test_mocked_request_with_get_queryset_get_serializer_class(no_warnings): + class M4(models.Model): + pass + + class XSerializer(serializers.ModelSerializer): + class Meta: + fields = '__all__' + model = M4 + + class XViewset(viewsets.ReadOnlyModelViewSet): + def get_serializer_class(self): + assert not self.request.user.is_authenticated + assert self.action in ['retrieve', 'list'] + return XSerializer + + def get_queryset(self): + assert not self.request.user.is_authenticated + assert self.request.method == 'GET' + return M4.objects.none() + + schema = generate_schema('x', XViewset) + validate_schema(schema) diff --git a/tests/test_warnings.py b/tests/test_warnings.py index 2a17a449..9a97913e 100644 --- a/tests/test_warnings.py +++ b/tests/test_warnings.py @@ -69,7 +69,9 @@ class X1Viewset(viewsets.ReadOnlyModelViewSet): serializer_class = X1Serializer generate_schema('x1', X1Viewset) - assert 'no queryset' in capsys.readouterr().err + stderr = capsys.readouterr().err + assert 'obtaining queryset from' in stderr # warning 1 + assert 'failed to obtain model through view\'s queryset' in stderr # warning 2 def test_path_param_not_in_model(capsys):