Skip to content

Commit 1b027f9

Browse files
committed
move some locks around
1 parent 9c08b02 commit 1b027f9

File tree

2 files changed

+122
-113
lines changed

2 files changed

+122
-113
lines changed

src/replit_river/server_transport.py

Lines changed: 94 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -90,19 +90,41 @@ async def handshake_to_get_session(
9090
async def close(self) -> None:
9191
await self._close_all_sessions()
9292

93+
async def _get_existing_session(self, to_id: str) -> Optional[Session]:
94+
async with self._session_lock:
95+
return self._sessions.get(to_id)
96+
9397
async def _get_or_create_session(
9498
self,
9599
transport_id: str,
96100
to_id: str,
97101
session_id: str,
98102
websocket: WebSocketCommonProtocol,
99103
) -> Session:
100-
session_to_close: Optional[Session] = None
101104
new_session: Optional[Session] = None
102-
async with self._session_lock:
103-
if to_id not in self._sessions:
105+
old_session: Optional[Session] = await self._get_existing_session(to_id)
106+
if not old_session:
107+
logger.info(
108+
'Creating new session with "%s" using ws: %s', to_id, websocket.id
109+
)
110+
new_session = Session(
111+
transport_id,
112+
to_id,
113+
session_id,
114+
websocket,
115+
self._transport_options,
116+
self._is_server,
117+
self._handlers,
118+
close_session_callback=self._delete_session,
119+
)
120+
else:
121+
if old_session.session_id != session_id:
104122
logger.info(
105-
'Creating new session with "%s" using ws: %s', to_id, websocket.id
123+
'Create new session with "%s" for session id %s'
124+
" and close old session %s",
125+
to_id,
126+
session_id,
127+
old_session.session_id,
106128
)
107129
new_session = Session(
108130
transport_id,
@@ -115,45 +137,25 @@ async def _get_or_create_session(
115137
close_session_callback=self._delete_session,
116138
)
117139
else:
118-
old_session = self._sessions[to_id]
119-
if old_session.session_id != session_id:
120-
logger.info(
121-
'Create new session with "%s" for session id %s'
122-
" and close old session %s",
123-
to_id,
124-
session_id,
125-
old_session.session_id,
126-
)
127-
session_to_close = old_session
128-
new_session = Session(
129-
transport_id,
130-
to_id,
131-
session_id,
132-
websocket,
133-
self._transport_options,
134-
self._is_server,
135-
self._handlers,
136-
close_session_callback=self._delete_session,
137-
)
138-
else:
139-
# If the instance id is the same, we reuse the session and assign
140-
# a new websocket to it.
141-
logger.debug(
142-
'Reuse old session with "%s" using new ws: %s',
143-
to_id,
144-
websocket.id,
145-
)
146-
try:
147-
await old_session.replace_with_new_websocket(websocket)
148-
new_session = old_session
149-
except FailedSendingMessageException as e:
150-
raise e
140+
# If the instance id is the same, we reuse the session and assign
141+
# a new websocket to it.
142+
logger.debug(
143+
'Reuse old session with "%s" using new ws: %s',
144+
to_id,
145+
websocket.id,
146+
)
147+
try:
148+
await old_session.replace_with_new_websocket(websocket)
149+
new_session = old_session
150+
except FailedSendingMessageException as e:
151+
raise e
151152

152-
self._set_session(new_session)
153+
if old_session and new_session != old_session:
154+
logger.info("Closing stale session %s", old_session.session_id)
155+
await old_session.close()
153156

154-
if session_to_close:
155-
logger.info("Closing stale session %s", session_to_close.session_id)
156-
await session_to_close.close()
157+
async with self._session_lock:
158+
self._set_session(new_session)
157159

158160
return new_session
159161

@@ -230,67 +232,67 @@ async def _establish_handshake(
230232
)
231233
raise InvalidMessageException("handshake request to wrong server")
232234

233-
async with self._session_lock:
234-
old_session = self._sessions.get(request_message.from_, None)
235-
client_next_expected_seq = (
236-
handshake_request.expectedSessionState.nextExpectedSeq
237-
)
238-
client_next_sent_seq = (
239-
handshake_request.expectedSessionState.nextSentSeq or 0
240-
)
241-
if old_session and old_session.session_id == handshake_request.sessionId:
242-
# check invariants
243-
# ordering must be correct
244-
our_next_seq = await old_session.get_next_sent_seq()
245-
our_ack = await old_session.get_next_expected_seq()
246-
247-
if client_next_sent_seq > our_ack:
248-
message = (
249-
"client is in the future: "
250-
f"server wanted {our_ack} but client has {client_next_sent_seq}"
251-
)
252-
await self._send_handshake_response(
253-
request_message,
254-
HandShakeStatus(ok=False, reason=message),
255-
websocket,
256-
)
257-
raise SessionStateMismatchException(message)
235+
old_session = await self._get_existing_session(request_message.from_)
236+
client_next_expected_seq = (
237+
handshake_request.expectedSessionState.nextExpectedSeq
238+
)
239+
client_next_sent_seq = (
240+
handshake_request.expectedSessionState.nextSentSeq or 0
241+
)
242+
if old_session and old_session.session_id == handshake_request.sessionId:
243+
# check invariants
244+
# ordering must be correct
245+
our_next_seq = await old_session.get_next_sent_seq()
246+
our_ack = await old_session.get_next_expected_seq()
258247

259-
if our_next_seq > client_next_expected_seq:
260-
message = (
261-
"server is in the future: "
262-
f"client wanted {client_next_expected_seq} "
263-
f"but server has {our_next_seq}"
264-
)
265-
await self._send_handshake_response(
266-
request_message,
267-
HandShakeStatus(ok=False, reason=message),
268-
websocket,
269-
)
270-
raise SessionStateMismatchException(message)
271-
elif old_session:
272-
# we have an old session but the session id is different
273-
# just delete the old session
274-
await old_session.close()
275-
old_session = None
248+
if client_next_sent_seq > our_ack:
249+
message = (
250+
"client is in the future: "
251+
f"server wanted {our_ack} but client has {client_next_sent_seq}"
252+
)
253+
await self._send_handshake_response(
254+
request_message,
255+
HandShakeStatus(ok=False, reason=message),
256+
websocket,
257+
)
258+
raise SessionStateMismatchException(message)
276259

277-
if not old_session and (
278-
client_next_sent_seq > 0 or client_next_expected_seq > 0
279-
):
280-
message = "client is trying to resume a session but we don't have it"
260+
if our_next_seq > client_next_expected_seq:
261+
message = (
262+
"server is in the future: "
263+
f"client wanted {client_next_expected_seq} "
264+
f"but server has {our_next_seq}"
265+
)
281266
await self._send_handshake_response(
282267
request_message,
283268
HandShakeStatus(ok=False, reason=message),
284269
websocket,
285270
)
286271
raise SessionStateMismatchException(message)
272+
elif old_session:
273+
# we have an old session but the session id is different
274+
# just delete the old session
275+
await old_session.close()
276+
old_session = None
277+
287278

288-
# from this point on, we're committed to connecting
289-
session_id = handshake_request.sessionId
290-
handshake_response = await self._send_handshake_response(
279+
if not old_session and (
280+
client_next_sent_seq > 0 or client_next_expected_seq > 0
281+
):
282+
message = "client is trying to resume a session but we don't have it"
283+
await self._send_handshake_response(
291284
request_message,
292-
HandShakeStatus(ok=True, sessionId=session_id),
285+
HandShakeStatus(ok=False, reason=message),
293286
websocket,
294287
)
288+
raise SessionStateMismatchException(message)
289+
290+
# from this point on, we're committed to connecting
291+
session_id = handshake_request.sessionId
292+
handshake_response = await self._send_handshake_response(
293+
request_message,
294+
HandShakeStatus(ok=True, sessionId=session_id),
295+
websocket,
296+
)
295297

296-
return handshake_request, handshake_response
298+
return handshake_request, handshake_response

tests/river_fixtures/clientserver.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23
from typing import AsyncGenerator, Literal
34

@@ -37,31 +38,37 @@ async def client(
3738
transport_options: TransportOptions,
3839
no_logging_error: NoErrors,
3940
) -> AsyncGenerator[Client, None]:
41+
binding = None
4042
try:
41-
async with serve(server.serve, "127.0.0.1") as binding:
42-
sockets = list(binding.sockets)
43-
assert len(sockets) == 1, "Too many sockets!"
44-
socket = sockets[0]
43+
binding = await serve(server.serve, "127.0.0.1")
44+
sockets = list(binding.sockets)
45+
assert len(sockets) == 1, "Too many sockets!"
46+
socket = sockets[0]
4547

46-
async def websocket_uri_factory() -> UriAndMetadata[None]:
47-
return {
48-
"uri": "ws://%s:%d" % socket.getsockname(),
49-
"metadata": None,
50-
}
48+
async def websocket_uri_factory() -> UriAndMetadata[None]:
49+
return {
50+
"uri": "ws://%s:%d" % socket.getsockname(),
51+
"metadata": None,
52+
}
53+
54+
client: Client[Literal[None]] = Client[None](
55+
uri_and_metadata_factory=websocket_uri_factory,
56+
client_id="test_client",
57+
server_id="test_server",
58+
transport_options=transport_options,
59+
)
60+
try:
61+
yield client
62+
finally:
63+
logging.debug("Start closing test client : %s", "test_client")
64+
await client.close()
5165

52-
client: Client[Literal[None]] = Client[None](
53-
uri_and_metadata_factory=websocket_uri_factory,
54-
client_id="test_client",
55-
server_id="test_server",
56-
transport_options=transport_options,
57-
)
58-
try:
59-
yield client
60-
finally:
61-
logging.debug("Start closing test client : %s", "test_client")
62-
await client.close()
6366
finally:
6467
logging.debug("Start closing test server")
68+
if binding:
69+
binding.close()
6570
await server.close()
71+
if binding:
72+
await binding.wait_closed()
6673
# Server should close normally
67-
no_logging_error()
74+
no_logging_error()

0 commit comments

Comments
 (0)