Skip to content

Commit

Permalink
bugfix enum substitution for enumed arrays (multiple choice)
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Jan 17, 2021
1 parent 22c5c7a commit 5c9d446
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 7 deletions.
12 changes: 11 additions & 1 deletion drf_spectacular/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def iter_prop_containers(schema, component_name=None):
yield component_name, schema['properties']
yield from iter_prop_containers(schema.get('oneOf', []), component_name)
yield from iter_prop_containers(schema.get('allOf', []), component_name)
yield from iter_prop_containers(schema.get('anyOf', []), component_name)

def create_enum_component(name, schema):
component = ResolvedComponent(
Expand All @@ -53,6 +54,8 @@ def create_enum_component(name, schema):
# collect all enums, their names and choice sets
for component_name, props in iter_prop_containers(schemas):
for prop_name, prop_schema in props.items():
if prop_schema.get('type') == 'array':
prop_schema = prop_schema.get('items', {})
if 'enum' not in prop_schema:
continue
# remove blank/null entry for hashing. will be reconstructed in the last step
Expand Down Expand Up @@ -98,6 +101,10 @@ def create_enum_component(name, schema):
# enum, replace it with a reference and add a corresponding component.
for _, props in iter_prop_containers(schemas):
for prop_name, prop_schema in props.items():
is_array = prop_schema.get('type') == 'array'
if is_array:
prop_schema = prop_schema.get('items', {})

if 'enum' not in prop_schema:
continue

Expand Down Expand Up @@ -126,7 +133,10 @@ def create_enum_component(name, schema):
else:
prop_schema.update({'oneOf': [c.ref for c in components]})

props[prop_name] = safe_ref(prop_schema)
if is_array:
props[prop_name]['items'] = safe_ref(prop_schema)
else:
props[prop_name] = safe_ref(prop_schema)

# sort again with additional components
result['components'] = generator.registry.build(spectacular_settings.APPEND_COMPONENTS)
Expand Down
3 changes: 2 additions & 1 deletion requirements/optionals.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ django-polymorphic>=2.1
django-rest-polymorphic>=0.1.8
django-oauth-toolkit>=1.2.0
djangorestframework-camel-case>=1.1.2
django-filter>=2.3.0
django-filter>=2.3.0
psycopg2-binary>=2.7.3.2
1 change: 1 addition & 0 deletions tests/contrib/test_django_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def test_django_filters(no_warnings):

@pytest.mark.urls(__name__)
@pytest.mark.django_db
@pytest.mark.contrib('django_filter')
def test_django_filters_requests(no_warnings):
other_sub_product = OtherSubProduct.objects.create(uuid=uuid.uuid4())
product = Product.objects.create(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_extend_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class QuerySerializer(serializers.Serializer):
min_length=3, max_length=10, help_text='filter by containing string', required=False
)
order_by = serializers.MultipleChoiceField(
choices=['a', 'b'],
choices=['a', 'b', 'c'],
default=['a'],
)

Expand Down
12 changes: 8 additions & 4 deletions tests/test_extend_schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ paths:
enum:
- a
- b
- c
type: string
default:
- a
Expand Down Expand Up @@ -339,6 +340,12 @@ components:
required:
- inline_b
- inline_i
OrderByEnum:
enum:
- a
- b
- c
type: string
Query:
type: object
properties:
Expand All @@ -355,10 +362,7 @@ components:
order_by:
type: array
items:
enum:
- a
- b
type: string
$ref: '#/components/schemas/OrderByEnum'
default:
- a
required:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,3 +1244,29 @@ class XViewset(viewsets.ModelViewSet):
schema = SchemaGenerator(patterns=router.urls).get_schema(request=None, public=True)
assert '/x/{related_field}/' in schema['paths']
assert '/x/{related_field}/{id}/' in schema['paths']


@pytest.mark.contrib('psycopg2')
def test_multiple_choice_enum(no_warnings):
from django.contrib.postgres.fields import ArrayField

class M4(models.Model):
multi = ArrayField(
models.CharField(max_length=10, choices=(('A', 'A'), ('B', 'B'))),
size=8,
)

class M4Serializer(serializers.ModelSerializer):
class Meta:
fields = '__all__'
model = M4

class XViewset(viewsets.ModelViewSet):
serializer_class = M4Serializer
queryset = M4.objects.none()

schema = generate_schema('x', XViewset)
assert 'MultiEnum' in schema['components']['schemas']
prop = schema['components']['schemas']['M4']['properties']['multi']
assert prop['type'] == 'array'
assert prop['items']['$ref'] == '#/components/schemas/MultiEnum'

0 comments on commit 5c9d446

Please sign in to comment.