diff --git a/labgrid/remote/client.py b/labgrid/remote/client.py index 63fdfa7af..6924090f5 100755 --- a/labgrid/remote/client.py +++ b/labgrid/remote/client.py @@ -4,6 +4,7 @@ import argparse import asyncio import contextlib +from contextvars import ContextVar import enum import os import pathlib @@ -1529,8 +1530,45 @@ def print_version(self): print(labgrid_version()) -def start_session(address, extra, debug=False): - loop = asyncio.get_event_loop() +_loop: ContextVar["asyncio.AbstractEventLoop | None"] = ContextVar("_loop", default=None) + + +def ensure_event_loop(external_loop=None): + """Get the event loop for this thread, or create a new event loop.""" + # get stashed loop + loop = _loop.get() + + # ignore closed stashed loop + if loop and loop.is_closed(): + loop = None + + if external_loop: + # if a loop is stashed, expect it to be the same as the external one + if loop: + assert loop is external_loop + _loop.set(external_loop) + return external_loop + + # return stashed loop + if loop: + return loop + + try: + # if called from async code, try to get current's thread loop + loop = asyncio.get_running_loop() + except RuntimeError: + # no previous, external or running loop found, create a new one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # stash it + _loop.set(loop) + return loop + + +def start_session(address, extra, debug=False, loop=None): + loop = ensure_event_loop(loop) + if debug: loop.set_debug(True) @@ -2040,7 +2078,9 @@ def main(): coordinator_address = os.environ.get("LG_COORDINATOR", "127.0.0.1:20408") logging.debug('Starting session with "%s"', coordinator_address) - session = start_session(coordinator_address, extra, args.debug) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + session = start_session(coordinator_address, extra=extra, debug=args.debug, loop=loop) logging.debug("Started session") try: