Skip to content

Commit 2e66b81

Browse files
committed
Merge branch 'main' into cjm/fix-pep695
* main: [ty] Add some additional type safety to `CycleDetector` (#19903)
2 parents 493ecc0 + baadb5a commit 2e66b81

File tree

11 files changed

+155
-107
lines changed

11 files changed

+155
-107
lines changed

crates/ty_python_semantic/src/types.rs

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,24 @@ fn definition_expression_type<'db>(
170170
}
171171
}
172172

173+
/// A [`TypeTransformer`] that is used in `apply_type_mapping` methods.
174+
pub(crate) type ApplyTypeMappingVisitor<'db> = TypeTransformer<'db, TypeMapping<'db, 'db>>;
175+
176+
/// A [`PairVisitor`] that is used in `has_relation_to` methods.
177+
pub(crate) type HasRelationToVisitor<'db> = PairVisitor<'db, TypeRelation>;
178+
179+
/// A [`PairVisitor`] that is used in `is_disjoint_from` methods.
180+
pub(crate) type IsDisjointVisitor<'db> = PairVisitor<'db, IsDisjoint>;
181+
pub(crate) struct IsDisjoint;
182+
183+
/// A [`PairVisitor`] that is used in `is_equivalent` methods.
184+
pub(crate) type IsEquivalentVisitor<'db> = PairVisitor<'db, IsEquivalent>;
185+
pub(crate) struct IsEquivalent;
186+
187+
/// A [`TypeTransformer`] that is used in `normalized` methods.
188+
pub(crate) type NormalizedVisitor<'db> = TypeTransformer<'db, Normalized>;
189+
pub(crate) struct Normalized;
190+
173191
/// The descriptor protocol distinguishes two kinds of descriptors. Non-data descriptors
174192
/// define a `__get__` method, while data descriptors additionally define a `__set__`
175193
/// method or a `__delete__` method. This enum is used to categorize attributes into two
@@ -419,7 +437,7 @@ impl<'db> PropertyInstanceType<'db> {
419437
Self::new(db, getter, setter)
420438
}
421439

422-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
440+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
423441
Self::new(
424442
db,
425443
self.getter(db).map(|ty| ty.normalized_impl(db, visitor)),
@@ -1068,7 +1086,7 @@ impl<'db> Type<'db> {
10681086
}
10691087

10701088
#[must_use]
1071-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
1089+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
10721090
match self {
10731091
Type::Union(union) => {
10741092
visitor.visit(self, || Type::Union(union.normalized_impl(db, visitor)))
@@ -1326,7 +1344,7 @@ impl<'db> Type<'db> {
13261344
db: &'db dyn Db,
13271345
target: Type<'db>,
13281346
relation: TypeRelation,
1329-
visitor: &PairVisitor<'db>,
1347+
visitor: &HasRelationToVisitor<'db>,
13301348
) -> bool {
13311349
// Subtyping implies assignability, so if subtyping is reflexive and the two types are
13321350
// equal, it is both a subtype and assignable. Assignability is always reflexive.
@@ -1762,7 +1780,7 @@ impl<'db> Type<'db> {
17621780
self,
17631781
db: &'db dyn Db,
17641782
other: Type<'db>,
1765-
visitor: &PairVisitor<'db>,
1783+
visitor: &IsEquivalentVisitor<'db>,
17661784
) -> bool {
17671785
if self == other {
17681786
return true;
@@ -1848,13 +1866,13 @@ impl<'db> Type<'db> {
18481866
self,
18491867
db: &'db dyn Db,
18501868
other: Type<'db>,
1851-
visitor: &PairVisitor<'db>,
1869+
visitor: &IsDisjointVisitor<'db>,
18521870
) -> bool {
18531871
fn any_protocol_members_absent_or_disjoint<'db>(
18541872
db: &'db dyn Db,
18551873
protocol: ProtocolInstanceType<'db>,
18561874
other: Type<'db>,
1857-
visitor: &PairVisitor<'db>,
1875+
visitor: &IsDisjointVisitor<'db>,
18581876
) -> bool {
18591877
protocol.interface(db).members(db).any(|member| {
18601878
other
@@ -5745,7 +5763,7 @@ impl<'db> Type<'db> {
57455763
self,
57465764
db: &'db dyn Db,
57475765
type_mapping: &TypeMapping<'a, 'db>,
5748-
visitor: &TypeTransformer<'db>,
5766+
visitor: &ApplyTypeMappingVisitor<'db>,
57495767
) -> Type<'db> {
57505768
match self {
57515769
Type::TypeVar(bound_typevar) => match type_mapping {
@@ -6267,7 +6285,7 @@ impl<'db> TypeMapping<'_, 'db> {
62676285
}
62686286
}
62696287

6270-
fn normalized_impl(&self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
6288+
fn normalized_impl(&self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
62716289
match self {
62726290
TypeMapping::Specialization(specialization) => {
62736291
TypeMapping::Specialization(specialization.normalized_impl(db, visitor))
@@ -6353,7 +6371,7 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
63536371
}
63546372

63556373
impl<'db> KnownInstanceType<'db> {
6356-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
6374+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
63576375
match self {
63586376
Self::SubscriptedProtocol(context) => {
63596377
Self::SubscriptedProtocol(context.normalized_impl(db, visitor))
@@ -6777,7 +6795,7 @@ pub struct FieldInstance<'db> {
67776795
impl get_size2::GetSize for FieldInstance<'_> {}
67786796

67796797
impl<'db> FieldInstance<'db> {
6780-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
6798+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
67816799
FieldInstance::new(
67826800
db,
67836801
self.default_type(db).normalized_impl(db, visitor),
@@ -6922,7 +6940,7 @@ impl<'db> TypeVarInstance<'db> {
69226940
})
69236941
}
69246942

6925-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
6943+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
69266944
Self::new(
69276945
db,
69286946
self.name(db),
@@ -7099,7 +7117,7 @@ impl<'db> BoundTypeVarInstance<'db> {
70997117
.map(|ty| ty.apply_type_mapping(db, &TypeMapping::BindLegacyTypevars(binding_context)))
71007118
}
71017119

7102-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
7120+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
71037121
Self::new(
71047122
db,
71057123
self.typevar(db).normalized_impl(db, visitor),
@@ -7176,7 +7194,7 @@ fn walk_type_var_bounds<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
71767194
}
71777195

71787196
impl<'db> TypeVarBoundOrConstraints<'db> {
7179-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
7197+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
71807198
match self {
71817199
TypeVarBoundOrConstraints::UpperBound(bound) => {
71827200
TypeVarBoundOrConstraints::UpperBound(bound.normalized_impl(db, visitor))
@@ -8214,7 +8232,7 @@ impl<'db> BoundMethodType<'db> {
82148232
)
82158233
}
82168234

8217-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
8235+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
82188236
Self::new(
82198237
db,
82208238
self.function(db).normalized_impl(db, visitor),
@@ -8331,7 +8349,7 @@ impl<'db> CallableType<'db> {
83318349
/// Return a "normalized" version of this `Callable` type.
83328350
///
83338351
/// See [`Type::normalized`] for more details.
8334-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
8352+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
83358353
CallableType::new(
83368354
db,
83378355
self.signatures(db).normalized_impl(db, visitor),
@@ -8495,7 +8513,7 @@ impl<'db> MethodWrapperKind<'db> {
84958513
}
84968514
}
84978515

8498-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
8516+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
84998517
match self {
85008518
MethodWrapperKind::FunctionTypeDunderGet(function) => {
85018519
MethodWrapperKind::FunctionTypeDunderGet(function.normalized_impl(db, visitor))
@@ -8679,7 +8697,7 @@ impl<'db> PEP695TypeAliasType<'db> {
86798697
definition_expression_type(db, definition, &type_alias_stmt_node.value)
86808698
}
86818699

8682-
fn normalized_impl(self, _db: &'db dyn Db, _visitor: &TypeTransformer<'db>) -> Self {
8700+
fn normalized_impl(self, _db: &'db dyn Db, _visitor: &NormalizedVisitor<'db>) -> Self {
86838701
self
86848702
}
86858703
}
@@ -8721,7 +8739,7 @@ fn walk_bare_type_alias<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
87218739
}
87228740

87238741
impl<'db> BareTypeAliasType<'db> {
8724-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
8742+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
87258743
Self::new(
87268744
db,
87278745
self.name(db),
@@ -8757,7 +8775,7 @@ fn walk_type_alias_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
87578775
}
87588776

87598777
impl<'db> TypeAliasType<'db> {
8760-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
8778+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
87618779
match self {
87628780
TypeAliasType::PEP695(type_alias) => {
87638781
TypeAliasType::PEP695(type_alias.normalized_impl(db, visitor))
@@ -8986,7 +9004,7 @@ impl<'db> UnionType<'db> {
89869004
self.normalized_impl(db, &TypeTransformer::default())
89879005
}
89889006

8989-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
9007+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
89909008
let mut new_elements: Vec<Type<'db>> = self
89919009
.elements(db)
89929010
.iter()
@@ -9060,11 +9078,11 @@ impl<'db> IntersectionType<'db> {
90609078
self.normalized_impl(db, &TypeTransformer::default())
90619079
}
90629080

9063-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
9081+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
90649082
fn normalized_set<'db>(
90659083
db: &'db dyn Db,
90669084
elements: &FxOrderSet<Type<'db>>,
9067-
visitor: &TypeTransformer<'db>,
9085+
visitor: &NormalizedVisitor<'db>,
90689086
) -> FxOrderSet<Type<'db>> {
90699087
let mut elements: FxOrderSet<Type<'db>> = elements
90709088
.iter()
@@ -9314,7 +9332,7 @@ impl<'db> TypedDictType<'db> {
93149332
self,
93159333
db: &'db dyn Db,
93169334
type_mapping: &TypeMapping<'a, 'db>,
9317-
visitor: &TypeTransformer<'db>,
9335+
visitor: &ApplyTypeMappingVisitor<'db>,
93189336
) -> Self {
93199337
Self {
93209338
defining_class: self
@@ -9386,7 +9404,7 @@ pub enum SuperOwnerKind<'db> {
93869404
}
93879405

93889406
impl<'db> SuperOwnerKind<'db> {
9389-
fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
9407+
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
93909408
match self {
93919409
SuperOwnerKind::Dynamic(dynamic) => SuperOwnerKind::Dynamic(dynamic.normalized()),
93929410
SuperOwnerKind::Class(class) => {
@@ -9658,7 +9676,7 @@ impl<'db> BoundSuperType<'db> {
96589676
}
96599677
}
96609678

9661-
pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
9679+
pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
96629680
Self::new(
96639681
db,
96649682
self.pivot_class(db).normalized_impl(db, visitor),

crates/ty_python_semantic/src/types/class.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ use crate::types::infer::nearest_enclosing_class;
2222
use crate::types::signatures::{CallableSignature, Parameter, Parameters, Signature};
2323
use crate::types::tuple::{TupleSpec, TupleType};
2424
use crate::types::{
25-
BareTypeAliasType, Binding, BoundSuperError, BoundSuperType, CallableType, DataclassParams,
26-
DeprecatedInstance, KnownInstanceType, LazyTypeVarBoundOrConstraints, StringLiteralType,
27-
TypeAliasType, TypeMapping, TypeRelation, TypeTransformer, TypeVarBoundOrConstraints,
28-
TypeVarDefault, TypeVarInstance, TypeVarKind, declaration_type, infer_definition_types,
29-
todo_type,
25+
ApplyTypeMappingVisitor, BareTypeAliasType, Binding, BoundSuperError, BoundSuperType,
26+
CallableType, DataclassParams, DeprecatedInstance, HasRelationToVisitor, KnownInstanceType,
27+
LazyTypeVarBoundOrConstraints, NormalizedVisitor, StringLiteralType, TypeAliasType,
28+
TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarDefault, TypeVarInstance,
29+
TypeVarKind, declaration_type, infer_definition_types, todo_type,
3030
};
3131
use crate::{
3232
Db, FxIndexMap, FxOrderSet, Program,
@@ -232,7 +232,7 @@ pub(super) fn walk_generic_alias<'db, V: super::visitor::TypeVisitor<'db> + ?Siz
232232
impl get_size2::GetSize for GenericAlias<'_> {}
233233

234234
impl<'db> GenericAlias<'db> {
235-
pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
235+
pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
236236
Self::new(
237237
db,
238238
self.origin(db),
@@ -256,7 +256,7 @@ impl<'db> GenericAlias<'db> {
256256
self,
257257
db: &'db dyn Db,
258258
type_mapping: &TypeMapping<'a, 'db>,
259-
visitor: &TypeTransformer<'db>,
259+
visitor: &ApplyTypeMappingVisitor<'db>,
260260
) -> Self {
261261
Self::new(
262262
db,
@@ -320,7 +320,7 @@ impl<'db> ClassType<'db> {
320320
}
321321
}
322322

323-
pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
323+
pub(super) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
324324
match self {
325325
Self::NonGeneric(_) => self,
326326
Self::Generic(generic) => Self::Generic(generic.normalized_impl(db, visitor)),
@@ -407,7 +407,7 @@ impl<'db> ClassType<'db> {
407407
self,
408408
db: &'db dyn Db,
409409
type_mapping: &TypeMapping<'a, 'db>,
410-
visitor: &TypeTransformer<'db>,
410+
visitor: &ApplyTypeMappingVisitor<'db>,
411411
) -> Self {
412412
match self {
413413
Self::NonGeneric(_) => self,
@@ -470,7 +470,7 @@ impl<'db> ClassType<'db> {
470470
db: &'db dyn Db,
471471
other: Self,
472472
relation: TypeRelation,
473-
visitor: &PairVisitor<'db>,
473+
visitor: &HasRelationToVisitor<'db>,
474474
) -> bool {
475475
self.iter_mro(db).any(|base| {
476476
match base {

crates/ty_python_semantic/src/types/class_base.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ use crate::Db;
22
use crate::types::generics::Specialization;
33
use crate::types::tuple::TupleType;
44
use crate::types::{
5-
ClassLiteral, ClassType, DynamicType, KnownClass, KnownInstanceType, MroError, MroIterator,
6-
SpecialFormType, Type, TypeMapping, TypeTransformer, todo_type,
5+
ApplyTypeMappingVisitor, ClassLiteral, ClassType, DynamicType, KnownClass, KnownInstanceType,
6+
MroError, MroIterator, NormalizedVisitor, SpecialFormType, Type, TypeMapping, TypeTransformer,
7+
todo_type,
78
};
89

910
/// Enumeration of the possible kinds of types we allow in class bases.
@@ -33,7 +34,7 @@ impl<'db> ClassBase<'db> {
3334
Self::Dynamic(DynamicType::Unknown)
3435
}
3536

36-
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &TypeTransformer<'db>) -> Self {
37+
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
3738
match self {
3839
Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized()),
3940
Self::Class(class) => Self::Class(class.normalized_impl(db, visitor)),
@@ -269,7 +270,7 @@ impl<'db> ClassBase<'db> {
269270
self,
270271
db: &'db dyn Db,
271272
type_mapping: &TypeMapping<'a, 'db>,
272-
visitor: &TypeTransformer<'db>,
273+
visitor: &ApplyTypeMappingVisitor<'db>,
273274
) -> Self {
274275
match self {
275276
Self::Class(class) => {

crates/ty_python_semantic/src/types/cyclic.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,20 @@
1818
//! `visitor.visit` when visiting a protocol type, and then internal `has_relation_to_impl` methods
1919
//! of the Rust types implementing protocols also call `visitor.visit`. The best way to avoid this
2020
//! is to prefer always calling `visitor.visit` only in the main recursive method on `Type`.
21-
use rustc_hash::FxHashMap;
2221
23-
use crate::FxIndexSet;
24-
use crate::types::Type;
2522
use std::cell::RefCell;
2623
use std::cmp::Eq;
2724
use std::hash::Hash;
25+
use std::marker::PhantomData;
2826

29-
pub(crate) type TypeTransformer<'db> = CycleDetector<Type<'db>, Type<'db>>;
27+
use rustc_hash::FxHashMap;
3028

31-
impl Default for TypeTransformer<'_> {
29+
use crate::FxIndexSet;
30+
use crate::types::Type;
31+
32+
pub(crate) type TypeTransformer<'db, Tag> = CycleDetector<Tag, Type<'db>, Type<'db>>;
33+
34+
impl<Tag> Default for TypeTransformer<'_, Tag> {
3235
fn default() -> Self {
3336
// TODO: proper recursive type handling
3437

@@ -38,10 +41,10 @@ impl Default for TypeTransformer<'_> {
3841
}
3942
}
4043

41-
pub(crate) type PairVisitor<'db> = CycleDetector<(Type<'db>, Type<'db>), bool>;
44+
pub(crate) type PairVisitor<'db, Tag> = CycleDetector<Tag, (Type<'db>, Type<'db>), bool>;
4245

4346
#[derive(Debug)]
44-
pub(crate) struct CycleDetector<T, R> {
47+
pub(crate) struct CycleDetector<Tag, T, R> {
4548
/// If the type we're visiting is present in `seen`, it indicates that we've hit a cycle (due
4649
/// to a recursive type); we need to immediately short circuit the whole operation and return
4750
/// the fallback value. That's why we pop items off the end of `seen` after we've visited them.
@@ -56,14 +59,17 @@ pub(crate) struct CycleDetector<T, R> {
5659
cache: RefCell<FxHashMap<T, R>>,
5760

5861
fallback: R,
62+
63+
_tag: PhantomData<Tag>,
5964
}
6065

61-
impl<T: Hash + Eq + Copy, R: Copy> CycleDetector<T, R> {
66+
impl<Tag, T: Hash + Eq + Copy, R: Copy> CycleDetector<Tag, T, R> {
6267
pub(crate) fn new(fallback: R) -> Self {
6368
CycleDetector {
6469
seen: RefCell::new(FxIndexSet::default()),
6570
cache: RefCell::new(FxHashMap::default()),
6671
fallback,
72+
_tag: PhantomData,
6773
}
6874
}
6975

0 commit comments

Comments
 (0)