Skip to content

Commit ca37377

Browse files
committed
Move ServerTransport specific methods from Transport to ServerTransport
1 parent 4a8e199 commit ca37377

File tree

4 files changed

+85
-81
lines changed

4 files changed

+85
-81
lines changed

replit_river/client_transport.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ async def _establish_new_connection(
125125
handshake_metadata = await self._handshake_metadata_factory()
126126
ws = await websockets.connect(websocket_uri)
127127
session_id = (
128-
self.generate_session_id()
128+
self.generate_nanoid()
129129
if not old_session
130130
else old_session.session_id
131131
)
@@ -178,7 +178,6 @@ async def _create_new_session(
178178
websocket=new_ws,
179179
transport_options=self._transport_options,
180180
is_server=False,
181-
handlers={},
182181
close_session_callback=self._delete_session,
183182
retry_connection_callback=self._retry_connection,
184183
)

replit_river/server.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def __init__(self, server_id: str, transport_options: TransportOptions) -> None:
2626
self._transport = ServerTransport(
2727
transport_id=self._server_id,
2828
transport_options=transport_options,
29-
is_server=True,
3029
)
3130

3231
async def close(self) -> None:

replit_river/server_transport.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, Tuple
2+
from typing import Any, Optional, Tuple
33

44
import nanoid # type: ignore # type: ignore
55
from pydantic import ValidationError
@@ -29,11 +29,23 @@
2929
)
3030
from replit_river.session import Session
3131
from replit_river.transport import Transport
32+
from replit_river.transport_options import TransportOptions
3233

3334
logger = logging.getLogger(__name__)
3435

3536

3637
class ServerTransport(Transport):
38+
def __init__(
39+
self,
40+
transport_id: str,
41+
transport_options: TransportOptions,
42+
) -> None:
43+
super().__init__(
44+
transport_id=transport_id,
45+
transport_options=transport_options,
46+
is_server=True,
47+
)
48+
3749
async def handshake_to_get_session(
3850
self,
3951
websocket: WebSocketServerProtocol,
@@ -61,7 +73,7 @@ async def handshake_to_get_session(
6173
if not session_id:
6274
raise InvalidMessageException("No session id in handshake request")
6375
try:
64-
return await self.get_or_create_session(
76+
return await self._get_or_create_session(
6577
transport_id,
6678
to_id,
6779
session_id,
@@ -75,6 +87,74 @@ async def handshake_to_get_session(
7587
raise InvalidMessageException(error_msg) from e
7688
raise WebsocketClosedException("No handshake message received")
7789

90+
async def close(self) -> None:
91+
await self._close_all_sessions()
92+
93+
async def _get_or_create_session(
94+
self,
95+
transport_id: str,
96+
to_id: str,
97+
session_id: str,
98+
websocket: WebSocketCommonProtocol,
99+
) -> Session:
100+
async with self._session_lock:
101+
session_to_close: Optional[Session] = None
102+
new_session: Optional[Session] = None
103+
if to_id not in self._sessions:
104+
logger.info(
105+
'Creating new session with "%s" using ws: %s', to_id, websocket.id
106+
)
107+
new_session = Session(
108+
transport_id,
109+
to_id,
110+
session_id,
111+
websocket,
112+
self._transport_options,
113+
self._is_server,
114+
self._handlers,
115+
close_session_callback=self._delete_session,
116+
)
117+
else:
118+
old_session = self._sessions[to_id]
119+
if old_session.session_id != session_id:
120+
logger.info(
121+
'Create new session with "%s" for session id %s'
122+
" and close old session %s",
123+
to_id,
124+
session_id,
125+
old_session.session_id,
126+
)
127+
session_to_close = old_session
128+
new_session = Session(
129+
transport_id,
130+
to_id,
131+
session_id,
132+
websocket,
133+
self._transport_options,
134+
self._is_server,
135+
self._handlers,
136+
close_session_callback=self._delete_session,
137+
)
138+
else:
139+
# If the instance id is the same, we reuse the session and assign
140+
# a new websocket to it.
141+
logger.debug(
142+
'Reuse old session with "%s" using new ws: %s',
143+
to_id,
144+
websocket.id,
145+
)
146+
try:
147+
await old_session.replace_with_new_websocket(websocket)
148+
new_session = old_session
149+
except FailedSendingMessageException as e:
150+
raise e
151+
152+
if session_to_close:
153+
logger.info("Closing stale session %s", session_to_close.session_id)
154+
await session_to_close.close()
155+
self._set_session(new_session)
156+
return new_session
157+
78158
async def _send_handshake_response(
79159
self,
80160
request_message: TransportMessage,

replit_river/transport.py

Lines changed: 2 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import asyncio
22
import logging
3-
from typing import Dict, Optional, Tuple
3+
from typing import Dict, Tuple
44

55
import nanoid # type: ignore
6-
from websockets import WebSocketCommonProtocol
76

8-
from replit_river.messages import FailedSendingMessageException
97
from replit_river.rpc import (
108
GenericRpcHandler,
119
)
@@ -21,21 +19,14 @@ def __init__(
2119
transport_id: str,
2220
transport_options: TransportOptions,
2321
is_server: bool,
24-
handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]] = {},
2522
) -> None:
2623
self._transport_id = transport_id
2724
self._transport_options = transport_options
2825
self._is_server = is_server
2926
self._sessions: Dict[str, Session] = {}
30-
self._handlers = handlers
27+
self._handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]] = {}
3128
self._session_lock = asyncio.Lock()
3229

33-
def generate_session_id(self) -> str:
34-
return self.generate_nanoid()
35-
36-
async def close(self) -> None:
37-
await self._close_all_sessions()
38-
3930
async def _close_all_sessions(self) -> None:
4031
sessions = self._sessions.values()
4132
logger.info(
@@ -61,68 +52,3 @@ def _set_session(self, session: Session) -> None:
6152

6253
def generate_nanoid(self) -> str:
6354
return str(nanoid.generate())
64-
65-
async def get_or_create_session(
66-
self,
67-
transport_id: str,
68-
to_id: str,
69-
session_id: str,
70-
websocket: WebSocketCommonProtocol,
71-
) -> Session:
72-
async with self._session_lock:
73-
session_to_close: Optional[Session] = None
74-
new_session: Optional[Session] = None
75-
if to_id not in self._sessions:
76-
logger.info(
77-
'Creating new session with "%s" using ws: %s', to_id, websocket.id
78-
)
79-
new_session = Session(
80-
transport_id,
81-
to_id,
82-
session_id,
83-
websocket,
84-
self._transport_options,
85-
self._is_server,
86-
self._handlers,
87-
close_session_callback=self._delete_session,
88-
)
89-
else:
90-
old_session = self._sessions[to_id]
91-
if old_session.session_id != session_id:
92-
logger.info(
93-
'Create new session with "%s" for session id %s'
94-
" and close old session %s",
95-
to_id,
96-
session_id,
97-
old_session.session_id,
98-
)
99-
session_to_close = old_session
100-
new_session = Session(
101-
transport_id,
102-
to_id,
103-
session_id,
104-
websocket,
105-
self._transport_options,
106-
self._is_server,
107-
self._handlers,
108-
close_session_callback=self._delete_session,
109-
)
110-
else:
111-
# If the instance id is the same, we reuse the session and assign
112-
# a new websocket to it.
113-
logger.debug(
114-
'Reuse old session with "%s" using new ws: %s',
115-
to_id,
116-
websocket.id,
117-
)
118-
try:
119-
await old_session.replace_with_new_websocket(websocket)
120-
new_session = old_session
121-
except FailedSendingMessageException as e:
122-
raise e
123-
124-
if session_to_close:
125-
logger.info("Closing stale session %s", session_to_close.session_id)
126-
await session_to_close.close()
127-
self._set_session(new_session)
128-
return new_session

0 commit comments

Comments
 (0)