Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Permit referring to some generic types in generic ways #908

Merged
merged 10 commits into from
Feb 8, 2019
2 changes: 1 addition & 1 deletion docs/source/reference-core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1221,7 +1221,7 @@ many cases, you just want to pass objects between different tasks
inside a single process, and for that you can use
:func:`trio.open_memory_channel`:

.. autofunction:: open_memory_channel
.. autofunction:: open_memory_channel(max_buffer_size)

.. note:: If you've used the :mod:`threading` or :mod:`asyncio`
modules, you may be familiar with :class:`queue.Queue` or
Expand Down
7 changes: 7 additions & 0 deletions newsfragments/908.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:class:`~trio.abc.SendChannel`, :class:`~trio.abc.ReceiveChannel`, :class:`~trio.abc.Listener`,
and :func:`~trio.open_memory_channel` can now be referenced using a generic type parameter
(the type of object sent over the channel or produced by the listener) using PEP 484 syntax:
``trio.abc.SendChannel[bytes]``, ``trio.abc.Listener[trio.SocketStream]``,
``trio.open_memory_channel[MyMessage](5)``, etc. The added type information does not change
the runtime semantics, but permits better integration with external static type checkers.

22 changes: 19 additions & 3 deletions trio/_abc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABCMeta, abstractmethod
from typing import Generic, TypeVar
from ._util import aiter_compat
from . import _core

Expand Down Expand Up @@ -483,7 +484,22 @@ async def send_eof(self):
"""


class Listener(AsyncResource):
# The type of object produced by a ReceiveChannel (covariant because
# ReceiveChannel[Derived] can be passed to someone expecting
# ReceiveChannel[Base])
T_co = TypeVar("T_co", covariant=True)

# The type of object accepted by a SendChannel (contravariant because
# SendChannel[Base] can be passed to someone expecting
# SendChannel[Derived])
T_contra = TypeVar("T_contra", contravariant=True)

# The type of object produced by a Listener (covariant plus must be
# an AsyncResource)
T_resource = TypeVar("T_resource", bound=AsyncResource, covariant=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the bound=AsyncResource will disappear when we fix #636. We could start skipping it now, or we could put it in and then take it out later, doesn't make much difference IMO.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You think we'll have Listeners that provide objects not representable as an AsyncResource? I'm a bit surprised at that, can you give an example?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have an example in mind. But the reason we restrict it to AsyncResource right now is because serve_listeners automatically closes the object after the handler returns:

async def _run_handler(stream, handler):
try:
await handler(stream)
finally:
await trio.aclose_forcefully(stream)

With #636, this code goes away, because the Listener itself is responsible for managing the handler lifetime. So... at that point it's really between the person developing a Listener[X] and their users, what kind of API X should support :-).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that makes sense! I think it's probably sensible to keep the AsyncResource bound for now, since it reflects the current requirements, and we can easily change it later.



class Listener(AsyncResource, Generic[T_resource]):
oremanj marked this conversation as resolved.
Show resolved Hide resolved
"""A standard interface for listening for incoming connections.

:class:`Listener` objects also implement the :class:`AsyncResource`
Expand Down Expand Up @@ -521,7 +537,7 @@ async def accept(self):
"""


class SendChannel(AsyncResource):
class SendChannel(AsyncResource, Generic[T_contra]):
"""A standard interface for sending Python objects to some receiver.

:class:`SendChannel` objects also implement the :class:`AsyncResource`
Expand Down Expand Up @@ -595,7 +611,7 @@ def clone(self):
"""


class ReceiveChannel(AsyncResource):
class ReceiveChannel(AsyncResource, Generic[T_co]):
"""A standard interface for receiving Python objects from some sender.

You can iterate over a :class:`ReceiveChannel` using an ``async for``
Expand Down
2 changes: 2 additions & 0 deletions trio/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

from . import _core
from .abc import SendChannel, ReceiveChannel
from ._util import generic_function


@generic_function
def open_memory_channel(max_buffer_size):
"""Open a channel for passing objects between tasks within a process.

Expand Down
35 changes: 33 additions & 2 deletions trio/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import signal
import sys
import pathlib
from functools import wraps
from functools import wraps, update_wrapper
import typing as t

import async_generator
Expand All @@ -22,6 +22,7 @@
"ConflictDetector",
"fixup_module_metadata",
"fspath",
"indexable_function",
oremanj marked this conversation as resolved.
Show resolved Hide resolved
]

# Equivalent to the C function raise(), which Python doesn't wrap
Expand Down Expand Up @@ -177,7 +178,9 @@ def fix_one(obj):
obj.__module__ = module_name
if isinstance(obj, type):
for attr_value in obj.__dict__.values():
fix_one(attr_value)
# avoid infinite recursion when using typing.Generic
if attr_value is not obj:
fix_one(attr_value)
oremanj marked this conversation as resolved.
Show resolved Hide resolved

for objname, obj in namespace.items():
if not objname.startswith("_"): # ignore private attributes
Expand Down Expand Up @@ -242,3 +245,31 @@ def fspath(path) -> t.Union[str, bytes]:

if hasattr(os, "fspath"):
fspath = os.fspath # noqa


class generic_function:
"""Decorator that makes a function indexable, to communicate
non-inferrable generic type parameters to a static type checker.

If you write::

@generic_function
def open_memory_channel(max_buffer_size: int) -> Tuple[
SendChannel[T], ReceiveChannel[T]
]: ...

it is valid at runtime to say ``open_memory_channel[bytes](5)``.
This behaves identically to ``open_memory_channel(5)`` at runtime,
and currently won't type-check without a mypy plugin or clever stubs,
but at least it becomes possible to write those.
"""

def __init__(self, fn):
update_wrapper(self, fn)
self._fn = fn

def __call__(self, *args, **kwargs):
return self._fn(*args, **kwargs)

def __getitem__(self, _):
return self
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you actually can make open_memory_channel directly typeable by mypy without any plugin, and it's not too terribly awful: python/mypy#6073 (comment)

But I think it does require inlining this class into the definition of open_memory_channel.

If that's where were going to end up, then maybe setting up a standalone @generic_function decorator, testing it, etc., is overkill, and we should instead move this code into _memory_channels.py right now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem with that approach is that AFAICT, when you write Type[T], mypy can't infer T as one of its "special forms" (Tuple, Union, Callable, etc). You just silently get Any if you use such a one. I figured that was error-prone enough that we should go this route instead, at least for now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh ick.

21 changes: 21 additions & 0 deletions trio/tests/test_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,24 @@ async def aclose(self):
assert myar.record == []

assert myar.record == ["ac"]


def test_abc_generics():
oremanj marked this conversation as resolved.
Show resolved Hide resolved
class SlottedChannel(tabc.SendChannel[tabc.Stream]):
__slots__ = ("x",)

def send_nowait(self, value):
raise RuntimeError

async def send(self, value):
raise RuntimeError # pragma: no cover

def clone(self):
raise RuntimeError # pragma: no cover

async def aclose(self):
pass # pragma: no cover

channel = SlottedChannel()
with pytest.raises(RuntimeError):
channel.send_nowait(None)
18 changes: 17 additions & 1 deletion trio/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

from .. import _core
from .._threads import run_sync_in_worker_thread
from .._util import signal_raise, ConflictDetector, fspath, is_main_thread
from .._util import (
signal_raise, ConflictDetector, fspath, is_main_thread, generic_function
)
from ..testing import wait_all_tasks_blocked, assert_checkpoints


Expand Down Expand Up @@ -168,3 +170,17 @@ def not_main_thread():
assert not is_main_thread()

await run_sync_in_worker_thread(not_main_thread)


def test_generic_function():
@generic_function
def test_func(arg):
"""Look, a docstring!"""
return arg

assert test_func is test_func[int] is test_func[int, str]
assert test_func(42) == test_func[int](42) == 42
assert test_func.__doc__ == "Look, a docstring!"
assert test_func.__qualname__ == "test_generic_function.<locals>.test_func"
assert test_func.__name__ == "test_func"
assert test_func.__module__ == __name__