Skip to content
This repository was archived by the owner on Apr 4, 2024. It is now read-only.

Commit 85df9bc

Browse files
committed
Rewrite SocketManager
1 parent a4fc06c commit 85df9bc

File tree

2 files changed

+17
-77
lines changed

2 files changed

+17
-77
lines changed

fastapi_socketio/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .socket_manager import SocketManager
1+
from .socket_manager import SocketManager

fastapi_socketio/socket_manager.py

Lines changed: 16 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
import socketio
44
from fastapi import FastAPI
5+
from fastapi.middleware.cors import CORSMiddleware
56

6-
class SocketManager:
7+
8+
class SocketManager(socketio.AsyncServer):
79
"""
810
Integrates SocketIO with FastAPI app.
911
Adds `sio` property to FastAPI object (app).
@@ -18,88 +20,26 @@ class SocketManager:
1820
"""
1921

2022
def __init__(
21-
self,
22-
app: FastAPI,
23-
mount_location: str = "/ws",
24-
socketio_path: str = "socket.io",
25-
cors_allowed_origins: Union[str, list] = '*',
26-
async_mode: str = "asgi",
27-
**kwargs
23+
self,
24+
app: FastAPI,
25+
mount_location: str = "/ws",
26+
socketio_path: str = "socket.io",
27+
cors_allowed_origins: Union[str, list] = '*',
28+
async_mode: str = "asgi",
29+
**kwargs
2830
) -> None:
29-
# TODO: Change Cors policy based on fastapi cors Middleware
30-
self._sio = socketio.AsyncServer(async_mode=async_mode, cors_allowed_origins=cors_allowed_origins, **kwargs)
31+
middleware = next((x for x in app.user_middleware if issubclass(x.cls, CORSMiddleware)), None)
32+
if middleware:
33+
cors_allowed_origins = middleware.options.get("allow_origins", "*")
34+
super().__init__(cors_allowed_origins=cors_allowed_origins, async_mode=async_mode, **kwargs)
3135
self._app = socketio.ASGIApp(
32-
socketio_server=self._sio, socketio_path=socketio_path
36+
socketio_server=self, socketio_path=socketio_path
3337
)
3438

3539
app.mount(mount_location, self._app)
3640
app.add_route(f"/{socketio_path}/", route=self._app, methods=["GET", "POST"])
3741
app.add_websocket_route(f"/{socketio_path}/", self._app)
38-
app.sio = self._sio
42+
app.sio = self
3943

4044
def is_asyncio_based(self) -> bool:
4145
return True
42-
43-
@property
44-
def on(self):
45-
return self._sio.on
46-
47-
@property
48-
def attach(self):
49-
return self._sio.attach
50-
51-
@property
52-
def emit(self):
53-
return self._sio.emit
54-
55-
@property
56-
def send(self):
57-
return self._sio.send
58-
59-
@property
60-
def call(self):
61-
return self._sio.call
62-
63-
@property
64-
def close_room(self):
65-
return self._sio.close_room
66-
67-
@property
68-
def get_session(self):
69-
return self._sio.get_session
70-
71-
@property
72-
def save_session(self):
73-
return self._sio.save_session
74-
75-
@property
76-
def session(self):
77-
return self._sio.session
78-
79-
@property
80-
def disconnect(self):
81-
return self._sio.disconnect
82-
83-
@property
84-
def handle_request(self):
85-
return self._sio.handle_request
86-
87-
@property
88-
def start_background_task(self):
89-
return self._sio.start_background_task
90-
91-
@property
92-
def sleep(self):
93-
return self._sio.sleep
94-
95-
@property
96-
def enter_room(self):
97-
return self._sio.enter_room
98-
99-
@property
100-
def leave_room(self):
101-
return self._sio.leave_room
102-
103-
@property
104-
def register_namespace(self):
105-
return self._sio.register_namespace

0 commit comments

Comments
 (0)