Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,8 @@ reveal_type(C.__init__) # revealed: (self: C, normal: int, conditionally_presen
python-version = "3.12"
```

### Basic

```py
from dataclasses import dataclass

Expand All @@ -658,6 +660,34 @@ reveal_type(d_int.description) # revealed: str
DataWithDescription[int](None, "description")
```

### Deriving from generic dataclasses

This is a regression test for <https://github.com/astral-sh/ty/issues/853>.

```py
from dataclasses import dataclass

@dataclass
class Wrap[T]:
data: T

reveal_type(Wrap[int].__init__) # revealed: (self: Wrap[int], data: int) -> None

@dataclass
class WrappedInt(Wrap[int]):
other_field: str

reveal_type(WrappedInt.__init__) # revealed: (self: WrappedInt, data: int, other_field: str) -> None

# Make sure that another generic type parameter does not affect the `data` field
@dataclass
class WrappedIntAndExtraData[T](Wrap[int]):
extra_data: T

# revealed: (self: WrappedIntAndExtraData[bytes], data: int, extra_data: bytes) -> None
reveal_type(WrappedIntAndExtraData[bytes].__init__)
```

## Descriptor-typed fields

### Same type in `__get__` and `__set__`
Expand Down
23 changes: 17 additions & 6 deletions crates/ty_python_semantic/src/types/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1669,16 +1669,16 @@ impl<'db> ClassLiteral<'db> {
if field_policy == CodeGeneratorKind::NamedTuple {
// NamedTuples do not allow multiple inheritance, so it is sufficient to enumerate the
// fields of this class only.
return self.own_fields(db);
return self.own_fields(db, specialization);
}

let matching_classes_in_mro: Vec<_> = self
.iter_mro(db, specialization)
.filter_map(|superclass| {
if let Some(class) = superclass.into_class() {
let class_literal = class.class_literal(db).0;
let (class_literal, specialization) = class.class_literal(db);
if field_policy.matches(db, class_literal) {
Some(class_literal)
Some((class_literal, specialization))
} else {
None
}
Expand All @@ -1692,7 +1692,7 @@ impl<'db> ClassLiteral<'db> {
matching_classes_in_mro
.into_iter()
.rev()
.flat_map(|class| class.own_fields(db))
.flat_map(|(class, specialization)| class.own_fields(db, specialization))
// We collect into a FxOrderMap here to deduplicate attributes
.collect()
}
Expand All @@ -1708,7 +1708,11 @@ impl<'db> ClassLiteral<'db> {
/// y: str = "a"
/// ```
/// we return a map `{"x": (int, None), "y": (str, Some(Literal["a"]))}`.
fn own_fields(self, db: &'db dyn Db) -> FxOrderMap<Name, (Type<'db>, Option<Type<'db>>)> {
fn own_fields(
self,
db: &'db dyn Db,
specialization: Option<Specialization<'db>>,
) -> FxOrderMap<Name, (Type<'db>, Option<Type<'db>>)> {
let mut attributes = FxOrderMap::default();

let class_body_scope = self.body_scope(db);
Expand Down Expand Up @@ -1748,7 +1752,14 @@ impl<'db> ClassLiteral<'db> {
let bindings = use_def.end_of_scope_bindings(place_id);
let default_ty = place_from_bindings(db, bindings).ignore_possibly_unbound();

attributes.insert(place_expr.expect_name().clone(), (attr_ty, default_ty));
attributes.insert(
place_expr.expect_name().clone(),
(
attr_ty.apply_optional_specialization(db, specialization),
default_ty
.map(|ty| ty.apply_optional_specialization(db, specialization)),
),
);
}
}
}
Expand Down
Loading