Skip to content

Commit ce93858

Browse files
authored
Handle OTK uploads off master (#17271)
And fallback keys uploads. Only device keys need handling on master
1 parent a963f57 commit ce93858

File tree

3 files changed

+60
-38
lines changed

3 files changed

+60
-38
lines changed

changelog.d/17271.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Handle OTK uploads off master.

synapse/handlers/e2e_keys.py

Lines changed: 55 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from synapse.handlers.device import DeviceHandler
3636
from synapse.logging.context import make_deferred_yieldable, run_in_background
3737
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
38+
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
3839
from synapse.types import (
3940
JsonDict,
4041
JsonMapping,
@@ -89,6 +90,12 @@ def __init__(self, hs: "HomeServer"):
8990
edu_updater.incoming_signing_key_update,
9091
)
9192

93+
self.device_key_uploader = self.upload_device_keys_for_user
94+
else:
95+
self.device_key_uploader = (
96+
ReplicationUploadKeysForUserRestServlet.make_client(hs)
97+
)
98+
9299
# doesn't really work as part of the generic query API, because the
93100
# query request requires an object POST, but we abuse the
94101
# "query handler" interface.
@@ -796,36 +803,17 @@ async def upload_keys_for_user(
796803
"one_time_keys": A mapping from algorithm to number of keys for that
797804
algorithm, including those previously persisted.
798805
"""
799-
# This can only be called from the main process.
800-
assert isinstance(self.device_handler, DeviceHandler)
801-
802806
time_now = self.clock.time_msec()
803807

804808
# TODO: Validate the JSON to make sure it has the right keys.
805809
device_keys = keys.get("device_keys", None)
806810
if device_keys:
807-
logger.info(
808-
"Updating device_keys for device %r for user %s at %d",
809-
device_id,
810-
user_id,
811-
time_now,
811+
await self.device_key_uploader(
812+
user_id=user_id,
813+
device_id=device_id,
814+
keys={"device_keys": device_keys},
812815
)
813-
log_kv(
814-
{
815-
"message": "Updating device_keys for user.",
816-
"user_id": user_id,
817-
"device_id": device_id,
818-
}
819-
)
820-
# TODO: Sign the JSON with the server key
821-
changed = await self.store.set_e2e_device_keys(
822-
user_id, device_id, time_now, device_keys
823-
)
824-
if changed:
825-
# Only notify about device updates *if* the keys actually changed
826-
await self.device_handler.notify_device_update(user_id, [device_id])
827-
else:
828-
log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
816+
829817
one_time_keys = keys.get("one_time_keys", None)
830818
if one_time_keys:
831819
log_kv(
@@ -861,18 +849,56 @@ async def upload_keys_for_user(
861849
{"message": "Did not update fallback_keys", "reason": "no keys given"}
862850
)
863851

852+
result = await self.store.count_e2e_one_time_keys(user_id, device_id)
853+
854+
set_tag("one_time_key_counts", str(result))
855+
return {"one_time_key_counts": result}
856+
857+
@tag_args
858+
async def upload_device_keys_for_user(
859+
self, user_id: str, device_id: str, keys: JsonDict
860+
) -> None:
861+
"""
862+
Args:
863+
user_id: user whose keys are being uploaded.
864+
device_id: device whose keys are being uploaded.
865+
device_keys: the `device_keys` of an /keys/upload request.
866+
867+
"""
868+
# This can only be called from the main process.
869+
assert isinstance(self.device_handler, DeviceHandler)
870+
871+
time_now = self.clock.time_msec()
872+
873+
device_keys = keys["device_keys"]
874+
logger.info(
875+
"Updating device_keys for device %r for user %s at %d",
876+
device_id,
877+
user_id,
878+
time_now,
879+
)
880+
log_kv(
881+
{
882+
"message": "Updating device_keys for user.",
883+
"user_id": user_id,
884+
"device_id": device_id,
885+
}
886+
)
887+
# TODO: Sign the JSON with the server key
888+
changed = await self.store.set_e2e_device_keys(
889+
user_id, device_id, time_now, device_keys
890+
)
891+
if changed:
892+
# Only notify about device updates *if* the keys actually changed
893+
await self.device_handler.notify_device_update(user_id, [device_id])
894+
864895
# the device should have been registered already, but it may have been
865896
# deleted due to a race with a DELETE request. Or we may be using an
866897
# old access_token without an associated device_id. Either way, we
867898
# need to double-check the device is registered to avoid ending up with
868899
# keys without a corresponding device.
869900
await self.device_handler.check_device_registered(user_id, device_id)
870901

871-
result = await self.store.count_e2e_one_time_keys(user_id, device_id)
872-
873-
set_tag("one_time_key_counts", str(result))
874-
return {"one_time_key_counts": result}
875-
876902
async def _upload_one_time_keys_for_user(
877903
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
878904
) -> None:

synapse/rest/client/keys.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
)
3737
from synapse.http.site import SynapseRequest
3838
from synapse.logging.opentracing import log_kv, set_tag
39-
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
4039
from synapse.rest.client._base import client_patterns, interactive_auth_handler
4140
from synapse.types import JsonDict, StreamToken
4241
from synapse.util.cancellation import cancellable
@@ -105,13 +104,8 @@ def __init__(self, hs: "HomeServer"):
105104
self.auth = hs.get_auth()
106105
self.e2e_keys_handler = hs.get_e2e_keys_handler()
107106
self.device_handler = hs.get_device_handler()
108-
109-
if hs.config.worker.worker_app is None:
110-
# if main process
111-
self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
112-
else:
113-
# then a worker
114-
self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
107+
self._clock = hs.get_clock()
108+
self._store = hs.get_datastores().main
115109

116110
async def on_POST(
117111
self, request: SynapseRequest, device_id: Optional[str]
@@ -151,9 +145,10 @@ async def on_POST(
151145
400, "To upload keys, you must pass device_id when authenticating"
152146
)
153147

154-
result = await self.key_uploader(
148+
result = await self.e2e_keys_handler.upload_keys_for_user(
155149
user_id=user_id, device_id=device_id, keys=body
156150
)
151+
157152
return 200, result
158153

159154

0 commit comments

Comments
 (0)