Skip to content

Commit

Permalink
fix annotated trivial type fallback and type dependency resolving
Browse files Browse the repository at this point in the history
  • Loading branch information
yinian1992 committed Dec 15, 2021
1 parent e9ad17d commit 25d13b4
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions django_rest_tsg/typescript.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ def tokenize_python_type(tp) -> list[Union[type, str]]:
* Brackets: '[' or ']'
* User defined types
"""
if get_origin(tp) is Annotated:
tp = get_args(tp)[0]

# non-generic fallback
origin = get_origin(tp)
if not origin:
Expand All @@ -129,7 +132,7 @@ def tokenize_python_type(tp) -> list[Union[type, str]]:
if top == RIGHT_BRACKET:
result.append(top)
continue
elif isinstance(top, str) or top in TYPE_MAPPING:
elif isinstance(top, str) or isinstance(top, int) or isinstance(top, float) or top in TYPE_MAPPING:
result.append(top)
current_type = top
continue
Expand All @@ -141,14 +144,13 @@ def tokenize_python_type(tp) -> list[Union[type, str]]:
if origin is Annotated:
current_type = get_args(current_type)[0]
continue
elif origin in GENERICS:
if origin in GENERICS:
result.append(origin)
result.append(LEFT_BRACKET)
stack.append(RIGHT_BRACKET)
args = get_args(current_type)
for arg in reversed(args):
stack.append(arg)
print(current_type, args, stack)
elif origin in TYPE_MAPPING:
result.append(origin)
# generic fallback
Expand Down Expand Up @@ -215,7 +217,7 @@ def _build_generic_type(tp, children=None) -> str:
if isinstance(child, str):
part = f"'{child}'"
else:
part = child
part = str(child)
parts.append(part)
return ' | '.join(parts)

Expand All @@ -225,7 +227,9 @@ def build_type(tp) -> Tuple[str, List[Type]]:
Build typescript type from python type.
"""
tokens = tokenize_python_type(tp)
dependencies = [token for token in tokens if token not in TYPE_MAPPING_WITH_GENERIC_FALLBACK]
dependencies = [token for token in tokens if
token not in TYPE_MAPPING_WITH_GENERIC_FALLBACK and
not type(token) in TRIVIAL_TYPE_MAPPING]
return _build_type(tokens), dependencies


Expand Down Expand Up @@ -296,7 +300,7 @@ def _get_serializer_field_type(field: Field) -> Tuple[str, Optional[Type]]:
if isinstance(value, str):
part = f"'{value}'"
else:
part = value
part = str(value)
parts.append(part)
field_type = ' | '.join(parts)
elif isinstance(field, ManyRelatedField):
Expand All @@ -307,6 +311,10 @@ def _get_serializer_field_type(field: Field) -> Tuple[str, Optional[Type]]:
elif isinstance(field, Serializer):
field_type = get_serializer_prefix(type(field))
dependency = field
else:
field_type = TYPESCRIPT_ANY
if field_type != TYPESCRIPT_ANY and field.allow_null:
field_type += TYPESCRIPT_NULLABLE
return field_type, dependency


Expand Down Expand Up @@ -336,7 +344,7 @@ def get_serializer_field_type(field: Field) -> Tuple[str, list]:
result = f'{{[index: string]: {result}}}'
else:
result = item
return result, sorted(list(dependencies))
return result, sorted(list(dependencies), key=lambda tp: tp.__name__)


def build_interface_from_serializer(serializer_class: Type[Serializer],
Expand All @@ -349,9 +357,12 @@ def build_interface_from_serializer(serializer_class: Type[Serializer],
interface_fields = []
interface_dependencies = set()
for field_name, field_instance in serializer.get_fields().items():
field_instance: Field
field_type = type(field_instance)
if field_type in DRF_FIELD_MAPPING:
field_type = DRF_FIELD_MAPPING[field_type]
if field_instance.allow_null:
field_type += TYPESCRIPT_NULLABLE
else:
field_type, field_dependencies = get_serializer_field_type(field_instance)
for dependency in field_dependencies:
Expand All @@ -363,6 +374,6 @@ def build_interface_from_serializer(serializer_class: Type[Serializer],
if not interface_name:
interface_name = get_serializer_prefix(serializer_class)
return TypeScriptCode(type=TypeScriptCodeType.INTERFACE, source=serializer_class, name=interface_name,
dependencies=sorted(list(interface_dependencies)),
dependencies=sorted(list(interface_dependencies), key=lambda tp: tp.__name__),
content=INTERFACE_TEMPLATE.substitute(fields='\n'.join(interface_fields),
name=interface_name))

0 comments on commit 25d13b4

Please sign in to comment.