From c71199e2d30bf3f79b6f46229a36e75b940dcfc0 Mon Sep 17 00:00:00 2001
From: Nick Mills-Barrett
Date: Tue, 17 Jan 2023 11:53:19 +0000
Subject: [PATCH] Update all stream IDs after processing replication rows
 (#14723) (#52)

* Update all stream IDs after processing replication rows (#14723)

This creates a new store method, `process_replication_position`, that is
called after `process_replication_rows`. Moving stream ID advances there
guarantees that any relevant cache invalidations have been applied before
the stream position is advanced.

This avoids race conditions where Python switches between threads midway
through processing the `process_replication_rows` method, during which
stream IDs could be advanced before caches were invalidated, due to class
resolution ordering.

See this comment/issue for further discussion:
https://github.com/matrix-org/synapse/issues/14158#issuecomment-1344048703

(A minimal illustrative sketch of the resulting pattern is appended after
the diff.)

# Conflicts:
#	synapse/storage/databases/main/devices.py
#	synapse/storage/databases/main/events_worker.py

* Fix bad cherry-picking

* Remove leftover stream advance
---
 changelog.d/14723.bugfix                      |  1 +
 synapse/replication/tcp/client.py             |  3 ++
 synapse/storage/_base.py                      | 17 +++++++-
 .../storage/databases/main/account_data.py    | 14 ++++--
 synapse/storage/databases/main/cache.py       | 21 ++++++---
 synapse/storage/databases/main/deviceinbox.py |  9 +++-
 synapse/storage/databases/main/devices.py     | 13 ++++--
 .../databases/main/event_federation.py        |  2 +-
 .../storage/databases/main/events_worker.py   | 43 ++-----------------
 synapse/storage/databases/main/presence.py    |  8 +++-
 synapse/storage/databases/main/push_rule.py   |  8 +++-
 synapse/storage/databases/main/pusher.py      |  6 +--
 synapse/storage/databases/main/receipts.py    | 10 +++--
 synapse/storage/databases/main/stream.py      | 18 ++++++++
 synapse/storage/databases/main/tags.py        |  8 +++-
 15 files changed, 115 insertions(+), 66 deletions(-)
 create mode 100644 changelog.d/14723.bugfix

diff --git a/changelog.d/14723.bugfix b/changelog.d/14723.bugfix
new file mode 100644
index 000000000000..e1f89cee35c8
--- /dev/null
+++ b/changelog.d/14723.bugfix
@@ -0,0 +1 @@
+Ensure stream IDs are always updated after caches get invalidated with workers. Contributed by Nick @ Beeper (@fizzadar).
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index b4dad47b45ad..d1c45e647805 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -148,6 +148,9 @@ async def on_rdata(
             rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
         self.store.process_replication_rows(stream_name, instance_name, token, rows)
+        # NOTE: this must be called after process_replication_rows to ensure any
+        # cache invalidations are first handled before any stream ID advances.
+        self.store.process_replication_position(stream_name, instance_name, token)

         if self.send_handler:
             await self.send_handler.process_replication_rows(stream_name, token, rows)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 7f690129c940..e8092b85f137 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -59,7 +59,22 @@ def process_replication_rows(  # noqa: B027 (no-op by design)
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        pass
+        """
+        Used by storage classes to invalidate caches based on incoming replication data. These
+        must not update any ID generators; use `process_replication_position` instead.
+ """ + + def process_replication_position( # noqa: B027 (no-op by design) + self, + stream_name: str, + instance_name: str, + token: int, + ) -> None: + """ + Used by storage classes to advance ID generators based on incoming replication data. This + is called after process_replication_rows such that caches are invalidated before any token + positions advance. + """ def _invalidate_state_caches( self, room_id: str, members_changed: Collection[str] diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 913efddcdb4a..f7874b7be650 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -415,10 +415,7 @@ def process_replication_rows( token: int, rows: Iterable[Any], ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) + if stream_name == AccountDataStream.NAME: for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( @@ -433,6 +430,15 @@ def process_replication_rows( super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + elif stream_name == AccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict ) -> int: diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index bfa9561adcbd..858575dfddbd 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -164,9 +164,6 @@ def process_replication_rows( backfilled=True, ) elif stream_name == CachesStream.NAME: - if self._cache_id_gen: - self._cache_id_gen.advance(instance_name, token) - for row in rows: if row.cache_func == CURRENT_STATE_CACHE_NAME: if row.keys is None: @@ -182,6 +179,14 @@ def process_replication_rows( super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == CachesStream.NAME: + if self._cache_id_gen: + self._cache_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data = row.data @@ -198,8 +203,14 @@ def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: - # TODO: Nothing to do here, handled in events_worker, cleanup? 
- pass + assert isinstance(data, EventsStreamCurrentStateRow) + self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) + + if data.type == EventTypes.Member: + self.get_rooms_for_user_with_stream_ordering.invalidate( + (data.state_key,) + ) + self.get_rooms_for_user.invalidate((data.state_key,)) else: raise Exception("Unknown events stream row type %s" % (row.type,)) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 87a0798e8f55..89b85ec9caa7 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -160,10 +160,15 @@ def process_replication_rows( self._device_federation_outbox_stream_cache.entity_has_changed( row.entity, token ) - # Important that the ID gen advances after stream change caches - self._device_inbox_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ToDeviceStream.NAME: + self._device_inbox_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def get_to_device_stream_token(self) -> int: return self._device_inbox_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 16d39d7bfd20..db1cbfb506ef 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -163,15 +163,20 @@ def process_replication_rows( ) -> None: if stream_name == DeviceListsStream.NAME: self._invalidate_caches_for_devices(token, rows) - # Important that the ID gen advances after stream change caches - self._device_list_id_gen.advance(instance_name, token) elif stream_name == UserSignatureStream.NAME: for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) - # Important that the ID gen advances after stream change caches - self._device_list_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == DeviceListsStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + elif stream_name == UserSignatureStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _invalidate_caches_for_devices( self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] ) -> None: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 9c2b38df401e..bbee02ab18f0 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1187,7 +1187,7 @@ async def get_forward_extremities_for_room_at_stream_ordering( """ # We want to make the cache more effective, so we clamp to the last # change before the given ordering. - last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) + last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined] # We don't always have a full stream_to_exterm_id table, e.g. 
after # the upgrade that introduced it, so we make sure we never ask for a diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index db90c8769ea1..2ed44790c718 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -249,22 +249,6 @@ def __init__( prefilled_cache=curr_state_delta_prefill, ) - event_cache_prefill, min_event_val = self.db_pool.get_cache_dict( - db_conn, - "events", - entity_column="room_id", - stream_column="stream_ordering", - max_value=events_max, - ) - self._events_stream_cache = StreamChangeCache( - "EventsRoomStreamChangeCache", - min_event_val, - prefilled_cache=event_cache_prefill, - ) - self._membership_stream_cache = StreamChangeCache( - "MembershipStreamChangeCache", events_max - ) - if hs.config.worker.run_background_tasks: # We periodically clean out old transaction ID mappings self._clock.looping_call( @@ -325,35 +309,14 @@ def get_chain_id_txn(txn: Cursor) -> int: id_column="chain_id", ) - def process_replication_rows( - self, - stream_name: str, - instance_name: str, - token: int, - rows: Iterable[Any], + def process_replication_position( + self, stream_name: str, instance_name: str, token: int ) -> None: - # Process event stream replication rows, handling both the ID generators from the events - # worker store and the stream change caches in this store as the two are interlinked. if stream_name == EventsStream.NAME: - for row in rows: - if row.type == EventsStreamEventRow.TypeId: - self._events_stream_cache.entity_has_changed( - row.data.room_id, token - ) - if row.data.type == EventTypes.Member: - self._membership_stream_cache.entity_has_changed( - row.data.state_key, token - ) - if row.type == EventsStreamCurrentStateRow.TypeId: - self._curr_state_delta_stream_cache.entity_has_changed( - row.data.room_id, token - ) - # Important that the ID gen advances after stream change caches self._stream_id_gen.advance(instance_name, token) elif stream_name == BackfillStream.NAME: self._backfill_id_gen.advance(instance_name, -token) - - super().process_replication_rows(stream_name, instance_name, token, rows) + super().process_replication_position(stream_name, instance_name, token) async def have_censored_event(self, event_id: str) -> bool: """Check if an event has been censored, i.e. 
if the content of the event has been erased diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 9769a18a9d0c..7b60815043a6 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -439,8 +439,14 @@ def process_replication_rows( rows: Iterable[Any], ) -> None: if stream_name == PresenceStream.NAME: - self._presence_id_gen.advance(instance_name, token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) return super().process_replication_rows(stream_name, instance_name, token, rows) + + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == PresenceStream.NAME: + self._presence_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index d4c64c46ad44..478eec8f9e9f 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -148,12 +148,18 @@ def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: if stream_name == PushRulesStream.NAME: - self._push_rules_stream_id_gen.advance(instance_name, token) for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.push_rules_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == PushRulesStream.NAME: + self._push_rules_stream_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + @cached(max_entries=5000) async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: rows = await self.db_pool.simple_select_list( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 40fd781a6ab1..7f24a3b6ec5e 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -111,12 +111,12 @@ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]: def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() - def process_replication_rows( - self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + def process_replication_position( + self, stream_name: str, instance_name: str, token: int ) -> None: if stream_name == PushersStream.NAME: self._pushers_id_gen.advance(instance_name, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + super().process_replication_position(stream_name, instance_name, token) async def get_pushers_by_app_id_and_pushkey( self, app_id: str, pushkey: str diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 89414097720f..b5cf2d8c3511 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -600,11 +600,15 @@ def process_replication_rows( row.room_id, row.receipt_type, row.user_id ) self._receipts_stream_cache.entity_has_changed(row.room_id, token) - # Important that the ID gen advances after stream change caches - 
self._receipts_id_gen.advance(instance_name, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ReceiptsStream.NAME: + self._receipts_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _insert_linearized_receipt_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 62f65f4cab54..6924add6c96e 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -71,6 +71,7 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import PersistedEventPosition, RoomStreamToken from synapse.util.caches.descriptors import cached +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.cancellation import cancellable if TYPE_CHECKING: @@ -396,6 +397,23 @@ def __init__( # during startup which would cause one to die. self._need_to_reset_federation_stream_positions = self._send_federation + events_max = self.get_room_max_stream_ordering() + event_cache_prefill, min_event_val = self.db_pool.get_cache_dict( + db_conn, + "events", + entity_column="room_id", + stream_column="stream_ordering", + max_value=events_max, + ) + self._events_stream_cache = StreamChangeCache( + "EventsRoomStreamChangeCache", + min_event_val, + prefilled_cache=event_cache_prefill, + ) + self._membership_stream_cache = StreamChangeCache( + "MembershipStreamChangeCache", events_max + ) + self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index b0f5de67a30d..e23c927e02f2 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -300,13 +300,19 @@ def process_replication_rows( rows: Iterable[Any], ) -> None: if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + class TagsStore(TagsWorkerStore): pass
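---

For reference, here is a minimal, runnable sketch of the pattern this patch
introduces: `process_replication_rows` only invalidates caches,
`process_replication_position` only advances ID generators, and the
replication client always calls them in that order. ExampleWorkerStore,
FakeIdGenerator and the stream/instance names below are hypothetical
stand-ins for illustration, not Synapse's real classes.

    # Hypothetical stand-ins illustrating the rows-then-position ordering.
    from typing import Any, Dict, Iterable


    class FakeIdGenerator:
        """Tracks the highest token seen (stand-in for a stream ID generator)."""

        def __init__(self) -> None:
            self._current = 0

        def advance(self, instance_name: str, token: int) -> None:
            self._current = max(self._current, token)

        def get_current_token(self) -> int:
            return self._current


    class ExampleWorkerStore:
        def __init__(self) -> None:
            self._id_gen = FakeIdGenerator()
            self._cache: Dict[str, str] = {"user_a": "stale value"}

        def process_replication_rows(
            self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
        ) -> None:
            # Invalidate caches only; never advance an ID generator here, so no
            # reader can observe the new token while stale entries remain.
            for row in rows:
                self._cache.pop(row, None)

        def process_replication_position(
            self, stream_name: str, instance_name: str, token: int
        ) -> None:
            # Advance the position only once every invalidation above has run.
            self._id_gen.advance(instance_name, token)


    # Mirrors on_rdata(): rows first, then the position.
    store = ExampleWorkerStore()
    store.process_replication_rows("example_stream", "worker1", 7, ["user_a"])
    store.process_replication_position("example_stream", "worker1", 7)
    assert "user_a" not in store._cache
    assert store._id_gen.get_current_token() == 7

Because every store's process_replication_rows (across the whole method
resolution order) finishes before any process_replication_position runs, a
reader that sees the advanced token can no longer race ahead of a pending
cache invalidation.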