Skip to content

Commit

Permalink
fix: use a custom dispatcher to run coroutines functions
Browse files Browse the repository at this point in the history
  • Loading branch information
hlouzada committed Apr 9, 2024
1 parent a1695db commit 9758b54
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 116 deletions.
213 changes: 108 additions & 105 deletions spyder/api/asyncrunner.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from __future__ import annotations
import asyncio
import atexit
import contextlib
from concurrent.futures import Future, CancelledError
import functools
import threading
from collections.abc import Coroutine
from typing import Callable
from typing import Any
import logging
import typing

from qt_async_threads import QtAsyncRunner
_logger = logging.getLogger(__name__)


class AsSync:
class AsyncDispatcher:
"""Decorator to convert a coroutine to a sync function.
Helper class to facilitate the conversion of coroutines to sync functions
or to run a coroutine as a sync function without the need to call the event
loop method.
Expand All @@ -20,21 +21,27 @@ class AsSync:
------
As a decorator:
```
@AsSync
@AsyncDispatcher.dispatch()
async def my_coroutine():
pass
my_coroutine()
```
As a class wrapper:
```
sync_coroutine = AsSync(my_coroutine)
sync_coroutine = AsyncDispatcher(my_coroutine)
sync_coroutine()
```
```
"""
def __init__(self, coro, loop=None):

__closed = False
__running_loops: typing.ClassVar[dict[int, asyncio.AbstractEventLoop]] = {}
__running_threads: typing.ClassVar[dict[int, threading.Thread]] = {}
_running_tasks: typing.ClassVar[list[Future]] = []

def __init__(self, coro, *, loop=None, early_return=True):
"""Initialize the decorator.
Parameters
Expand All @@ -44,110 +51,106 @@ def __init__(self, coro, loop=None):
loop : asyncio.AbstractEventLoop, optional
The event loop to be used, by default get the current event loop.
"""
self.__coro = coro
self.__loop = loop or asyncio.get_event_loop()
functools.update_wrapper(self, coro)
if not asyncio.iscoroutinefunction(coro):
msg = f'{coro} is not a coroutine function'
raise TypeError(msg)
self._coro = coro
self._loop = self._ensure_running_loop(loop)
self._early_return = early_return

def __call__(self, *args, **kwargs):
return self.__loop.run_until_complete(self.__coro(*args, **kwargs))
task = asyncio.run_coroutine_threadsafe(self._coro(*args, **kwargs), loop=self._loop)
if self._early_return:
AsyncDispatcher._running_tasks.append(task)
task.add_done_callback(self._callback_task_done)
return task
return task.result()

def __get__(self, instance, owner):
if instance is None:
return self
else:
bound_method = self.__coro.__get__(instance, owner)
return functools.partial(self.__class__(bound_method, self.__loop))
@classmethod
def dispatch(cls, *, loop=None, early_return=True):
"""Create a decorator to run the coroutine with a given event loop."""
def decorator(coro):
@functools.wraps(coro)
def wrapper(*args, **kwargs):
return cls(coro, loop=loop, early_return=early_return)(*args, **kwargs)
return wrapper
return decorator

def _callback_task_done(self, future):
AsyncDispatcher._running_tasks.remove(future)
with contextlib.suppress(asyncio.CancelledError, CancelledError):
if (exception := future.exception()) is not None:
raise exception

@classmethod
def _ensure_running_loop(cls, loop=None):
loop = cls.__running_loops.get(id(loop), asyncio.get_event_loop())

class SpyderQAsyncRunner(QtAsyncRunner):
"""Reimplement QtAsyncRunner as a singleton."""
try:
if loop.is_running():
return loop
except RuntimeError:
_logger.warning('Failed to check if the loop is running, defaulting to the current loop.')
if len(cls.__running_loops):
return cls.__running_loops[id(asyncio.get_event_loop())]

_instance = None
_rlock = threading.RLock()
__inside_instance = False
return cls.__run_loop(id(loop), loop)

@classmethod
def instance(cls, *args, **kwargs):
"""Get *the* class instance.
def __run_loop(cls, loop_id, loop):
cls.__running_threads[loop_id] = threading.Thread(target=loop.run_forever, daemon=True)
cls.__running_threads[loop_id].start()
cls.__running_loops[loop_id] = loop
return loop

Return the instance of the class. If it did not exist yet, create it
by calling the "constructor" with whatever arguments and keyword
arguments provided.
@atexit.register
@staticmethod
def close():
"""Close the thread pool."""
if AsyncDispatcher.__closed:
return
AsyncDispatcher.cancel_all()
AsyncDispatcher.join()
AsyncDispatcher.__closed = True

Returns
-------
instance(object)
Class Singleton Instance
"""
if cls._instance is not None:
return cls._instance
with cls._rlock:
# Re-check, perhaps it was created in the meantime...
if cls._instance is None:
cls.__inside_instance = True
try:
cls._instance = cls(*args, **kwargs)
finally:
cls.__inside_instance = False
return cls._instance

def __new__(cls, *args, **kwargs):
"""Class constructor.
Ensures that this class isn't created without
the ``instance`` class method.
Returns
-------
object
Class instance
Raises
------
RuntimeError
Exception when not called from the ``instance`` class method.
"""
if cls._instance is None:
with cls._rlock:
if cls._instance is None and cls.__inside_instance:
return super().__new__(cls)

raise RuntimeError(
f"Attempt to create a {cls.__qualname__} instance outside of instance()"
)

async def run(
self, func: Callable | Coroutine, *args: Any, **kwargs: Any
) -> Any:
"""
Updated to handle both functions and coroutines. If `func` is a coroutine,
it is scheduled to run in the thread pool.
"""
if asyncio.iscoroutinefunction(func) or isinstance(func, Coroutine):
# Wrap the coroutine in a function that sets up an event loop
async_func = func if asyncio.iscoroutinefunction(func) else functools.partial(lambda x: x, func)
loop = asyncio.get_event_loop()
def run_coro_in_thread():
coroutine = async_func(*args, **kwargs)
return asyncio.run_coroutine_threadsafe(coroutine, loop)

func_to_run = run_coro_in_thread
else:
func_to_run = functools.partial(func, *args, **kwargs)

async for result in self.run_parallel([func_to_run]):
return result
assert False, "should never be reached"
@classmethod
def cancel_all(self):
"""Cancel all running tasks."""
for task in self._running_tasks:
task.cancel()

@atexit.register
def __atexit():
SpyderQAsyncRunner.instance().close()
@classmethod
def join(self, timeout: float | None = None):
""" Blocking call to close the thread pool
:param timeout: timeout for polling a thread to check if its async tasks are all finished
"""
for loop_id in list(self.__running_loops.keys()):
self._stop_running_loop(loop_id, timeout)

@classmethod
def run_async(cls, coro):
"""Decorator to run a coroutine asynchronously."""
@functools.wraps(coro)
def wrapper(*args, **kwargs):
return cls.instance().start_coroutine(cls.instance().run(coro, *args, **kwargs))
return wrapper
def _stop_running_loop(cls, loop_id, timeout = None):
thread = cls.__running_threads.pop(loop_id, None)
loop = cls.__running_loops.pop(loop_id, None)

if loop is None:
return

if thread is None:
return

if loop.is_closed():
thread.join(timeout)
return

loop_stoped = threading.Event()

def _stop():
loop.stop()
loop_stoped.set()

loop.call_soon_threadsafe(_stop)

loop_stoped.wait(timeout)
thread.join(timeout)
loop.close()
21 changes: 10 additions & 11 deletions spyder/plugins/remoteclient/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
from spyder.api.plugins import Plugins, SpyderPluginV2
from spyder.plugins.remoteclient.api.protocol import KernelConnectionInfo, DeleteKernel, KernelInfo, KernelsList, SSHClientOptions
from spyder.plugins.remoteclient.api.client import SpyderRemoteClient
from spyder.api.asyncrunner import AsSync
from spyder.api.asyncrunner import SpyderQAsyncRunner
from spyder.api.asyncrunner import AsyncDispatcher

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -77,7 +76,7 @@ def on_first_registration(self):
def on_close(self, cancellable=True):
"""Stops remote server and close any opened connection."""
for client in self._remote_clients.values():
AsSync(client.close)()
AsyncDispatcher(client.close, early_return=False)()

# ---- Public API
# -------------------------------------------------------------------------
Expand All @@ -87,28 +86,28 @@ def get_remote_server(self, config_id):
if config_id in self._remote_clients:
return self._remote_clients[config_id]

@SpyderQAsyncRunner.run_async
@AsyncDispatcher.dispatch()
async def _install_remote_server(self, config_id):
"""Install remote server."""
if config_id in self._remote_clients:
client = self._remote_clients[config_id]
await client.connect_and_install_remote_server()

@SpyderQAsyncRunner.run_async
@AsyncDispatcher.dispatch()
async def start_remote_server(self, config_id):
"""Start remote server."""
if config_id in self._remote_clients:
server = self._remote_clients[config_id]
await server.connect_and_start_server()

@SpyderQAsyncRunner.run_async
@AsyncDispatcher.dispatch()
async def stop_remote_server(self, config_id):
"""Stop remote server."""
if config_id in self._remote_clients:
client = self._remote_clients[config_id]
await client.stop_remote_server()

@SpyderQAsyncRunner.run_async
@AsyncDispatcher.dispatch()
async def ensure_remote_server(self, config_id):
"""Ensure remote server is running and installed."""
if config_id in self._remote_clients:
Expand Down Expand Up @@ -149,7 +148,7 @@ def get_conf_ids(self):
# -------------------------------------------------------------------------
# --- Remote Server Kernel Methods
@Slot(str)
@SpyderQAsyncRunner.run_async
@AsyncDispatcher.dispatch()
async def get_kernels(self, config_id):
"""Get opened kernels."""
if config_id in self._remote_clients:
Expand All @@ -158,7 +157,7 @@ async def get_kernels(self, config_id):
self.sig_kernel_list.emit(kernels_list)

@Slot(str)
@SpyderQAsyncRunner.run_async
@AsyncDispatcher.dispatch()
async def get_kernel_info(self, config_id, kernel_key):
"""Get kernel info."""
if config_id in self._remote_clients:
Expand All @@ -167,7 +166,7 @@ async def get_kernel_info(self, config_id, kernel_key):
self.sig_kernel_info.emit(kernel_info or {})

@Slot(str)
@SpyderQAsyncRunner.run_async
@AsyncDispatcher.dispatch()
async def terminate_kernel(self, config_id, kernel_key):
"""Terminate opened kernel."""
if config_id in self._remote_clients:
Expand All @@ -176,7 +175,7 @@ async def terminate_kernel(self, config_id, kernel_key):
self.sig_kernel_terminated.emit(delete_kernel or {})

@Slot(str)
@SpyderQAsyncRunner.run_async
@AsyncDispatcher.dispatch()
async def start_new_kernel(self, config_id):
"""Start new kernel."""
if config_id in self._remote_clients:
Expand Down

0 comments on commit 9758b54

Please sign in to comment.