From f7195175b3c7999a6ed76073ef67c797c4bd8c0e Mon Sep 17 00:00:00 2001 From: Andre Merzky Date: Wed, 26 Jul 2023 15:40:39 +0200 Subject: [PATCH] cleaner RPC handling --- src/radical/pilot/messages.py | 227 ++++---------------------- src/radical/pilot/utils/__init__.py | 1 + src/radical/pilot/utils/rpc_helper.py | 206 +++++++++++++++++++++++ 3 files changed, 241 insertions(+), 193 deletions(-) create mode 100644 src/radical/pilot/utils/rpc_helper.py diff --git a/src/radical/pilot/messages.py b/src/radical/pilot/messages.py index be6401b012..758d3a0daa 100644 --- a/src/radical/pilot/messages.py +++ b/src/radical/pilot/messages.py @@ -9,48 +9,10 @@ # class HeartbeatMessage(ru.Message): - # ------------------------------ - class Payload(ru.TypedDict): - _schema = {'uid': str } - _defaults = {'uid': None } - # ------------------------------ - _schema = { - 'payload': Payload - } - - _defaults = { - 'msg_type': 'heartbeat', - 'payload' : {} - } - - - - # -------------------------------------------------------------------------- - def __init__(self, uid : Optional[str] = None, - from_dict: Optional[Dict[str, Any]] = None): - ''' - support msg construction and usage like this: - - hb_msg = rp.HeartbeatMessage(uid='foo.1') - assert hb_msg.uid == 'foo.1 - - ''' - - if uid: - from_dict = {'payload': {'uid': uid}} - - super().__init__(from_dict=from_dict) - - - # -------------------------------------------------------------------------- - @property - def uid(self): - return self.payload.uid - - @uid.setter - def uid(self, value): - self.payload.uid = value + _schema = {'uid' : str } + _defaults = {'_msg_type': 'heartbeat', + 'uid' : None} ru.Message.register_msg_type('heartbeat', HeartbeatMessage) @@ -60,82 +22,18 @@ def uid(self, value): # class RPCRequestMessage(ru.Message): - # ------------------------------ - class Payload(ru.TypedDict): - _schema = { - 'uid' : str, # uid of message - 'addr': str, # who is expected to act on the request - 'cmd' : str, # rpc command - 'args': dict, # rpc command arguments - } - _defaults = { - 'uid' : None, - 'addr': None, - 'cmd' : None, - 'args': {}, - } - # ------------------------------ - - _schema = { - 'payload': Payload - } - + _schema = {'uid' : str, # uid of message + 'addr' : str, # who is expected to act on the request + 'cmd' : str, # rpc command + 'args' : list, # rpc command arguments + 'kwargs' : dict} # rpc command named arguments _defaults = { - 'msg_type': 'rpc_req', - 'payload' : {} - } - - - - # -------------------------------------------------------------------------- - def __init__(self, uid : Optional[str] = None, - addr: Optional[str] = None, - rpc : Optional[str] = None, - args: Optional[Dict[str, Any]] = None): - ''' - support msg construction and usage like this: - - msg = rp.Message(addr='pilot.0000', rpc='stop') - assert msg.addr == 'pilot.0000' - - ''' - - from_dict = dict() - - if addr: from_dict['addr'] = addr - if rpc: from_dict['rpc'] = rpc - if args: from_dict['args'] = args - - super().__init__(from_dict=from_dict) - - - # -------------------------------------------------------------------------- - @property - def addr(self): - return self.payload.addr - - @addr.setter - def addr(self, value): - self.payload.addr = value - - - @property - def rpc(self): - return self.payload.rpc - - @rpc.setter - def rpc(self, value): - self.payload.rpc = value - - - @property - def args(self): - return self.payload.args - - @args.setter - def args(self, value): - self.payload.args = value - + '_msg_type': 'rpc_req', + 'uid' : None, + 'addr' : None, + 'cmd' : None, + 'args' : [], + 'kwargs' : {}} ru.Message.register_msg_type('rpc_req', RPCRequestMessage) @@ -144,92 +42,35 @@ def args(self, value): # class RPCResultMessage(ru.Message): - # ------------------------------ - class Payload(ru.TypedDict): - _schema = { - 'uid': str, # uid of rpc call - 'val': Any, # return value (`None` by default) - 'out': str, # stdout - 'err': str, # stderr - 'exc': str, # raised exception representation - } - _defaults = { - 'uid': None, - 'val': None, - 'out': None, - 'err': None, - 'exc': None, - } - # ------------------------------ - - _schema = { - 'payload': Payload - } - - _defaults = { - 'msg_type': 'rpc_res', - 'payload' : {} - } - - + _schema = {'uid' : str, # uid of rpc call + 'val' : Any, # return value (`None` by default) + 'out' : str, # stdout + 'err' : str, # stderr + 'exc' : str} # raised exception representation + _defaults = {'_msg_type': 'rpc_res', + 'uid' : None, + 'val' : None, + 'out' : None, + 'err' : None, + 'exc' : None} # -------------------------------------------------------------------------- - def __init__(self, rpc_req: Optional[RPCRequestMessage] = None, - uid : Optional[str] = None, - val : Optional[Any] = None, - out : Optional[str] = None, - err : Optional[str] = None, - exc : Optional[Tuple[str, str]] = None): - ''' - support rpc response message construction from an rpc request message - (carries over `uid`): + # + def __init__(self, rpc_req=None, from_dict=None, **kwargs): - msg = rp.Message(rpc_req=req_msg, val=42) + # when constfructed from a request message copy the uid - ''' + if rpc_req: + if not from_dict: + from_dict = dict() - from_dict = dict() + from_dict['uid'] = rpc_req['uid'] - if rpc_req: from_dict['uid'] = rpc_req.uid + super().__init__(from_dict, **kwargs) - if uid: from_dict['uid'] = uid - if val: from_dict['val'] = uid - if out: from_dict['out'] = uid - if err: from_dict['err'] = uid - if exc: from_dict['exc'] = uid - super().__init__(from_dict=from_dict) +ru.Message.register_msg_type('rpc_res', RPCResultMessage) - # -------------------------------------------------------------------------- - @property - def addr(self): - return self.payload.addr - - @addr.setter - def addr(self, value): - self.payload.addr = value - - - @property - def rpc(self): - return self.payload.rpc - - @rpc.setter - def rpc(self, value): - self.payload.rpc = value - - - @property - def args(self): - return self.payload.args - - @args.setter - def args(self, value): - self.payload.args = value - - -ru.Message.register_msg_type('rpc_req', RPCRequestMessage) - # ------------------------------------------------------------------------------ diff --git a/src/radical/pilot/utils/__init__.py b/src/radical/pilot/utils/__init__.py index 120969e3a1..b838704efe 100644 --- a/src/radical/pilot/utils/__init__.py +++ b/src/radical/pilot/utils/__init__.py @@ -39,6 +39,7 @@ from .component import * from .component_manager import * from .serializer import * +from .rpc_helper import * # ------------------------------------------------------------------------------ diff --git a/src/radical/pilot/utils/rpc_helper.py b/src/radical/pilot/utils/rpc_helper.py new file mode 100644 index 0000000000..b3ba63531f --- /dev/null +++ b/src/radical/pilot/utils/rpc_helper.py @@ -0,0 +1,206 @@ + +__copyright__ = 'Copyright 2023, The RADICAL-Cybertools Team' +__license__ = 'MIT' + +import io +import sys +import queue + +import threading as mt + +import radical.utils as ru + +from ..constants import CONTROL_PUBSUB +from ..messages import RPCRequestMessage, RPCResultMessage + + +# ------------------------------------------------------------------------------ +# +class RPCHelper(object): + ''' + This class implements a simple synchronous RPC mechanism. It only requires + the addresses of the control pubsub to use. + ''' + + + # -------------------------------------------------------------------------- + # + def __init__(self, ctrl_addr_pub, ctrl_addr_sub, log, prof): + + self._addr_pub = ctrl_addr_pub + self._addr_sub = ctrl_addr_sub + + self._log = log + self._prof = prof + + self._active = None + self._queue = queue.Queue() + self._lock = mt.Lock() + self._handlers = dict() + + self._pub = ru.zmq.Publisher(channel=CONTROL_PUBSUB, + url=self._addr_pub, + log=self._log, + prof=self._prof) + + self._thread = mt.Thread(target=self._work) + self._thread.daemon = True + self._thread.start() + + + # -------------------------------------------------------------------------- + # + def request(self, cmd, *args, **kwargs): + + rid = ru.generate_id('rpc') + req = RPCRequestMessage(uid=rid, cmd=cmd, args=args, kwargs=kwargs) + + self._active = rid + + self._pub.put(CONTROL_PUBSUB, req) + self._log.debug_3('sent rpc req %s', req) + + res = self._queue.get() + + assert res.uid == req.uid + + if res.exc: + # FIXME: try to deserialize exception type + # this should work at least for standard exceptions + raise RuntimeError(str(res.exc)) + + return res + + + # -------------------------------------------------------------------------- + # + def _work(self): + + pub = ru.zmq.Publisher(channel=CONTROL_PUBSUB, + url=self._addr_pub, + log=self._log, + prof=self._prof) + + sub = ru.zmq.Subscriber(channel=CONTROL_PUBSUB, + topic=CONTROL_PUBSUB, + url=self._addr_sub, + log=self._log, + prof=self._prof) + sub.subscribe(CONTROL_PUBSUB) + + import time + time.sleep(1) + + while True: + + data = sub.get_nowait(100) + if not data or data == [None, None]: + continue + + msg_topic = data[0] + msg_data = data[1] + + if not isinstance(msg_data, dict): + continue + + try: + msg = ru.zmq.Message.deserialize(msg_data) + + except Exception as e: + # not a `ru.zmq.Message` type + continue + + if isinstance(msg, RPCRequestMessage): + + # handle any RPC requests for which a handler is registered + self._log.debug_2('got rpc req: %s', msg) + + with self._lock: + if msg.cmd in self._handlers: + rep = self._handle_request(msg) + pub.put(CONTROL_PUBSUB, rep) + else: + self._log.debug_2('no rpc handler for %s', msg.cmd) + + elif isinstance(msg, RPCResultMessage): + + # collect any RPC response whose uid matches the one we wait for + + self._log.debug_2('got rpc res', self._active, msg.uid) + if self._active and msg.uid == self._active: + self._active = None + self._queue.put(msg) + + + # -------------------------------------------------------------------------- + # + def _handle_request(self, msg): + + bakout = sys.stdout + bakerr = sys.stderr + + strout = None + strerr = None + + val = None + out = None + err = None + exc = None + + try: + self._log.debug_2('rpc handler: %s(%s, %s)', + self._handlers[msg.cmd], *msg.args, **msg.kwargs) + + sys.stdout = strout = io.StringIO() + sys.stderr = strerr = io.StringIO() + + val = self._handlers[msg.cmd](*msg.args, **msg.kwargs) + out = strout.getvalue() + err = strerr.getvalue() + + except Exception as e: + self._log.exception('rpc call failed: %s' % (msg)) + val = None + out = strout.getvalue() + err = strerr.getvalue() + exc = (repr(e), '\n'.join(ru.get_exception_trace())) + + finally: + # restore stdio + sys.stdout = bakout + sys.stderr = bakerr + + return RPCResultMessage(rpc_req=msg, val=val, out=out, err=err, exc=exc) + + + # -------------------------------------------------------------------------- + # + def add_handler(self, cmd, handler): + ''' + register a handler for the specified rpc command type + ''' + + with self._lock: + + if cmd in self._handlers: + raise ValueError('handler for rpc cmd %s already set' % cmd) + + self._handlers[cmd] = handler + + + # -------------------------------------------------------------------------- + # + def del_handler(self, cmd): + ''' + unregister a handler for the specified rpc command type + ''' + + with self._lock: + + if cmd not in self._handlers: + raise ValueError('handler for rpc cmd %s not set' % cmd) + + del self._handlers[cmd] + + +# ------------------------------------------------------------------------------