5050 Dict ,
5151 Iterable ,
5252 List ,
53+ Mapping ,
5354 Optional ,
5455 Protocol ,
5556 Set ,
8081from synapse .storage .engines import BaseDatabaseEngine , PostgresEngine , Sqlite3Engine
8182from synapse .storage .util .id_generators import MultiWriterIdGenerator
8283from synapse .types import PersistedEventPosition , RoomStreamToken , StrCollection
83- from synapse .util .caches .descriptors import cached
84+ from synapse .util .caches .descriptors import cached , cachedList
8485from synapse .util .caches .stream_change_cache import StreamChangeCache
8586from synapse .util .cancellation import cancellable
8687from synapse .util .iterutils import batch_iter
@@ -1381,40 +1382,85 @@ async def bulk_get_last_event_pos_in_room_before_stream_ordering(
13811382 rooms
13821383 """
13831384
1385+ # First we just get the latest positions for the room, as the vast
1386+ # majority of them will be before the given end token anyway. By doing
1387+ # this we can cache most rooms.
1388+ uncapped_results = await self ._bulk_get_max_event_pos (room_ids )
1389+
1390+ # Check that the stream position for the rooms are from before the
1391+ # minimum position of the token. If not then we need to fetch more
1392+ # rows.
1393+ results : Dict [str , int ] = {}
1394+ recheck_rooms : Set [str ] = set ()
13841395 min_token = end_token .stream
1385- max_token = end_token .get_max_stream_pos ()
1396+ for room_id , stream in uncapped_results .items ():
1397+ if stream <= min_token :
1398+ results [room_id ] = stream
1399+ else :
1400+ recheck_rooms .add (room_id )
1401+
1402+ if not recheck_rooms :
1403+ return results
1404+
1405+ # There shouldn't be many rooms that we need to recheck, so we do them
1406+ # one-by-one.
1407+ for room_id in recheck_rooms :
1408+ result = await self .get_last_event_pos_in_room_before_stream_ordering (
1409+ room_id , end_token
1410+ )
1411+ if result is not None :
1412+ results [room_id ] = result [1 ].stream
1413+
1414+ return results
1415+
@cached()
async def _get_max_event_pos(self, room_id: str) -> int:
    """Per-room cache entry point for the max persisted event position.

    This method is deliberately left unimplemented: calling it always raises
    `NotImplementedError`. It is referenced by name as the
    `cached_method_name` of `_bulk_get_max_event_pos`.

    NOTE(review): presumably it exists only so that the bulk method has a
    per-room cache to read from and populate -- confirm against the
    `cachedList` implementation.

    Args:
        room_id: The room the cached position belongs to.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
1419+
@cachedList(cached_method_name="_get_max_event_pos", list_name="room_ids")
async def _bulk_get_max_event_pos(
    self, room_ids: StrCollection
) -> Mapping[str, int]:
    """Fetch the max position of a persisted event in each of the rooms.

    For every room we first consult the events stream change cache and only
    fall back to the database for the rooms it cannot answer. Database
    results are capped to the current room max token so that we never return
    a position ahead of it.

    Args:
        room_ids: The rooms to look up.

    Returns:
        A mapping from room ID to the stream ordering of the latest
        non-outlier, non-rejected event found for that room. Rooms for which
        no such event was found are omitted.
    """

    # We need to be careful not to return positions ahead of the current
    # positions, so we get the current token now and cap our queries to it.
    now_token = self.get_room_max_token()
    max_pos = now_token.get_max_stream_pos()

    results: Dict[str, int] = {}

    # First, we check for the rooms in the stream change cache to see if we
    # can just use the latest position from it.
    missing_room_ids: Set[str] = set()
    for room_id in room_ids:
        stream_pos = self._events_stream_cache.get_max_pos_of_last_change(room_id)
        if stream_pos is not None:
            results[room_id] = stream_pos
        else:
            missing_room_ids.add(room_id)

    if not missing_room_ids:
        return results

    # Next, we query the stream position from the DB for the rooms the cache
    # could not answer. We fetch the latest position at or below the *max*
    # stream position of the current token, as that is a cheap query. Any
    # position that turns out to be above the token's *min* stream position
    # cannot be trusted as-is and is rechecked individually further down.

    def bulk_get_max_event_pos_txn(
        txn: LoggingTransaction, batched_room_ids: StrCollection
    ) -> Dict[str, int]:
        clause, args = make_in_list_sql_clause(
            self.database_engine, "room_id", batched_room_ids
        )
        sql = f"""
            SELECT room_id, (
                SELECT stream_ordering FROM events AS e
                LEFT JOIN rejections USING (event_id)
                WHERE e.room_id = r.room_id
                    AND e.stream_ordering <= ?
                    AND NOT outlier
                    AND rejection_reason IS NULL
                ORDER BY stream_ordering DESC
                LIMIT 1
            )
            FROM rooms AS r
            WHERE {clause}
        """
        txn.execute(sql, [max_pos] + args)
        return {row[0]: row[1] for row in txn}

    recheck_rooms: Set[str] = set()
    # Only hit the database for rooms the stream change cache did not cover;
    # querying every room again would waste work and clobber cached answers.
    for batched in batch_iter(missing_room_ids, 1000):
        batch_results = await self.db_pool.runInteraction(
            "_bulk_get_max_event_pos", bulk_get_max_event_pos_txn, batched
        )
        for room_id, stream_ordering in batch_results.items():
            if stream_ordering is None:
                # The scalar subquery yields NULL when the room has no
                # qualifying event at or before `max_pos`; there is no
                # position to report (and `None <= int` would raise).
                continue

            if stream_ordering <= now_token.stream:
                # Record just this room. Merging the whole batch here would
                # also copy in rooms that still need to be rechecked.
                results[room_id] = stream_ordering
            else:
                recheck_rooms.add(room_id)

    # We now need to handle rooms where the above query returned a stream
    # position that was potentially too new. This should happen very rarely
    # so we just query the rooms one-by-one.
    for room_id in recheck_rooms:
        result = await self.get_last_event_pos_in_room_before_stream_ordering(
            room_id, now_token
        )
        if result is not None:
            results[room_id] = result[1].stream

    return results
14941497
0 commit comments