@@ -82,15 +82,19 @@ def __repr__(self) -> str:
8282 op_str = "<:"
8383 if self .op == SUPERTYPE_OF :
8484 op_str = ":>"
85- return f"{ self .type_var } { op_str } { self .target } "
85+ return f"{ self .origin_type_var } { op_str } { self .target } "
8686
8787 def __hash__ (self ) -> int :
88- return hash ((self .type_var , self .op , self .target ))
88+ return hash ((self .origin_type_var , self .op , self .target ))
8989
9090 def __eq__ (self , other : object ) -> bool :
9191 if not isinstance (other , Constraint ):
9292 return False
93- return (self .type_var , self .op , self .target ) == (other .type_var , other .op , other .target )
93+ return (self .origin_type_var , self .op , self .target ) == (
94+ other .origin_type_var ,
95+ other .op ,
96+ other .target ,
97+ )
9498
9599
96100def infer_constraints_for_callable (
@@ -698,25 +702,54 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
698702 )
699703 elif isinstance (tvar , ParamSpecType ) and isinstance (mapped_arg , ParamSpecType ):
700704 suffix = get_proper_type (instance_arg )
705+ prefix = mapped_arg .prefix
706+ length = len (prefix .arg_types )
701707
702708 if isinstance (suffix , CallableType ):
703- prefix = mapped_arg .prefix
704709 from_concat = bool (prefix .arg_types ) or suffix .from_concatenate
705710 suffix = suffix .copy_modified (from_concatenate = from_concat )
706711
707712 if isinstance (suffix , (Parameters , CallableType )):
708713 # no such thing as variance for ParamSpecs
709714 # TODO: is there a case I am missing?
710- # TODO: constraints between prefixes
711- prefix = mapped_arg .prefix
712- suffix = suffix .copy_modified (
713- suffix .arg_types [len (prefix .arg_types ) :],
714- suffix .arg_kinds [len (prefix .arg_kinds ) :],
715- suffix .arg_names [len (prefix .arg_names ) :],
715+ length = min (length , len (suffix .arg_types ))
716+
717+ constrained_to = suffix .copy_modified (
718+ suffix .arg_types [length :],
719+ suffix .arg_kinds [length :],
720+ suffix .arg_names [length :],
721+ )
722+ constrained_from = mapped_arg .copy_modified (
723+ prefix = prefix .copy_modified (
724+ prefix .arg_types [length :],
725+ prefix .arg_kinds [length :],
726+ prefix .arg_names [length :],
727+ )
716728 )
717- res .append (Constraint (mapped_arg , SUPERTYPE_OF , suffix ))
729+
730+ res .append (Constraint (constrained_from , SUPERTYPE_OF , constrained_to ))
731+ res .append (Constraint (constrained_from , SUBTYPE_OF , constrained_to ))
718732 elif isinstance (suffix , ParamSpecType ):
719- res .append (Constraint (mapped_arg , SUPERTYPE_OF , suffix ))
733+ suffix_prefix = suffix .prefix
734+ length = min (length , len (suffix_prefix .arg_types ))
735+
736+ constrained = suffix .copy_modified (
737+ prefix = suffix_prefix .copy_modified (
738+ suffix_prefix .arg_types [length :],
739+ suffix_prefix .arg_kinds [length :],
740+ suffix_prefix .arg_names [length :],
741+ )
742+ )
743+ constrained_from = mapped_arg .copy_modified (
744+ prefix = prefix .copy_modified (
745+ prefix .arg_types [length :],
746+ prefix .arg_kinds [length :],
747+ prefix .arg_names [length :],
748+ )
749+ )
750+
751+ res .append (Constraint (constrained_from , SUPERTYPE_OF , constrained ))
752+ res .append (Constraint (constrained_from , SUBTYPE_OF , constrained ))
720753 else :
721754 # This case should have been handled above.
722755 assert not isinstance (tvar , TypeVarTupleType )
@@ -768,26 +801,56 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
768801 template_arg , ParamSpecType
769802 ):
770803 suffix = get_proper_type (mapped_arg )
804+ prefix = template_arg .prefix
805+ length = len (prefix .arg_types )
771806
772807 if isinstance (suffix , CallableType ):
773808 prefix = template_arg .prefix
774809 from_concat = bool (prefix .arg_types ) or suffix .from_concatenate
775810 suffix = suffix .copy_modified (from_concatenate = from_concat )
776811
812+ # TODO: this is almost a copy-paste of code above: make this into a function
777813 if isinstance (suffix , (Parameters , CallableType )):
778814 # no such thing as variance for ParamSpecs
779815 # TODO: is there a case I am missing?
780- # TODO: constraints between prefixes
781- prefix = template_arg .prefix
816+ length = min (length , len (suffix .arg_types ))
782817
783- suffix = suffix .copy_modified (
784- suffix .arg_types [len ( prefix . arg_types ) :],
785- suffix .arg_kinds [len ( prefix . arg_kinds ) :],
786- suffix .arg_names [len ( prefix . arg_names ) :],
818+ constrained_to = suffix .copy_modified (
819+ suffix .arg_types [length :],
820+ suffix .arg_kinds [length :],
821+ suffix .arg_names [length :],
787822 )
788- res .append (Constraint (template_arg , SUPERTYPE_OF , suffix ))
823+ constrained_from = template_arg .copy_modified (
824+ prefix = prefix .copy_modified (
825+ prefix .arg_types [length :],
826+ prefix .arg_kinds [length :],
827+ prefix .arg_names [length :],
828+ )
829+ )
830+
831+ res .append (Constraint (constrained_from , SUPERTYPE_OF , constrained_to ))
832+ res .append (Constraint (constrained_from , SUBTYPE_OF , constrained_to ))
789833 elif isinstance (suffix , ParamSpecType ):
790- res .append (Constraint (template_arg , SUPERTYPE_OF , suffix ))
834+ suffix_prefix = suffix .prefix
835+ length = min (length , len (suffix_prefix .arg_types ))
836+
837+ constrained = suffix .copy_modified (
838+ prefix = suffix_prefix .copy_modified (
839+ suffix_prefix .arg_types [length :],
840+ suffix_prefix .arg_kinds [length :],
841+ suffix_prefix .arg_names [length :],
842+ )
843+ )
844+ constrained_from = template_arg .copy_modified (
845+ prefix = prefix .copy_modified (
846+ prefix .arg_types [length :],
847+ prefix .arg_kinds [length :],
848+ prefix .arg_names [length :],
849+ )
850+ )
851+
852+ res .append (Constraint (constrained_from , SUPERTYPE_OF , constrained ))
853+ res .append (Constraint (constrained_from , SUBTYPE_OF , constrained ))
791854 else :
792855 # This case should have been handled above.
793856 assert not isinstance (tvar , TypeVarTupleType )
@@ -954,9 +1017,19 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
9541017 prefix_len = len (prefix .arg_types )
9551018 cactual_ps = cactual .param_spec ()
9561019
1020+ cactual_prefix : Parameters | CallableType
1021+ if cactual_ps :
1022+ cactual_prefix = cactual_ps .prefix
1023+ else :
1024+ cactual_prefix = cactual
1025+
1026+ max_prefix_len = len (
1027+ [k for k in cactual_prefix .arg_kinds if k in (ARG_POS , ARG_OPT )]
1028+ )
1029+ prefix_len = min (prefix_len , max_prefix_len )
1030+
1031+ # we could check the prefixes match here, but that should be caught elsewhere.
9571032 if not cactual_ps :
958- max_prefix_len = len ([k for k in cactual .arg_kinds if k in (ARG_POS , ARG_OPT )])
959- prefix_len = min (prefix_len , max_prefix_len )
9601033 res .append (
9611034 Constraint (
9621035 param_spec ,
@@ -970,7 +1043,17 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
9701043 )
9711044 )
9721045 else :
973- res .append (Constraint (param_spec , SUBTYPE_OF , cactual_ps ))
1046+ # earlier, cactual_prefix = cactual_ps.prefix. thus, this is guaranteed
1047+ assert isinstance (cactual_prefix , Parameters )
1048+
1049+ constrained_by = cactual_ps .copy_modified (
1050+ prefix = cactual_prefix .copy_modified (
1051+ cactual_prefix .arg_types [prefix_len :],
1052+ cactual_prefix .arg_kinds [prefix_len :],
1053+ cactual_prefix .arg_names [prefix_len :],
1054+ )
1055+ )
1056+ res .append (Constraint (param_spec , SUBTYPE_OF , constrained_by ))
9741057
9751058 # compare prefixes
9761059 cactual_prefix = cactual .copy_modified (
0 commit comments