Skip to content

Commit

Permalink
fallback improvements to typing system with typing_extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Sep 22, 2021
1 parent 536a4d6 commit 0437440
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 53 deletions.
21 changes: 13 additions & 8 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
from enum import Enum
from typing import DefaultDict, Generic, List, Optional, Type, TypeVar, Union

if sys.version_info >= (3, 8):
from typing import Literal, _TypedDictMeta # type: ignore[attr-defined]
else:
from typing_extensions import Literal, _TypedDictMeta # type: ignore[attr-defined]

import inflection
import uritemplate
from django.apps import apps
Expand All @@ -24,14 +29,13 @@
from django.db.models.fields.reverse_related import ForeignObjectRel
from django.db.models.sql.query import Query
from django.urls.converters import get_converters
from django.urls.resolvers import ( # type: ignore
from django.urls.resolvers import ( # type: ignore[attr-defined]
_PATH_PARAMETER_COMPONENT_RE, RegexPattern, Resolver404, RoutePattern, URLPattern, URLResolver,
get_resolver,
)
from django.utils.functional import Promise, cached_property
from django.utils.module_loading import import_string
from django.utils.translation import gettext_lazy as _
from django.utils.version import PY38
from rest_framework import exceptions, fields, mixins, serializers, versioning
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
Expand All @@ -51,7 +55,7 @@
class Choices: # type: ignore
pass

if PY38:
if sys.version_info >= (3, 8):
CACHED_PROPERTY_FUNCS = (functools.cached_property, cached_property) # type: ignore
else:
CACHED_PROPERTY_FUNCS = (cached_property,) # type: ignore
Expand Down Expand Up @@ -478,7 +482,7 @@ def _follow_return_type(a_callable):
if target_type is None:
return target_type
origin, args = _get_type_hint_origin(target_type)
if origin is typing.Union:
if origin is Union:
type_args = [arg for arg in args if arg is not type(None)] # noqa: E721
if len(type_args) > 1:
warn(
Expand Down Expand Up @@ -1088,8 +1092,9 @@ def resolve_type_hint(hint):
return build_array_type(resolve_type_hint(args[0]))
elif origin is frozenset:
return build_array_type(resolve_type_hint(args[0]))
elif hasattr(typing, 'Literal') and origin is typing.Literal:
# python >= 3.8
elif origin is Literal:
# Literal only works for python >= 3.8 despite typing_extensions, because it
# behaves slightly different w.r.t. __origin__
schema = {'enum': list(args)}
if all(type(args[0]) is type(choice) for choice in args):
schema.update(build_basic_type(type(args[0])))
Expand All @@ -1100,13 +1105,13 @@ def resolve_type_hint(hint):
if mixin_base_types:
schema.update(build_basic_type(mixin_base_types[0]))
return schema
elif hasattr(typing, 'TypedDict') and isinstance(hint, typing._TypedDictMeta):
elif isinstance(hint, _TypedDictMeta):
return build_object_type(
properties={
k: resolve_type_hint(v) for k, v in get_type_hints(hint).items()
}
)
elif origin is typing.Union:
elif origin is Union:
type_args = [arg for arg in args if arg is not type(None)] # noqa: E721
if len(type_args) > 1:
schema = {'oneOf': [resolve_type_hint(arg) for arg in type_args]}
Expand Down
19 changes: 14 additions & 5 deletions drf_spectacular/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import inspect
import sys
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union

from rest_framework.fields import Field, empty
Expand All @@ -9,8 +10,16 @@
from drf_spectacular.drainage import error, get_view_methods, set_override, warn
from drf_spectacular.types import OpenApiTypes, _KnownPythonTypes

if sys.version_info >= (3, 8):
from typing import Final, Literal
else:
from typing_extensions import Final, Literal


_SerializerType = Union[Serializer, Type[Serializer]]

_ParameterLocationType = Literal['query', 'path', 'header', 'cookie']


class PolymorphicProxySerializer(Serializer):
"""
Expand Down Expand Up @@ -125,16 +134,16 @@ class OpenApiParameter(OpenApiSchemaBase):
For valid ``style`` choices please consult the
`OpenAPI specification <https://swagger.io/specification/#style-values>`_.
"""
QUERY = 'query'
PATH = 'path'
HEADER = 'header'
COOKIE = 'cookie'
QUERY: Final = 'query'
PATH: Final = 'path'
HEADER: Final = 'header'
COOKIE: Final = 'cookie'

def __init__(
self,
name: str,
type: Union[_SerializerType, _KnownPythonTypes, OpenApiTypes, dict] = str,
location: str = QUERY,
location: _ParameterLocationType = QUERY,
required: bool = False,
description: str = '',
enum: Optional[List[Any]] = None,
Expand Down
3 changes: 2 additions & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ djangorestframework>=3.10
uritemplate>=2.0.0
PyYAML>=5.1
jsonschema>=2.6.0
inflection>=0.3.1
inflection>=0.3.1
typing-extensions; python_version < "3.8"
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
long_description = readme.read()

with open('requirements/base.txt') as fh:
requirements = [r for r in fh.read().split() if not r.startswith('#')]
requirements = [r for r in fh.read().split('\n') if not r.startswith('#')]


def get_version(package):
Expand Down
9 changes: 5 additions & 4 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import json
import sys
import tempfile
import uuid
from datetime import timedelta
Expand All @@ -20,14 +21,14 @@
from drf_spectacular.generators import SchemaGenerator
from tests import assert_equal, assert_schema, build_absolute_file_path

try:
functools_cached_property = functools.cached_property # type: ignore
except AttributeError:
# functools.cached_property is only available in Python 3.8+.
if sys.version_info >= (3, 8):
functools_cached_property = functools.cached_property
else:
# We re-use Django's cached_property when it's not avaiable to
# keep tests unified across Python versions.
functools_cached_property = cached_property


fs = FileSystemStorage(location=tempfile.gettempdir())


Expand Down
76 changes: 43 additions & 33 deletions tests/test_plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
from drf_spectacular.validation import validate_schema
from tests import generate_schema

if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict


def test_is_serializer():
assert not is_serializer(serializers.SlugField)
Expand Down Expand Up @@ -86,6 +91,9 @@ def test_detype_patterns_with_module_includes(no_warnings):
)


NamedTupleA = collections.namedtuple("NamedTupleA", "a, b")


class NamedTupleB(typing.NamedTuple):
a: int
b: str
Expand All @@ -103,6 +111,10 @@ class InvalidLanguageEnum(Enum):
DE = 'de'


TD1 = TypedDict('TD1', {"foo": int, "bar": typing.List[str]})
TD2 = TypedDict('TD2', {"foo": str, "bar": typing.Dict[str, int]})


TYPE_HINT_TEST_PARAMS = [
(
typing.Optional[int],
Expand Down Expand Up @@ -152,6 +164,34 @@ class InvalidLanguageEnum(Enum):
), (
InvalidLanguageEnum,
{'enum': ['en', 'de']}
), (
TD1,
{
'type': 'object',
'properties': {
'foo': {'type': 'integer'},
'bar': {'type': 'array', 'items': {'type': 'string'}}
}
}
), (
typing.List[TD2],
{
'type': 'array',
'items': {
'type': 'object',
'properties': {
'foo': {'type': 'string'},
'bar': {'type': 'object', 'additionalProperties': {'type': 'integer'}}
}
}
}
), (
NamedTupleB,
{
'type': 'object',
'properties': {'a': {'type': 'integer'}, 'b': {'type': 'string'}},
'required': ['a', 'b']
}
)
]

Expand All @@ -170,50 +210,20 @@ class LanguageChoices(TextChoices):

if sys.version_info >= (3, 7):
TYPE_HINT_TEST_PARAMS.append((
typing.Iterable[collections.namedtuple("NamedTupleA", "a, b")], # noqa
typing.Iterable[NamedTupleA],
{
'type': 'array',
'items': {'type': 'object', 'properties': {'a': {}, 'b': {}}, 'required': ['a', 'b']}
}
))
TYPE_HINT_TEST_PARAMS.append((
NamedTupleB,
{
'type': 'object',
'properties': {'a': {'type': 'integer'}, 'b': {'type': 'string'}},
'required': ['a', 'b']
}
))

if sys.version_info >= (3, 8):
# Literal only works for python >= 3.8 despite typing_extensions, because it
# behaves slightly different w.r.t. __origin__
TYPE_HINT_TEST_PARAMS.append((
typing.Literal['x', 'y'],
{'enum': ['x', 'y'], 'type': 'string'}
))
TYPE_HINT_TEST_PARAMS.append((
typing.TypedDict('TD', foo=int, bar=typing.List[str]),
{
'type': 'object',
'properties': {
'foo': {'type': 'integer'},
'bar': {'type': 'array', 'items': {'type': 'string'}}
}
}
))
TYPE_HINT_TEST_PARAMS.append((
typing.List[typing.TypedDict('TD', foo=str, bar=typing.Dict[str, int])], # noqa: F821
{
'type': 'array',
'items': {
'type': 'object',
'properties': {
'foo': {'type': 'string'},
'bar': {'type': 'object', 'additionalProperties': {'type': 'integer'}}
}
}
}
))


if sys.version_info >= (3, 9):
TYPE_HINT_TEST_PARAMS.append((
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ use_parentheses = true
include_trailing_comma = true

[mypy]
python_version = 3.6
python_version = 3.8
plugins = mypy_django_plugin.main,mypy_drf_plugin.main

[mypy.plugins.django-stubs]
Expand Down

0 comments on commit 0437440

Please sign in to comment.