77from  dataclasses  import  dataclass 
88from  typing  import  Any , Callable 
99
10+ from  exceptiongroup  import  BaseExceptionGroup 
1011from  starlette .applications  import  Starlette 
1112from  starlette .middleware .cors  import  CORSMiddleware 
1213from  starlette .requests  import  Request 
@@ -137,8 +138,6 @@ async def serve_index(request: Request) -> HTMLResponse:
137138def  _setup_single_view_dispatcher_route (
138139    options : Options , app : Starlette , component : RootComponentConstructor 
139140) ->  None :
140-     @app .websocket_route (str (STREAM_PATH )) 
141-     @app .websocket_route (f"{ STREAM_PATH }  ) 
142141    async  def  model_stream (socket : WebSocket ) ->  None :
143142        await  socket .accept ()
144143        send , recv  =  _make_send_recv_callbacks (socket )
@@ -162,8 +161,16 @@ async def model_stream(socket: WebSocket) -> None:
162161                send ,
163162                recv ,
164163            )
165-         except  WebSocketDisconnect  as  error :
166-             logger .info (f"WebSocket disconnect: { error .code }  )
164+         except  BaseExceptionGroup  as  egroup :
165+             for  e  in  egroup .exceptions :
166+                 if  isinstance (e , WebSocketDisconnect ):
167+                     logger .info (f"WebSocket disconnect: { e .code }  )
168+                     break 
169+             else :
170+                 raise 
171+ 
172+     app .add_websocket_route (str (STREAM_PATH ), model_stream )
173+     app .add_websocket_route (f"{ STREAM_PATH }  , model_stream )
167174
168175
169176def  _make_send_recv_callbacks (
0 commit comments