From 25d13b431ce7457124a41feb0f461ef14fcda12d Mon Sep 17 00:00:00 2001 From: yinian1992 Date: Wed, 15 Dec 2021 19:29:32 +0800 Subject: [PATCH] fix annotated trivial type fallback and type dependency resolving --- django_rest_tsg/typescript.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/django_rest_tsg/typescript.py b/django_rest_tsg/typescript.py index fbd30db..886497e 100644 --- a/django_rest_tsg/typescript.py +++ b/django_rest_tsg/typescript.py @@ -115,6 +115,9 @@ def tokenize_python_type(tp) -> list[Union[type, str]]: * Brackets: '[' or ']' * User defined types """ + if get_origin(tp) is Annotated: + tp = get_args(tp)[0] + # non-generic fallback origin = get_origin(tp) if not origin: @@ -129,7 +132,7 @@ def tokenize_python_type(tp) -> list[Union[type, str]]: if top == RIGHT_BRACKET: result.append(top) continue - elif isinstance(top, str) or top in TYPE_MAPPING: + elif isinstance(top, str) or isinstance(top, int) or isinstance(top, float) or top in TYPE_MAPPING: result.append(top) current_type = top continue @@ -141,14 +144,13 @@ def tokenize_python_type(tp) -> list[Union[type, str]]: if origin is Annotated: current_type = get_args(current_type)[0] continue - elif origin in GENERICS: + if origin in GENERICS: result.append(origin) result.append(LEFT_BRACKET) stack.append(RIGHT_BRACKET) args = get_args(current_type) for arg in reversed(args): stack.append(arg) - print(current_type, args, stack) elif origin in TYPE_MAPPING: result.append(origin) # generic fallback @@ -215,7 +217,7 @@ def _build_generic_type(tp, children=None) -> str: if isinstance(child, str): part = f"'{child}'" else: - part = child + part = str(child) parts.append(part) return ' | '.join(parts) @@ -225,7 +227,9 @@ def build_type(tp) -> Tuple[str, List[Type]]: Build typescript type from python type. """ tokens = tokenize_python_type(tp) - dependencies = [token for token in tokens if token not in TYPE_MAPPING_WITH_GENERIC_FALLBACK] + dependencies = [token for token in tokens if + token not in TYPE_MAPPING_WITH_GENERIC_FALLBACK and + not type(token) in TRIVIAL_TYPE_MAPPING] return _build_type(tokens), dependencies @@ -296,7 +300,7 @@ def _get_serializer_field_type(field: Field) -> Tuple[str, Optional[Type]]: if isinstance(value, str): part = f"'{value}'" else: - part = value + part = str(value) parts.append(part) field_type = ' | '.join(parts) elif isinstance(field, ManyRelatedField): @@ -307,6 +311,10 @@ def _get_serializer_field_type(field: Field) -> Tuple[str, Optional[Type]]: elif isinstance(field, Serializer): field_type = get_serializer_prefix(type(field)) dependency = field + else: + field_type = TYPESCRIPT_ANY + if field_type != TYPESCRIPT_ANY and field.allow_null: + field_type += TYPESCRIPT_NULLABLE return field_type, dependency @@ -336,7 +344,7 @@ def get_serializer_field_type(field: Field) -> Tuple[str, list]: result = f'{{[index: string]: {result}}}' else: result = item - return result, sorted(list(dependencies)) + return result, sorted(list(dependencies), key=lambda tp: tp.__name__) def build_interface_from_serializer(serializer_class: Type[Serializer], @@ -349,9 +357,12 @@ def build_interface_from_serializer(serializer_class: Type[Serializer], interface_fields = [] interface_dependencies = set() for field_name, field_instance in serializer.get_fields().items(): + field_instance: Field field_type = type(field_instance) if field_type in DRF_FIELD_MAPPING: field_type = DRF_FIELD_MAPPING[field_type] + if field_instance.allow_null: + field_type += TYPESCRIPT_NULLABLE else: field_type, field_dependencies = get_serializer_field_type(field_instance) for dependency in field_dependencies: @@ -363,6 +374,6 @@ def build_interface_from_serializer(serializer_class: Type[Serializer], if not interface_name: interface_name = get_serializer_prefix(serializer_class) return TypeScriptCode(type=TypeScriptCodeType.INTERFACE, source=serializer_class, name=interface_name, - dependencies=sorted(list(interface_dependencies)), + dependencies=sorted(list(interface_dependencies), key=lambda tp: tp.__name__), content=INTERFACE_TEMPLATE.substitute(fields='\n'.join(interface_fields), name=interface_name))