33"""
44import asyncio
55from inspect import _empty , getmembers , ismethod , signature
6- from typing import Any , Coroutine , Dict , List
6+ from typing import Any , Dict , List
77
88from pydantic import ValidationError
99
10- from .utils import gen_uid
10+ from .logger import get_logger
1111from .rpc_methods import EXPOSED_BUILT_IN_METHODS , NoResponse , RpcMethodsBase
1212from .schemas import RpcMessage , RpcRequest , RpcResponse
13+ from .utils import gen_uid , pydantic_parse
1314
14- from .logger import get_logger
1515logger = get_logger ("RPC_CHANNEL" )
1616
1717
@@ -31,6 +31,7 @@ class RpcChannelClosedException(Exception):
3131 """
3232 Raised when the channel is closed mid-operation
3333 """
34+
3435 pass
3536
3637
@@ -92,11 +93,16 @@ class RpcCaller:
9293
9394 def __init__ (self , channel , methods = None ) -> None :
9495 self ._channel = channel
95- self ._method_names = [method [0 ] for method in getmembers (
96- methods , lambda i : ismethod (i ))] if methods is not None else None
96+ self ._method_names = (
97+ [method [0 ] for method in getmembers (methods , lambda i : ismethod (i ))]
98+ if methods is not None
99+ else None
100+ )
97101
98102 def __getattribute__ (self , name : str ):
99- if (not name .startswith ("_" ) or name in EXPOSED_BUILT_IN_METHODS ) and (self ._method_names is None or name in self ._method_names ):
103+ if (not name .startswith ("_" ) or name in EXPOSED_BUILT_IN_METHODS ) and (
104+ self ._method_names is None or name in self ._method_names
105+ ):
100106 return RpcProxy (self ._channel , name )
101107 else :
102108 return super ().__getattribute__ (name )
@@ -124,7 +130,15 @@ class RpcChannel:
124130 e.g. answer = channel.other.add(a=1,b=1) will (For example) ask the other side to perform 1+1 and will return an RPC-response of 2
125131 """
126132
127- def __init__ (self , methods : RpcMethodsBase , socket , channel_id = None , default_response_timeout = None , sync_channel_id = False , ** kwargs ):
133+ def __init__ (
134+ self ,
135+ methods : RpcMethodsBase ,
136+ socket ,
137+ channel_id = None ,
138+ default_response_timeout = None ,
139+ sync_channel_id = False ,
140+ ** kwargs ,
141+ ):
128142 """
129143
130144 Args:
@@ -177,12 +191,18 @@ async def get_other_channel_id(self) -> str:
177191 The _channel_id_synced verify we have it
178192 Timeout exception can be raised if the value isn't available
179193 """
180- await asyncio .wait_for (self ._channel_id_synced .wait (), self .default_response_timeout )
194+ await asyncio .wait_for (
195+ self ._channel_id_synced .wait (), self .default_response_timeout
196+ )
181197 return self ._other_channel_id
182198
183199 def get_return_type (self , method ):
184200 method_signature = signature (method )
185- return method_signature .return_annotation if method_signature .return_annotation is not _empty else str
201+ return (
202+ method_signature .return_annotation
203+ if method_signature .return_annotation is not _empty
204+ else str
205+ )
186206
187207 async def send (self , data ):
188208 """
@@ -217,14 +237,13 @@ async def on_message(self, data):
217237 This is the main function servers/clients using the channel need to call (upon reading a message on the wire)
218238 """
219239 try :
220- message = RpcMessage . parse_obj ( data )
240+ message = pydantic_parse ( RpcMessage , data )
221241 if message .request is not None :
222242 await self .on_request (message .request )
223243 if message .response is not None :
224244 await self .on_response (message .response )
225245 except ValidationError as e :
226- logger .error (f"Failed to parse message" , {
227- 'message' : data , 'error' : e })
246+ logger .error (f"Failed to parse message" , {"message" : data , "error" : e })
228247 await self .on_error (e )
229248 except Exception as e :
230249 await self .on_error (e )
@@ -267,7 +286,8 @@ async def on_connect(self):
267286 """
268287 if self ._sync_channel_id :
269288 self ._get_other_channel_id_task = asyncio .create_task (
270- self ._get_other_channel_id ())
289+ self ._get_other_channel_id ()
290+ )
271291 await self .on_handler_event (self ._connect_handlers , self )
272292
273293 async def _get_other_channel_id (self ):
@@ -277,7 +297,11 @@ async def _get_other_channel_id(self):
277297 """
278298 if self ._other_channel_id is None :
279299 other_channel_id = await self .other ._get_channel_id_ ()
280- self ._other_channel_id = other_channel_id .result if other_channel_id and other_channel_id .result else None
300+ self ._other_channel_id = (
301+ other_channel_id .result
302+ if other_channel_id and other_channel_id .result
303+ else None
304+ )
281305 if self ._other_channel_id is None :
282306 raise RemoteValueError ()
283307 # update asyncio event that we received remote channel id
@@ -303,11 +327,14 @@ async def on_request(self, message: RpcRequest):
303327 message (RpcRequest): the RPC request with the method to call
304328 """
305329 # TODO add exception support (catch exceptions and pass to other side as response with errors)
306- logger .debug ("Handling RPC request - %s" ,
307- {'request' : message , 'channel' : self .id })
330+ logger .debug (
331+ "Handling RPC request - %s" , {"request" : message , "channel" : self .id }
332+ )
308333 method_name = message .method
309334 # Ignore "_" prefixed methods (except the built in "_ping_")
310- if (isinstance (method_name , str ) and (not method_name .startswith ("_" ) or method_name in EXPOSED_BUILT_IN_METHODS )):
335+ if isinstance (method_name , str ) and (
336+ not method_name .startswith ("_" ) or method_name in EXPOSED_BUILT_IN_METHODS
337+ ):
311338 method = getattr (self .methods , method_name )
312339 if callable (method ):
313340 result = await method (** message .arguments )
@@ -317,8 +344,17 @@ async def on_request(self, message: RpcRequest):
317344 # if no type given - try to convert to string
318345 if result_type is str and type (result ) is not str :
319346 result = str (result )
320- response = RpcMessage (response = RpcResponse [result_type ](
321- call_id = message .call_id , result = result , result_type = getattr (result_type , "__name__" , getattr (result_type , "_name" , "unknown-type" ))))
347+ response = RpcMessage (
348+ response = RpcResponse [result_type ](
349+ call_id = message .call_id ,
350+ result = result ,
351+ result_type = getattr (
352+ result_type ,
353+ "__name__" ,
354+ getattr (result_type , "_name" , "unknown-type" ),
355+ ),
356+ )
357+ )
322358 await self .send (response )
323359
324360 def get_saved_promise (self , call_id ):
@@ -338,7 +374,7 @@ async def on_response(self, response: RpcResponse):
338374 Args:
339375 response (RpcResponse): the received response
340376 """
341- logger .debug ("Handling RPC response - %s" , {' response' : response })
377+ logger .debug ("Handling RPC response - %s" , {" response" : response })
342378 if response .call_id is not None and response .call_id in self .requests :
343379 self .responses [response .call_id ] = response
344380 promise = self .requests [response .call_id ]
@@ -360,15 +396,23 @@ async def wait_for_response(self, promise, timeout=DEFAULT_TIMEOUT) -> RpcRespon
360396 if timeout is DEFAULT_TIMEOUT :
361397 timeout = self .default_response_timeout
362398 # wait for the promise or until the channel is terminated
363- _ , pending = await asyncio .wait ([asyncio .ensure_future (promise .wait ()), asyncio .ensure_future (self ._closed .wait ())], timeout = timeout , return_when = asyncio .FIRST_COMPLETED )
399+ _ , pending = await asyncio .wait (
400+ [
401+ asyncio .ensure_future (promise .wait ()),
402+ asyncio .ensure_future (self ._closed .wait ()),
403+ ],
404+ timeout = timeout ,
405+ return_when = asyncio .FIRST_COMPLETED ,
406+ )
364407 # Cancel all pending futures and then detect if close was the first done
365408 for fut in pending :
366409 fut .cancel ()
367410 response = self .responses .get (promise .call_id , NoResponse )
368411 # if the channel was closed before we could finish
369412 if response is NoResponse :
370413 raise RpcChannelClosedException (
371- f"Channel Closed before RPC response for { promise .call_id } could be received" )
414+ f"Channel Closed before RPC response for { promise .call_id } could be received"
415+ )
372416 self .clear_saved_call (promise .call_id )
373417 return response
374418
@@ -382,9 +426,10 @@ async def async_call(self, name, args={}, call_id=None) -> RpcPromise:
382426 call_id (string, optional): a UUID to use to track the call () - override only with true UUIDs
383427 """
384428 call_id = call_id or gen_uid ()
385- msg = RpcMessage (request = RpcRequest (
386- method = name , arguments = args , call_id = call_id ))
387- logger .debug ("Calling RPC method - %s" , {'message' : msg })
429+ msg = RpcMessage (
430+ request = RpcRequest (method = name , arguments = args , call_id = call_id )
431+ )
432+ logger .debug ("Calling RPC method - %s" , {"message" : msg })
388433 await self .send (msg )
389434 promise = self .requests [msg .request .call_id ] = RpcPromise (msg .request )
390435 return promise
0 commit comments