@@ -90,19 +90,41 @@ async def handshake_to_get_session(
9090 async def close (self ) -> None :
9191 await self ._close_all_sessions ()
9292
93+ async def _get_existing_session (self , to_id : str ) -> Optional [Session ]:
94+ async with self ._session_lock :
95+ return self ._sessions .get (to_id )
96+
9397 async def _get_or_create_session (
9498 self ,
9599 transport_id : str ,
96100 to_id : str ,
97101 session_id : str ,
98102 websocket : WebSocketCommonProtocol ,
99103 ) -> Session :
100- session_to_close : Optional [Session ] = None
101104 new_session : Optional [Session ] = None
102- async with self ._session_lock :
103- if to_id not in self ._sessions :
105+ old_session : Optional [Session ] = await self ._get_existing_session (to_id )
106+ if not old_session :
107+ logger .info (
108+ 'Creating new session with "%s" using ws: %s' , to_id , websocket .id
109+ )
110+ new_session = Session (
111+ transport_id ,
112+ to_id ,
113+ session_id ,
114+ websocket ,
115+ self ._transport_options ,
116+ self ._is_server ,
117+ self ._handlers ,
118+ close_session_callback = self ._delete_session ,
119+ )
120+ else :
121+ if old_session .session_id != session_id :
104122 logger .info (
105- 'Creating new session with "%s" using ws: %s' , to_id , websocket .id
123+ 'Create new session with "%s" for session id %s'
124+ " and close old session %s" ,
125+ to_id ,
126+ session_id ,
127+ old_session .session_id ,
106128 )
107129 new_session = Session (
108130 transport_id ,
@@ -115,45 +137,25 @@ async def _get_or_create_session(
115137 close_session_callback = self ._delete_session ,
116138 )
117139 else :
118- old_session = self ._sessions [to_id ]
119- if old_session .session_id != session_id :
120- logger .info (
121- 'Create new session with "%s" for session id %s'
122- " and close old session %s" ,
123- to_id ,
124- session_id ,
125- old_session .session_id ,
126- )
127- session_to_close = old_session
128- new_session = Session (
129- transport_id ,
130- to_id ,
131- session_id ,
132- websocket ,
133- self ._transport_options ,
134- self ._is_server ,
135- self ._handlers ,
136- close_session_callback = self ._delete_session ,
137- )
138- else :
139- # If the instance id is the same, we reuse the session and assign
140- # a new websocket to it.
141- logger .debug (
142- 'Reuse old session with "%s" using new ws: %s' ,
143- to_id ,
144- websocket .id ,
145- )
146- try :
147- await old_session .replace_with_new_websocket (websocket )
148- new_session = old_session
149- except FailedSendingMessageException as e :
150- raise e
140+ # If the instance id is the same, we reuse the session and assign
141+ # a new websocket to it.
142+ logger .debug (
143+ 'Reuse old session with "%s" using new ws: %s' ,
144+ to_id ,
145+ websocket .id ,
146+ )
147+ try :
148+ await old_session .replace_with_new_websocket (websocket )
149+ new_session = old_session
150+ except FailedSendingMessageException as e :
151+ raise e
151152
152- self ._set_session (new_session )
153+ if old_session and new_session != old_session :
154+ logger .info ("Closing stale session %s" , old_session .session_id )
155+ await old_session .close ()
153156
154- if session_to_close :
155- logger .info ("Closing stale session %s" , session_to_close .session_id )
156- await session_to_close .close ()
157+ async with self ._session_lock :
158+ self ._set_session (new_session )
157159
158160 return new_session
159161
@@ -230,67 +232,67 @@ async def _establish_handshake(
230232 )
231233 raise InvalidMessageException ("handshake request to wrong server" )
232234
233- async with self ._session_lock :
234- old_session = self ._sessions .get (request_message .from_ , None )
235- client_next_expected_seq = (
236- handshake_request .expectedSessionState .nextExpectedSeq
237- )
238- client_next_sent_seq = (
239- handshake_request .expectedSessionState .nextSentSeq or 0
240- )
241- if old_session and old_session .session_id == handshake_request .sessionId :
242- # check invariants
243- # ordering must be correct
244- our_next_seq = await old_session .get_next_sent_seq ()
245- our_ack = await old_session .get_next_expected_seq ()
246-
247- if client_next_sent_seq > our_ack :
248- message = (
249- "client is in the future: "
250- f"server wanted { our_ack } but client has { client_next_sent_seq } "
251- )
252- await self ._send_handshake_response (
253- request_message ,
254- HandShakeStatus (ok = False , reason = message ),
255- websocket ,
256- )
257- raise SessionStateMismatchException (message )
235+ old_session = await self ._get_existing_session (request_message .from_ )
236+ client_next_expected_seq = (
237+ handshake_request .expectedSessionState .nextExpectedSeq
238+ )
239+ client_next_sent_seq = (
240+ handshake_request .expectedSessionState .nextSentSeq or 0
241+ )
242+ if old_session and old_session .session_id == handshake_request .sessionId :
243+ # check invariants
244+ # ordering must be correct
245+ our_next_seq = await old_session .get_next_sent_seq ()
246+ our_ack = await old_session .get_next_expected_seq ()
258247
259- if our_next_seq > client_next_expected_seq :
260- message = (
261- "server is in the future: "
262- f"client wanted { client_next_expected_seq } "
263- f"but server has { our_next_seq } "
264- )
265- await self ._send_handshake_response (
266- request_message ,
267- HandShakeStatus (ok = False , reason = message ),
268- websocket ,
269- )
270- raise SessionStateMismatchException (message )
271- elif old_session :
272- # we have an old session but the session id is different
273- # just delete the old session
274- await old_session .close ()
275- old_session = None
248+ if client_next_sent_seq > our_ack :
249+ message = (
250+ "client is in the future: "
251+ f"server wanted { our_ack } but client has { client_next_sent_seq } "
252+ )
253+ await self ._send_handshake_response (
254+ request_message ,
255+ HandShakeStatus (ok = False , reason = message ),
256+ websocket ,
257+ )
258+ raise SessionStateMismatchException (message )
276259
277- if not old_session and (
278- client_next_sent_seq > 0 or client_next_expected_seq > 0
279- ):
280- message = "client is trying to resume a session but we don't have it"
260+ if our_next_seq > client_next_expected_seq :
261+ message = (
262+ "server is in the future: "
263+ f"client wanted { client_next_expected_seq } "
264+ f"but server has { our_next_seq } "
265+ )
281266 await self ._send_handshake_response (
282267 request_message ,
283268 HandShakeStatus (ok = False , reason = message ),
284269 websocket ,
285270 )
286271 raise SessionStateMismatchException (message )
272+ elif old_session :
273+ # we have an old session but the session id is different
274+ # just delete the old session
275+ await old_session .close ()
276+ old_session = None
277+
287278
288- # from this point on, we're committed to connecting
289- session_id = handshake_request .sessionId
290- handshake_response = await self ._send_handshake_response (
279+ if not old_session and (
280+ client_next_sent_seq > 0 or client_next_expected_seq > 0
281+ ):
282+ message = "client is trying to resume a session but we don't have it"
283+ await self ._send_handshake_response (
291284 request_message ,
292- HandShakeStatus (ok = True , sessionId = session_id ),
285+ HandShakeStatus (ok = False , reason = message ),
293286 websocket ,
294287 )
288+ raise SessionStateMismatchException (message )
289+
290+ # from this point on, we're committed to connecting
291+ session_id = handshake_request .sessionId
292+ handshake_response = await self ._send_handshake_response (
293+ request_message ,
294+ HandShakeStatus (ok = True , sessionId = session_id ),
295+ websocket ,
296+ )
295297
296- return handshake_request , handshake_response
298+ return handshake_request , handshake_response
0 commit comments