@@ -123,7 +123,89 @@ def _set_non_inheritable_non_atomic(fd: int) -> None: # noqa: ARG001
123123_IS_SYNC = False
124124
125125
126- class AsyncConnection :
126+ class AsyncBaseConnection :
127+ """A base connection object for server and kms connections."""
128+
129+ def __init__ (self , conn : AsyncNetworkingInterface , opts : PoolOptions ):
130+ self .conn = conn
131+ self .socket_checker : SocketChecker = SocketChecker ()
132+ self .cancel_context : _CancellationContext = _CancellationContext ()
133+ self .is_sdam = False
134+ self .closed = False
135+ self .last_timeout : float | None = None
136+ self .more_to_come = False
137+ self .opts = opts
138+ self .max_wire_version = - 1
139+
140+ def set_conn_timeout (self , timeout : Optional [float ]) -> None :
141+ """Cache last timeout to avoid duplicate calls to conn.settimeout."""
142+ if timeout == self .last_timeout :
143+ return
144+ self .last_timeout = timeout
145+ self .conn .get_conn .settimeout (timeout )
146+
147+ def apply_timeout (
148+ self , client : AsyncMongoClient [Any ], cmd : Optional [MutableMapping [str , Any ]]
149+ ) -> Optional [float ]:
150+ # CSOT: use remaining timeout when set.
151+ timeout = _csot .remaining ()
152+ if timeout is None :
153+ # Reset the socket timeout unless we're performing a streaming monitor check.
154+ if not self .more_to_come :
155+ self .set_conn_timeout (self .opts .socket_timeout )
156+ return None
157+ # RTT validation.
158+ rtt = _csot .get_rtt ()
159+ if rtt is None :
160+ rtt = self .connect_rtt
161+ max_time_ms = timeout - rtt
162+ if max_time_ms < 0 :
163+ timeout_details = _get_timeout_details (self .opts )
164+ formatted = format_timeout_details (timeout_details )
165+ # CSOT: raise an error without running the command since we know it will time out.
166+ errmsg = f"operation would exceed time limit, remaining timeout:{ timeout :.5f} <= network round trip time:{ rtt :.5f} { formatted } "
167+ if self .max_wire_version != - 1 :
168+ raise ExecutionTimeout (
169+ errmsg ,
170+ 50 ,
171+ {"ok" : 0 , "errmsg" : errmsg , "code" : 50 },
172+ self .max_wire_version ,
173+ )
174+ else :
175+ raise TimeoutError (errmsg )
176+ if cmd is not None :
177+ cmd ["maxTimeMS" ] = int (max_time_ms * 1000 )
178+ self .set_conn_timeout (timeout )
179+ return timeout
180+
181+ async def close_conn (self , reason : Optional [str ]) -> None :
182+ """Close this connection with a reason."""
183+ if self .closed :
184+ return
185+ await self ._close_conn ()
186+
187+ async def _close_conn (self ) -> None :
188+ """Close this connection."""
189+ if self .closed :
190+ return
191+ self .closed = True
192+ self .cancel_context .cancel ()
193+ # Note: We catch exceptions to avoid spurious errors on interpreter
194+ # shutdown.
195+ try :
196+ await self .conn .close ()
197+ except Exception : # noqa: S110
198+ pass
199+
200+ def conn_closed (self ) -> bool :
201+ """Return True if we know socket has been closed, False otherwise."""
202+ if _IS_SYNC :
203+ return self .socket_checker .socket_closed (self .conn .get_conn )
204+ else :
205+ return self .conn .is_closing ()
206+
207+
208+ class AsyncConnection (AsyncBaseConnection ):
127209 """Store a connection with some metadata.
128210
129211 :param conn: a raw connection object
@@ -141,29 +223,27 @@ def __init__(
141223 id : int ,
142224 is_sdam : bool ,
143225 ):
226+ super ().__init__ (conn , pool .opts )
144227 self .pool_ref = weakref .ref (pool )
145- self .conn = conn
146- self .address = address
147- self .id = id
228+ self .address : tuple [str , int ] = address
229+ self .id : int = id
148230 self .is_sdam = is_sdam
149- self .closed = False
150231 self .last_checkin_time = time .monotonic ()
151232 self .performed_handshake = False
152233 self .is_writable : bool = False
153234 self .max_wire_version = MAX_WIRE_VERSION
154- self .max_bson_size = MAX_BSON_SIZE
155- self .max_message_size = MAX_MESSAGE_SIZE
156- self .max_write_batch_size = MAX_WRITE_BATCH_SIZE
235+ self .max_bson_size : int = MAX_BSON_SIZE
236+ self .max_message_size : int = MAX_MESSAGE_SIZE
237+ self .max_write_batch_size : int = MAX_WRITE_BATCH_SIZE
157238 self .supports_sessions = False
158239 self .hello_ok : bool = False
159- self .is_mongos = False
240+ self .is_mongos : bool = False
160241 self .op_msg_enabled = False
161242 self .listeners = pool .opts ._event_listeners
162243 self .enabled_for_cmap = pool .enabled_for_cmap
163244 self .enabled_for_logging = pool .enabled_for_logging
164245 self .compression_settings = pool .opts ._compression_settings
165246 self .compression_context : Union [SnappyContext , ZlibContext , ZstdContext , None ] = None
166- self .socket_checker : SocketChecker = SocketChecker ()
167247 self .oidc_token_gen_id : Optional [int ] = None
168248 # Support for mechanism negotiation on the initial handshake.
169249 self .negotiated_mechs : Optional [list [str ]] = None
@@ -174,9 +254,6 @@ def __init__(
174254 self .pool_gen = pool .gen
175255 self .generation = self .pool_gen .get_overall ()
176256 self .ready = False
177- self .cancel_context : _CancellationContext = _CancellationContext ()
178- self .opts = pool .opts
179- self .more_to_come : bool = False
180257 # For load balancer support.
181258 self .service_id : Optional [ObjectId ] = None
182259 self .server_connection_id : Optional [int ] = None
@@ -192,44 +269,6 @@ def __init__(
192269 # For gossiping $clusterTime from the connection handshake to the client.
193270 self ._cluster_time = None
194271
195- def set_conn_timeout (self , timeout : Optional [float ]) -> None :
196- """Cache last timeout to avoid duplicate calls to conn.settimeout."""
197- if timeout == self .last_timeout :
198- return
199- self .last_timeout = timeout
200- self .conn .get_conn .settimeout (timeout )
201-
202- def apply_timeout (
203- self , client : AsyncMongoClient [Any ], cmd : Optional [MutableMapping [str , Any ]]
204- ) -> Optional [float ]:
205- # CSOT: use remaining timeout when set.
206- timeout = _csot .remaining ()
207- if timeout is None :
208- # Reset the socket timeout unless we're performing a streaming monitor check.
209- if not self .more_to_come :
210- self .set_conn_timeout (self .opts .socket_timeout )
211- return None
212- # RTT validation.
213- rtt = _csot .get_rtt ()
214- if rtt is None :
215- rtt = self .connect_rtt
216- max_time_ms = timeout - rtt
217- if max_time_ms < 0 :
218- timeout_details = _get_timeout_details (self .opts )
219- formatted = format_timeout_details (timeout_details )
220- # CSOT: raise an error without running the command since we know it will time out.
221- errmsg = f"operation would exceed time limit, remaining timeout:{ timeout :.5f} <= network round trip time:{ rtt :.5f} { formatted } "
222- raise ExecutionTimeout (
223- errmsg ,
224- 50 ,
225- {"ok" : 0 , "errmsg" : errmsg , "code" : 50 },
226- self .max_wire_version ,
227- )
228- if cmd is not None :
229- cmd ["maxTimeMS" ] = int (max_time_ms * 1000 )
230- self .set_conn_timeout (timeout )
231- return timeout
232-
233272 def pin_txn (self ) -> None :
234273 self .pinned_txn = True
235274 assert not self .pinned_cursor
@@ -573,26 +612,6 @@ async def close_conn(self, reason: Optional[str]) -> None:
573612 error = reason ,
574613 )
575614
576- async def _close_conn (self ) -> None :
577- """Close this connection."""
578- if self .closed :
579- return
580- self .closed = True
581- self .cancel_context .cancel ()
582- # Note: We catch exceptions to avoid spurious errors on interpreter
583- # shutdown.
584- try :
585- await self .conn .close ()
586- except Exception : # noqa: S110
587- pass
588-
589- def conn_closed (self ) -> bool :
590- """Return True if we know socket has been closed, False otherwise."""
591- if _IS_SYNC :
592- return self .socket_checker .socket_closed (self .conn .get_conn )
593- else :
594- return self .conn .is_closing ()
595-
596615 def send_cluster_time (
597616 self ,
598617 command : MutableMapping [str , Any ],
0 commit comments