@@ -763,20 +763,39 @@ def _check_style_compatibility(self, ctx: Context, style: str, value: t.Any) ->
763
763
def convert (self , ctx : Context ) -> t .Any :
764
764
datatype = ctx .datatype
765
765
union : t .Optional [Union ]
766
- literal_types : list [TypeHint ] = []
767
766
768
767
if isinstance (datatype , UnionTypeHint ):
769
768
if datatype .has_none_type ():
770
769
raise NotImplementedError ("unable to handle Union type with None in it" )
771
770
772
- literal_types = [a for a in datatype if isinstance (a , LiteralTypeHint )]
771
+ literal_values = [
772
+ literal_value
773
+ for literal_type in datatype
774
+ if isinstance (literal_type , LiteralTypeHint )
775
+ for literal_value in literal_type .values
776
+ ]
777
+
778
+ literal_type : t .Optional [TypeHint ] = None
779
+ if literal_values :
780
+ literal_type = t .Literal [tuple (literal_values )] # type: ignore
781
+
773
782
non_literal_types = [a for a in datatype if not isinstance (a , LiteralTypeHint )]
774
783
if not all (isinstance (a , ClassTypeHint ) for a in non_literal_types ):
775
784
raise NotImplementedError (f"members of plain Union must be concrete or Literal types: { datatype } " )
776
785
777
786
members = {t .cast (ClassTypeHint , a ).type .__name__ : a for a in non_literal_types }
778
787
if len (members ) != len (non_literal_types ):
779
- raise NotImplementedError (f"members of plain Union cannot have overlapping type names: { datatype } " )
788
+ raise ConversionError (
789
+ self , ctx , f"members of plain Union cannot have overlapping type names: { datatype } "
790
+ )
791
+
792
+ if literal_type is not None :
793
+ if "Literal" in members :
794
+ raise ConversionError (
795
+ self , ctx , f"members of plain Union with a Literal cannot have type name Literal: { datatype } "
796
+ )
797
+ members ["Literal" ] = literal_type
798
+
780
799
union = Union (members , Union .BEST_MATCH )
781
800
elif isinstance (datatype , (AnnotatedTypeHint , ClassTypeHint )):
782
801
union = ctx .get_setting (Union )
@@ -794,11 +813,6 @@ def convert(self, ctx: Context) -> t.Any:
794
813
return ctx .spawn (ctx .value , member_type , None ).convert ()
795
814
except ConversionError as exc :
796
815
errors .append ((exc .origin , exc ))
797
- for literal_type in literal_types :
798
- try :
799
- return ctx .spawn (ctx .value , literal_type , None ).convert ()
800
- except ConversionError as exc :
801
- errors .append ((exc .origin , exc ))
802
816
raise ConversionError (
803
817
self ,
804
818
ctx ,
0 commit comments