Skip to content

Commit 0c5094b

Browse files
committed
Structure code so by choosing {"type": Literal} we can pick the Literal
1 parent d54db1c commit 0c5094b

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

databind/src/databind/json/converters.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -763,20 +763,39 @@ def _check_style_compatibility(self, ctx: Context, style: str, value: t.Any) ->
763763
def convert(self, ctx: Context) -> t.Any:
764764
datatype = ctx.datatype
765765
union: t.Optional[Union]
766-
literal_types: list[TypeHint] = []
767766

768767
if isinstance(datatype, UnionTypeHint):
769768
if datatype.has_none_type():
770769
raise NotImplementedError("unable to handle Union type with None in it")
771770

772-
literal_types = [a for a in datatype if isinstance(a, LiteralTypeHint)]
771+
literal_values = [
772+
literal_value
773+
for literal_type in datatype
774+
if isinstance(literal_type, LiteralTypeHint)
775+
for literal_value in literal_type.values
776+
]
777+
778+
literal_type: t.Optional[TypeHint] = None
779+
if literal_values:
780+
literal_type = t.Literal[tuple(literal_values)] # type: ignore
781+
773782
non_literal_types = [a for a in datatype if not isinstance(a, LiteralTypeHint)]
774783
if not all(isinstance(a, ClassTypeHint) for a in non_literal_types):
775784
raise NotImplementedError(f"members of plain Union must be concrete or Literal types: {datatype}")
776785

777786
members = {t.cast(ClassTypeHint, a).type.__name__: a for a in non_literal_types}
778787
if len(members) != len(non_literal_types):
779-
raise NotImplementedError(f"members of plain Union cannot have overlapping type names: {datatype}")
788+
raise ConversionError(
789+
self, ctx, f"members of plain Union cannot have overlapping type names: {datatype}"
790+
)
791+
792+
if literal_type is not None:
793+
if "Literal" in members:
794+
raise ConversionError(
795+
self, ctx, f"members of plain Union with a Literal cannot have type name Literal: {datatype}"
796+
)
797+
members["Literal"] = literal_type
798+
780799
union = Union(members, Union.BEST_MATCH)
781800
elif isinstance(datatype, (AnnotatedTypeHint, ClassTypeHint)):
782801
union = ctx.get_setting(Union)
@@ -794,11 +813,6 @@ def convert(self, ctx: Context) -> t.Any:
794813
return ctx.spawn(ctx.value, member_type, None).convert()
795814
except ConversionError as exc:
796815
errors.append((exc.origin, exc))
797-
for literal_type in literal_types:
798-
try:
799-
return ctx.spawn(ctx.value, literal_type, None).convert()
800-
except ConversionError as exc:
801-
errors.append((exc.origin, exc))
802816
raise ConversionError(
803817
self,
804818
ctx,

0 commit comments

Comments
 (0)