|
35 | 35 | from synapse.handlers.device import DeviceHandler
|
36 | 36 | from synapse.logging.context import make_deferred_yieldable, run_in_background
|
37 | 37 | from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
|
| 38 | +from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet |
38 | 39 | from synapse.types import (
|
39 | 40 | JsonDict,
|
40 | 41 | JsonMapping,
|
@@ -89,6 +90,12 @@ def __init__(self, hs: "HomeServer"):
|
89 | 90 | edu_updater.incoming_signing_key_update,
|
90 | 91 | )
|
91 | 92 |
|
| 93 | + self.device_key_uploader = self.upload_device_keys_for_user |
| 94 | + else: |
| 95 | + self.device_key_uploader = ( |
| 96 | + ReplicationUploadKeysForUserRestServlet.make_client(hs) |
| 97 | + ) |
| 98 | + |
92 | 99 | # doesn't really work as part of the generic query API, because the
|
93 | 100 | # query request requires an object POST, but we abuse the
|
94 | 101 | # "query handler" interface.
|
@@ -796,36 +803,17 @@ async def upload_keys_for_user(
|
796 | 803 | "one_time_keys": A mapping from algorithm to number of keys for that
|
797 | 804 | algorithm, including those previously persisted.
|
798 | 805 | """
|
799 |
| - # This can only be called from the main process. |
800 |
| - assert isinstance(self.device_handler, DeviceHandler) |
801 |
| - |
802 | 806 | time_now = self.clock.time_msec()
|
803 | 807 |
|
804 | 808 | # TODO: Validate the JSON to make sure it has the right keys.
|
805 | 809 | device_keys = keys.get("device_keys", None)
|
806 | 810 | if device_keys:
|
807 |
| - logger.info( |
808 |
| - "Updating device_keys for device %r for user %s at %d", |
809 |
| - device_id, |
810 |
| - user_id, |
811 |
| - time_now, |
| 811 | + await self.device_key_uploader( |
| 812 | + user_id=user_id, |
| 813 | + device_id=device_id, |
| 814 | + keys={"device_keys": device_keys}, |
812 | 815 | )
|
813 |
| - log_kv( |
814 |
| - { |
815 |
| - "message": "Updating device_keys for user.", |
816 |
| - "user_id": user_id, |
817 |
| - "device_id": device_id, |
818 |
| - } |
819 |
| - ) |
820 |
| - # TODO: Sign the JSON with the server key |
821 |
| - changed = await self.store.set_e2e_device_keys( |
822 |
| - user_id, device_id, time_now, device_keys |
823 |
| - ) |
824 |
| - if changed: |
825 |
| - # Only notify about device updates *if* the keys actually changed |
826 |
| - await self.device_handler.notify_device_update(user_id, [device_id]) |
827 |
| - else: |
828 |
| - log_kv({"message": "Not updating device_keys for user", "user_id": user_id}) |
| 816 | + |
829 | 817 | one_time_keys = keys.get("one_time_keys", None)
|
830 | 818 | if one_time_keys:
|
831 | 819 | log_kv(
|
@@ -861,18 +849,56 @@ async def upload_keys_for_user(
|
861 | 849 | {"message": "Did not update fallback_keys", "reason": "no keys given"}
|
862 | 850 | )
|
863 | 851 |
|
| 852 | + result = await self.store.count_e2e_one_time_keys(user_id, device_id) |
| 853 | + |
| 854 | + set_tag("one_time_key_counts", str(result)) |
| 855 | + return {"one_time_key_counts": result} |
| 856 | + |
| 857 | + @tag_args |
| 858 | + async def upload_device_keys_for_user( |
| 859 | + self, user_id: str, device_id: str, keys: JsonDict |
| 860 | + ) -> None: |
| 861 | + """ |
| 862 | + Args: |
| 863 | + user_id: user whose keys are being uploaded. |
| 864 | + device_id: device whose keys are being uploaded. |
| 865 | + device_keys: the `device_keys` of an /keys/upload request. |
| 866 | +
|
| 867 | + """ |
| 868 | + # This can only be called from the main process. |
| 869 | + assert isinstance(self.device_handler, DeviceHandler) |
| 870 | + |
| 871 | + time_now = self.clock.time_msec() |
| 872 | + |
| 873 | + device_keys = keys["device_keys"] |
| 874 | + logger.info( |
| 875 | + "Updating device_keys for device %r for user %s at %d", |
| 876 | + device_id, |
| 877 | + user_id, |
| 878 | + time_now, |
| 879 | + ) |
| 880 | + log_kv( |
| 881 | + { |
| 882 | + "message": "Updating device_keys for user.", |
| 883 | + "user_id": user_id, |
| 884 | + "device_id": device_id, |
| 885 | + } |
| 886 | + ) |
| 887 | + # TODO: Sign the JSON with the server key |
| 888 | + changed = await self.store.set_e2e_device_keys( |
| 889 | + user_id, device_id, time_now, device_keys |
| 890 | + ) |
| 891 | + if changed: |
| 892 | + # Only notify about device updates *if* the keys actually changed |
| 893 | + await self.device_handler.notify_device_update(user_id, [device_id]) |
| 894 | + |
864 | 895 | # the device should have been registered already, but it may have been
|
865 | 896 | # deleted due to a race with a DELETE request. Or we may be using an
|
866 | 897 | # old access_token without an associated device_id. Either way, we
|
867 | 898 | # need to double-check the device is registered to avoid ending up with
|
868 | 899 | # keys without a corresponding device.
|
869 | 900 | await self.device_handler.check_device_registered(user_id, device_id)
|
870 | 901 |
|
871 |
| - result = await self.store.count_e2e_one_time_keys(user_id, device_id) |
872 |
| - |
873 |
| - set_tag("one_time_key_counts", str(result)) |
874 |
| - return {"one_time_key_counts": result} |
875 |
| - |
876 | 902 | async def _upload_one_time_keys_for_user(
|
877 | 903 | self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
|
878 | 904 | ) -> None:
|
|
0 commit comments