Skip to content

Commit

Permalink
ehn: use qt-async-thread to run coro in parallel thread safe
Browse files Browse the repository at this point in the history
  • Loading branch information
hlouzada committed Apr 8, 2024
1 parent 815d507 commit a27fa35
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 70 deletions.
153 changes: 153 additions & 0 deletions spyder/api/asyncrunner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import asyncio
import atexit
import functools
import threading
from collections.abc import Coroutine
from typing import Callable
from typing import Any

from qt_async_threads import QtAsyncRunner


class AsSync:
"""Decorator to convert a coroutine to a sync function.
Helper class to facilitate the conversion of coroutines to sync functions
or to run a coroutine as a sync function without the need to call the event
loop method.
Usage
------
As a decorator:
```
@AsSync
async def my_coroutine():
pass
my_coroutine()
```
As a class wrapper:
```
sync_coroutine = AsSync(my_coroutine)
sync_coroutine()
```
"""
def __init__(self, coro, loop=None):
"""Initialize the decorator.
Parameters
----------
coro : coroutine
The coroutine to be wrapped.
loop : asyncio.AbstractEventLoop, optional
The event loop to be used, by default get the current event loop.
"""
self.__coro = coro
self.__loop = loop or asyncio.get_event_loop()
functools.update_wrapper(self, coro)

def __call__(self, *args, **kwargs):
return self.__loop.run_until_complete(self.__coro(*args, **kwargs))

def __get__(self, instance, owner):
if instance is None:
return self
else:
bound_method = self.__coro.__get__(instance, owner)
return functools.partial(self.__class__(bound_method, self.__loop))


class SpyderQAsyncRunner(QtAsyncRunner):
"""Reimplement QtAsyncRunner as a singleton."""

_instance = None
_rlock = threading.RLock()
__inside_instance = False

@classmethod
def instance(cls, *args, **kwargs):
"""Get *the* class instance.
Return the instance of the class. If it did not exist yet, create it
by calling the "constructor" with whatever arguments and keyword
arguments provided.
Returns
-------
instance(object)
Class Singleton Instance
"""
if cls._instance is not None:
return cls._instance
with cls._rlock:
# Re-check, perhaps it was created in the meantime...
if cls._instance is None:
cls.__inside_instance = True
try:
cls._instance = cls(*args, **kwargs)
finally:
cls.__inside_instance = False
return cls._instance

def __new__(cls, *args, **kwargs):
"""Class constructor.
Ensures that this class isn't created without
the ``instance`` class method.
Returns
-------
object
Class instance
Raises
------
RuntimeError
Exception when not called from the ``instance`` class method.
"""
if cls._instance is None:
with cls._rlock:
if cls._instance is None and cls.__inside_instance:
return super().__new__(cls)

raise RuntimeError(
f"Attempt to create a {cls.__qualname__} instance outside of instance()"
)

async def run(
self, func: Callable | Coroutine, *args: Any, **kwargs: Any
) -> Any:
"""
Updated to handle both functions and coroutines. If `func` is a coroutine,
it is scheduled to run in the thread pool.
"""
if asyncio.iscoroutinefunction(func) or isinstance(func, Coroutine):
# Wrap the coroutine in a function that sets up an event loop
async_func = func if asyncio.iscoroutinefunction(func) else functools.partial(lambda x: x, func)
loop = asyncio.get_event_loop()
def run_coro_in_thread():
coroutine = async_func(*args, **kwargs)
return asyncio.run_coroutine_threadsafe(coroutine, loop)

func_to_run = run_coro_in_thread
else:
func_to_run = functools.partial(func, *args, **kwargs)

async for result in self.run_parallel([func_to_run]):
return result
assert False, "should never be reached"

@atexit.register
def __atexit():
SpyderQAsyncRunner.instance().close()

@classmethod
def run_async(cls, coro):
"""Decorator to run a coroutine asynchronously."""
@functools.wraps(coro)
def wrapper(*args, **kwargs):
return cls.instance().start_coroutine(cls.instance().run(coro, *args, **kwargs))
return wrapper

51 changes: 0 additions & 51 deletions spyder/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,54 +66,3 @@ class classproperty(property):

def __get__(self, cls, owner):
return classmethod(self.fget).__get__(None, owner)()


class AsSync:
"""Decorator to convert a coroutine to a sync function.
Helper class to facilitate the conversion of coroutines to sync functions
or to run a coroutine as a sync function without the need to call the event
loop method.
Usage
------
As a decorator:
```
@AsSync
async def my_coroutine():
pass
my_coroutine()
```
As a class wrapper:
```
sync_coroutine = AsSync(my_coroutine)
sync_coroutine()
```
"""
def __init__(self, coro, loop=None):
"""Initialize the decorator.
Parameters
----------
coro : coroutine
The coroutine to be wrapped.
loop : asyncio.AbstractEventLoop, optional
The event loop to be used, by default get the current event loop.
"""
self.__coro = coro
self.__loop = loop or asyncio.get_event_loop()
functools.update_wrapper(self, coro)

def __call__(self, *args, **kwargs):
return self.__loop.run_until_complete(self.__coro(*args, **kwargs))

def __get__(self, instance, owner):
if instance is None:
return self
else:
bound_method = self.__coro.__get__(instance, owner)
return functools.partial(self.__class__(bound_method, self.__loop))

35 changes: 16 additions & 19 deletions spyder/plugins/remoteclient/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from spyder.api.plugins import Plugins, SpyderPluginV2
from spyder.plugins.remoteclient.api.protocol import KernelConnectionInfo, DeleteKernel, KernelInfo, KernelsList, SSHClientOptions
from spyder.plugins.remoteclient.api.client import SpyderRemoteClient
from spyder.api.utils import AsSync
from spyder.api.asyncrunner import AsSync
from spyder.api.asyncrunner import SpyderQAsyncRunner

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,33 +87,33 @@ def get_remote_server(self, config_id):
if config_id in self._remote_clients:
return self._remote_clients[config_id]

@AsSync
async def install_remote_server(self, config_id):
@SpyderQAsyncRunner.as_sync
async def _install_remote_server(self, config_id):
"""Install remote server."""
if config_id in self._remote_clients:
client = self._remote_clients[config_id]
return await client.connect_and_install_remote_server()
await client.connect_and_install_remote_server()

@AsSync
@SpyderQAsyncRunner.as_sync
async def start_remote_server(self, config_id):
"""Start remote server."""
if config_id in self._remote_clients:
server = self._remote_clients[config_id]
return await server.connect_and_start_server()
await server.connect_and_start_server()

@AsSync
@SpyderQAsyncRunner.as_sync
async def stop_remote_server(self, config_id):
"""Stop remote server."""
if config_id in self._remote_clients:
client = self._remote_clients[config_id]
await client.stop_remote_server()

@AsSync
@SpyderQAsyncRunner.as_sync
async def ensure_remote_server(self, config_id):
"""Ensure remote server is running and installed."""
if config_id in self._remote_clients:
client = self._remote_clients[config_id]
return await client.connect_and_ensure_server()
await client.connect_and_ensure_server()

def restart_remote_server(self, config_id):
"""Restart remote server."""
Expand All @@ -123,7 +124,7 @@ def restart_remote_server(self, config_id):
def load_client_from_id(self, config_id):
"""Load remote server from configuration id."""
options = self.load_conf(config_id)
self.load_client(config_id, options)
self.load_client(config_id, options)

def load_client(self, config_id: str, options: SSHClientOptions):
"""Load remote server."""
Expand All @@ -148,41 +149,37 @@ def get_conf_ids(self):
# -------------------------------------------------------------------------
# --- Remote Server Kernel Methods
@Slot(str)
@AsSync
@SpyderQAsyncRunner.as_sync
async def get_kernels(self, config_id):
"""Get opened kernels."""
if config_id in self._remote_clients:
client = self._remote_clients[config_id]
kernels_list = await client.get_kernels()
self.sig_kernel_list.emit(kernels_list)
return kernels_list

@Slot(str)
@AsSync
@SpyderQAsyncRunner.as_sync
async def get_kernel_info(self, config_id, kernel_key):
"""Get kernel info."""
if config_id in self._remote_clients:
client = self._remote_clients[config_id]
kernel_info = await client.get_kernel_info(kernel_key)
self.sig_kernel_info.emit(kernel_info or {})
return kernel_info

@Slot(str)
@AsSync
@SpyderQAsyncRunner.as_sync
async def terminate_kernel(self, config_id, kernel_key):
"""Terminate opened kernel."""
if config_id in self._remote_clients:
client = self._remote_clients[config_id]
delete_kernel = await client.terminate_kernel(kernel_key)
self.sig_kernel_terminated.emit(delete_kernel or {})
return delete_kernel

@Slot(str)
@AsSync
@SpyderQAsyncRunner.as_sync
async def start_new_kernel(self, config_id):
"""Start new kernel."""
if config_id in self._remote_clients:
client = self._remote_clients[config_id]
kernel_connection_info = await client.start_new_kernel_ensure_server()
self.sig_kernel_started.emit(kernel_connection_info)
return kernel_connection_info
self.sig_kernel_started.emit(kernel_connection_info or {})

0 comments on commit a27fa35

Please sign in to comment.