Skip to content

Commit 1a94b3f

Browse files
committed
is this faster
1 parent 0a478a2 commit 1a94b3f

File tree

2 files changed

+29
-44
lines changed

2 files changed

+29
-44
lines changed

crates/red_knot_python_semantic/src/types.rs

Lines changed: 7 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -811,29 +811,6 @@ impl<'db> Type<'db> {
811811
}
812812
}
813813

814-
/// Normalize the type `bool` -> `Literal[True, False]`.
815-
///
816-
/// Using this method in various type-relational methods
817-
/// ensures that the following invariants hold true:
818-
///
819-
/// - bool ≡ Literal[True, False]
820-
/// - bool | T ≡ Literal[True, False] | T
821-
/// - bool <: Literal[True, False]
822-
/// - bool | T <: Literal[True, False] | T
823-
/// - Literal[True, False] <: bool
824-
/// - Literal[True, False] | T <: bool | T
825-
#[must_use]
826-
pub fn with_normalized_bools(self, db: &'db dyn Db) -> Self {
827-
match self {
828-
Type::Instance(InstanceType { class }) if class.is_known(db, KnownClass::Bool) => {
829-
Type::normalized_bool(db)
830-
}
831-
// TODO: decompose `LiteralString` into `Literal[""] | TruthyLiteralString`?
832-
// We'd need to rename this method... --Alex
833-
_ => self,
834-
}
835-
}
836-
837814
/// Return a normalized version of `self` in which all unions and intersections are sorted
838815
/// according to a canonical order, no matter how "deeply" a union/intersection may be nested.
839816
#[must_use]
@@ -905,10 +882,11 @@ impl<'db> Type<'db> {
905882
(_, Type::Never) => false,
906883

907884
(Type::Instance(InstanceType { class }), _) if class.is_known(db, KnownClass::Bool) => {
908-
Type::normalized_bool(db).is_subtype_of(db, target)
885+
Type::BooleanLiteral(true).is_subtype_of(db, target)
886+
&& Type::BooleanLiteral(false).is_subtype_of(db, target)
909887
}
910888
(_, Type::Instance(InstanceType { class })) if class.is_known(db, KnownClass::Bool) => {
911-
self.is_subtype_of(db, Type::normalized_bool(db))
889+
self.is_boolean_literal()
912890
}
913891

914892
(Type::Union(union), _) => union
@@ -1125,10 +1103,12 @@ impl<'db> Type<'db> {
11251103
}
11261104

11271105
(Type::Instance(InstanceType { class }), _) if class.is_known(db, KnownClass::Bool) => {
1128-
Type::normalized_bool(db).is_assignable_to(db, target)
1106+
Type::BooleanLiteral(true).is_assignable_to(db, target)
1107+
&& Type::BooleanLiteral(false).is_assignable_to(db, target)
11291108
}
11301109
(_, Type::Instance(InstanceType { class })) if class.is_known(db, KnownClass::Bool) => {
1131-
self.is_assignable_to(db, Type::normalized_bool(db))
1110+
self.is_assignable_to(db, Type::BooleanLiteral(true))
1111+
|| self.is_assignable_to(db, Type::BooleanLiteral(false))
11321112
}
11331113

11341114
// A union is assignable to a type T iff every element of the union is assignable to T.
@@ -2409,13 +2389,6 @@ impl<'db> Type<'db> {
24092389
KnownClass::NoneType.to_instance(db)
24102390
}
24112391

2412-
/// The type `Literal[True, False]`, which is exactly equivalent to `bool`
2413-
/// (and which `bool` is eagerly normalized to in several situations)
2414-
pub fn normalized_bool(db: &'db dyn Db) -> Type<'db> {
2415-
const LITERAL_BOOLS: [Type; 2] = [Type::BooleanLiteral(false), Type::BooleanLiteral(true)];
2416-
Type::Union(UnionType::new(db, Box::from(LITERAL_BOOLS)))
2417-
}
2418-
24192392
/// Return the type of `tuple(sys.version_info)`.
24202393
///
24212394
/// This is not exactly the type that `sys.version_info` has at runtime,

crates/red_knot_python_semantic/src/types/builder.rs

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
//! eliminate the supertype from the intersection).
2727
//! * An intersection containing two non-overlapping types should simplify to [`Type::Never`].
2828
29-
use crate::types::{IntersectionType, KnownClass, Type, UnionType};
29+
use crate::types::{InstanceType, IntersectionType, KnownClass, Type, UnionType};
3030
use crate::{Db, FxOrderSet};
3131
use smallvec::SmallVec;
3232

@@ -45,7 +45,6 @@ impl<'db> UnionBuilder<'db> {
4545

4646
/// Adds a type to this union.
4747
pub(crate) fn add(mut self, ty: Type<'db>) -> Self {
48-
let ty = ty.with_normalized_bools(self.db);
4948
match ty {
5049
Type::Union(union) => {
5150
let new_elements = union.elements(self.db);
@@ -55,6 +54,11 @@ impl<'db> UnionBuilder<'db> {
5554
}
5655
}
5756
Type::Never => {}
57+
Type::Instance(InstanceType { class }) if class.is_known(self.db, KnownClass::Bool) => {
58+
self = self
59+
.add(Type::BooleanLiteral(true))
60+
.add(Type::BooleanLiteral(false));
61+
}
5862
_ => {
5963
let mut to_remove = SmallVec::<[usize; 2]>::new();
6064
let ty_negated = ty.negate(self.db);
@@ -162,9 +166,20 @@ impl<'db> IntersectionBuilder<'db> {
162166
}
163167
}
164168

169+
fn elements_of_union(&self, ty: Type<'db>) -> Option<&'db [Type<'db>]> {
170+
const BOOL_LITERALS: &[Type] = &[Type::BooleanLiteral(true), Type::BooleanLiteral(false)];
171+
172+
match ty {
173+
Type::Union(union) => Some(union.elements(self.db)),
174+
Type::Instance(InstanceType { class }) if class.is_known(self.db, KnownClass::Bool) => {
175+
Some(BOOL_LITERALS)
176+
}
177+
_ => None,
178+
}
179+
}
180+
165181
pub(crate) fn add_positive(mut self, ty: Type<'db>) -> Self {
166-
let ty = ty.with_normalized_bools(self.db);
167-
if let Type::Union(union) = ty {
182+
if let Some(elements) = self.elements_of_union(ty) {
168183
// Distribute ourself over this union: for each union element, clone ourself and
169184
// intersect with that union element, then create a new union-of-intersections with all
170185
// of those sub-intersections in it. E.g. if `self` is a simple intersection `T1 & T2`
@@ -173,8 +188,7 @@ impl<'db> IntersectionBuilder<'db> {
173188
// (T2 & T4)`. If `self` is already a union-of-intersections `(T1 & T2) | (T3 & T4)`
174189
// and we add `T5 | T6` to it, that flattens all the way out to `(T1 & T2 & T5) | (T1 &
175190
// T2 & T6) | (T3 & T4 & T5) ...` -- you get the idea.
176-
union
177-
.elements(self.db)
191+
elements
178192
.iter()
179193
.map(|elem| self.clone().add_positive(*elem))
180194
.fold(IntersectionBuilder::empty(self.db), |mut builder, sub| {
@@ -194,10 +208,8 @@ impl<'db> IntersectionBuilder<'db> {
194208
pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self {
195209
// See comments above in `add_positive`; this is just the negated version.
196210

197-
let ty = ty.with_normalized_bools(self.db);
198-
199-
if let Type::Union(union) = ty {
200-
for elem in union.elements(self.db) {
211+
if let Some(elements) = self.elements_of_union(ty) {
212+
for elem in elements {
201213
self = self.add_negative(*elem);
202214
}
203215
self

0 commit comments

Comments
 (0)