Skip to content

Commit b250304

Browse files
jgeralnikcarljm
andauthored
[red-knot] Improve is_disjoint for two intersections (#16636)
## Summary Background - as a follow up to #16611 I noticed that there's a lot of code duplicated between the `is_assignable_to` and `is_subtype_of` functions and considered trying to merge them. [A subtype and an assignable type are pretty much the same](https://typing.python.org/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation), except that subtypes are by definition fully static, so I think we can replace the whole of `is_subtype_of` with: ``` if !self.is_fully_static(db) || !target.is_fully_static(db) { return false; } return self.is_assignable_to(target) ``` if we move all of the logic to is_assignable_to and delete duplicate code. Then we can discuss if it even makes sense to have a separate is_subtype_of function (I think the answer is yes since it's used by a bunch of other places, but we may be able to basically rip out the concept). Anyways while playing with combining the functions I noticed is that the handling of Intersections in `is_subtype_of` has a special case for two intersections, which I didn't include in the last PR - rather I first handled right hand intersections before left hand, which should properly handle double intersections (hand-wavy explanation I can justify if needed - (A & B & C) is assignable to (A & B) because the left is assignable to both A and B, but none of A, B, or C is assignable to (A & B)). I took a look at what breaks if I remove the handling for double intersections, and the reason it is needed is because is_disjoint does not properly handle intersections with negative conditions (so instead `is_subtype_of` basically implements the check correctly). This PR adds support to is_disjoint for properly checking negative branches, which also lets us simplify `is_subtype_of`, bringing it in line with `is_assignable_to` ## Test Plan Added a bunch of tests, most of which failed before this fix --------- Co-authored-by: Carl Meyer <carl@astral.sh>
1 parent 11b5cbc commit b250304

File tree

3 files changed

+65
-43
lines changed

3 files changed

+65
-43
lines changed

crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,10 @@ static_assert(is_assignable_to(Intersection[int, Parent], Intersection[int, Not[
266266
static_assert(not is_assignable_to(int, Not[int]))
267267
static_assert(not is_assignable_to(int, Not[Literal[1]]))
268268

269+
static_assert(is_assignable_to(Not[Parent], Not[Child1]))
270+
static_assert(not is_assignable_to(Not[Parent], Parent))
271+
static_assert(not is_assignable_to(Intersection[Unrelated, Not[Parent]], Parent))
272+
269273
# Intersection with `Any` dominates the left hand side of intersections
270274
static_assert(is_assignable_to(Intersection[Any, Parent], Parent))
271275
static_assert(is_assignable_to(Intersection[Any, Child1], Parent))
@@ -277,6 +281,7 @@ static_assert(is_assignable_to(Intersection[Any, Parent, Unrelated], Intersectio
277281

278282
# Even Any & Not[Parent] is assignable to Parent, since it could be Never
279283
static_assert(is_assignable_to(Intersection[Any, Not[Parent]], Parent))
284+
static_assert(is_assignable_to(Intersection[Any, Not[Parent]], Not[Parent]))
280285

281286
# Intersection with `Any` is effectively ignored on the right hand side for the sake of assignment
282287
static_assert(is_assignable_to(Parent, Intersection[Any, Parent]))

crates/red_knot_python_semantic/resources/mdtest/type_properties/is_disjoint_from.md

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ static_assert(not is_disjoint_from(bool, object))
1616

1717
static_assert(not is_disjoint_from(Any, bool))
1818
static_assert(not is_disjoint_from(Any, Any))
19+
static_assert(not is_disjoint_from(Any, Not[Any]))
1920

2021
static_assert(not is_disjoint_from(LiteralString, LiteralString))
2122
static_assert(not is_disjoint_from(str, LiteralString))
@@ -95,8 +96,8 @@ static_assert(not is_disjoint_from(Literal[1, 2], Literal[2, 3]))
9596
## Intersections
9697

9798
```py
98-
from typing_extensions import Literal, final
99-
from knot_extensions import Intersection, is_disjoint_from, static_assert
99+
from typing_extensions import Literal, final, Any
100+
from knot_extensions import Intersection, is_disjoint_from, static_assert, Not
100101

101102
@final
102103
class P: ...
@@ -130,6 +131,27 @@ static_assert(not is_disjoint_from(Y, Z))
130131
static_assert(not is_disjoint_from(Intersection[X, Y], Z))
131132
static_assert(not is_disjoint_from(Intersection[X, Z], Y))
132133
static_assert(not is_disjoint_from(Intersection[Y, Z], X))
134+
135+
# If one side has a positive fully-static element and the other side has a negative of that element, they are disjoint
136+
static_assert(is_disjoint_from(int, Not[int]))
137+
static_assert(is_disjoint_from(Intersection[X, Y, Not[Z]], Intersection[X, Z]))
138+
static_assert(is_disjoint_from(Intersection[X, Not[Literal[1]]], Literal[1]))
139+
140+
class Parent: ...
141+
class Child(Parent): ...
142+
143+
static_assert(not is_disjoint_from(Parent, Child))
144+
static_assert(not is_disjoint_from(Parent, Not[Child]))
145+
static_assert(not is_disjoint_from(Not[Parent], Not[Child]))
146+
static_assert(is_disjoint_from(Not[Parent], Child))
147+
static_assert(is_disjoint_from(Intersection[X, Not[Parent]], Child))
148+
static_assert(is_disjoint_from(Intersection[X, Not[Parent]], Intersection[X, Child]))
149+
150+
static_assert(not is_disjoint_from(Intersection[Any, X], Intersection[Any, Not[Y]]))
151+
static_assert(not is_disjoint_from(Intersection[Any, Not[Y]], Intersection[Any, X]))
152+
153+
static_assert(is_disjoint_from(Intersection[int, Any], Not[int]))
154+
static_assert(is_disjoint_from(Not[int], Intersection[int, Any]))
133155
```
134156

135157
## Special types
@@ -152,7 +174,7 @@ static_assert(is_disjoint_from(Never, object))
152174

153175
```py
154176
from typing_extensions import Literal, LiteralString
155-
from knot_extensions import is_disjoint_from, static_assert
177+
from knot_extensions import is_disjoint_from, static_assert, Intersection, Not
156178

157179
static_assert(is_disjoint_from(None, Literal[True]))
158180
static_assert(is_disjoint_from(None, Literal[1]))
@@ -165,6 +187,9 @@ static_assert(is_disjoint_from(None, type[object]))
165187
static_assert(not is_disjoint_from(None, None))
166188
static_assert(not is_disjoint_from(None, int | None))
167189
static_assert(not is_disjoint_from(None, object))
190+
191+
static_assert(is_disjoint_from(Intersection[int, Not[str]], None))
192+
static_assert(is_disjoint_from(None, Intersection[int, Not[str]]))
168193
```
169194

170195
### Literals

crates/red_knot_python_semantic/src/types.rs

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -580,38 +580,9 @@ impl<'db> Type<'db> {
580580
true
581581
}
582582

583-
(Type::Intersection(self_intersection), Type::Intersection(target_intersection)) => {
584-
// Check that all target positive values are covered in self positive values
585-
target_intersection
586-
.positive(db)
587-
.iter()
588-
.all(|&target_pos_elem| {
589-
self_intersection
590-
.positive(db)
591-
.iter()
592-
.any(|&self_pos_elem| self_pos_elem.is_subtype_of(db, target_pos_elem))
593-
})
594-
// Check that all target negative values are excluded in self, either by being
595-
// subtypes of a self negative value or being disjoint from a self positive value.
596-
&& target_intersection
597-
.negative(db)
598-
.iter()
599-
.all(|&target_neg_elem| {
600-
// Is target negative value is subtype of a self negative value
601-
self_intersection.negative(db).iter().any(|&self_neg_elem| {
602-
target_neg_elem.is_subtype_of(db, self_neg_elem)
603-
// Is target negative value is disjoint from a self positive value?
604-
}) || self_intersection.positive(db).iter().any(|&self_pos_elem| {
605-
self_pos_elem.is_disjoint_from(db, target_neg_elem)
606-
})
607-
})
608-
}
609-
610-
(Type::Intersection(intersection), _) => intersection
611-
.positive(db)
612-
.iter()
613-
.any(|&elem_ty| elem_ty.is_subtype_of(db, target)),
614-
583+
// If both sides are intersections we need to handle the right side first
584+
// (A & B & C) is a subtype of (A & B) because the left is a subtype of both A and B,
585+
// but none of A, B, or C is a subtype of (A & B).
615586
(_, Type::Intersection(intersection)) => {
616587
intersection
617588
.positive(db)
@@ -623,6 +594,11 @@ impl<'db> Type<'db> {
623594
.all(|&neg_ty| self.is_disjoint_from(db, neg_ty))
624595
}
625596

597+
(Type::Intersection(intersection), _) => intersection
598+
.positive(db)
599+
.iter()
600+
.any(|&elem_ty| elem_ty.is_subtype_of(db, target)),
601+
626602
// Note that the definition of `Type::AlwaysFalsy` depends on the return value of `__bool__`.
627603
// If `__bool__` always returns True or False, it can be treated as a subtype of `AlwaysTruthy` or `AlwaysFalsy`, respectively.
628604
(left, Type::AlwaysFalsy) => left.bool(db).is_always_false(),
@@ -799,6 +775,10 @@ impl<'db> Type<'db> {
799775
.iter()
800776
.any(|&elem_ty| ty.is_assignable_to(db, elem_ty)),
801777

778+
// If both sides are intersections we need to handle the right side first
779+
// (A & B & C) is assignable to (A & B) because the left is assignable to both A and B,
780+
// but none of A, B, or C is assignable to (A & B).
781+
//
802782
// A type S is assignable to an intersection type T if
803783
// S is assignable to all positive elements of T (e.g. `str & int` is assignable to `str & Any`), and
804784
// S is disjoint from all negative elements of T (e.g. `int` is not assignable to Intersection[int, Not[Literal[1]]]).
@@ -995,19 +975,31 @@ impl<'db> Type<'db> {
995975
.iter()
996976
.all(|e| e.is_disjoint_from(db, other)),
997977

978+
// If we have two intersections, we test the positive elements of each one against the other intersection
979+
// Negative elements need a positive element on the other side in order to be disjoint.
980+
// This is similar to what would happen if we tried to build a new intersection that combines the two
981+
(Type::Intersection(self_intersection), Type::Intersection(other_intersection)) => {
982+
self_intersection
983+
.positive(db)
984+
.iter()
985+
.any(|p| p.is_disjoint_from(db, other))
986+
|| other_intersection
987+
.positive(db)
988+
.iter()
989+
.any(|p: &Type<'_>| p.is_disjoint_from(db, self))
990+
}
991+
998992
(Type::Intersection(intersection), other)
999993
| (other, Type::Intersection(intersection)) => {
1000-
if intersection
994+
intersection
1001995
.positive(db)
1002996
.iter()
1003997
.any(|p| p.is_disjoint_from(db, other))
1004-
{
1005-
true
1006-
} else {
1007-
// TODO we can do better here. For example:
1008-
// X & ~Literal[1] is disjoint from Literal[1]
1009-
false
1010-
}
998+
// A & B & Not[C] is disjoint from C
999+
|| intersection
1000+
.negative(db)
1001+
.iter()
1002+
.any(|&neg_ty| other.is_subtype_of(db, neg_ty))
10111003
}
10121004

10131005
// any single-valued type is disjoint from another single-valued type

0 commit comments

Comments
 (0)