Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
5844c01
[ty] propagate the annotated return type of functions to the inferenc…
mtshiba Sep 23, 2025
c6f798c
don't wrap the raw return types of async functions in `CoroutineType`
mtshiba Sep 23, 2025
d14dcc8
improve `SpecializationBuilder::infer` behavior when `formal` is a un…
mtshiba Sep 23, 2025
238ebf5
Merge branch 'main' into bidi-return-type
mtshiba Sep 23, 2025
69e8582
Merge branch 'main' into bidi-return-type
mtshiba Sep 24, 2025
d33ee55
fix `nearest_enclosing_function` returning incorrect types for decora…
mtshiba Sep 24, 2025
9622841
Update bidirectional.md
mtshiba Sep 24, 2025
ef75f0a
prevent incorrect specializations in `SpecializationBuilder::infer`
mtshiba Sep 24, 2025
6c5625b
Update crates/ty_python_semantic/src/types/infer/builder.rs
mtshiba Sep 26, 2025
1565280
Update crates/ty_python_semantic/resources/mdtest/bidirectional.md
mtshiba Sep 26, 2025
7e8595a
Update crates/ty_python_semantic/resources/mdtest/bidirectional.md
mtshiba Sep 26, 2025
d666494
refactor according to the review
mtshiba Sep 26, 2025
40daa3f
Update bidirectional.md
mtshiba Sep 29, 2025
32bd211
Merge branch 'main' into bidi-return-type
mtshiba Sep 29, 2025
b0a62f1
Update signatures.rs
mtshiba Sep 29, 2025
4211ac9
Update bidirectional.md
mtshiba Sep 29, 2025
2212cc7
Update bidirectional.md
mtshiba Sep 29, 2025
f99d03f
Merge branch 'main' into bidi-return-type
mtshiba Oct 6, 2025
819a415
Update generics.rs
mtshiba Oct 6, 2025
1b6a505
Apply suggestion from @ibraheemdev
mtshiba Oct 6, 2025
b27c1ed
Apply suggestions from code review
mtshiba Oct 6, 2025
0249ba2
improve bidirectional inference in `infer_collection_literal`
mtshiba Oct 6, 2025
9507d72
don't set `generic_context` on the `Signature` returned by `OverloadL…
mtshiba Oct 6, 2025
bd5a465
Revert "don't set `generic_context` on the `Signature` returned by `O…
mtshiba Oct 6, 2025
8834fec
Update function.rs
mtshiba Oct 6, 2025
d1e9455
improve specialization between unions
mtshiba Oct 7, 2025
b5fddbc
Revert "improve specialization between unions"
mtshiba Oct 7, 2025
1b42e05
Merge branch 'main' into bidi-return-type
mtshiba Oct 7, 2025
97c065d
add `TypeDict` test cases
mtshiba Oct 7, 2025
7898360
Merge branch 'main' into bidi-return-type
mtshiba Oct 10, 2025
9cea903
update mdtest
mtshiba Oct 10, 2025
db90cda
Merge branch 'main' into bidi-return-type
ibraheemdev Oct 11, 2025
e7b5c14
update tests
ibraheemdev Oct 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,75 @@ r: dict[int | str, int | str] = {1: 1, 2: 2, 3: 3}
reveal_type(r) # revealed: dict[int | str, int | str]
```

## Incorrect collection literal assignments are complained aobut
## Optional collection literal annotations are understood

```toml
[environment]
python-version = "3.12"
```

```py
import typing

a: list[int] | None = [1, 2, 3]
reveal_type(a) # revealed: list[int]

b: list[int | str] | None = [1, 2, 3]
reveal_type(b) # revealed: list[int | str]

c: typing.List[int] | None = [1, 2, 3]
reveal_type(c) # revealed: list[int]

d: list[typing.Any] | None = []
reveal_type(d) # revealed: list[Any]

e: set[int] | None = {1, 2, 3}
reveal_type(e) # revealed: set[int]

f: set[int | str] | None = {1, 2, 3}
reveal_type(f) # revealed: set[int | str]

g: typing.Set[int] | None = {1, 2, 3}
reveal_type(g) # revealed: set[int]

h: list[list[int]] | None = [[], [42]]
reveal_type(h) # revealed: list[list[int]]

i: list[typing.Any] | None = [1, 2, "3", ([4],)]
reveal_type(i) # revealed: list[Any | int | str | tuple[list[Unknown | int]]]

j: list[tuple[str | int, ...]] | None = [(1, 2), ("foo", "bar"), ()]
reveal_type(j) # revealed: list[tuple[str | int, ...]]

k: list[tuple[list[int], ...]] | None = [([],), ([1, 2], [3, 4]), ([5], [6], [7])]
reveal_type(k) # revealed: list[tuple[list[int], ...]]

l: tuple[list[int], *tuple[list[typing.Any], ...], list[str]] | None = ([1, 2, 3], [4, 5, 6], [7, 8, 9], ["10", "11", "12"])
# TODO: this should be `tuple[list[int], list[Any | int], list[Any | int], list[str]]`
reveal_type(l) # revealed: tuple[list[Unknown | int], list[Unknown | int], list[Unknown | int], list[Unknown | str]]

type IntList = list[int]

m: IntList | None = [1, 2, 3]
reveal_type(m) # revealed: list[int]

n: list[typing.Literal[1, 2, 3]] | None = [1, 2, 3]
reveal_type(n) # revealed: list[Literal[1, 2, 3]]

o: list[typing.LiteralString] | None = ["a", "b", "c"]
reveal_type(o) # revealed: list[LiteralString]

p: dict[int, int] | None = {}
reveal_type(p) # revealed: dict[int, int]

q: dict[int | str, int] | None = {1: 1, 2: 2, 3: 3}
reveal_type(q) # revealed: dict[int | str, int]

r: dict[int | str, int | str] | None = {1: 1, 2: 2, 3: 3}
reveal_type(r) # revealed: dict[int | str, int | str]
```

## Incorrect collection literal assignments are complained about

```py
# error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to `list[str]`"
Expand Down
147 changes: 147 additions & 0 deletions crates/ty_python_semantic/resources/mdtest/bidirectional.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Bidirectional type inference

ty partially supports bidirectional type inference. This is a mechanism for inferring the type of an
expression "from the outside in". Normally, type inference proceeds "from the inside out". That is,
in order to infer the type of an expression, the types of all sub-expressions must first be
inferred. There is no reverse dependency. However, when performing complex type inference, such as
when generics are involved, the type of an outer expression can sometimes be useful in inferring
inner expressions. Bidirectional type inference is a mechanism that propagates such "expected types"
to the inference of inner expressions.

## Propagating target type annotation

```toml
[environment]
python-version = "3.12"
```

```py
def list1[T](x: T) -> list[T]:
return [x]

l1 = list1(1)
reveal_type(l1) # revealed: list[Literal[1]]
l2: list[int] = list1(1)
reveal_type(l2) # revealed: list[int]

# `list[Literal[1]]` and `list[int]` are incompatible, since `list[T]` is invariant in `T`.
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
l2 = l1

intermediate = list1(1)
# TODO: the error will not occur if we can infer the type of `intermediate` to be `list[int]`
# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`"
l3: list[int] = intermediate
# TODO: it would be nice if this were `list[int]`
reveal_type(intermediate) # revealed: list[Literal[1]]
reveal_type(l3) # revealed: list[int]

l4: list[int | str] | None = list1(1)
reveal_type(l4) # revealed: list[int | str]

def _(l: list[int] | None = None):
l1 = l or list()
reveal_type(l1) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown]

l2: list[int] = l or list()
# it would be better if this were `list[int]`? (https://github.com/astral-sh/ty/issues/136)
reveal_type(l2) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown]

def f[T](x: T, cond: bool) -> T | list[T]:
return x if cond else [x]

# TODO: no error
# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`"
l5: int | list[int] = f(1, True)
```

`typed_dict.py`:

```py
from typing import TypedDict

class TD(TypedDict):
x: int

d1 = {"x": 1}
d2: TD = {"x": 1}
d3: dict[str, int] = {"x": 1}

reveal_type(d1) # revealed: dict[Unknown | str, Unknown | int]
reveal_type(d2) # revealed: TD
reveal_type(d3) # revealed: dict[str, int]

def _() -> TD:
return {"x": 1}

def _() -> TD:
# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor"
return {}
```

## Propagating return type annotation

```toml
[environment]
python-version = "3.12"
```

```py
from typing import overload, Callable

def list1[T](x: T) -> list[T]:
return [x]

def get_data() -> dict | None:
return {}

def wrap_data() -> list[dict]:
if not (res := get_data()):
return list1({})
reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
# `list[dict[Unknown, Unknown] & ~AlwaysFalsy]` and `list[dict[Unknown, Unknown]]` are incompatible,
# but the return type check passes here because the type of `list1(res)` is inferred
# by bidirectional type inference using the annotated return type, and the type of `res` is not used.
return list1(res)

def wrap_data2() -> list[dict] | None:
if not (res := get_data()):
return None
reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
return list1(res)

def deco[T](func: Callable[[], T]) -> Callable[[], T]:
return func

def outer() -> Callable[[], list[dict]]:
@deco
def inner() -> list[dict]:
if not (res := get_data()):
return list1({})
reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy]
return list1(res)
return inner

@overload
def f(x: int) -> list[int]: ...
@overload
def f(x: str) -> list[str]: ...
def f(x: int | str) -> list[int] | list[str]:
# `list[int] | list[str]` is disjoint from `list[int | str]`.
if isinstance(x, int):
return list1(x)
else:
return list1(x)

reveal_type(f(1)) # revealed: list[int]
reveal_type(f("a")) # revealed: list[str]

async def g() -> list[int | str]:
return list1(1)

def h[T](x: T, cond: bool) -> T | list[T]:
return i(x, cond)

def i[T](x: T, cond: bool) -> T | list[T]:
return x if cond else [x]
```
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,9 @@ def union_param(x: T | None) -> T:
reveal_type(union_param("a")) # revealed: Literal["a"]
reveal_type(union_param(1)) # revealed: Literal[1]
reveal_type(union_param(None)) # revealed: Unknown

def _(x: int | None):
reveal_type(union_param(x)) # revealed: int
```

```py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ def union_param[T](x: T | None) -> T:
reveal_type(union_param("a")) # revealed: Literal["a"]
reveal_type(union_param(1)) # revealed: Literal[1]
reveal_type(union_param(None)) # revealed: Unknown

def _(x: int | None):
reveal_type(union_param(x)) # revealed: int
```

```py
Expand Down
14 changes: 11 additions & 3 deletions crates/ty_python_semantic/resources/mdtest/typed_dict.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,10 @@ def homogeneous_list[T](*args: T) -> list[T]:
reveal_type(homogeneous_list(1, 2, 3)) # revealed: list[Literal[1, 2, 3]]
plot2: Plot = {"y": homogeneous_list(1, 2, 3), "x": None}
reveal_type(plot2["y"]) # revealed: list[int]
# TODO: no error
# error: [invalid-argument-type]

plot3: Plot = {"y": homogeneous_list(1, 2, 3), "x": homogeneous_list(1, 2, 3)}
reveal_type(plot3["y"]) # revealed: list[int]
reveal_type(plot3["x"]) # revealed: list[int] | None

Y = "y"
X = "x"
Expand Down Expand Up @@ -362,7 +363,7 @@ qualifiers override the class-level `total` setting, which sets the default (`to
all keys are required by default, `total=False` means that all keys are non-required by default):

```py
from typing_extensions import TypedDict, Required, NotRequired
from typing_extensions import TypedDict, Required, NotRequired, Final

# total=False by default, but id is explicitly Required
class Message(TypedDict, total=False):
Expand All @@ -376,10 +377,17 @@ class User(TypedDict):
email: Required[str] # Explicitly required (redundant here)
bio: NotRequired[str] # Optional despite total=True

ID: Final = "id"

# Valid Message constructions
msg1 = Message(id=1) # id required, content optional
msg2 = Message(id=2, content="Hello") # both provided
msg3 = Message(id=3, timestamp="2024-01-01") # id required, timestamp optional
msg4: Message = {"id": 4} # id required, content optional
msg5: Message = {ID: 5} # id required, content optional

def msg() -> Message:
return {ID: 1}

# Valid User constructions
user1 = User(name="Alice", email="alice@example.com") # required fields
Expand Down
13 changes: 13 additions & 0 deletions crates/ty_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,10 @@ impl<'db> Type<'db> {
}
}

pub(crate) fn has_type_var(self, db: &'db dyn Db) -> bool {
any_over_type(db, self, &|ty| matches!(ty, Type::TypeVar(_)), false)
}

pub(crate) const fn into_class_literal(self) -> Option<ClassLiteral<'db>> {
match self {
Type::ClassLiteral(class_type) => Some(class_type),
Expand Down Expand Up @@ -1167,6 +1171,15 @@ impl<'db> Type<'db> {
if yes { self.negate(db) } else { *self }
}

/// Remove the union elements that are not related to `target`.
pub(crate) fn filter_disjoint_elements(self, db: &'db dyn Db, target: Type<'db>) -> Type<'db> {
if let Type::Union(union) = self {
union.filter(db, |elem| !elem.is_disjoint_from(db, target))
} else {
self
}
}

/// Returns the fallback instance type that a literal is an instance of, or `None` if the type
/// is not a literal.
pub(crate) fn literal_fallback_instance(self, db: &'db dyn Db) -> Option<Type<'db>> {
Expand Down
Loading
Loading