Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Move the E2E key handling into the e2e handler #1112

Merged
merged 1 commit into from
Sep 13, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 102 additions & 3 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import ujson as json
import logging

from canonicaljson import encode_canonical_json
from twisted.internet import defer

from synapse.api.errors import SynapseError, CodeMessageException
Expand All @@ -29,7 +30,9 @@ class E2eKeysHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.device_handler = hs.get_device_handler()
self.is_mine_id = hs.is_mine_id
self.clock = hs.get_clock()

# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
Expand Down Expand Up @@ -103,9 +106,9 @@ def do_remote_query(destination):
for destination in remote_queries
]))

defer.returnValue((200, {
defer.returnValue({
"device_keys": results, "failures": failures,
}))
})

@defer.inlineCallbacks
def query_local_devices(self, query):
Expand Down Expand Up @@ -159,3 +162,99 @@ def on_federation_query_client_keys(self, query_body):
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
defer.returnValue({"device_keys": res})

@defer.inlineCallbacks
def claim_one_time_keys(self, query, timeout):
local_query = []
remote_queries = {}

for user_id, device_keys in query.get("one_time_keys", {}).items():
if self.is_mine_id(user_id):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_keys

results = yield self.store.claim_e2e_one_time_keys(local_query)

json_result = {}
failures = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}

@defer.inlineCallbacks
def claim_client_keys(destination):
device_keys = remote_queries[destination]
try:
remote_result = yield self.federation.claim_client_keys(
destination,
{"one_time_keys": device_keys},
timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}

yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(claim_client_keys)(destination)
for destination in remote_queries
]))

defer.returnValue({
"one_time_keys": json_result,
"failures": failures
})

@defer.inlineCallbacks
def upload_keys_for_user(self, user_id, device_id, keys):
time_now = self.clock.time_msec()

# TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None)
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id, user_id, time_now
)
# TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys(
user_id, device_id, time_now,
encode_canonical_json(device_keys)
)

one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
logger.info(
"Adding %d one_time_keys for device %r for user %r at %d",
len(one_time_keys), device_id, user_id, time_now
)
key_list = []
for key_id, key_json in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, encode_canonical_json(key_json)
))

yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, key_list
)

# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
self.device_handler.check_device_registered(user_id, device_id)

result = yield self.store.count_e2e_one_time_keys(user_id, device_id)

defer.returnValue({"one_time_key_counts": result})
128 changes: 16 additions & 112 deletions synapse/rest/client/v2_alpha/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,12 @@

import logging

import simplejson as json
from canonicaljson import encode_canonical_json
from twisted.internet import defer

from synapse.api.errors import SynapseError, CodeMessageException
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, parse_integer
)
from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from ._base import client_v2_patterns

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -64,17 +60,13 @@ def __init__(self, hs):
hs (synapse.server.HomeServer): server
"""
super(KeyUploadServlet, self).__init__()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.e2e_keys_handler = hs.get_e2e_keys_handler()

@defer.inlineCallbacks
def on_POST(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)

user_id = requester.user.to_string()

body = parse_json_object_from_request(request)

if device_id is not None:
Expand All @@ -94,47 +86,10 @@ def on_POST(self, request, device_id):
"To upload keys, you must pass device_id when authenticating"
)

time_now = self.clock.time_msec()

# TODO: Validate the JSON to make sure it has the right keys.
device_keys = body.get("device_keys", None)
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id, user_id, time_now
)
# TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys(
user_id, device_id, time_now,
encode_canonical_json(device_keys)
)

one_time_keys = body.get("one_time_keys", None)
if one_time_keys:
logger.info(
"Adding %d one_time_keys for device %r for user %r at %d",
len(one_time_keys), device_id, user_id, time_now
)
key_list = []
for key_id, key_json in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, encode_canonical_json(key_json)
))

yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, key_list
)

# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
self.device_handler.check_device_registered(user_id, device_id)

result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
result = yield self.e2e_keys_handler.upload_keys_for_user(
user_id, device_id, body
)
defer.returnValue((200, result))


class KeyQueryServlet(RestServlet):
Expand Down Expand Up @@ -199,7 +154,7 @@ def on_POST(self, request, user_id, device_id):
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.query_devices(body, timeout)
defer.returnValue(result)
defer.returnValue((200, result))

@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id):
Expand All @@ -212,7 +167,7 @@ def on_GET(self, request, user_id, device_id):
{"device_keys": {user_id: device_ids}},
timeout,
)
defer.returnValue(result)
defer.returnValue((200, result))


class OneTimeKeyServlet(RestServlet):
Expand Down Expand Up @@ -244,80 +199,29 @@ class OneTimeKeyServlet(RestServlet):

def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.federation = hs.get_replication_layer()
self.is_mine_id = hs.is_mine_id
self.e2e_keys_handler = hs.get_e2e_keys_handler()

@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000)
result = yield self.handle_request(
result = yield self.e2e_keys_handler.claim_one_time_keys(
{"one_time_keys": {user_id: {device_id: algorithm}}},
timeout,
)
defer.returnValue(result)
defer.returnValue((200, result))

@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = yield self.handle_request(body, timeout)
defer.returnValue(result)

@defer.inlineCallbacks
def handle_request(self, body, timeout):
local_query = []
remote_queries = {}

for user_id, device_keys in body.get("one_time_keys", {}).items():
if self.is_mine_id(user_id):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_keys

results = yield self.store.claim_e2e_one_time_keys(local_query)

json_result = {}
failures = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}

@defer.inlineCallbacks
def claim_client_keys(destination):
device_keys = remote_queries[destination]
try:
remote_result = yield self.federation.claim_client_keys(
destination,
{"one_time_keys": device_keys},
timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}

yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(claim_client_keys)(destination)
for destination in remote_queries
]))

defer.returnValue((200, {
"one_time_keys": json_result,
"failures": failures
}))
result = yield self.e2e_keys_handler.claim_one_time_keys(
body,
timeout,
)
defer.returnValue((200, result))


def register_servlets(hs, http_server):
Expand Down