diff --git a/drf_spectacular/hooks.py b/drf_spectacular/hooks.py index 82dd6711..8ea118ed 100644 --- a/drf_spectacular/hooks.py +++ b/drf_spectacular/hooks.py @@ -33,6 +33,7 @@ def iter_prop_containers(schema, component_name=None): yield component_name, schema['properties'] yield from iter_prop_containers(schema.get('oneOf', []), component_name) yield from iter_prop_containers(schema.get('allOf', []), component_name) + yield from iter_prop_containers(schema.get('anyOf', []), component_name) def create_enum_component(name, schema): component = ResolvedComponent( @@ -53,6 +54,8 @@ def create_enum_component(name, schema): # collect all enums, their names and choice sets for component_name, props in iter_prop_containers(schemas): for prop_name, prop_schema in props.items(): + if prop_schema.get('type') == 'array': + prop_schema = prop_schema.get('items', {}) if 'enum' not in prop_schema: continue # remove blank/null entry for hashing. will be reconstructed in the last step @@ -98,6 +101,10 @@ def create_enum_component(name, schema): # enum, replace it with a reference and add a corresponding component. for _, props in iter_prop_containers(schemas): for prop_name, prop_schema in props.items(): + is_array = prop_schema.get('type') == 'array' + if is_array: + prop_schema = prop_schema.get('items', {}) + if 'enum' not in prop_schema: continue @@ -126,7 +133,10 @@ def create_enum_component(name, schema): else: prop_schema.update({'oneOf': [c.ref for c in components]}) - props[prop_name] = safe_ref(prop_schema) + if is_array: + props[prop_name]['items'] = safe_ref(prop_schema) + else: + props[prop_name] = safe_ref(prop_schema) # sort again with additional components result['components'] = generator.registry.build(spectacular_settings.APPEND_COMPONENTS) diff --git a/requirements/optionals.txt b/requirements/optionals.txt index 7272a266..b50a973e 100644 --- a/requirements/optionals.txt +++ b/requirements/optionals.txt @@ -6,4 +6,5 @@ django-polymorphic>=2.1 django-rest-polymorphic>=0.1.8 django-oauth-toolkit>=1.2.0 djangorestframework-camel-case>=1.1.2 -django-filter>=2.3.0 \ No newline at end of file +django-filter>=2.3.0 +psycopg2-binary>=2.7.3.2 \ No newline at end of file diff --git a/tests/contrib/test_django_filters.py b/tests/contrib/test_django_filters.py index e92a8eab..5f51e4e3 100644 --- a/tests/contrib/test_django_filters.py +++ b/tests/contrib/test_django_filters.py @@ -103,6 +103,7 @@ def test_django_filters(no_warnings): @pytest.mark.urls(__name__) @pytest.mark.django_db +@pytest.mark.contrib('django_filter') def test_django_filters_requests(no_warnings): other_sub_product = OtherSubProduct.objects.create(uuid=uuid.uuid4()) product = Product.objects.create( diff --git a/tests/test_extend_schema.py b/tests/test_extend_schema.py index c7c49c49..55ae57da 100644 --- a/tests/test_extend_schema.py +++ b/tests/test_extend_schema.py @@ -47,7 +47,7 @@ class QuerySerializer(serializers.Serializer): min_length=3, max_length=10, help_text='filter by containing string', required=False ) order_by = serializers.MultipleChoiceField( - choices=['a', 'b'], + choices=['a', 'b', 'c'], default=['a'], ) diff --git a/tests/test_extend_schema.yml b/tests/test_extend_schema.yml index 82d0ce91..613fb9fc 100644 --- a/tests/test_extend_schema.yml +++ b/tests/test_extend_schema.yml @@ -230,6 +230,7 @@ paths: enum: - a - b + - c type: string default: - a @@ -339,6 +340,12 @@ components: required: - inline_b - inline_i + OrderByEnum: + enum: + - a + - b + - c + type: string Query: type: object properties: @@ -355,10 +362,7 @@ components: order_by: type: array items: - enum: - - a - - b - type: string + $ref: '#/components/schemas/OrderByEnum' default: - a required: diff --git a/tests/test_regressions.py b/tests/test_regressions.py index 10c596fd..69569ab0 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -1244,3 +1244,29 @@ class XViewset(viewsets.ModelViewSet): schema = SchemaGenerator(patterns=router.urls).get_schema(request=None, public=True) assert '/x/{related_field}/' in schema['paths'] assert '/x/{related_field}/{id}/' in schema['paths'] + + +@pytest.mark.contrib('psycopg2') +def test_multiple_choice_enum(no_warnings): + from django.contrib.postgres.fields import ArrayField + + class M4(models.Model): + multi = ArrayField( + models.CharField(max_length=10, choices=(('A', 'A'), ('B', 'B'))), + size=8, + ) + + class M4Serializer(serializers.ModelSerializer): + class Meta: + fields = '__all__' + model = M4 + + class XViewset(viewsets.ModelViewSet): + serializer_class = M4Serializer + queryset = M4.objects.none() + + schema = generate_schema('x', XViewset) + assert 'MultiEnum' in schema['components']['schemas'] + prop = schema['components']['schemas']['M4']['properties']['multi'] + assert prop['type'] == 'array' + assert prop['items']['$ref'] == '#/components/schemas/MultiEnum'