Skip to content

Commit 68b0386

Browse files
[ty] Implement DataClassInstance protocol for dataclasses. (#18018)
Fixes: astral-sh/ty#92 ## Summary We currently get a `invalid-argument-type` error when using `dataclass.fields` on a dataclass, because we do not synthesize the `__dataclass_fields__` member. This PR fixes this diagnostic. Note that we do not yet model the `Field` type correctly. After that is done, we can assign a more precise `tuple[Field, ...]` type to this new member. ## Test Plan New mdtest. --------- Co-authored-by: David Peter <mail@david-peter.de>
1 parent 0ae07cd commit 68b0386

File tree

3 files changed

+47
-5
lines changed

3 files changed

+47
-5
lines changed

crates/ty_python_semantic/resources/mdtest/dataclasses.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,25 @@ reveal_type(C.__init__) # revealed: (field: str | int = int) -> None
616616

617617
To do
618618

619+
## `dataclass.fields`
620+
621+
Dataclasses have `__dataclass_fields__` in them, which makes them a subtype of the
622+
`DataclassInstance` protocol.
623+
624+
Here, we verify that dataclasses can be passed to `dataclasses.fields` without any errors, and that
625+
the return type of `dataclasses.fields` is correct.
626+
627+
```py
628+
from dataclasses import dataclass, fields
629+
630+
@dataclass
631+
class Foo:
632+
x: int
633+
634+
reveal_type(Foo.__dataclass_fields__) # revealed: dict[str, Field[Any]]
635+
reveal_type(fields(Foo)) # revealed: tuple[Field[Any], ...]
636+
```
637+
619638
## Other special cases
620639

621640
### `dataclasses.dataclass`

crates/ty_python_semantic/src/types.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2939,6 +2939,19 @@ impl<'db> Type<'db> {
29392939
))
29402940
.into()
29412941
}
2942+
Type::ClassLiteral(class)
2943+
if name == "__dataclass_fields__" && class.dataclass_params(db).is_some() =>
2944+
{
2945+
// Make this class look like a subclass of the `DataClassInstance` protocol
2946+
Symbol::bound(KnownClass::Dict.to_specialized_instance(
2947+
db,
2948+
[
2949+
KnownClass::Str.to_instance(db),
2950+
KnownClass::Field.to_specialized_instance(db, [Type::any()]),
2951+
],
2952+
))
2953+
.with_qualifiers(TypeQualifiers::CLASS_VAR)
2954+
}
29422955
Type::BoundMethod(bound_method) => match name_str {
29432956
"__self__" => Symbol::bound(bound_method.self_instance(db)).into(),
29442957
"__func__" => {

crates/ty_python_semantic/src/types/class.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1958,6 +1958,8 @@ pub enum KnownClass {
19581958
// backported as `builtins.ellipsis` by typeshed on Python <=3.9
19591959
EllipsisType,
19601960
NotImplementedType,
1961+
// dataclasses
1962+
Field,
19611963
}
19621964

19631965
impl<'db> KnownClass {
@@ -2037,7 +2039,8 @@ impl<'db> KnownClass {
20372039
// and raises a `TypeError` in Python >=3.14
20382040
// (see https://docs.python.org/3/library/constants.html#NotImplemented)
20392041
| Self::NotImplementedType
2040-
| Self::Classmethod => Truthiness::Ambiguous,
2042+
| Self::Classmethod
2043+
| Self::Field => Truthiness::Ambiguous,
20412044
}
20422045
}
20432046

@@ -2108,7 +2111,8 @@ impl<'db> KnownClass {
21082111
| Self::VersionInfo
21092112
| Self::EllipsisType
21102113
| Self::NotImplementedType
2111-
| Self::UnionType => false,
2114+
| Self::UnionType
2115+
| Self::Field => false,
21122116
}
21132117
}
21142118

@@ -2181,6 +2185,7 @@ impl<'db> KnownClass {
21812185
}
21822186
}
21832187
Self::NotImplementedType => "_NotImplementedType",
2188+
Self::Field => "Field",
21842189
}
21852190
}
21862191

@@ -2405,6 +2410,7 @@ impl<'db> KnownClass {
24052410
| Self::DefaultDict
24062411
| Self::Deque
24072412
| Self::OrderedDict => KnownModule::Collections,
2413+
Self::Field => KnownModule::Dataclasses,
24082414
}
24092415
}
24102416

@@ -2464,7 +2470,8 @@ impl<'db> KnownClass {
24642470
| Self::ABCMeta
24652471
| Self::Super
24662472
| Self::NamedTuple
2467-
| Self::NewType => false,
2473+
| Self::NewType
2474+
| Self::Field => false,
24682475
}
24692476
}
24702477

@@ -2526,7 +2533,8 @@ impl<'db> KnownClass {
25262533
| Self::Super
25272534
| Self::UnionType
25282535
| Self::NamedTuple
2529-
| Self::NewType => false,
2536+
| Self::NewType
2537+
| Self::Field => false,
25302538
}
25312539
}
25322540

@@ -2596,6 +2604,7 @@ impl<'db> KnownClass {
25962604
Self::EllipsisType
25972605
}
25982606
"_NotImplementedType" => Self::NotImplementedType,
2607+
"Field" => Self::Field,
25992608
_ => return None,
26002609
};
26012610

@@ -2647,7 +2656,8 @@ impl<'db> KnownClass {
26472656
| Self::UnionType
26482657
| Self::GeneratorType
26492658
| Self::AsyncGeneratorType
2650-
| Self::WrapperDescriptorType => module == self.canonical_module(db),
2659+
| Self::WrapperDescriptorType
2660+
| Self::Field => module == self.canonical_module(db),
26512661
Self::NoneType => matches!(module, KnownModule::Typeshed | KnownModule::Types),
26522662
Self::SpecialForm
26532663
| Self::TypeVar

0 commit comments

Comments
 (0)