Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Claim fallback keys in bulk #16570

Merged
merged 9 commits into from
Oct 30, 2023
Merged

Claim fallback keys in bulk #16570

merged 9 commits into from
Oct 30, 2023

Conversation

DMRobertson
Copy link
Contributor

@DMRobertson DMRobertson commented Oct 29, 2023

This is the second performance change as suggested by from Rich in #16554. The first was #16565.

Testing that the query is legit
\echo nuke table\\
DROP TABLE IF EXISTS e2e_fallback_keys_json;

\echo Make table\\
CREATE TABLE e2e_fallback_keys_json (
    user_id   text    NOT NULL,
    device_id text    NOT NULL,
    algorithm text    NOT NULL,
    key_id    text    NOT NULL,
    key_json  text    NOT NULL,
    used      boolean NOT NULL,
    UNIQUE (user_id, device_id, algorithm)
);

\echo Dummy data. 10 users, with 10 devices, with 10 algorithms. One key per alg. Keys alternate used and unused. \\
INSERT INTO e2e_fallback_keys_json (user_id, device_id, algorithm, key_id, key_json, used)
SELECT concat('@user_', id / 100, ':test'),
       concat('user_', id / 100, '_dev_', (id / 10) % 10),
       concat('alg_', id % 10),
       concat('key', id.id),
       concat('json', id.id),
       id % 2 = 0
FROM generate_series(1, 1000) as id(id);

\echo Select the rows we want to update. \\

SELECT *
FROM e2e_fallback_keys_json
WHERE (user_id, device_id, algorithm) = ('@user_0:test', 'user_0_dev_0', 'alg_1')
   OR (user_id, device_id, algorithm) = ('@user_0:test', 'user_0_dev_1', 'alg_2')
   OR (user_id, device_id, algorithm) = ('@user_1:test', 'user_1_dev_2', 'alg_4')
   OR (user_id, device_id, algorithm) = ('@user_1:test', 'user_1_dev_2', 'alg_9')
;

\echo The query of doom. \\

WITH
    claims(user_id, device_id, algorithm, mark_as_used) AS (
        VALUES ('@user_0:test', 'user_0_dev_0', 'alg_1', FALSE),
               ('@user_0:test', 'user_0_dev_1', 'alg_2', FALSE),
               ('@user_1:test', 'user_1_dev_2', 'alg_4', TRUE),
               ('@user_1:test', 'user_1_dev_2', 'alg_9', TRUE)
    )
UPDATE e2e_fallback_keys_json k
SET used = used OR mark_as_used
FROM claims
WHERE (k.user_id, k.device_id, k.algorithm) = (claims.user_id, claims.device_id, claims.algorithm)
RETURNING k.user_id, k.device_id, k.algorithm, k.key_id, k.key_json, k.used;

\echo Reselect the rows that should have been updated. \\
SELECT *
FROM e2e_fallback_keys_json
WHERE (user_id, device_id, algorithm) = ('@user_0:test', 'user_0_dev_0', 'alg_1')  -- used false, mark false
   OR (user_id, device_id, algorithm) = ('@user_0:test', 'user_0_dev_1', 'alg_2')  -- used true, mark true
   OR (user_id, device_id, algorithm) = ('@user_1:test', 'user_1_dev_2', 'alg_4')  -- used true, mark false
   OR (user_id, device_id, algorithm) = ('@user_1:test', 'user_1_dev_2', 'alg_9')  -- used false, mark true
;

Commitwise reviewable. Completely untested outside of the practice query and CI.

@DMRobertson DMRobertson marked this pull request as ready for review October 29, 2023 01:08
@DMRobertson DMRobertson requested a review from a team as a code owner October 29, 2023 01:08
Comment on lines +1318 to +1321
async def _claim_e2e_fallback_keys_simple(
self,
query_list: Iterable[Tuple[str, str, str, bool]],
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ouch -- this was doing 2 queries for each item in the input list.

Comment on lines +1290 to +1297
WITH claims(user_id, device_id, algorithm, mark_as_used) AS (
VALUES ?
)
UPDATE e2e_fallback_keys_json k
SET used = used OR mark_as_used
FROM claims
WHERE (k.user_id, k.device_id, k.algorithm) = (claims.user_id, claims.device_id, claims.algorithm)
RETURNING k.user_id, k.device_id, k.algorithm, k.key_id, k.key_json;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The form WITH ... UPDATE ... is non-standard, according to https://www.postgresql.org/docs/11/sql-update.html#id-1.9.3.182.10:

This command conforms to the SQL standard, except that the FROM and RETURNING clauses are PostgreSQL extensions, as is the ability to use WITH with UPDATE.

Copy link
Member

@clokep clokep left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems reasonable! Any idea if this is decently unit tests or not?

@DMRobertson
Copy link
Contributor Author

Seems reasonable! Any idea if this is decently unit tests or not?

def test_fallback_key(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "fallback_key1"}
fallback_key2 = {"alg1:k2": "fallback_key2"}
fallback_key3 = {"alg1:k2": "fallback_key3"}
otk = {"alg1:k2": "key2"}
# we shouldn't have any unused fallback keys yet
res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, [])
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id,
{"fallback_keys": fallback_key},
)
)
# we should now have an unused alg1 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(fallback_res, ["alg1"])
# claiming an OTK when no OTKs are available should return the fallback
# key
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
# we shouldn't have any unused fallback keys again
unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(unused_res, [])
# claiming an OTK again should return the same fallback key
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
# re-uploading the same fallback key should still result in no unused fallback
# keys
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id,
{"fallback_keys": fallback_key},
)
)
unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(unused_res, [])
# uploading a new fallback key should result in an unused fallback key
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id,
{"fallback_keys": fallback_key2},
)
)
unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(unused_res, ["alg1"])
# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
self.get_success(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": otk}
)
)
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
)
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
)
# using the unstable prefix should also set the fallback key
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id,
{"org.matrix.msc2732.fallback_keys": fallback_key3},
)
)
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=False,
)
)
self.assertEqual(
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)
def test_fallback_key_always_returned(self) -> None:
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "fallback_key1"}
otk = {"alg1:k2": "key2"}
# we shouldn't have any unused fallback keys yet
res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, [])
# Upload a OTK & fallback key.
self.get_success(
self.handler.upload_keys_for_user(
local_user,
device_id,
{"one_time_keys": otk, "fallback_keys": fallback_key},
)
)
# we should now have an unused alg1 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(fallback_res, ["alg1"])
# Claiming an OTK and requesting to always return the fallback key should
# return both.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=True,
)
)
self.assertEqual(
claim_res,
{
"failures": {},
"one_time_keys": {local_user: {device_id: {**fallback_key, **otk}}},
},
)
# This should not mark the key as used.
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(fallback_res, ["alg1"])
# Claiming an OTK again should return only the fallback key.
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{local_user: {device_id: {"alg1": 1}}},
self.requester,
timeout=None,
always_include_fallback_keys=True,
)
)
self.assertEqual(
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
# And mark it as used.
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(fallback_res, [])
looks decent. In particular this tests that

  • you can claim a key
  • after doing so, that key doesn't show up as unused

it doesn't seem to test requesting more than one key at a time though. It might be prudent to add something for that.

@DMRobertson
Copy link
Contributor Author

@clokep I've written a short test of the bulk-fetching behaviour to keep me honest. Mind taking one more (last?) look?

@clokep
Copy link
Member

clokep commented Oct 30, 2023

LGTM! Thank for adding.

@DMRobertson DMRobertson merged commit fdce83e into develop Oct 30, 2023
38 checks passed
@DMRobertson DMRobertson deleted the dmr/batch-get-fallback-key branch October 30, 2023 14:34
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants