Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit aa07c37

Browse files
authored
Move and rename get_devices_with_keys_by_user (#8204)
* Move `get_devices_with_keys_by_user` to `EndToEndKeyWorkerStore` this seems a better fit for it. This commit simply moves the existing code: no other changes at all. * Rename `get_devices_with_keys_by_user` to better reflect what it does. * get_device_stream_token abstract method To avoid referencing fields which are declared in the derived classes, make `get_device_stream_token` abstract, and define that in the classes which define `_device_list_id_gen`.
1 parent 45e8f77 commit aa07c37

File tree

6 files changed

+67
-49
lines changed

6 files changed

+67
-49
lines changed

changelog.d/8204.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Refactor queries for device keys and cross-signatures.

synapse/handlers/device.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,9 @@ async def get_user_ids_changed(self, user_id, from_token):
234234
return result
235235

236236
async def on_federation_query_user_devices(self, user_id):
237-
stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
237+
stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
238+
user_id
239+
)
238240
master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
239241
self_signing_key = await self.store.get_e2e_cross_signing_key(
240242
user_id, "self_signing"

synapse/replication/slave/storage/devices.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def __init__(self, database: DatabasePool, db_conn, hs):
4848
"DeviceListFederationStreamChangeCache", device_list_max
4949
)
5050

51+
def get_device_stream_token(self) -> int:
52+
return self._device_list_id_gen.get_current_token()
53+
5154
def process_replication_rows(self, stream_name, instance_name, token, rows):
5255
if stream_name == DeviceListsStream.NAME:
5356
self._device_list_id_gen.advance(instance_name, token)

synapse/storage/databases/main/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,9 @@ def __init__(self, database: DatabasePool, db_conn, hs):
264264
# Used in _generate_user_daily_visits to keep track of progress
265265
self._last_user_visit_update = self._get_start_of_day()
266266

267+
def get_device_stream_token(self) -> int:
268+
return self._device_list_id_gen.get_current_token()
269+
267270
def take_presence_startup_info(self):
268271
active_on_startup = self._presence_on_startup
269272
self._presence_on_startup = None

synapse/storage/databases/main/devices.py

Lines changed: 5 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17+
import abc
1718
import logging
1819
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
1920

@@ -101,7 +102,7 @@ async def get_device_updates_by_remote(
101102
update included in the response), and the list of updates, where
102103
each update is a pair of EDU type and EDU contents.
103104
"""
104-
now_stream_id = self._device_list_id_gen.get_current_token()
105+
now_stream_id = self.get_device_stream_token()
105106

106107
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
107108
destination, int(from_stream_id)
@@ -412,8 +413,10 @@ def _add_user_signature_change_txn(
412413
},
413414
)
414415

416+
@abc.abstractmethod
415417
def get_device_stream_token(self) -> int:
416-
return self._device_list_id_gen.get_current_token()
418+
"""Get the current stream id from the _device_list_id_gen"""
419+
...
417420

418421
@trace
419422
async def get_user_devices_from_cache(
@@ -481,51 +484,6 @@ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]
481484
device["device_id"]: db_to_json(device["content"]) for device in devices
482485
}
483486

484-
def get_devices_with_keys_by_user(self, user_id: str):
485-
"""Get all devices (with any device keys) for a user
486-
487-
Returns:
488-
Deferred which resolves to (stream_id, devices)
489-
"""
490-
return self.db_pool.runInteraction(
491-
"get_devices_with_keys_by_user",
492-
self._get_devices_with_keys_by_user_txn,
493-
user_id,
494-
)
495-
496-
def _get_devices_with_keys_by_user_txn(
497-
self, txn: LoggingTransaction, user_id: str
498-
) -> Tuple[int, List[JsonDict]]:
499-
now_stream_id = self._device_list_id_gen.get_current_token()
500-
501-
devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
502-
503-
if devices:
504-
user_devices = devices[user_id]
505-
results = []
506-
for device_id, device in user_devices.items():
507-
result = {"device_id": device_id}
508-
509-
key_json = device.get("key_json", None)
510-
if key_json:
511-
result["keys"] = db_to_json(key_json)
512-
513-
if "signatures" in device:
514-
for sig_user_id, sigs in device["signatures"].items():
515-
result["keys"].setdefault("signatures", {}).setdefault(
516-
sig_user_id, {}
517-
).update(sigs)
518-
519-
device_display_name = device.get("device_display_name", None)
520-
if device_display_name:
521-
result["device_display_name"] = device_display_name
522-
523-
results.append(result)
524-
525-
return now_stream_id, results
526-
527-
return now_stream_id, []
528-
529487
async def get_users_whose_devices_changed(
530488
self, from_key: str, user_ids: Iterable[str]
531489
) -> Set[str]:

synapse/storage/databases/main/end_to_end_keys.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17+
import abc
1718
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
1819

1920
from canonicaljson import encode_canonical_json
@@ -22,7 +23,7 @@
2223

2324
from synapse.logging.opentracing import log_kv, set_tag, trace
2425
from synapse.storage._base import SQLBaseStore, db_to_json
25-
from synapse.storage.database import make_in_list_sql_clause
26+
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
2627
from synapse.types import JsonDict
2728
from synapse.util import json_encoder
2829
from synapse.util.caches.descriptors import cached, cachedList
@@ -33,6 +34,51 @@
3334

3435

3536
class EndToEndKeyWorkerStore(SQLBaseStore):
37+
def get_e2e_device_keys_for_federation_query(self, user_id: str):
38+
"""Get all devices (with any device keys) for a user
39+
40+
Returns:
41+
Deferred which resolves to (stream_id, devices)
42+
"""
43+
return self.db_pool.runInteraction(
44+
"get_e2e_device_keys_for_federation_query",
45+
self._get_e2e_device_keys_for_federation_query_txn,
46+
user_id,
47+
)
48+
49+
def _get_e2e_device_keys_for_federation_query_txn(
50+
self, txn: LoggingTransaction, user_id: str
51+
) -> Tuple[int, List[JsonDict]]:
52+
now_stream_id = self.get_device_stream_token()
53+
54+
devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
55+
56+
if devices:
57+
user_devices = devices[user_id]
58+
results = []
59+
for device_id, device in user_devices.items():
60+
result = {"device_id": device_id}
61+
62+
key_json = device.get("key_json", None)
63+
if key_json:
64+
result["keys"] = db_to_json(key_json)
65+
66+
if "signatures" in device:
67+
for sig_user_id, sigs in device["signatures"].items():
68+
result["keys"].setdefault("signatures", {}).setdefault(
69+
sig_user_id, {}
70+
).update(sigs)
71+
72+
device_display_name = device.get("device_display_name", None)
73+
if device_display_name:
74+
result["device_display_name"] = device_display_name
75+
76+
results.append(result)
77+
78+
return now_stream_id, results
79+
80+
return now_stream_id, []
81+
3682
@trace
3783
async def get_e2e_device_keys_for_cs_api(
3884
self, query_list: List[Tuple[str, Optional[str]]]
@@ -533,6 +579,11 @@ def _get_all_user_signature_changes_for_remotes_txn(txn):
533579
_get_all_user_signature_changes_for_remotes_txn,
534580
)
535581

582+
@abc.abstractmethod
583+
def get_device_stream_token(self) -> int:
584+
"""Get the current stream id from the _device_list_id_gen"""
585+
...
586+
536587

537588
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
538589
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):

0 commit comments

Comments
 (0)