Skip to content

Commit 4bb1d50

Browse files
refactor variance inferrence to get rid of explicit polarity
1 parent 388d8fd commit 4bb1d50

File tree

6 files changed

+273
-290
lines changed

6 files changed

+273
-290
lines changed

crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -475,18 +475,17 @@ contravariant occurrence _of_ `D`. The latter occurrence is ultimately thus cova
475475
occurrence in contravariant position). Then `C` has both a covariant and a contravariant occurrence
476476
of `D`, so it is invariant.
477477

478+
TODO: the bottom set of these are failing, perhaps something to do with subclasses of specialized
479+
aliases
480+
478481
```py
479482
static_assert(not is_subtype_of(C[B], C[A]))
480483
static_assert(not is_subtype_of(C[A], C[B]))
481484
static_assert(not is_subtype_of(C[A], C[Any]))
482485
static_assert(not is_subtype_of(C[B], C[Any]))
483486
static_assert(not is_subtype_of(C[Any], C[A]))
484487
static_assert(not is_subtype_of(C[Any], C[B]))
485-
```
486488

487-
TODO: these are failing
488-
489-
```py
490489
static_assert(not is_subtype_of(C[B].D, C[A].D))
491490
static_assert(is_subtype_of(C[A].D, C[B].D))
492491
static_assert(not is_subtype_of(C[A].D, C[Any].D))

crates/ty_python_semantic/src/types.rs

Lines changed: 127 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -284,14 +284,65 @@ fn class_lookup_cycle_initial<'db>(
284284
Place::bound(Type::Never).into()
285285
}
286286

287+
pub(crate) trait VarianceInferable<'db>: Sized {
288+
fn variance_of(self, db: &'db dyn Db, type_var: TypeVarInstance<'db>) -> TypeVarVariance;
289+
290+
fn with_polarity(self, polarity: TypeVarVariance) -> WithPolarity<Self> {
291+
WithPolarity {
292+
variance_inferable: self,
293+
polarity,
294+
}
295+
}
296+
}
297+
298+
pub(crate) struct WithPolarity<T> {
299+
variance_inferable: T,
300+
polarity: TypeVarVariance,
301+
}
302+
303+
impl<'db, T> VarianceInferable<'db> for WithPolarity<T>
304+
where
305+
T: VarianceInferable<'db>,
306+
{
307+
// Based on the variance composition/transformation operator in
308+
// https://people.cs.umass.edu/~yannis/variance-extended2011.pdf, page 5
309+
//
310+
// While their operation has compose(invariant, bivariant) = invariant, we
311+
// instead have it evalaute to bivariant. This is a valid choice, as
312+
// discussed on that same page, where type equality is semantic rather than
313+
// syntactic. To see that this holds for our setting consider the type
314+
// ```python
315+
// type ConstantInt[T] = int
316+
// ```
317+
// We would say `ConstantInt[str]` = ConstantInt[float], so we qualify as
318+
// using semantic equivalence.
319+
fn variance_of(self, db: &'db dyn Db, type_var: TypeVarInstance<'db>) -> TypeVarVariance {
320+
let WithPolarity {
321+
variance_inferable,
322+
polarity,
323+
} = self;
324+
match polarity {
325+
TypeVarVariance::Covariant => variance_inferable.variance_of(db, type_var),
326+
TypeVarVariance::Contravariant => variance_inferable.variance_of(db, type_var).flip(),
327+
TypeVarVariance::Bivariant => TypeVarVariance::Bivariant,
328+
TypeVarVariance::Invariant => {
329+
if let TypeVarVariance::Bivariant = variance_inferable.variance_of(db, type_var) {
330+
TypeVarVariance::Bivariant
331+
} else {
332+
TypeVarVariance::Invariant
333+
}
334+
}
335+
}
336+
}
337+
}
338+
287339
#[allow(clippy::trivially_copy_pass_by_ref)]
288340
fn variance_cycle_recover<'db, T>(
289341
_db: &'db dyn Db,
290342
_value: &TypeVarVariance,
291343
count: u32,
292344
_self: T,
293345
_type_var: TypeVarInstance<'db>,
294-
_variance: TypeVarVariance,
295346
) -> salsa::CycleRecoveryAction<TypeVarVariance> {
296347
assert!(
297348
count <= 2,
@@ -304,7 +355,6 @@ fn variance_cycle_initial<'db, T>(
304355
_db: &'db dyn Db,
305356
_self: T,
306357
_type_var: TypeVarInstance<'db>,
307-
_variance: TypeVarVariance,
308358
) -> TypeVarVariance {
309359
TypeVarVariance::Bivariant
310360
}
@@ -757,115 +807,6 @@ impl<'db> Type<'db> {
757807
}
758808
}
759809

760-
#[salsa::tracked(cycle_fn=variance_cycle_recover, cycle_initial=variance_cycle_initial)]
761-
fn variance_of(
762-
self,
763-
db: &'db dyn Db,
764-
type_var: TypeVarInstance<'db>,
765-
variance: TypeVarVariance,
766-
) -> TypeVarVariance {
767-
tracing::debug!(
768-
"Checking variance of '{tvar}' in `{ty:?}` (currently `{variance:?}`)",
769-
tvar = type_var.name(db),
770-
ty = self,
771-
variance = variance
772-
);
773-
774-
// Some optimizations:
775-
// we rewrite all inference to be in terms of covariant polarity.
776-
// Consolidating along a single polarity allows us to re-use the cache,
777-
// i.e. if we need the variance of T in C[T] at both covariant and
778-
// contravariant polarities, instead of traversing everything twice we
779-
// can re-use the contravariance result.
780-
// This is possible due to homomorphism of the variance lattice over inference, in particular
781-
// ty.variance_of(tvar, p).flip() === ty.variance_of(tvar, p.flip())
782-
// and
783-
// ty.variance_of(tvar, p).join(ty.variance_of(tvar, q)) === ty.variance_of(tvar, p.join(q))
784-
match variance {
785-
// If the variance is bivariant, it will never change, because the
786-
// only way this parameter changes is by flipping. Then either the
787-
// type variable doesn't occur in `self` at all in which case it's
788-
// bivariant, or occurs, but because our current polarity is
789-
// bivariant, that will be bivariant as well.
790-
TypeVarVariance::Bivariant => return TypeVarVariance::Bivariant,
791-
TypeVarVariance::Invariant => {
792-
return self
793-
.variance_of(db, type_var, TypeVarVariance::Covariant)
794-
.join(self.variance_of(db, type_var, TypeVarVariance::Contravariant));
795-
}
796-
TypeVarVariance::Covariant => (), // proceed
797-
TypeVarVariance::Contravariant => {
798-
return self
799-
.variance_of(db, type_var, TypeVarVariance::Covariant)
800-
.flip();
801-
}
802-
}
803-
804-
let v = match self {
805-
Type::ClassLiteral(class_literal) => class_literal.variance_of(db, type_var, variance),
806-
807-
Type::FunctionLiteral(function_type) => {
808-
function_type
809-
.signature(db)
810-
.variance_of(db, type_var, variance)
811-
}
812-
813-
Type::BoundMethod(method_type) => {
814-
// TODO: do we need to replace self?
815-
method_type
816-
.function(db)
817-
.signature(db)
818-
.variance_of(db, type_var, variance)
819-
}
820-
821-
Type::NominalInstance(nominal_instance_type) => {
822-
nominal_instance_type.variance_of(db, type_var, variance)
823-
}
824-
Type::GenericAlias(generic_alias) => generic_alias.variance_of(db, type_var, variance),
825-
Type::Callable(callable_type) => callable_type
826-
.signatures(db)
827-
.variance_of(db, type_var, variance),
828-
Type::TypeVar(other_type_var) if other_type_var == type_var => {
829-
// If the TypeVar occurs, we return the variance of the TypeVar itself.
830-
variance
831-
}
832-
Type::ProtocolInstance(protocol_instance_type) => protocol_instance_type.variance_of(db, type_var, variance),
833-
Type::Union(union_type) => union_type.elements(db).iter().map(|ty| ty.variance_of(db, type_var, variance)).collect(),
834-
Type::Intersection(intersection_type) => itertools::chain(intersection_type.positive(db).iter().map(|ty| ty.variance_of(db, type_var, variance)),
835-
intersection_type.negative(db).iter().map(|ty| ty.variance_of(db, type_var, variance.flip()))).collect(),
836-
Type::Tuple(tuple_type) => tuple_type.elements(db).iter().map(|ty| ty.variance_of(db, type_var, variance)).collect(),
837-
| Type::Dynamic(_)
838-
| Type::Never
839-
| Type::WrapperDescriptor(_)
840-
| Type::MethodWrapper(_)
841-
| Type::DataclassDecorator(_)
842-
| Type::DataclassTransformer(_)
843-
| Type::ModuleLiteral(_)
844-
| Type::IntLiteral(_)
845-
| Type::BooleanLiteral(_)
846-
| Type::StringLiteral(_)
847-
| Type::LiteralString
848-
| Type::BytesLiteral(_)
849-
| Type::SpecialForm(_)
850-
| Type::KnownInstance(_)
851-
| Type::AlwaysFalsy
852-
| Type::AlwaysTruthy
853-
| Type::PropertyInstance(_)
854-
| Type::BoundSuper(_)
855-
| Type::SubclassOf(_) // TODO: double check
856-
| Type::TypeVar(_)
857-
| Type::TypeIs(_) => TypeVarVariance::Bivariant,
858-
};
859-
860-
tracing::debug!(
861-
"Result of variance of '{tvar}' in `{ty:?}` is `{v:?}` (polarity was `{variance:?}`)",
862-
tvar = type_var.name(db),
863-
ty = self,
864-
variance = variance
865-
);
866-
v
867-
}
868-
869810
/// Replace references to the class `class` with a self-reference marker. This is currently
870811
/// used for recursive protocols, but could probably be extended to self-referential type-
871812
/// aliases and similar.
@@ -5794,6 +5735,79 @@ impl<'db> From<&Type<'db>> for Type<'db> {
57945735
}
57955736
}
57965737

5738+
impl<'db> VarianceInferable<'db> for Type<'db> {
5739+
fn variance_of(self, db: &'db dyn Db, type_var: TypeVarInstance) -> TypeVarVariance {
5740+
tracing::debug!(
5741+
"Checking variance of '{tvar}' in `{ty:?}`",
5742+
tvar = type_var.name(db),
5743+
ty = self,
5744+
);
5745+
5746+
let v = match self {
5747+
Type::ClassLiteral(class_literal) => class_literal.variance_of(db, type_var),
5748+
5749+
Type::FunctionLiteral(function_type) => {
5750+
function_type
5751+
.signature(db)
5752+
.variance_of(db, type_var)
5753+
}
5754+
5755+
Type::BoundMethod(method_type) => {
5756+
// TODO: do we need to replace self?
5757+
method_type
5758+
.function(db)
5759+
.signature(db)
5760+
.variance_of(db, type_var)
5761+
}
5762+
5763+
Type::NominalInstance(nominal_instance_type) => {
5764+
nominal_instance_type.variance_of(db, type_var)
5765+
}
5766+
Type::GenericAlias(generic_alias) => generic_alias.variance_of(db, type_var),
5767+
Type::Callable(callable_type) => callable_type
5768+
.signatures(db)
5769+
.variance_of(db, type_var),
5770+
Type::TypeVar(other_type_var) if other_type_var == type_var => {
5771+
// type variables are covariant in themselves
5772+
TypeVarVariance::Covariant
5773+
}
5774+
Type::ProtocolInstance(protocol_instance_type) => protocol_instance_type.variance_of(db, type_var),
5775+
Type::Union(union_type) => union_type.elements(db).iter().map(|ty| ty.variance_of(db, type_var)).collect(),
5776+
Type::Intersection(intersection_type) => itertools::chain(intersection_type.positive(db).iter().map(|ty| ty.variance_of(db, type_var)),
5777+
intersection_type.negative(db).iter().map(|ty| ty.with_polarity(TypeVarVariance::Contravariant).variance_of(db, type_var))).collect(),
5778+
Type::Tuple(tuple_type) => tuple_type.tuple(db).all_elements().map(|ty| ty.variance_of(db, type_var)).collect(),
5779+
| Type::Dynamic(_)
5780+
| Type::Never
5781+
| Type::WrapperDescriptor(_)
5782+
| Type::MethodWrapper(_)
5783+
| Type::DataclassDecorator(_)
5784+
| Type::DataclassTransformer(_)
5785+
| Type::ModuleLiteral(_)
5786+
| Type::IntLiteral(_)
5787+
| Type::BooleanLiteral(_)
5788+
| Type::StringLiteral(_)
5789+
| Type::LiteralString
5790+
| Type::BytesLiteral(_)
5791+
| Type::SpecialForm(_)
5792+
| Type::KnownInstance(_)
5793+
| Type::AlwaysFalsy
5794+
| Type::AlwaysTruthy
5795+
| Type::PropertyInstance(_)
5796+
| Type::BoundSuper(_)
5797+
| Type::SubclassOf(_) // TODO: double check
5798+
| Type::TypeVar(_)
5799+
| Type::TypeIs(_) => TypeVarVariance::Bivariant,
5800+
};
5801+
5802+
tracing::debug!(
5803+
"Result of variance of '{tvar}' in `{ty:?}` is `{v:?}`",
5804+
tvar = type_var.name(db),
5805+
ty = self,
5806+
);
5807+
v
5808+
}
5809+
}
5810+
57975811
/// A mapping that can be applied to a type, producing another type. This is applied inductively to
57985812
/// the components of complex types.
57995813
///
@@ -6282,7 +6296,7 @@ impl<'db> TypeVarInstance<'db> {
62826296

62836297
#[track_caller]
62846298
fn inferred_variance(self, db: &'db dyn Db) -> TypeVarVariance {
6285-
let _span = tracing::trace_span!("variance_of").entered();
6299+
let _span = tracing::trace_span!("inferred_variance").entered();
62866300
assert_eq!(self.kind(db), TypeVarKind::Pep695);
62876301
match self.definition(db) {
62886302
Some(definition) => {
@@ -6317,10 +6331,7 @@ impl<'db> TypeVarInstance<'db> {
63176331
let semantic = semantic_index(db, file);
63186332
let defn = semantic.expect_single_definition(defn_key);
63196333
let type_inference = infer_definition_types(db, defn);
6320-
type_inference
6321-
.binding_type(defn)
6322-
// initially, we want any occurrences of this tvar to be identified as covariant
6323-
.variance_of(db, self, TypeVarVariance::Covariant)
6334+
type_inference.binding_type(defn).variance_of(db, self)
63246335
}
63256336
None => {
63266337
// TODO: idk what to do here

0 commit comments

Comments
 (0)