Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Bulk-invalidate e2e cached queries after claiming keys #16613

Merged
merged 15 commits into from
Nov 9, 2023
1 change: 1 addition & 0 deletions changelog.d/16613.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve the performance of claiming encryption keys.
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,7 @@ async def simple_insert_many(
def simple_insert_many_txn(
txn: LoggingTransaction,
table: str,
keys: Collection[str],
keys: Sequence[str],
values: Collection[Iterable[Any]],
) -> None:
"""Executes an INSERT query on the named table.
Expand Down
75 changes: 71 additions & 4 deletions synapse/storage/databases/main/cache.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we update the non-bulk versions of these to actually require tuples?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering if we should actually be requiring List[str], since the keys column is []text. It'd save a conversion tuple->list here and there? 🤷

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think tuples are used though because they're immutable so is what is used in the cache keys?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh, that makes sense!

I'd be in favour of specifying Tuple[str, ...], but in another PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do they have to be strings? I'd be kind of surprised if we don't have other types in there?

Anyway, yes sounds like a different PR.

Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,30 @@ def _invalidate_cache_and_stream(
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)

def _invalidate_cache_and_stream_bulk(
self,
txn: LoggingTransaction,
cache_func: CachedFunction,
key_tuples: Collection[Tuple[Any, ...]],
) -> None:
"""A bulk version of _invalidate_cache_and_stream.

Locally invalidate every key-tuple in `key_tuples`, then emit invalidations
for each key-tuple over replication.

This implementation is more efficient than a loop which repeatedly calls the
non-bulk version.
"""
if not key_tuples:
return

for keys in key_tuples:
txn.call_after(cache_func.invalidate, keys)
Comment on lines +503 to +504
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DMRobertson postulated if it would be better to do:

Suggested change
for keys in key_tuples:
txn.call_after(cache_func.invalidate, keys)
def invalidate():
for keys in key_tuples:
cache_func.invalidate(keys)
txn.call_after(invalidate)

But I don't think it matters much since txn_call_after just calls them in a loop anyway? Maybe one way would use less memory. 🤷


self._send_invalidation_to_replication_bulk(
txn, cache_func.__name__, key_tuples
)

def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None:
Expand Down Expand Up @@ -564,10 +588,6 @@ def _send_invalidation_to_replication(
if isinstance(self.database_engine, PostgresEngine):
assert self._cache_id_gen is not None

# get_next() returns a context manager which is designed to wrap
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
stream_id = self._cache_id_gen.get_next_txn(txn)
txn.call_after(self.hs.get_notifier().on_new_replication_data)

Expand All @@ -586,6 +606,53 @@ def _send_invalidation_to_replication(
},
)

def _send_invalidation_to_replication_bulk(
self,
txn: LoggingTransaction,
cache_name: str,
key_tuples: Collection[Tuple[Any, ...]],
) -> None:
"""Announce the invalidation of multiple (but not all) cache entries.

This is more efficient than repeated calls to the non-bulk version. It should
NOT be used to invalidating the entire cache: use
`_send_invalidation_to_replication` with keys=None.

Note that this does *not* invalidate the cache locally.

Args:
txn
cache_name
key_tuples: Key-tuples to invalidate. Assumed to be non-empty.
"""
if isinstance(self.database_engine, PostgresEngine):
assert self._cache_id_gen is not None

stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples))
ts = self._clock.time_msec()
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self.db_pool.simple_insert_many_txn(
txn,
table="cache_invalidation_stream_by_instance",
keys=(
"stream_id",
"instance_name",
"cache_func",
"keys",
"invalidation_ts",
),
values=[
# We convert key_tuples to a list here because psycopg2 serialises
# lists as pq arrrays, but serialises tuples as "composite types".
# (We need an array because the `keys` column has type `[]text`.)
# See:
# https://www.psycopg.org/docs/usage.html#adapt-list
# https://www.psycopg.org/docs/usage.html#adapt-tuple
(stream_id, self._instance_name, cache_name, list(key_tuple), ts)
for stream_id, key_tuple in zip(stream_ids, key_tuples)
],
)

def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
if self._cache_id_gen:
return self._cache_id_gen.get_current_token_for_writer(instance_name)
Expand Down
26 changes: 12 additions & 14 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,13 +1237,11 @@ def _claim_e2e_fallback_keys_bulk_txn(
for user_id, device_id, algorithm, key_id, key_json in claimed_keys:
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)

if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
)

self._invalidate_cache_and_stream_bulk(
txn, self.get_e2e_unused_fallback_key_types, seen_user_device
)

return results

Expand Down Expand Up @@ -1376,14 +1374,14 @@ def _claim_e2e_one_time_keys_bulk(
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
)

seen_user_device: Set[Tuple[str, str]] = set()
for user_id, device_id, _, _, _ in otk_rows:
if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
seen_user_device = {
(user_id, device_id) for user_id, device_id, _, _, _ in otk_rows
}
self._invalidate_cache_and_stream_bulk(
txn,
self.count_e2e_one_time_keys,
seen_user_device,
)

return otk_rows

Expand Down
56 changes: 47 additions & 9 deletions synapse/storage/util/id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,8 +650,8 @@ def get_next_txn(self, txn: LoggingTransaction) -> int:

next_id = self._load_next_id_txn(txn)

txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
txn.call_after(self._mark_ids_as_finished, [next_id])
txn.call_on_exception(self._mark_ids_as_finished, [next_id])
txn.call_after(self._notifier.notify_replication)

# Update the `stream_positions` table with newly updated stream
Expand All @@ -671,14 +671,50 @@ def get_next_txn(self, txn: LoggingTransaction) -> int:

return self._return_factor * next_id

def _mark_id_as_finished(self, next_id: int) -> None:
"""The ID has finished being processed so we should advance the
def get_next_mult_txn(self, txn: LoggingTransaction, n: int) -> List[int]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if get_next_txn should call this instead of duplicating some of the logic? (Same kind of goes for the cache & stream functions above TBH)

"""
Usage:

stream_id = stream_id_gen.get_next_txn(txn)
# ... persist event ...
"""

# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
raise Exception("Tried to allocate stream ID on non-writer")

next_ids = self._load_next_mult_id_txn(txn, n)

txn.call_after(self._mark_ids_as_finished, next_ids)
txn.call_on_exception(self._mark_ids_as_finished, next_ids)
txn.call_after(self._notifier.notify_replication)

# Update the `stream_positions` table with newly updated stream
# ID (unless self._writers is not set in which case we don't
# bother, as nothing will read it).
#
# We only do this on the success path so that the persisted current
# position points to a persisted row with the correct instance name.
if self._writers:
txn.call_after(
run_as_background_process,
"MultiWriterIdGenerator._update_table",
self._db.runInteraction,
"MultiWriterIdGenerator._update_table",
self._update_stream_positions_table_txn,
)

return [self._return_factor * next_id for next_id in next_ids]

def _mark_ids_as_finished(self, next_ids: List[int]) -> None:
"""These IDs have finished being processed so we should advance the
current position if possible.
"""

with self._lock:
self._unfinished_ids.discard(next_id)
self._finished_ids.add(next_id)
self._unfinished_ids.difference_update(next_ids)
clokep marked this conversation as resolved.
Show resolved Hide resolved
self._finished_ids.update(next_ids)

new_cur: Optional[int] = None

Expand Down Expand Up @@ -727,7 +763,10 @@ def _mark_id_as_finished(self, next_id: int) -> None:
curr, new_cur, self._max_position_of_local_instance
)

self._add_persisted_position(next_id)
# TODO Can we call this for just the last position or somehow batch
# _add_persisted_position.
Comment on lines +766 to +767
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a question for @erikjohnston -- _add_persisted_position seems fairly heavy and if we could optimize what we call it with (or take multiple IDs at once) I think that'd be a win.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we certainly could make make it take multiple IDs? It does expect to be called with every position (though I don't think anything would break, just be less efficient)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess if it has to be called w/ every position maybe it isn't a huge deal...

for next_id in next_ids:
self._add_persisted_position(next_id)

def get_current_token(self) -> int:
return self.get_persisted_upto_position()
Expand Down Expand Up @@ -933,8 +972,7 @@ async def __aexit__(
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> bool:
for i in self.stream_ids:
self.id_gen._mark_id_as_finished(i)
self.id_gen._mark_ids_as_finished(self.stream_ids)

self.notifier.notify_replication()

Expand Down
117 changes: 117 additions & 0 deletions tests/storage/databases/main/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock, call

from synapse.storage.database import LoggingTransaction

from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.unittest import HomeserverTestCase


class CacheInvalidationTestCase(HomeserverTestCase):
def setUp(self) -> None:
super().setUp()
self.store = self.hs.get_datastores().main

def test_bulk_invalidation(self) -> None:
master_invalidate = Mock()

self.store._get_cached_user_device.invalidate = master_invalidate

keys_to_invalidate = [
("a", "b"),
("c", "d"),
("e", "f"),
("g", "h"),
]

def test_txn(txn: LoggingTransaction) -> None:
self.store._invalidate_cache_and_stream_bulk(
txn,
# This is an arbitrarily chosen cached store function. It was chosen
# because it takes more than one argument. We'll use this later to
# check that the invalidation was actioned over replication.
cache_func=self.store._get_cached_user_device,
key_tuples=keys_to_invalidate,
)

self.get_success(
self.store.db_pool.runInteraction(
"test_invalidate_cache_and_stream_bulk", test_txn
)
)

master_invalidate.assert_has_calls(
[call(key_list) for key_list in keys_to_invalidate],
any_order=True,
)


class CacheInvalidationOverReplicationTestCase(BaseMultiWorkerStreamTestCase):
def setUp(self) -> None:
super().setUp()
self.store = self.hs.get_datastores().main

def test_bulk_invalidation_replicates(self) -> None:
"""Like test_bulk_invalidation, but also checks the invalidations replicate."""
master_invalidate = Mock()
worker_invalidate = Mock()

self.store._get_cached_user_device.invalidate = master_invalidate
worker = self.make_worker_hs("synapse.app.generic_worker")
worker_ds = worker.get_datastores().main
worker_ds._get_cached_user_device.invalidate = worker_invalidate

keys_to_invalidate = [
("a", "b"),
("c", "d"),
("e", "f"),
("g", "h"),
]

def test_txn(txn: LoggingTransaction) -> None:
self.store._invalidate_cache_and_stream_bulk(
txn,
# This is an arbitrarily chosen cached store function. It was chosen
# because it takes more than one argument. We'll use this later to
# check that the invalidation was actioned over replication.
cache_func=self.store._get_cached_user_device,
key_tuples=keys_to_invalidate,
)

assert self.store._cache_id_gen is not None
initial_token = self.store._cache_id_gen.get_current_token()
self.get_success(
self.database_pool.runInteraction(
"test_invalidate_cache_and_stream_bulk", test_txn
)
)
second_token = self.store._cache_id_gen.get_current_token()

self.assertGreaterEqual(second_token, initial_token + len(keys_to_invalidate))

self.get_success(
worker.get_replication_data_handler().wait_for_stream_position(
"master", "caches", second_token
)
)

master_invalidate.assert_has_calls(
[call(key_list) for key_list in keys_to_invalidate],
any_order=True,
)
worker_invalidate.assert_has_calls(
[call(key_list) for key_list in keys_to_invalidate],
any_order=True,
)
Loading