1515from contextlib import asynccontextmanager
1616from dataclasses import dataclass
1717from http import HTTPStatus
18+ from types import TracebackType
19+ from typing import Self
1820
1921import anyio
2022from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
@@ -140,6 +142,7 @@ def __init__(
140142 is_json_response_enabled : bool = False ,
141143 event_store : EventStore | None = None ,
142144 security_settings : TransportSecuritySettings | None = None ,
145+ timeout : float | None = None ,
143146 ) -> None :
144147 """
145148 Initialize a new StreamableHTTP server transport.
@@ -153,6 +156,9 @@ def __init__(
153156 resumability will be enabled, allowing clients to
154157 reconnect and resume messages.
155158 security_settings: Optional security settings for DNS rebinding protection.
159+ timeout: Optional idle timeout for transport. If provided, the transport will
160+ terminate if it remains idle for longer than the defined timeout
161+ duration in seconds.
156162
157163 Raises:
158164 ValueError: If the session ID contains invalid characters.
@@ -172,6 +178,12 @@ def __init__(
172178 ],
173179 ] = {}
174180 self ._terminated = False
181+ self ._timeout = timeout
182+
183+ # for idle detection
184+ self ._processing_request_count = 0
185+ self ._idle_condition = anyio .Condition ()
186+ self ._has_request = False
175187
176188 @property
177189 def is_terminated (self ) -> bool :
@@ -626,6 +638,9 @@ async def terminate(self) -> None:
626638 Once terminated, all requests with this session ID will receive 404 Not Found.
627639 """
628640
641+ if self ._terminated :
642+ return
643+
629644 self ._terminated = True
630645 logger .info (f"Terminating session: { self .mcp_session_id } " )
631646
@@ -796,6 +811,42 @@ async def send_event(event_message: EventMessage) -> None:
796811 )
797812 await response (request .scope , request .receive , send )
798813
814+ async def __aenter__ (self ) -> Self :
815+ async with self ._idle_condition :
816+ self ._processing_request_count += 1
817+ self ._has_request = True
818+ return self
819+
820+ async def __aexit__ (
821+ self ,
822+ exc_type : type [BaseException ] | None ,
823+ exc_value : BaseException | None ,
824+ traceback : TracebackType | None ,
825+ ) -> None :
826+ async with self ._idle_condition :
827+ self ._processing_request_count -= 1
828+ if self ._processing_request_count == 0 :
829+ self ._idle_condition .notify_all ()
830+
831+ async def _idle_timeout_terminate (self , timeout : float ) -> None :
832+ """
833+ Terminate the transport if it remains idle for longer than the defined timeout duration.
834+ """
835+ while not self ._terminated :
836+ # wait for transport to be idle
837+ async with self ._idle_condition :
838+ if self ._processing_request_count > 0 :
839+ await self ._idle_condition .wait ()
840+ self ._has_request = False
841+
842+ # wait for idle timeout
843+ await anyio .sleep (timeout )
844+
845+ # If there are no requests during the wait period, terminate the transport
846+ if not self ._has_request :
847+ logger .debug (f"Terminating transport due to idle timeout: { self .mcp_session_id } " )
848+ await self .terminate ()
849+
799850 @asynccontextmanager
800851 async def connect (
801852 self ,
@@ -812,6 +863,10 @@ async def connect(
812863 Tuple of (read_stream, write_stream) for bidirectional communication
813864 """
814865
866+ # Terminated transports should not be connected again
867+ if self ._terminated :
868+ raise RuntimeError ("Transport is terminated" )
869+
815870 # Create the memory streams for this connection
816871
817872 read_stream_writer , read_stream = anyio .create_memory_object_stream [SessionMessage | Exception ](0 )
@@ -884,20 +939,13 @@ async def message_router():
884939 # Start the message router
885940 tg .start_soon (message_router )
886941
942+ # Start idle timeout task if timeout is set
943+ if self ._timeout is not None :
944+ tg .start_soon (self ._idle_timeout_terminate , self ._timeout )
945+
887946 try :
888947 # Yield the streams for the caller to use
889948 yield read_stream , write_stream
890949 finally :
891- for stream_id in list (self ._request_streams .keys ()):
892- await self ._clean_up_memory_streams (stream_id )
893- self ._request_streams .clear ()
894-
895- # Clean up the read and write streams
896- try :
897- await read_stream_writer .aclose ()
898- await read_stream .aclose ()
899- await write_stream_reader .aclose ()
900- await write_stream .aclose ()
901- except Exception as e :
902- # During cleanup, we catch all exceptions since streams might be in various states
903- logger .debug (f"Error closing streams: { e } " )
950+ # Terminate the transport when the context manager exits
951+ await self .terminate ()
0 commit comments