@@ -133,9 +133,9 @@ def __init__(self, hs: "HomeServer"):
133133 if hs .should_send_federation ():
134134 self .send_handler = FederationSenderHandler (hs )
135135
136- # Map from stream to list of deferreds waiting for the stream to
136+ # Map from stream and instance to list of deferreds waiting for the stream to
137137 # arrive at a particular position. The lists are sorted by stream position.
138- self ._streams_to_waiters : Dict [str , List [Tuple [int , Deferred ]]] = {}
138+ self ._streams_to_waiters : Dict [Tuple [ str , str ] , List [Tuple [int , Deferred ]]] = {}
139139
140140 async def on_rdata (
141141 self , stream_name : str , instance_name : str , token : int , rows : list
@@ -270,7 +270,7 @@ async def on_rdata(
270270 # Notify any waiting deferreds. The list is ordered by position so we
271271 # just iterate through the list until we reach a position that is
272272 # greater than the received row position.
273- waiting_list = self ._streams_to_waiters .get (stream_name , [])
273+ waiting_list = self ._streams_to_waiters .get (( stream_name , instance_name ) , [])
274274
275275 # Index of first item with a position after the current token, i.e we
276276 # have called all deferreds before this index. If not overwritten by
@@ -279,14 +279,13 @@ async def on_rdata(
279279 # `len(list)` works for both cases.
280280 index_of_first_deferred_not_called = len (waiting_list )
281281
282+ # We don't fire the deferreds until after we finish iterating over the
283+ # list, to avoid the list changing when we fire the deferreds.
284+ deferreds_to_callback = []
285+
282286 for idx , (position , deferred ) in enumerate (waiting_list ):
283287 if position <= token :
284- try :
285- with PreserveLoggingContext ():
286- deferred .callback (None )
287- except Exception :
288- # The deferred has been cancelled or timed out.
289- pass
288+ deferreds_to_callback .append (deferred )
290289 else :
291290 # The list is sorted by position so we don't need to continue
292291 # checking any further entries in the list.
@@ -297,6 +296,14 @@ async def on_rdata(
297296 # loop. (This maintains the order so no need to resort)
298297 waiting_list [:] = waiting_list [index_of_first_deferred_not_called :]
299298
299+ for deferred in deferreds_to_callback :
300+ try :
301+ with PreserveLoggingContext ():
302+ deferred .callback (None )
303+ except Exception :
304+ # The deferred has been cancelled or timed out.
305+ pass
306+
300307 async def on_position (
301308 self , stream_name : str , instance_name : str , token : int
302309 ) -> None :
@@ -349,7 +356,9 @@ async def wait_for_stream_position(
349356 deferred , _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS , self ._reactor
350357 )
351358
352- waiting_list = self ._streams_to_waiters .setdefault (stream_name , [])
359+ waiting_list = self ._streams_to_waiters .setdefault (
360+ (stream_name , instance_name ), []
361+ )
353362
354363 waiting_list .append ((position , deferred ))
355364 waiting_list .sort (key = lambda t : t [0 ])
0 commit comments