[Feature Request] Better Type Annotation Support for PyTrees #6

Closed
1 task done
XuehaiPan opened this issue Oct 21, 2022 · 2 comments · Fixed by #166

XuehaiPan commented Oct 21, 2022

Motivation

OpTree implements PyTree as a generic custom class whose __class_getitem__ method returns a Union alias.

For example:

from optree.typing import PyTree

TreeOfInts = PyTree[int]

will expand to:

TreeOfInts = typing.Union[
    int, 
    typing.Tuple[ForwardRef('PyTree[int]'), ...], 
    typing.List[ForwardRef('PyTree[int]')], 
    typing.Dict[typing.Any, ForwardRef('PyTree[int]')], 
    typing.Deque[ForwardRef('PyTree[int]')],
    optree.typing.CustomTreeNode[ForwardRef('PyTree[int]')]
]

at runtime.
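
As a rough illustration of that mechanism, here is a minimal sketch of a class whose __class_getitem__ builds a Union alias with unevaluated forward references. PyTreeSketch is a hypothetical stand-in written for this example, not the actual optree.typing.PyTree implementation, and it omits the CustomTreeNode member.

```python
import typing
from typing import Any, Deque, Dict, ForwardRef, List, Tuple, Union


class PyTreeSketch:
    """Hypothetical stand-in for optree.typing.PyTree (assumption, not the real class)."""

    def __class_getitem__(cls, item):
        name = getattr(item, "__name__", repr(item))
        # The recursive occurrences are stored as forward references;
        # nothing in this sketch ever evaluates them.
        ref = ForwardRef(f"PyTreeSketch[{name}]")
        return Union[
            item,
            Tuple[ref, ...],
            List[ref],
            Dict[Any, ref],
            Deque[ref],
        ]


# Subscripting runs __class_getitem__, so this is a plain typing.Union alias.
TreeOfInts = PyTreeSketch[int]
```

The alias only exists once the subscript expression has been executed, which is why a tool that reads the source without running it cannot see the expanded form.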

The type checker mypy is a static analyzer: it does not actually run the code, so it never sees this runtime expansion.

In addition, Python does not support generic recursive type annotations yet. In function type annotations, the generic form PyTree[T] is substituted with:

typing.Union[~T,
    typing.Tuple[ForwardRef('PyTree[T]'), ...],
    typing.List[ForwardRef('PyTree[T]')],
    typing.Dict[typing.Any, ForwardRef('PyTree[T]')],
    typing.Deque[ForwardRef('PyTree[T]')],
    optree.typing.CustomTreeNode[ForwardRef('PyTree[T]')]
]

while the ForwardRef('PyTree[T]') is never evaluated. As a result, mypy fails to infer the argument and return types and raises either an arg-type or an assignment error. Using typing.cast alleviates the issue, but cast has a non-zero overhead at runtime.
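
To make the overhead concrete: typing.cast is an ordinary function that returns its value argument unchanged, so the runtime cost is one extra Python-level call per use. A minimal sketch, where the type is passed as a string only to keep the example free of an optree import:

```python
from typing import cast

raw_tree = (0, {'a': 1, 'b': (2, [3, 4])})

# cast() does no checking or conversion; it hands back the same object.
# The string "PyTree[int]" is never evaluated at runtime.
tree_of_ints = cast("PyTree[int]", raw_tree)

same_object = tree_of_ints is raw_tree
```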

Function signature:

def tree_leaves(
    tree: PyTree[T],
    is_leaf: Optional[Callable[[T], bool]] = None,
    *,
    none_is_leaf: bool = False,
) -> List[T]: ...
  • Use cast:
from typing import List, cast

import optree
from optree.typing import PyTree

list_of_ints: List[int]

tree_of_ints = cast(PyTree[int], (0, {'a': 1, 'b': (2, [3, 4])}))
list_of_ints = optree.tree_leaves(tree_of_ints)
$ mypy test.py
Success: no issues found in 1 source file
  • Annotate returns: mypy infers T = int and requires the input to be an exact PyTree[int] object rather than a Union[...] type.
# test1.py

from typing import List

import optree

list_of_ints: List[int]

tree_of_ints = (0, {'a': 1, 'b': (2, [3, 4])})
list_of_ints = optree.tree_leaves(tree_of_ints)
$ mypy test1.py
test1.py:10: error: Argument 1 to "tree_leaves" has incompatible type "Tuple[int, Dict[str, object]]"; expected "PyTree[int]"  [arg-type]
    list_of_ints = optree.tree_leaves(tree_of_ints)
                                      ^~~~~~~~~~~~
Found 1 error in 1 file (checked 1 source file)
  • Annotate args: mypy refuses the assignment to a variable annotated as PyTree[int].
# test2.py

from typing import List

import optree
from optree.typing import PyTree

list_of_ints: List[int]

tree_of_ints: PyTree[int] = (0, {'a': 1, 'b': (2, [3, 4])})
list_of_ints = optree.tree_leaves(tree_of_ints)
$ mypy test2.py
test2.py:10: error: Incompatible types in assignment (expression has type "Tuple[int, Dict[str, object]]", variable has type "PyTree[int]")  [assignment]
    tree_of_ints: PyTree[int] = (0, {'a': 1, 'b': (2, [3, 4])})
                                ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Found 1 error in 1 file (checked 1 source file)
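
For comparison, a non-generic recursive alias can be written by hand. The sketch below is an assumption-laden illustration, not necessarily what the eventual fix does: IntTree and int_leaves are hypothetical names introduced here, and whether the annotated assignment type-checks depends on the checker and version (newer mypy releases support recursive type aliases).

```python
from typing import Any, Dict, List, Tuple, Union

# Non-generic recursive alias; the string names are forward references
# that a checker with recursive-alias support can resolve.
IntTree = Union[int, Tuple["IntTree", ...], List["IntTree"], Dict[Any, "IntTree"]]


def int_leaves(tree: IntTree) -> List[int]:
    """Hypothetical helper (not optree.tree_leaves): collect leaves depth-first."""
    if isinstance(tree, int):
        return [tree]
    # Dict children are visited in insertion order in this sketch.
    children = tree.values() if isinstance(tree, dict) else tree
    leaves: List[int] = []
    for child in children:
        leaves.extend(int_leaves(child))
    return leaves


tree_of_ints: IntTree = (0, {'a': 1, 'b': (2, [3, 4])})
list_of_ints = int_leaves(tree_of_ints)
```

This sidesteps the unevaluated ForwardRef only because the alias is fixed to int; it does not address the generic PyTree[T] case described above.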

Related issues:

Checklist

  • I have checked that there is no similar issue in the repo. (required)
@XuehaiPan XuehaiPan added the enhancement New feature or request label Oct 21, 2022
@XuehaiPan XuehaiPan added the py Something related to the Python source code label Jan 22, 2023
@XuehaiPan XuehaiPan self-assigned this May 23, 2023

rhaps0dy commented Sep 9, 2023

Here's another possible improvement to the typing. Consider tree_map:

def tree_map(func: Callable[..., U], tree: PyTree[T], ...) -> PyTree[U]

Type-checking:

a: int = 1
a = tree_map(lambda x: x+1, 1)

will fail, even though a evaluates to 2.

Fixing this is achievable (although quite annoying) with overloads. E.g.:

@overload
def tree_map(func: Callable[..., U], tree: T, ...) -> U:
    ...

@overload
def tree_map(func: Callable[..., U], tree: Tuple[T, ...], ...) -> Tuple[U, ...]:
    ...

@overload
def tree_map(func: Callable[..., U], tree: Dict[Any, T], ...) -> Dict[Any, U]:
    ...

and so on. Going one level deep is probably enough; two levels would be extra thorough.
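
A runnable version of the overload idea might look like the sketch below. tree_map_sketch is a hypothetical function written for this example, not optree.tree_map: it takes a single tree, narrows the callable to one argument, and handles only tuples and dicts. The container overloads are listed before the bare-leaf one because overloads are matched in order; a checker may still report overlap between them.

```python
from typing import Any, Callable, Dict, Tuple, TypeVar, overload

T = TypeVar("T")
U = TypeVar("U")


@overload
def tree_map_sketch(func: Callable[[T], U], tree: Tuple[T, ...]) -> Tuple[U, ...]: ...
@overload
def tree_map_sketch(func: Callable[[T], U], tree: Dict[Any, T]) -> Dict[Any, U]: ...
@overload
def tree_map_sketch(func: Callable[[T], U], tree: T) -> U: ...


def tree_map_sketch(func, tree):
    """Hypothetical one-argument map over tuples, dicts, and leaves."""
    if isinstance(tree, tuple):
        return tuple(tree_map_sketch(func, item) for item in tree)
    if isinstance(tree, dict):
        return {key: tree_map_sketch(func, value) for key, value in tree.items()}
    # Anything else is treated as a leaf.
    return func(tree)


a: int = 1
a = tree_map_sketch(lambda x: x + 1, 1)
```

The overloads describe only one level of nesting, while the runtime implementation recurses; deeper structures would need further overloads, which is the annoyance mentioned above.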

LarsKue commented Jul 30, 2024

To expand on this, type hinting return types is usually too loose; I would appreciate an improvement here as well:

def f(x: PyTree[T]) -> PyTree[T]:
    return x

Here, we know that the return value of f(x) necessarily has the same type and layout as x itself. However, while the argument PyTree[T] is permissive, the return value PyTree[T] is unspecific. This means we necessarily lose type information through the invocation of f. Subsequent code may need to rely on the concrete type of f(x), which is no longer possible:

x: list[int] = [1, 2, 3]
y = f(x)  # static type checker warns here due to OP's issue
s = sum(y)  # static type checker warns here due to my issue, but we know that y is a list[int]

Normally, one would avoid this by using TypeVar:

T = TypeVar("T")

def f(x: T) -> T:
    return x

However, no such concept exists for the layout and type of PyTree. A potential fix could look like this:

T = TypeVar("T")
S = SpecVar("S")  # new

def f(x: PyTree[T, S]) -> PyTree[T, S]:
    return x

This way, we can annotate that the returned value follows the same spec as the input argument. Unfortunately, subscripts do not accept keyword arguments, so we cannot use one to avoid tying the spec to a fixed positional slot:

PyTree[T, spec=S]  # SyntaxError

A possible workaround, as shown in PEP 472, would be to use slices or dictionaries:

PyTree[T, "spec": S]  # passes slice('spec', S, None) to __class_getitem__
PyTree[T, {"spec": S}]  # passes the dict to __class_getitem__
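
What __class_getitem__ receives for these two spellings can be checked directly at runtime. SubscriptProbe is a hypothetical class used only to echo the subscript back; this says nothing about whether static type checkers would accept either form as a type expression.

```python
class SubscriptProbe:
    """Hypothetical class that returns whatever was passed in the subscript."""

    def __class_getitem__(cls, item):
        return item


# Slice spelling: the second element arrives as slice('spec', str, None).
with_slice = SubscriptProbe[int, "spec": str]

# Dict spelling: the second element arrives as the dict itself.
with_dict = SubscriptProbe[int, {"spec": str}]
```

In both cases the subscript is delivered as a two-element tuple, so the class would have to unpack the second element itself to recover the spec.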
