
Support TypedDict unpacking in ParamSpec specifications #16120

Open
tmke8 opened this issue Sep 15, 2023 · 0 comments
Labels
feature topic-paramspec PEP 612, ParamSpec, Concatenate topic-typed-dict

Comments

@tmke8
Contributor

tmke8 commented Sep 15, 2023

Feature

Related to

TypedDict unpacking in a ParamSpec specialization would work just as it already does in Callable:

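from typing import Callable, Generic, ParamSpec, TypedDict, Unpack  # Unpack: Python 3.11+

P = ParamSpec("P")

class C(Generic[P]):
    def __init__(self, f: Callable[P, None]): ...

class Args(TypedDict):
    x: int
    y: str

def f(*, x: int, y: str) -> None: ...

c: C[[Unpack[Args]]] = C(f)  # OK: x and y are both passed by keyword
d: C[[int, str]] = C(f)  # error: `f` expects keyword arguments, not positional ones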
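For comparison, PEP 692 already lets a function definition annotate `**kwargs` with `Unpack` of a TypedDict, which type checkers treat like the keyword parameters `x: int` and `y: str`; the proposal would reuse these semantics inside a ParamSpec. A minimal sketch (the function name `g` and its tuple return value are hypothetical, chosen only so the behavior is observable):

```python
from typing import TypedDict

try:
    from typing import Unpack  # Python 3.11+
except ImportError:  # older Pythons: requires the typing_extensions package
    from typing_extensions import Unpack


class Args(TypedDict):
    x: int
    y: str


# Hypothetical function: under PEP 692, **kwargs: Unpack[Args] is treated by
# type checkers like the keyword parameters (*, x: int, y: str).
def g(**kwargs: Unpack[Args]) -> tuple[int, str]:
    return kwargs["x"], kwargs["y"]


print(g(x=1, y="a"))  # type checkers accept this call
# g(1, "a") would be rejected: only keyword arguments are accepted
```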

Pitch

To express a callable type with keyword arguments, you can use a callback protocol, but that doesn't help with other classes that are generic over a ParamSpec. For example, in PyTorch, network layers have to inherit from Module, which would ideally be typed approximately like this (using Python 3.12 generic syntax):
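For reference, a callback protocol (a `Protocol` class with a `__call__` method) can describe keyword-only parameters for plain callables. A minimal sketch with the hypothetical names `TakesXY` and `describe`; note that `TakesXY` is a type, not a parameter list, so it cannot be used to specialize a ParamSpec-generic class such as `C` or `Module`, which is the gap this issue describes:

```python
from typing import Protocol


class TakesXY(Protocol):
    # Callback protocol: matches any callable taking keyword-only x and y.
    def __call__(self, *, x: int, y: str) -> str: ...


def describe(*, x: int, y: str) -> str:
    return f"{y}={x}"


cb: TakesXY = describe  # type checkers accept this: the signatures match
print(cb(x=1, y="a"))  # prints "a=1"
```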

from abc import abstractmethod

class Module[T, **P]:
    @abstractmethod
    def forward(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # do other stuff, then delegate to forward
        return self.forward(*args, **kwargs)

But what if I want to override forward with an optional keyword-only argument?

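class Dense(Module[Tensor, [Tensor, bool]]):
    # error: incompatible override; [Tensor, bool] promises a second positional
    # argument, but with_dropout is keyword-only
    def forward(self, x: Tensor, *, with_dropout: bool = False) -> Tensor:
        # my implementation
        return x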
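The mismatch also shows up at runtime: `[Tensor, bool]` promises that a second positional argument is accepted, but a keyword-only parameter rejects one. A minimal sketch using a plain function with `float` standing in for `Tensor`; the tuple return value is hypothetical, chosen only to make the behavior observable:

```python
def forward(x: float, *, with_dropout: bool = False) -> tuple[float, bool]:
    # Hypothetical stand-in for Dense.forward: just echoes its inputs.
    return x, with_dropout


print(forward(1.0))                     # (1.0, False): keyword omitted
print(forward(1.0, with_dropout=True))  # (1.0, True): passed by keyword

try:
    forward(1.0, True)  # the call shape that [Tensor, bool] would allow
except TypeError as exc:
    print("TypeError:", exc)
```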

With TypedDict unpacking in ParamSpec, this could be written as:

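from typing import NotRequired, TypedDict, Unpack  # Python 3.11+

class ExtraArgs(TypedDict):
    with_dropout: NotRequired[bool]  # optional, matching the default in forward

class Dense(Module[Tensor, [Tensor, Unpack[ExtraArgs]]]):
    def forward(self, x: Tensor, *, with_dropout: bool = False) -> Tensor:
        # my implementation
        return x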
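Under PEP 692, a TypedDict key marked `NotRequired` corresponds to a keyword argument the caller may omit, which is how a default such as `with_dropout=False` would be mirrored in the unpacked TypedDict. A minimal sketch of that correspondence on a plain function (the name `forward_kw` is hypothetical, and `float` stands in for `Tensor`):

```python
from typing import TypedDict

try:
    from typing import NotRequired, Unpack  # Python 3.11+
except ImportError:  # older Pythons: requires the typing_extensions package
    from typing_extensions import NotRequired, Unpack


class ExtraArgs(TypedDict):
    with_dropout: NotRequired[bool]  # the caller may leave this key out


def forward_kw(x: float, **kwargs: Unpack[ExtraArgs]) -> tuple[float, bool]:
    # Supply the default that a missing NotRequired key implies.
    with_dropout = kwargs.get("with_dropout", False)
    return x, with_dropout


print(forward_kw(1.0))                     # (1.0, False)
print(forward_kw(1.0, with_dropout=True))  # (1.0, True)
```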