From 7f79fb1d8767575822442e86843ad46091635699 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Sun, 10 Dec 2023 11:54:21 +0100 Subject: [PATCH] Add anyio 4.x support --- pyproject.toml | 2 +- requirements.txt | 3 +++ src/aiometer/_compat.py | 23 ++++++++++++++++++++++ src/aiometer/_impl/amap.py | 39 +++++++++++++++++++++----------------- 4 files changed, 49 insertions(+), 18 deletions(-) create mode 100644 src/aiometer/_compat.py diff --git a/pyproject.toml b/pyproject.toml index e2b4301..5e022a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", ] dependencies = [ - "anyio~=3.2", + "anyio>=3.2", "typing-extensions; python_version<'3.8'", ] dynamic = ["version", "readme"] diff --git a/requirements.txt b/requirements.txt index 6e3b886..eff3a20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,8 @@ -e . +# Compatibility testing. +anyio~=3.2; python_version<'3.11' + # Packaging. twine wheel diff --git a/src/aiometer/_compat.py b/src/aiometer/_compat.py new file mode 100644 index 0000000..b95288a --- /dev/null +++ b/src/aiometer/_compat.py @@ -0,0 +1,23 @@ +import sys +from contextlib import contextmanager +from typing import Generator + +has_exceptiongroups = True + +if sys.version_info < (3, 11): # pragma: no cover + try: + from exceptiongroup import BaseExceptionGroup + except ImportError: + has_exceptiongroups = False + + +@contextmanager +def collapse_excgroups() -> Generator[None, None, None]: + try: + yield + except BaseException as exc: + if has_exceptiongroups: + while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: + exc = exc.exceptions[0] + + raise exc diff --git a/src/aiometer/_impl/amap.py b/src/aiometer/_impl/amap.py index 35c828b..d251a49 100644 --- a/src/aiometer/_impl/amap.py +++ b/src/aiometer/_impl/amap.py @@ -14,7 +14,9 @@ ) import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from .._compat import collapse_excgroups from .run_on_each import run_on_each from .types import T, U @@ -61,27 +63,30 @@ def amap( ) -> AsyncContextManager[AsyncIterable]: @asynccontextmanager async def _amap() -> AsyncIterator[AsyncIterable]: - send_channel, receive_channel = anyio.create_memory_object_stream( - max_buffer_size=len(args) - ) + channels: Tuple[ + MemoryObjectSendStream, MemoryObjectReceiveStream + ] = anyio.create_memory_object_stream(max_buffer_size=len(args)) + + send_channel, receive_channel = channels with send_channel, receive_channel: - async with anyio.create_task_group() as task_group: + with collapse_excgroups(): + async with anyio.create_task_group() as task_group: - async def sender() -> None: - # Make any `async for ... in results: ...` terminate. - with send_channel: - await run_on_each( - async_fn, - args, - max_at_once=max_at_once, - max_per_second=max_per_second, - _include_index=_include_index, - _send_to=send_channel, - ) + async def sender() -> None: + # Make any `async for ... in results: ...` terminate. + with send_channel: + await run_on_each( + async_fn, + args, + max_at_once=max_at_once, + max_per_second=max_per_second, + _include_index=_include_index, + _send_to=send_channel, + ) - task_group.start_soon(sender) + task_group.start_soon(sender) - yield receive_channel + yield receive_channel return _amap()