Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add support for claiming multiple OTKs at once. #15468

Merged
merged 9 commits into from
Apr 27, 2023
Prev Previous commit
Next Next commit
Use a flat list of algorithms instead of a map.
  • Loading branch information
clokep committed Apr 25, 2023
commit bb9081eeb72c8d0c9b9cf5e9aadd99795ebca5d1
7 changes: 4 additions & 3 deletions synapse/appservice/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,9 +471,10 @@ async def claim_client_keys(

# Create the expected payload shape.
body: Dict[str, Dict[str, List[str]]] = {}
for user_id, device, algorithm, _count in query:
# Note that only a single OTK can be claimed this way.
body.setdefault(user_id, {}).setdefault(device, []).append(algorithm)
for user_id, device, algorithm, count in query:
body.setdefault(user_id, {}).setdefault(device, []).extend(
[algorithm] * count
)

uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
try:
Expand Down
29 changes: 19 additions & 10 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ async def query_user_devices(
async def claim_client_keys(
self,
destination: str,
content: Dict[str, Dict[str, Dict[str, int]]],
query: Dict[str, Dict[str, Dict[str, int]]],
timeout: Optional[int],
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.
Expand All @@ -251,24 +251,33 @@ async def claim_client_keys(
"""
sent_queries_counter.labels("client_one_time_keys").inc()

# Convert the query with counts into a legacy query and check if attempting
# to claim more than 1 OTK.
legacy_content: Dict[str, Dict[str, str]] = {}
# Convert the query with counts into a stable and unstable query and check
# if attempting to claim more than 1 OTK.
content: Dict[str, Dict[str, str]] = {}
unstable_content: Dict[str, Dict[str, List[str]]] = {}
use_unstable = False
for user_id, one_time_keys in content.items():
for user_id, one_time_keys in query.items():
for device_id, algorithms in one_time_keys.items():
if any(count > 1 for count in algorithms.values()):
use_unstable = True
if algorithms:
# Choose the first algorithm only.
legacy_content.setdefault(user_id, {})[device_id] = next(
iter(algorithms)
# Choose the first algorithm only for the stable query.
content.setdefault(user_id, {})[device_id] = next(iter(algorithms))
# Flatten the map of algorithm -> count to a list repeating
# each algorithm count times for the unstable query.
unstable_content.setdefault(user_id, {})[device_id] = list(
itertools.chain(
*(
itertools.repeat(algorithm, count)
for algorithm, count in algorithms.items()
)
)
)

if use_unstable:
try:
return await self.transport_layer.claim_client_keys_unstable(
destination, content, timeout
destination, unstable_content, timeout
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
Expand All @@ -284,7 +293,7 @@ async def claim_client_keys(
logger.debug("Skipping unstable claim client keys API")

return await self.transport_layer.claim_client_keys(
destination, legacy_content, timeout
destination, content, timeout
)

@trace
Expand Down
8 changes: 5 additions & 3 deletions synapse/federation/transport/server/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import Counter
from typing import (
TYPE_CHECKING,
Dict,
Expand Down Expand Up @@ -577,7 +578,7 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
# Flatten the request query.
# Generate a count for each algorithm, which is hard-coded to 1.
key_query: List[Tuple[str, str, str, int]] = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
Expand All @@ -603,11 +604,12 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
# Flatten the request query.
# Generate a count for each algorithm.
key_query: List[Tuple[str, str, str, int]] = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithms in device_keys.items():
for algorithm, count in algorithms.items():
counts = Counter(algorithms)
for algorithm, count in counts.items():
key_query.append((user_id, device_id, algorithm, count))

response = await self.handler.on_claim_client_keys(
Expand Down
16 changes: 11 additions & 5 deletions synapse/rest/client/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging
import re
from collections import Counter
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

from synapse.api.errors import InvalidAPICallError, SynapseError
Expand Down Expand Up @@ -290,7 +291,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)

# Map the legacy request to the new request format.
# Generate a count for each algorithm, which is hard-coded to 1.
query: Dict[str, Dict[str, Dict[str, int]]] = {}
for user_id, one_time_keys in body.get("one_time_keys", {}).items():
for device_id, algorithm in one_time_keys.items():
Expand All @@ -312,9 +313,8 @@ class UnstableOneTimeKeyServlet(RestServlet):
{
"one_time_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>": <count>
} } } }
"<device_id>": ["<algorithm>", ...]
} } }

HTTP/1.1 200 OK
{
Expand All @@ -338,7 +338,13 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
query = body.get("one_time_keys", {})

# Generate a count for each algorithm.
query: Dict[str, Dict[str, Dict[str, int]]] = {}
for user_id, one_time_keys in body.get("one_time_keys", {}).items():
for device_id, algorithms in one_time_keys.items():
query.setdefault(user_id, {})[device_id] = Counter(algorithms)

result = await self.e2e_keys_handler.claim_one_time_keys(
query, timeout, always_include_fallback_keys=True
)
Expand Down