Skip to content

Commit 1138667

Browse files
authored
Merge pull request #24 from ff137/upgrade/pydantic-v2
Support both Pydantic v1 and v2
2 parents f8dce43 + ba2b718 commit 1138667

File tree

8 files changed

+156
-63
lines changed

8 files changed

+156
-63
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ Websockets are ideal to create bi-directional realtime connections over the web.
134134
- Server Endpoint:
135135
- Based on [FAST-API](https://github.com/tiangolo/fastapi): enjoy all the benefits of a full ASGI platform, including Async-io and dependency injections (for example to authenticate connections)
136136

137-
- Based on [Pydnatic](https://pydantic-docs.helpmanual.io/): easily serialize structured data as part of RPC requests and responses (see 'tests/basic_rpc_test.py :: test_structured_response' for an example)
137+
- Based on [Pydantic](https://pydantic-docs.helpmanual.io/): easily serialize structured data as part of RPC requests and responses (see 'tests/basic_rpc_test.py :: test_structured_response' for an example)
138138

139139
- Client :
140140
- Based on [Tenacity](https://tenacity.readthedocs.io/en/latest/index.html): allowing configurable retries to keep to connection alive

fastapi_websocket_rpc/rpc_channel.py

Lines changed: 70 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
"""
44
import asyncio
55
from inspect import _empty, getmembers, ismethod, signature
6-
from typing import Any, Coroutine, Dict, List
6+
from typing import Any, Dict, List
77

88
from pydantic import ValidationError
99

10-
from .utils import gen_uid
10+
from .logger import get_logger
1111
from .rpc_methods import EXPOSED_BUILT_IN_METHODS, NoResponse, RpcMethodsBase
1212
from .schemas import RpcMessage, RpcRequest, RpcResponse
13+
from .utils import gen_uid, pydantic_parse
1314

14-
from .logger import get_logger
1515
logger = get_logger("RPC_CHANNEL")
1616

1717

@@ -31,6 +31,7 @@ class RpcChannelClosedException(Exception):
3131
"""
3232
Raised when the channel is closed mid-operation
3333
"""
34+
3435
pass
3536

3637

@@ -92,11 +93,16 @@ class RpcCaller:
9293

9394
def __init__(self, channel, methods=None) -> None:
9495
self._channel = channel
95-
self._method_names = [method[0] for method in getmembers(
96-
methods, lambda i: ismethod(i))] if methods is not None else None
96+
self._method_names = (
97+
[method[0] for method in getmembers(methods, lambda i: ismethod(i))]
98+
if methods is not None
99+
else None
100+
)
97101

98102
def __getattribute__(self, name: str):
99-
if (not name.startswith("_") or name in EXPOSED_BUILT_IN_METHODS) and (self._method_names is None or name in self._method_names):
103+
if (not name.startswith("_") or name in EXPOSED_BUILT_IN_METHODS) and (
104+
self._method_names is None or name in self._method_names
105+
):
100106
return RpcProxy(self._channel, name)
101107
else:
102108
return super().__getattribute__(name)
@@ -124,7 +130,15 @@ class RpcChannel:
124130
e.g. answer = channel.other.add(a=1,b=1) will (For example) ask the other side to perform 1+1 and will return an RPC-response of 2
125131
"""
126132

127-
def __init__(self, methods: RpcMethodsBase, socket, channel_id=None, default_response_timeout=None, sync_channel_id=False, **kwargs):
133+
def __init__(
134+
self,
135+
methods: RpcMethodsBase,
136+
socket,
137+
channel_id=None,
138+
default_response_timeout=None,
139+
sync_channel_id=False,
140+
**kwargs,
141+
):
128142
"""
129143
130144
Args:
@@ -177,12 +191,18 @@ async def get_other_channel_id(self) -> str:
177191
The _channel_id_synced verify we have it
178192
Timeout exception can be raised if the value isn't available
179193
"""
180-
await asyncio.wait_for(self._channel_id_synced.wait(), self.default_response_timeout)
194+
await asyncio.wait_for(
195+
self._channel_id_synced.wait(), self.default_response_timeout
196+
)
181197
return self._other_channel_id
182198

183199
def get_return_type(self, method):
184200
method_signature = signature(method)
185-
return method_signature.return_annotation if method_signature.return_annotation is not _empty else str
201+
return (
202+
method_signature.return_annotation
203+
if method_signature.return_annotation is not _empty
204+
else str
205+
)
186206

187207
async def send(self, data):
188208
"""
@@ -217,14 +237,13 @@ async def on_message(self, data):
217237
This is the main function servers/clients using the channel need to call (upon reading a message on the wire)
218238
"""
219239
try:
220-
message = RpcMessage.parse_obj(data)
240+
message = pydantic_parse(RpcMessage, data)
221241
if message.request is not None:
222242
await self.on_request(message.request)
223243
if message.response is not None:
224244
await self.on_response(message.response)
225245
except ValidationError as e:
226-
logger.error(f"Failed to parse message", {
227-
'message': data, 'error': e})
246+
logger.error(f"Failed to parse message", {"message": data, "error": e})
228247
await self.on_error(e)
229248
except Exception as e:
230249
await self.on_error(e)
@@ -267,7 +286,8 @@ async def on_connect(self):
267286
"""
268287
if self._sync_channel_id:
269288
self._get_other_channel_id_task = asyncio.create_task(
270-
self._get_other_channel_id())
289+
self._get_other_channel_id()
290+
)
271291
await self.on_handler_event(self._connect_handlers, self)
272292

273293
async def _get_other_channel_id(self):
@@ -277,7 +297,11 @@ async def _get_other_channel_id(self):
277297
"""
278298
if self._other_channel_id is None:
279299
other_channel_id = await self.other._get_channel_id_()
280-
self._other_channel_id = other_channel_id.result if other_channel_id and other_channel_id.result else None
300+
self._other_channel_id = (
301+
other_channel_id.result
302+
if other_channel_id and other_channel_id.result
303+
else None
304+
)
281305
if self._other_channel_id is None:
282306
raise RemoteValueError()
283307
# update asyncio event that we received remote channel id
@@ -303,11 +327,14 @@ async def on_request(self, message: RpcRequest):
303327
message (RpcRequest): the RPC request with the method to call
304328
"""
305329
# TODO add exception support (catch exceptions and pass to other side as response with errors)
306-
logger.debug("Handling RPC request - %s",
307-
{'request': message, 'channel': self.id})
330+
logger.debug(
331+
"Handling RPC request - %s", {"request": message, "channel": self.id}
332+
)
308333
method_name = message.method
309334
# Ignore "_" prefixed methods (except the built in "_ping_")
310-
if (isinstance(method_name, str) and (not method_name.startswith("_") or method_name in EXPOSED_BUILT_IN_METHODS)):
335+
if isinstance(method_name, str) and (
336+
not method_name.startswith("_") or method_name in EXPOSED_BUILT_IN_METHODS
337+
):
311338
method = getattr(self.methods, method_name)
312339
if callable(method):
313340
result = await method(**message.arguments)
@@ -317,8 +344,17 @@ async def on_request(self, message: RpcRequest):
317344
# if no type given - try to convert to string
318345
if result_type is str and type(result) is not str:
319346
result = str(result)
320-
response = RpcMessage(response=RpcResponse[result_type](
321-
call_id=message.call_id, result=result, result_type=getattr(result_type, "__name__", getattr(result_type, "_name", "unknown-type"))))
347+
response = RpcMessage(
348+
response=RpcResponse[result_type](
349+
call_id=message.call_id,
350+
result=result,
351+
result_type=getattr(
352+
result_type,
353+
"__name__",
354+
getattr(result_type, "_name", "unknown-type"),
355+
),
356+
)
357+
)
322358
await self.send(response)
323359

324360
def get_saved_promise(self, call_id):
@@ -338,7 +374,7 @@ async def on_response(self, response: RpcResponse):
338374
Args:
339375
response (RpcResponse): the received response
340376
"""
341-
logger.debug("Handling RPC response - %s", {'response': response})
377+
logger.debug("Handling RPC response - %s", {"response": response})
342378
if response.call_id is not None and response.call_id in self.requests:
343379
self.responses[response.call_id] = response
344380
promise = self.requests[response.call_id]
@@ -360,15 +396,23 @@ async def wait_for_response(self, promise, timeout=DEFAULT_TIMEOUT) -> RpcRespon
360396
if timeout is DEFAULT_TIMEOUT:
361397
timeout = self.default_response_timeout
362398
# wait for the promise or until the channel is terminated
363-
_, pending = await asyncio.wait([asyncio.ensure_future(promise.wait()), asyncio.ensure_future(self._closed.wait())], timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
399+
_, pending = await asyncio.wait(
400+
[
401+
asyncio.ensure_future(promise.wait()),
402+
asyncio.ensure_future(self._closed.wait()),
403+
],
404+
timeout=timeout,
405+
return_when=asyncio.FIRST_COMPLETED,
406+
)
364407
# Cancel all pending futures and then detect if close was the first done
365408
for fut in pending:
366409
fut.cancel()
367410
response = self.responses.get(promise.call_id, NoResponse)
368411
# if the channel was closed before we could finish
369412
if response is NoResponse:
370413
raise RpcChannelClosedException(
371-
f"Channel Closed before RPC response for {promise.call_id} could be received")
414+
f"Channel Closed before RPC response for {promise.call_id} could be received"
415+
)
372416
self.clear_saved_call(promise.call_id)
373417
return response
374418

@@ -382,9 +426,10 @@ async def async_call(self, name, args={}, call_id=None) -> RpcPromise:
382426
call_id (string, optional): a UUID to use to track the call () - override only with true UUIDs
383427
"""
384428
call_id = call_id or gen_uid()
385-
msg = RpcMessage(request=RpcRequest(
386-
method=name, arguments=args, call_id=call_id))
387-
logger.debug("Calling RPC method - %s", {'message': msg})
429+
msg = RpcMessage(
430+
request=RpcRequest(method=name, arguments=args, call_id=call_id)
431+
)
432+
logger.debug("Calling RPC method - %s", {"message": msg})
388433
await self.send(msg)
389434
promise = self.requests[msg.request.call_id] = RpcPromise(msg.request)
390435
return promise

fastapi_websocket_rpc/schemas.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import Dict, Generic, List, Optional, TypeVar
21
from enum import Enum
2+
from typing import Dict, Generic, Optional, TypeVar
33

44
from pydantic import BaseModel
5-
from pydantic.generics import GenericModel
5+
6+
from .utils import is_pydantic_pre_v2
67

78
UUID = str
89

@@ -13,13 +14,24 @@ class RpcRequest(BaseModel):
1314
call_id: Optional[UUID] = None
1415

1516

16-
ResponseT = TypeVar('ResponseT')
17+
ResponseT = TypeVar("ResponseT")
1718

1819

19-
class RpcResponse(GenericModel, Generic[ResponseT]):
20-
result: ResponseT
21-
result_type: Optional[str]
22-
call_id: Optional[UUID] = None
20+
# Check pydantic version to handle deprecated GenericModel
21+
if is_pydantic_pre_v2():
22+
from pydantic.generics import GenericModel
23+
24+
class RpcResponse(GenericModel, Generic[ResponseT]):
25+
result: ResponseT
26+
result_type: Optional[str]
27+
call_id: Optional[UUID] = None
28+
29+
else:
30+
31+
class RpcResponse(BaseModel, Generic[ResponseT]):
32+
result: ResponseT
33+
result_type: Optional[str]
34+
call_id: Optional[UUID] = None
2335

2436

2537
class RpcMessage(BaseModel):

fastapi_websocket_rpc/simplewebsocket.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
from typing import Any
2-
from abc import ABC, abstractmethod
31
import json
2+
from abc import ABC, abstractmethod
3+
4+
from .utils import pydantic_serialize
45

56

67
class SimpleWebSocket(ABC):
78
"""
89
Abstract base class for all websocket related wrappers.
910
"""
11+
1012
@abstractmethod
1113
def send(self, msg):
1214
pass
@@ -25,7 +27,7 @@ def __init__(self, websocket: SimpleWebSocket):
2527
self._websocket = websocket
2628

2729
def _serialize(self, msg):
28-
return msg.json()
30+
return pydantic_serialize(msg)
2931

3032
def _deserialize(self, buffer):
3133
return json.loads(buffer)

fastapi_websocket_rpc/utils.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
from datetime import timedelta
66
from random import SystemRandom, randrange
77

8-
__author__ = 'OrW'
8+
import pydantic
9+
from packaging import version
10+
11+
__author__ = "OrW"
912

1013

1114
class RandomUtils(object):
1215
@staticmethod
1316
def gen_cookie_id():
14-
return os.urandom(16).encode('hex')
17+
return os.urandom(16).encode("hex")
1518

1619
@staticmethod
1720
def gen_uid():
@@ -21,8 +24,10 @@ def gen_uid():
2124
def gen_token(size=256):
2225
if size % 2 != 0:
2326
raise ValueError("Size in bits must be an even number.")
24-
return uuid.UUID(int=SystemRandom().getrandbits(size/2)).hex + \
25-
uuid.UUID(int=SystemRandom().getrandbits(size/2)).hex
27+
return (
28+
uuid.UUID(int=SystemRandom().getrandbits(size / 2)).hex
29+
+ uuid.UUID(int=SystemRandom().getrandbits(size / 2)).hex
30+
)
2631

2732
@staticmethod
2833
def random_datetime(start=None, end=None):
@@ -52,9 +57,28 @@ def random_datetime(start=None, end=None):
5257
class StringUtils(object):
5358
@staticmethod
5459
def convert_camelcase_to_underscore(name, lower=True):
55-
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
56-
res = re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1)
60+
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
61+
res = re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1)
5762
if lower:
5863
return res.lower()
5964
else:
6065
return res.upper()
66+
67+
68+
# Helper methods for supporting Pydantic v1 and v2
69+
def is_pydantic_pre_v2():
70+
return version.parse(pydantic.VERSION) < version.parse("2.0.0")
71+
72+
73+
def pydantic_serialize(model, **kwargs):
74+
if is_pydantic_pre_v2():
75+
return model.json(**kwargs)
76+
else:
77+
return model.model_dump_json(**kwargs)
78+
79+
80+
def pydantic_parse(model, data, **kwargs):
81+
if is_pydantic_pre_v2():
82+
return model.parse_obj(data, **kwargs)
83+
else:
84+
return model.model_validate(data, **kwargs)

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
fastapi>=0.78.0,<1
2-
pydantic>=1.9.1,<2
2+
packaging>=20.4
3+
pydantic>=1.9.1
34
uvicorn>=0.17.6,<1
45
websockets>=10.3,<11
56
tenacity>=8.0.1,<9

0 commit comments

Comments
 (0)