[Feature Request] Better Type Annotation Support for PyTrees #6
Comments
Here's another possible improvement to the typing. Consider

    tree_map(Callable[..., U], tree: PyTree[T], ...) -> PyTree[U]

Type-checking

    a: int = 1
    a = tree_map(lambda x: x+1, 1)

will fail, even though the call returns an int at runtime. Fixing this is achievable (although quite annoying) with overloads. E.g.:

    @overload
    tree_map(Callable[..., U], tree: T, ...) -> U:
        ...
    @overload
    tree_map(Callable[..., U], tree: Tuple[T, ...], ...) -> Tuple[U, ...]:
        ...
    @overload
    tree_map(Callable[..., U], tree: Dict[Any, T], ...) -> Dict[Any, U]:
        ...

and so on. Going one level deep is probably enough; two levels for extra thoroughness.
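The overload idea above can be sketched as runnable Python. This is only an illustration under assumptions: toy_tree_map is a hypothetical single-tree stand-in, not OpTree's tree_map, and it handles only tuples and dicts one level deep in its annotations. The leaf overload is placed last here because type checkers try overloads in order; whether a given checker accepts this exact set without overlap warnings is not something the sketch guarantees.

```python
from typing import Any, Callable, Dict, Tuple, TypeVar, overload

T = TypeVar("T")
U = TypeVar("U")


# Container overloads come first; the catch-all leaf overload comes last.
@overload
def toy_tree_map(func: Callable[..., U], tree: Tuple[T, ...]) -> Tuple[U, ...]: ...
@overload
def toy_tree_map(func: Callable[..., U], tree: Dict[Any, T]) -> Dict[Any, U]: ...
@overload
def toy_tree_map(func: Callable[..., U], tree: T) -> U: ...
def toy_tree_map(func, tree):
    """Hypothetical toy: recurse into tuples and dicts, apply func to leaves."""
    if isinstance(tree, tuple):
        return tuple(toy_tree_map(func, child) for child in tree)
    if isinstance(tree, dict):
        return {key: toy_tree_map(func, value) for key, value in tree.items()}
    return func(tree)


# The leaf case from the comment above: the declared type is int.
a: int = toy_tree_map(lambda x: x + 1, 1)
```

The runtime behaviour is the same for every overload; the overloads only give the checker a more precise return type for each shallow container shape.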
To expand on this, type hinting return types is usually too loose; I would appreciate an improvement here as well:

    def f(x: PyTree[T]) -> PyTree[T]:
        return x

Here, we know that the return value of f has the same structure as its argument:

    x: list[int] = [1, 2, 3]
    y = f(x)    # static type checker warns here due to OP's issue
    s = sum(y)  # static type checker warns here due to my issue, but we know that y is a list[int]

Normally, one would avoid this by using a TypeVar:

    T = TypeVar("T")
    def f(x: T) -> T:
        return x

However, no such concept exists for the layout and type of a PyTree. A new construct, for example a SpecVar, could fill that role:

    T = TypeVar("T")
    S = SpecVar("S")  # new
    def f(x: PyTree[T, S]) -> PyTree[T, S]:
        return x

This way, we can annotate that the returned value follows the same spec as the input argument. Unfortunately, we can't use keyword arguments here to avoid locking us to a single positional argument:

    PyTree[T, spec=S]  # SyntaxError

A possible work-around, as shown in PEP 472, would be to use slices or dictionaries:

    PyTree[T, "spec": S]    # __class_getitem__ receives (T, slice('spec', S, None))
    PyTree[T, {"spec": S}]  # __class_getitem__ receives (T, {'spec': S})
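To make the slice and dictionary work-around concrete, here is a minimal sketch of what __class_getitem__ actually receives at runtime. PyTreeDemo and split_params are hypothetical names used only for illustration; nothing here is OpTree's API, and static type checkers are unlikely to understand either subscript form.

```python
class PyTreeDemo:
    """Hypothetical stand-in that simply returns whatever the subscript passes in."""

    def __class_getitem__(cls, item):
        return item


def split_params(item):
    """Split a subscript into positional parts and keyword-like parts."""
    parts = item if isinstance(item, tuple) else (item,)
    positional, keywords = [], {}
    for part in parts:
        if isinstance(part, slice):
            # "spec": S arrives as slice("spec", S, None)
            keywords[part.start] = part.stop
        elif isinstance(part, dict):
            # {"spec": S} arrives unchanged
            keywords.update(part)
        else:
            positional.append(part)
    return tuple(positional), keywords


via_slice = PyTreeDemo[int, "spec": str]
via_dict = PyTreeDemo[int, {"spec": str}]
```

Both forms carry the same information, so a runtime implementation could normalise them with a helper like split_params; the open question from the comment above is how a static checker would be taught to read them.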
Motivation
OpTree uses a generic custom class that returns a Union alias from its __class_getitem__ method. For example, subscripting PyTree with a type, as in PyTree[int], expands to a Union[...] alias at runtime.
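A rough sketch of this mechanism, under stated assumptions: PyTreeSketch is a hypothetical class, and the set of containers in the Union is chosen for illustration rather than copied from OpTree. It shows a __class_getitem__ that builds a Union alias whose recursive members are unevaluated ForwardRef objects.

```python
from typing import Any, Dict, ForwardRef, List, Tuple, Union, get_args, get_origin


class PyTreeSketch:
    """Hypothetical generic class whose subscript returns a Union alias."""

    def __class_getitem__(cls, item):
        # The recursive reference is only a string wrapped in ForwardRef;
        # nothing resolves it back into a real type.
        recurse = ForwardRef(f"PyTreeSketch[{item.__name__}]")
        return Union[item, Tuple[recurse, ...], List[recurse], Dict[Any, recurse]]


alias = PyTreeSketch[int]

# The result is an ordinary Union alias, not an instance of PyTreeSketch.
origin = get_origin(alias)
leaf_type = get_args(alias)[0]
list_member = get_args(get_args(alias)[2])[0]
```

Because the alias exists only once the subscript has run, a tool that does not execute the code has nothing to inspect except the class definition itself.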
The typing linter mypy is a static analyzer, which does not actually run the code. In addition, Python does not support generic recursive type annotations yet. For function type annotations, the generic form PyTree[T] is substituted with a Union[...] alias whose recursive members are ForwardRef('PyTree[T]'), and that ForwardRef is never evaluated. This causes mypy to fail to infer the argument or return type; it raises either an arg-type or an assignment error. Using typing.cast alleviates the issue, but cast has a non-zero overhead at runtime.

For a function signature annotated with PyTree[T], mypy infers T = int and requires the input to be an exact PyTree[int] object rather than a Union[...] type. mypy also refuses to add type assignments for PyTree[int].

Related issues: typing.ForwardRef to support generic recursive types