@@ -284,14 +284,65 @@ fn class_lookup_cycle_initial<'db>(
284284 Place :: bound ( Type :: Never ) . into ( )
285285}
286286
287+ pub ( crate ) trait VarianceInferable < ' db > : Sized {
288+ fn variance_of ( self , db : & ' db dyn Db , type_var : TypeVarInstance < ' db > ) -> TypeVarVariance ;
289+
290+ fn with_polarity ( self , polarity : TypeVarVariance ) -> WithPolarity < Self > {
291+ WithPolarity {
292+ variance_inferable : self ,
293+ polarity,
294+ }
295+ }
296+ }
297+
298+ pub ( crate ) struct WithPolarity < T > {
299+ variance_inferable : T ,
300+ polarity : TypeVarVariance ,
301+ }
302+
303+ impl < ' db , T > VarianceInferable < ' db > for WithPolarity < T >
304+ where
305+ T : VarianceInferable < ' db > ,
306+ {
307+ // Based on the variance composition/transformation operator in
308+ // https://people.cs.umass.edu/~yannis/variance-extended2011.pdf, page 5
309+ //
310+ // While their operation has compose(invariant, bivariant) = invariant, we
311+ // instead have it evalaute to bivariant. This is a valid choice, as
312+ // discussed on that same page, where type equality is semantic rather than
313+ // syntactic. To see that this holds for our setting consider the type
314+ // ```python
315+ // type ConstantInt[T] = int
316+ // ```
317+ // We would say `ConstantInt[str]` = ConstantInt[float], so we qualify as
318+ // using semantic equivalence.
319+ fn variance_of ( self , db : & ' db dyn Db , type_var : TypeVarInstance < ' db > ) -> TypeVarVariance {
320+ let WithPolarity {
321+ variance_inferable,
322+ polarity,
323+ } = self ;
324+ match polarity {
325+ TypeVarVariance :: Covariant => variance_inferable. variance_of ( db, type_var) ,
326+ TypeVarVariance :: Contravariant => variance_inferable. variance_of ( db, type_var) . flip ( ) ,
327+ TypeVarVariance :: Bivariant => TypeVarVariance :: Bivariant ,
328+ TypeVarVariance :: Invariant => {
329+ if let TypeVarVariance :: Bivariant = variance_inferable. variance_of ( db, type_var) {
330+ TypeVarVariance :: Bivariant
331+ } else {
332+ TypeVarVariance :: Invariant
333+ }
334+ }
335+ }
336+ }
337+ }
338+
287339#[ allow( clippy:: trivially_copy_pass_by_ref) ]
288340fn variance_cycle_recover < ' db , T > (
289341 _db : & ' db dyn Db ,
290342 _value : & TypeVarVariance ,
291343 count : u32 ,
292344 _self : T ,
293345 _type_var : TypeVarInstance < ' db > ,
294- _variance : TypeVarVariance ,
295346) -> salsa:: CycleRecoveryAction < TypeVarVariance > {
296347 assert ! (
297348 count <= 2 ,
@@ -304,7 +355,6 @@ fn variance_cycle_initial<'db, T>(
304355 _db : & ' db dyn Db ,
305356 _self : T ,
306357 _type_var : TypeVarInstance < ' db > ,
307- _variance : TypeVarVariance ,
308358) -> TypeVarVariance {
309359 TypeVarVariance :: Bivariant
310360}
@@ -757,115 +807,6 @@ impl<'db> Type<'db> {
757807 }
758808 }
759809
760- #[ salsa:: tracked( cycle_fn=variance_cycle_recover, cycle_initial=variance_cycle_initial) ]
761- fn variance_of (
762- self ,
763- db : & ' db dyn Db ,
764- type_var : TypeVarInstance < ' db > ,
765- variance : TypeVarVariance ,
766- ) -> TypeVarVariance {
767- tracing:: debug!(
768- "Checking variance of '{tvar}' in `{ty:?}` (currently `{variance:?}`)" ,
769- tvar = type_var. name( db) ,
770- ty = self ,
771- variance = variance
772- ) ;
773-
774- // Some optimizations:
775- // we rewrite all inference to be in terms of covariant polarity.
776- // Consolidating along a single polarity allows us to re-use the cache,
777- // i.e. if we need the variance of T in C[T] at both covariant and
778- // contravariant polarities, instead of traversing everything twice we
779- // can re-use the contravariance result.
780- // This is possible due to homomorphism of the variance lattice over inference, in particular
781- // ty.variance_of(tvar, p).flip() === ty.variance_of(tvar, p.flip())
782- // and
783- // ty.variance_of(tvar, p).join(ty.variance_of(tvar, q)) === ty.variance_of(tvar, p.join(q))
784- match variance {
785- // If the variance is bivariant, it will never change, because the
786- // only way this parameter changes is by flipping. Then either the
787- // type variable doesn't occur in `self` at all in which case it's
788- // bivariant, or occurs, but because our current polarity is
789- // bivariant, that will be bivariant as well.
790- TypeVarVariance :: Bivariant => return TypeVarVariance :: Bivariant ,
791- TypeVarVariance :: Invariant => {
792- return self
793- . variance_of ( db, type_var, TypeVarVariance :: Covariant )
794- . join ( self . variance_of ( db, type_var, TypeVarVariance :: Contravariant ) ) ;
795- }
796- TypeVarVariance :: Covariant => ( ) , // proceed
797- TypeVarVariance :: Contravariant => {
798- return self
799- . variance_of ( db, type_var, TypeVarVariance :: Covariant )
800- . flip ( ) ;
801- }
802- }
803-
804- let v = match self {
805- Type :: ClassLiteral ( class_literal) => class_literal. variance_of ( db, type_var, variance) ,
806-
807- Type :: FunctionLiteral ( function_type) => {
808- function_type
809- . signature ( db)
810- . variance_of ( db, type_var, variance)
811- }
812-
813- Type :: BoundMethod ( method_type) => {
814- // TODO: do we need to replace self?
815- method_type
816- . function ( db)
817- . signature ( db)
818- . variance_of ( db, type_var, variance)
819- }
820-
821- Type :: NominalInstance ( nominal_instance_type) => {
822- nominal_instance_type. variance_of ( db, type_var, variance)
823- }
824- Type :: GenericAlias ( generic_alias) => generic_alias. variance_of ( db, type_var, variance) ,
825- Type :: Callable ( callable_type) => callable_type
826- . signatures ( db)
827- . variance_of ( db, type_var, variance) ,
828- Type :: TypeVar ( other_type_var) if other_type_var == type_var => {
829- // If the TypeVar occurs, we return the variance of the TypeVar itself.
830- variance
831- }
832- Type :: ProtocolInstance ( protocol_instance_type) => protocol_instance_type. variance_of ( db, type_var, variance) ,
833- Type :: Union ( union_type) => union_type. elements ( db) . iter ( ) . map ( |ty| ty. variance_of ( db, type_var, variance) ) . collect ( ) ,
834- Type :: Intersection ( intersection_type) => itertools:: chain ( intersection_type. positive ( db) . iter ( ) . map ( |ty| ty. variance_of ( db, type_var, variance) ) ,
835- intersection_type. negative ( db) . iter ( ) . map ( |ty| ty. variance_of ( db, type_var, variance. flip ( ) ) ) ) . collect ( ) ,
836- Type :: Tuple ( tuple_type) => tuple_type. elements ( db) . iter ( ) . map ( |ty| ty. variance_of ( db, type_var, variance) ) . collect ( ) ,
837- | Type :: Dynamic ( _)
838- | Type :: Never
839- | Type :: WrapperDescriptor ( _)
840- | Type :: MethodWrapper ( _)
841- | Type :: DataclassDecorator ( _)
842- | Type :: DataclassTransformer ( _)
843- | Type :: ModuleLiteral ( _)
844- | Type :: IntLiteral ( _)
845- | Type :: BooleanLiteral ( _)
846- | Type :: StringLiteral ( _)
847- | Type :: LiteralString
848- | Type :: BytesLiteral ( _)
849- | Type :: SpecialForm ( _)
850- | Type :: KnownInstance ( _)
851- | Type :: AlwaysFalsy
852- | Type :: AlwaysTruthy
853- | Type :: PropertyInstance ( _)
854- | Type :: BoundSuper ( _)
855- | Type :: SubclassOf ( _) // TODO: double check
856- | Type :: TypeVar ( _)
857- | Type :: TypeIs ( _) => TypeVarVariance :: Bivariant ,
858- } ;
859-
860- tracing:: debug!(
861- "Result of variance of '{tvar}' in `{ty:?}` is `{v:?}` (polarity was `{variance:?}`)" ,
862- tvar = type_var. name( db) ,
863- ty = self ,
864- variance = variance
865- ) ;
866- v
867- }
868-
869810 /// Replace references to the class `class` with a self-reference marker. This is currently
870811 /// used for recursive protocols, but could probably be extended to self-referential type-
871812 /// aliases and similar.
@@ -5794,6 +5735,79 @@ impl<'db> From<&Type<'db>> for Type<'db> {
57945735 }
57955736}
57965737
5738+ impl < ' db > VarianceInferable < ' db > for Type < ' db > {
5739+ fn variance_of ( self , db : & ' db dyn Db , type_var : TypeVarInstance ) -> TypeVarVariance {
5740+ tracing:: debug!(
5741+ "Checking variance of '{tvar}' in `{ty:?}`" ,
5742+ tvar = type_var. name( db) ,
5743+ ty = self ,
5744+ ) ;
5745+
5746+ let v = match self {
5747+ Type :: ClassLiteral ( class_literal) => class_literal. variance_of ( db, type_var) ,
5748+
5749+ Type :: FunctionLiteral ( function_type) => {
5750+ function_type
5751+ . signature ( db)
5752+ . variance_of ( db, type_var)
5753+ }
5754+
5755+ Type :: BoundMethod ( method_type) => {
5756+ // TODO: do we need to replace self?
5757+ method_type
5758+ . function ( db)
5759+ . signature ( db)
5760+ . variance_of ( db, type_var)
5761+ }
5762+
5763+ Type :: NominalInstance ( nominal_instance_type) => {
5764+ nominal_instance_type. variance_of ( db, type_var)
5765+ }
5766+ Type :: GenericAlias ( generic_alias) => generic_alias. variance_of ( db, type_var) ,
5767+ Type :: Callable ( callable_type) => callable_type
5768+ . signatures ( db)
5769+ . variance_of ( db, type_var) ,
5770+ Type :: TypeVar ( other_type_var) if other_type_var == type_var => {
5771+ // type variables are covariant in themselves
5772+ TypeVarVariance :: Covariant
5773+ }
5774+ Type :: ProtocolInstance ( protocol_instance_type) => protocol_instance_type. variance_of ( db, type_var) ,
5775+ Type :: Union ( union_type) => union_type. elements ( db) . iter ( ) . map ( |ty| ty. variance_of ( db, type_var) ) . collect ( ) ,
5776+ Type :: Intersection ( intersection_type) => itertools:: chain ( intersection_type. positive ( db) . iter ( ) . map ( |ty| ty. variance_of ( db, type_var) ) ,
5777+ intersection_type. negative ( db) . iter ( ) . map ( |ty| ty. with_polarity ( TypeVarVariance :: Contravariant ) . variance_of ( db, type_var) ) ) . collect ( ) ,
5778+ Type :: Tuple ( tuple_type) => tuple_type. tuple ( db) . all_elements ( ) . map ( |ty| ty. variance_of ( db, type_var) ) . collect ( ) ,
5779+ | Type :: Dynamic ( _)
5780+ | Type :: Never
5781+ | Type :: WrapperDescriptor ( _)
5782+ | Type :: MethodWrapper ( _)
5783+ | Type :: DataclassDecorator ( _)
5784+ | Type :: DataclassTransformer ( _)
5785+ | Type :: ModuleLiteral ( _)
5786+ | Type :: IntLiteral ( _)
5787+ | Type :: BooleanLiteral ( _)
5788+ | Type :: StringLiteral ( _)
5789+ | Type :: LiteralString
5790+ | Type :: BytesLiteral ( _)
5791+ | Type :: SpecialForm ( _)
5792+ | Type :: KnownInstance ( _)
5793+ | Type :: AlwaysFalsy
5794+ | Type :: AlwaysTruthy
5795+ | Type :: PropertyInstance ( _)
5796+ | Type :: BoundSuper ( _)
5797+ | Type :: SubclassOf ( _) // TODO: double check
5798+ | Type :: TypeVar ( _)
5799+ | Type :: TypeIs ( _) => TypeVarVariance :: Bivariant ,
5800+ } ;
5801+
5802+ tracing:: debug!(
5803+ "Result of variance of '{tvar}' in `{ty:?}` is `{v:?}`" ,
5804+ tvar = type_var. name( db) ,
5805+ ty = self ,
5806+ ) ;
5807+ v
5808+ }
5809+ }
5810+
57975811/// A mapping that can be applied to a type, producing another type. This is applied inductively to
57985812/// the components of complex types.
57995813///
@@ -6282,7 +6296,7 @@ impl<'db> TypeVarInstance<'db> {
62826296
62836297 #[ track_caller]
62846298 fn inferred_variance ( self , db : & ' db dyn Db ) -> TypeVarVariance {
6285- let _span = tracing:: trace_span!( "variance_of " ) . entered ( ) ;
6299+ let _span = tracing:: trace_span!( "inferred_variance " ) . entered ( ) ;
62866300 assert_eq ! ( self . kind( db) , TypeVarKind :: Pep695 ) ;
62876301 match self . definition ( db) {
62886302 Some ( definition) => {
@@ -6317,10 +6331,7 @@ impl<'db> TypeVarInstance<'db> {
63176331 let semantic = semantic_index ( db, file) ;
63186332 let defn = semantic. expect_single_definition ( defn_key) ;
63196333 let type_inference = infer_definition_types ( db, defn) ;
6320- type_inference
6321- . binding_type ( defn)
6322- // initially, we want any occurrences of this tvar to be identified as covariant
6323- . variance_of ( db, self , TypeVarVariance :: Covariant )
6334+ type_inference. binding_type ( defn) . variance_of ( db, self )
63246335 }
63256336 None => {
63266337 // TODO: idk what to do here
0 commit comments