@@ -68,9 +68,10 @@ async def main():
6868import logging
6969import warnings
7070from collections .abc import Awaitable , Callable
71- from contextlib import AbstractAsyncContextManager , asynccontextmanager
71+ from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager
7272from typing import Any , AsyncIterator , Generic , Sequence , TypeVar
7373
74+ import anyio
7475from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
7576from pydantic import AnyUrl
7677
@@ -469,41 +470,47 @@ async def run(
469470 # in-process servers.
470471 raise_exceptions : bool = False ,
471472 ):
472- with warnings .catch_warnings (record = True ) as w :
473- from contextlib import AsyncExitStack
474-
475- async with AsyncExitStack () as stack :
476- lifespan_context = await stack .enter_async_context (self .lifespan (self ))
477- session = await stack .enter_async_context (
478- ServerSession (read_stream , write_stream , initialization_options )
479- )
473+ async with AsyncExitStack () as stack :
474+ lifespan_context = await stack .enter_async_context (self .lifespan (self ))
475+ session = await stack .enter_async_context (
476+ ServerSession (read_stream , write_stream , initialization_options )
477+ )
480478
479+ async with anyio .create_task_group () as tg :
481480 async for message in session .incoming_messages :
482481 logger .debug (f"Received message: { message } " )
483482
484- match message :
485- case (
486- RequestResponder (
487- request = types .ClientRequest (root = req )
488- ) as responder
489- ):
490- with responder :
491- await self ._handle_request (
492- message ,
493- req ,
494- session ,
495- lifespan_context ,
496- raise_exceptions ,
497- )
498- case types .ClientNotification (root = notify ):
499- await self ._handle_notification (notify )
500-
501- for warning in w :
502- logger .info (
503- "Warning: %s: %s" ,
504- warning .category .__name__ ,
505- warning .message ,
483+ tg .start_soon (
484+ self ._handle_message ,
485+ message ,
486+ session ,
487+ lifespan_context ,
488+ raise_exceptions ,
489+ )
490+
491+ async def _handle_message (
492+ self ,
493+ message : RequestResponder [types .ClientRequest , types .ServerResult ]
494+ | types .ClientNotification
495+ | Exception ,
496+ session : ServerSession ,
497+ lifespan_context : LifespanResultT ,
498+ raise_exceptions : bool = False ,
499+ ):
500+ with warnings .catch_warnings (record = True ) as w :
501+ match message :
502+ case (
503+ RequestResponder (request = types .ClientRequest (root = req )) as responder
504+ ):
505+ with responder :
506+ await self ._handle_request (
507+ message , req , session , lifespan_context , raise_exceptions
506508 )
509+ case types .ClientNotification (root = notify ):
510+ await self ._handle_notification (notify )
511+
512+ for warning in w :
513+ logger .info (f"Warning: { warning .category .__name__ } : { warning .message } " )
507514
508515 async def _handle_request (
509516 self ,
0 commit comments