Skip to content

Commit

Permalink
Merge pull request #908 from oremanj/generics
Browse files Browse the repository at this point in the history
Permit referring to some generic types in generic ways
  • Loading branch information
njsmith authored Feb 8, 2019
2 parents 777b7bb + 7c2d2d6 commit 1dc21fa
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/source/reference-core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1221,7 +1221,7 @@ many cases, you just want to pass objects between different tasks
inside a single process, and for that you can use
:func:`trio.open_memory_channel`:

.. autofunction:: open_memory_channel
.. autofunction:: open_memory_channel(max_buffer_size)

.. note:: If you've used the :mod:`threading` or :mod:`asyncio`
modules, you may be familiar with :class:`queue.Queue` or
Expand Down
7 changes: 7 additions & 0 deletions newsfragments/908.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:class:`~trio.abc.SendChannel`, :class:`~trio.abc.ReceiveChannel`, :class:`~trio.abc.Listener`,
and :func:`~trio.open_memory_channel` can now be referenced using a generic type parameter
(the type of object sent over the channel or produced by the listener) using PEP 484 syntax:
``trio.abc.SendChannel[bytes]``, ``trio.abc.Listener[trio.SocketStream]``,
``trio.open_memory_channel[MyMessage](5)``, etc. The added type information does not change
the runtime semantics, but permits better integration with external static type checkers.

22 changes: 19 additions & 3 deletions trio/_abc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABCMeta, abstractmethod
from typing import Generic, TypeVar
from ._util import aiter_compat
from . import _core

Expand Down Expand Up @@ -483,7 +484,22 @@ async def send_eof(self):
"""


class Listener(AsyncResource):
# The type of object produced by a ReceiveChannel (covariant because
# ReceiveChannel[Derived] can be passed to someone expecting
# ReceiveChannel[Base])
T_co = TypeVar("T_co", covariant=True)

# The type of object accepted by a SendChannel (contravariant because
# SendChannel[Base] can be passed to someone expecting
# SendChannel[Derived])
T_contra = TypeVar("T_contra", contravariant=True)

# The type of object produced by a Listener (covariant plus must be
# an AsyncResource)
T_resource = TypeVar("T_resource", bound=AsyncResource, covariant=True)


class Listener(AsyncResource, Generic[T_resource]):
"""A standard interface for listening for incoming connections.
:class:`Listener` objects also implement the :class:`AsyncResource`
Expand Down Expand Up @@ -521,7 +537,7 @@ async def accept(self):
"""


class SendChannel(AsyncResource):
class SendChannel(AsyncResource, Generic[T_contra]):
"""A standard interface for sending Python objects to some receiver.
:class:`SendChannel` objects also implement the :class:`AsyncResource`
Expand Down Expand Up @@ -595,7 +611,7 @@ def clone(self):
"""


class ReceiveChannel(AsyncResource):
class ReceiveChannel(AsyncResource, Generic[T_co]):
"""A standard interface for receiving Python objects from some sender.
You can iterate over a :class:`ReceiveChannel` using an ``async for``
Expand Down
2 changes: 2 additions & 0 deletions trio/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

from . import _core
from .abc import SendChannel, ReceiveChannel
from ._util import generic_function


@generic_function
def open_memory_channel(max_buffer_size):
"""Open a channel for passing objects between tasks within a process.
Expand Down
2 changes: 1 addition & 1 deletion trio/_highlevel_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def getsockopt(self, level, option, buffersize=0):
pass


class SocketListener(Listener):
class SocketListener(Listener[SocketStream]):
"""A :class:`~trio.abc.Listener` that uses a listening socket to accept
incoming connections as :class:`SocketStream` objects.
Expand Down
2 changes: 1 addition & 1 deletion trio/_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ async def wait_send_all_might_not_block(self):
await self.transport_stream.wait_send_all_might_not_block()


class SSLListener(Listener):
class SSLListener(Listener[SSLStream]):
"""A :class:`~trio.abc.Listener` for SSL/TLS-encrypted servers.
:class:`SSLListener` wraps around another Listener, and converts
Expand Down
39 changes: 38 additions & 1 deletion trio/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import signal
import sys
import pathlib
from functools import wraps
from functools import wraps, update_wrapper
import typing as t

import async_generator
Expand All @@ -22,6 +22,7 @@
"ConflictDetector",
"fixup_module_metadata",
"fspath",
"generic_function",
]

# Equivalent to the C function raise(), which Python doesn't wrap
Expand Down Expand Up @@ -171,7 +172,15 @@ def decorator(func):


def fixup_module_metadata(module_name, namespace):
seen_ids = set()

def fix_one(obj):
# avoid infinite recursion (relevant when using
# typing.Generic, for example)
if id(obj) in seen_ids:
return
seen_ids.add(id(obj))

mod = getattr(obj, "__module__", None)
if mod is not None and mod.startswith("trio."):
obj.__module__ = module_name
Expand Down Expand Up @@ -242,3 +251,31 @@ def fspath(path) -> t.Union[str, bytes]:

if hasattr(os, "fspath"):
fspath = os.fspath # noqa


class generic_function:
"""Decorator that makes a function indexable, to communicate
non-inferrable generic type parameters to a static type checker.
If you write::
@generic_function
def open_memory_channel(max_buffer_size: int) -> Tuple[
SendChannel[T], ReceiveChannel[T]
]: ...
it is valid at runtime to say ``open_memory_channel[bytes](5)``.
This behaves identically to ``open_memory_channel(5)`` at runtime,
and currently won't type-check without a mypy plugin or clever stubs,
but at least it becomes possible to write those.
"""

def __init__(self, fn):
update_wrapper(self, fn)
self._fn = fn

def __call__(self, *args, **kwargs):
return self._fn(*args, **kwargs)

def __getitem__(self, _):
return self
28 changes: 28 additions & 0 deletions trio/tests/test_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,31 @@ async def aclose(self):
assert myar.record == []

assert myar.record == ["ac"]


def test_abc_generics():
# Pythons below 3.5.2 had a typing.Generic that would throw
# errors when instantiating or subclassing a parameterized
# version of a class with any __slots__. This is why RunVar
# (which has slots) is not generic. This tests that
# the generic ABCs are fine, because while they are slotted
# they don't actually define any slots.

class SlottedChannel(tabc.SendChannel[tabc.Stream]):
__slots__ = ("x",)

def send_nowait(self, value):
raise RuntimeError

async def send(self, value):
raise RuntimeError # pragma: no cover

def clone(self):
raise RuntimeError # pragma: no cover

async def aclose(self):
pass # pragma: no cover

channel = SlottedChannel()
with pytest.raises(RuntimeError):
channel.send_nowait(None)
18 changes: 17 additions & 1 deletion trio/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

from .. import _core
from .._threads import run_sync_in_worker_thread
from .._util import signal_raise, ConflictDetector, fspath, is_main_thread
from .._util import (
signal_raise, ConflictDetector, fspath, is_main_thread, generic_function
)
from ..testing import wait_all_tasks_blocked, assert_checkpoints


Expand Down Expand Up @@ -168,3 +170,17 @@ def not_main_thread():
assert not is_main_thread()

await run_sync_in_worker_thread(not_main_thread)


def test_generic_function():
@generic_function
def test_func(arg):
"""Look, a docstring!"""
return arg

assert test_func is test_func[int] is test_func[int, str]
assert test_func(42) == test_func[int](42) == 42
assert test_func.__doc__ == "Look, a docstring!"
assert test_func.__qualname__ == "test_generic_function.<locals>.test_func"
assert test_func.__name__ == "test_func"
assert test_func.__module__ == __name__

0 comments on commit 1dc21fa

Please sign in to comment.