Skip to content

Commit

Permalink
Harmonized default task names across backends (nedbat#116)
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm authored Jun 14, 2020
1 parent a9fc2e1 commit 8cf63e0
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 6 deletions.
15 changes: 14 additions & 1 deletion anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def current_task(loop: Optional[asyncio.AbstractEventLoop] = None) -> Optional[a
_native_task_names = hasattr(asyncio.Task, 'get_name')


def get_callable_name(func: Callable) -> str:
module = getattr(func, '__module__', None)
qualname = getattr(func, '__qualname__', None)
return '.'.join([x for x in (module, qualname) if x])


#
# Event loop
#
Expand Down Expand Up @@ -94,7 +100,13 @@ async def wrapper():
exception = retval = None
loop = asyncio.get_event_loop()
loop.set_debug(debug)
loop.run_until_complete(wrapper())
task = loop.create_task(wrapper())
task_state = TaskState(None, get_callable_name(func), None)
_task_states[task] = task_state
if _native_task_names:
task.set_name(task_state.name)

loop.run_until_complete(task)
if exception is not None:
raise exception
else:
Expand Down Expand Up @@ -409,6 +421,7 @@ async def spawn(self, func: Callable[..., Coroutine], *args, name=None) -> None:
if not self._active:
raise RuntimeError('This task group is not active; no new tasks can be spawned.')

name = name or get_callable_name(func)
if _native_task_names is None:
task = create_task(self._run_wrapped_task(func, args), name=name) # type: ignore
else:
Expand Down
13 changes: 10 additions & 3 deletions anyio/_backends/_curio.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
spawn_kwargs = {}


def get_callable_name(func: Callable) -> str:
module = getattr(func, '__module__', None)
qualname = getattr(func, '__qualname__', None)
return '.'.join([x for x in (module, qualname) if x])


#
# Event loop
#
Expand All @@ -42,7 +48,9 @@ async def wrapper():
exception = exc

exception = retval = None
curio.run(wrapper, **curio_options)
coro = wrapper()
coro.__qualname__ = get_callable_name(func)
curio.run(coro, **curio_options)
if exception is not None:
raise exception
else:
Expand Down Expand Up @@ -348,8 +356,7 @@ async def spawn(self, func: Callable[..., Coroutine], *args, name=None) -> None:

task = await curio.spawn(self._run_wrapped_task, func, args, daemon=True, **spawn_kwargs)
task.parentid = (await curio.current_task()).id
if name is not None:
task.name = name
task.name = name or get_callable_name(func)

# Make the spawned task inherit the task group's cancel scope
_task_states[task] = TaskState(cancel_scope=self.cancel_scope)
Expand Down
3 changes: 3 additions & 0 deletions anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import trio.from_thread
from async_generator import async_generator, yield_, asynccontextmanager, aclosing
from trio.to_thread import run_sync

from .. import abc, claim_worker_thread, T_Retval, TaskInfo
from ..exceptions import (
ExceptionGroup as BaseExceptionGroup, ClosedResourceError, ResourceBusyError, WouldBlock)
Expand All @@ -17,6 +18,8 @@
from trio.hazmat import wait_readable, wait_writable, notify_closing
else:
from trio.lowlevel import wait_readable, wait_writable, notify_closing


#
# Event loop
#
Expand Down
1 change: 1 addition & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
**UNRELEASED**

- Fixed ``fail.after(0)`` not raising a timeout error on asyncio and curio
- Harmonized the default task names across all backends

**1.3.1**

Expand Down
16 changes: 14 additions & 2 deletions tests/test_debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,21 @@

import pytest

import anyio
from anyio import (
create_task_group, create_event, wait_all_tasks_blocked, get_running_tasks, get_current_task)


def test_main_task_name(anyio_backend_name, anyio_backend_options):
async def main():
nonlocal task_name
task_name = (await get_current_task()).name

task_name = None
anyio.run(main, backend=anyio_backend_name, backend_options=anyio_backend_options)
assert task_name == 'test_debugging.test_main_task_name.<locals>.main'


@pytest.mark.anyio
async def test_get_running_tasks():
async def inspect():
Expand All @@ -21,10 +32,11 @@ async def inspect():
existing_tasks = set(await get_running_tasks())
await tg.spawn(event.wait, name='task1')
await tg.spawn(event.wait, name='task2')
await tg.spawn(inspect, name='inspector')
await tg.spawn(inspect)

assert len(task_infos) == 3
for task, expected_name in zip(task_infos, ['inspector', 'task1', 'task2']):
expected_names = ['task1', 'task2', 'test_debugging.test_get_running_tasks.<locals>.inspect']
for task, expected_name in zip(task_infos, expected_names):
assert task.parent_id == host_task.id
assert task.name == expected_name
assert repr(task) == "TaskInfo(id={}, name={!r})".format(task.id, expected_name)
Expand Down

0 comments on commit 8cf63e0

Please sign in to comment.