Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 220f901

Browse files
authored
Remove not needed database updates in modify user admin API (#10627)
1 parent 0c3565d commit 220f901

File tree

5 files changed

+118
-33
lines changed

5 files changed

+118
-33
lines changed

changelog.d/10627.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Remove not needed database updates in modify user admin API.

docs/admin_api/user_admin_api.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,15 @@ It returns a JSON body like the following:
2121
"threepids": [
2222
{
2323
"medium": "email",
24-
"address": "<user_mail_1>"
24+
"address": "<user_mail_1>",
25+
"added_at": 1586458409743,
26+
"validated_at": 1586458409743
2527
},
2628
{
2729
"medium": "email",
28-
"address": "<user_mail_2>"
30+
"address": "<user_mail_2>",
31+
"added_at": 1586458409743,
32+
"validated_at": 1586458409743
2933
}
3034
],
3135
"avatar_url": "<avatar_url>",

synapse/rest/admin/users.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -228,13 +228,18 @@ async def on_PUT(
228228
if not isinstance(deactivate, bool):
229229
raise SynapseError(400, "'deactivated' parameter is not of type boolean")
230230

231-
# convert into List[Tuple[str, str]]
231+
# convert List[Dict[str, str]] into Set[Tuple[str, str]]
232232
if external_ids is not None:
233-
new_external_ids = []
234-
for external_id in external_ids:
235-
new_external_ids.append(
236-
(external_id["auth_provider"], external_id["external_id"])
237-
)
233+
new_external_ids = {
234+
(external_id["auth_provider"], external_id["external_id"])
235+
for external_id in external_ids
236+
}
237+
238+
# convert List[Dict[str, str]] into Set[Tuple[str, str]]
239+
if threepids is not None:
240+
new_threepids = {
241+
(threepid["medium"], threepid["address"]) for threepid in threepids
242+
}
238243

239244
if user: # modify user
240245
if "displayname" in body:
@@ -243,29 +248,39 @@ async def on_PUT(
243248
)
244249

245250
if threepids is not None:
246-
# remove old threepids from user
247-
old_threepids = await self.store.user_get_threepids(user_id)
248-
for threepid in old_threepids:
251+
# get changed threepids (added and removed)
252+
# convert List[Dict[str, Any]] into Set[Tuple[str, str]]
253+
cur_threepids = {
254+
(threepid["medium"], threepid["address"])
255+
for threepid in await self.store.user_get_threepids(user_id)
256+
}
257+
add_threepids = new_threepids - cur_threepids
258+
del_threepids = cur_threepids - new_threepids
259+
260+
# remove old threepids
261+
for medium, address in del_threepids:
249262
try:
250263
await self.auth_handler.delete_threepid(
251-
user_id, threepid["medium"], threepid["address"], None
264+
user_id, medium, address, None
252265
)
253266
except Exception:
254267
logger.exception("Failed to remove threepids")
255268
raise SynapseError(500, "Failed to remove threepids")
256269

257-
# add new threepids to user
270+
# add new threepids
258271
current_time = self.hs.get_clock().time_msec()
259-
for threepid in threepids:
272+
for medium, address in add_threepids:
260273
await self.auth_handler.add_threepid(
261-
user_id, threepid["medium"], threepid["address"], current_time
274+
user_id, medium, address, current_time
262275
)
263276

264277
if external_ids is not None:
265278
# get changed external_ids (added and removed)
266-
cur_external_ids = await self.store.get_external_ids_by_user(user_id)
267-
add_external_ids = set(new_external_ids) - set(cur_external_ids)
268-
del_external_ids = set(cur_external_ids) - set(new_external_ids)
279+
cur_external_ids = set(
280+
await self.store.get_external_ids_by_user(user_id)
281+
)
282+
add_external_ids = new_external_ids - cur_external_ids
283+
del_external_ids = cur_external_ids - new_external_ids
269284

270285
# remove old external_ids
271286
for auth_provider, external_id in del_external_ids:
@@ -348,9 +363,9 @@ async def on_PUT(
348363

349364
if threepids is not None:
350365
current_time = self.hs.get_clock().time_msec()
351-
for threepid in threepids:
366+
for medium, address in new_threepids:
352367
await self.auth_handler.add_threepid(
353-
user_id, threepid["medium"], threepid["address"], current_time
368+
user_id, medium, address, current_time
354369
)
355370
if (
356371
self.hs.config.email_enable_notifs
@@ -362,8 +377,8 @@ async def on_PUT(
362377
kind="email",
363378
app_id="m.email",
364379
app_display_name="Email Notifications",
365-
device_display_name=threepid["address"],
366-
pushkey=threepid["address"],
380+
device_display_name=address,
381+
pushkey=address,
367382
lang=None, # We don't know a user's language here
368383
data={},
369384
)

synapse/storage/databases/main/registration.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -754,16 +754,18 @@ async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[s
754754
)
755755
return user_id
756756

757-
def get_user_id_by_threepid_txn(self, txn, medium, address):
757+
def get_user_id_by_threepid_txn(
758+
self, txn, medium: str, address: str
759+
) -> Optional[str]:
758760
"""Returns user id from threepid
759761
760762
Args:
761763
txn (cursor):
762-
medium (str): threepid medium e.g. email
763-
address (str): threepid address e.g. me@example.com
764+
medium: threepid medium e.g. email
765+
address: threepid address e.g. me@example.com
764766
765767
Returns:
766-
str|None: user id or None if no user id/threepid mapping exists
768+
user id, or None if no user id/threepid mapping exists
767769
"""
768770
ret = self.db_pool.simple_select_one_txn(
769771
txn,
@@ -776,22 +778,31 @@ def get_user_id_by_threepid_txn(self, txn, medium, address):
776778
return ret["user_id"]
777779
return None
778780

779-
async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
781+
async def user_add_threepid(
782+
self,
783+
user_id: str,
784+
medium: str,
785+
address: str,
786+
validated_at: int,
787+
added_at: int,
788+
) -> None:
780789
await self.db_pool.simple_upsert(
781790
"user_threepids",
782791
{"medium": medium, "address": address},
783792
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
784793
)
785794

786-
async def user_get_threepids(self, user_id):
795+
async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
787796
return await self.db_pool.simple_select_list(
788797
"user_threepids",
789798
{"user_id": user_id},
790799
["medium", "address", "validated_at", "added_at"],
791800
"user_get_threepids",
792801
)
793802

794-
async def user_delete_threepid(self, user_id, medium, address) -> None:
803+
async def user_delete_threepid(
804+
self, user_id: str, medium: str, address: str
805+
) -> None:
795806
await self.db_pool.simple_delete(
796807
"user_threepids",
797808
keyvalues={"user_id": user_id, "medium": medium, "address": address},

tests/rest/admin/test_user.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,12 +1431,14 @@ def test_create_user(self):
14311431
self.assertEqual("Bob's name", channel.json_body["displayname"])
14321432
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
14331433
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
1434+
self.assertEqual(1, len(channel.json_body["threepids"]))
14341435
self.assertEqual(
14351436
"external_id1", channel.json_body["external_ids"][0]["external_id"]
14361437
)
14371438
self.assertEqual(
14381439
"auth_provider1", channel.json_body["external_ids"][0]["auth_provider"]
14391440
)
1441+
self.assertEqual(1, len(channel.json_body["external_ids"]))
14401442
self.assertFalse(channel.json_body["admin"])
14411443
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
14421444
self._check_fields(channel.json_body)
@@ -1676,18 +1678,53 @@ def test_set_threepid(self):
16761678
Test setting threepid for an other user.
16771679
"""
16781680

1679-
# Delete old and add new threepid to user
1681+
# Add two threepids to user
16801682
channel = self.make_request(
16811683
"PUT",
16821684
self.url_other_user,
16831685
access_token=self.admin_user_tok,
1684-
content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]},
1686+
content={
1687+
"threepids": [
1688+
{"medium": "email", "address": "bob1@bob.bob"},
1689+
{"medium": "email", "address": "bob2@bob.bob"},
1690+
],
1691+
},
16851692
)
16861693

16871694
self.assertEqual(200, channel.code, msg=channel.json_body)
16881695
self.assertEqual("@user:test", channel.json_body["name"])
1696+
self.assertEqual(2, len(channel.json_body["threepids"]))
1697+
# result does not always have the same sort order, therefore it becomes sorted
1698+
sorted_result = sorted(
1699+
channel.json_body["threepids"], key=lambda k: k["address"]
1700+
)
1701+
self.assertEqual("email", sorted_result[0]["medium"])
1702+
self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
1703+
self.assertEqual("email", sorted_result[1]["medium"])
1704+
self.assertEqual("bob2@bob.bob", sorted_result[1]["address"])
1705+
self._check_fields(channel.json_body)
1706+
1707+
# Set a new and remove a threepid
1708+
channel = self.make_request(
1709+
"PUT",
1710+
self.url_other_user,
1711+
access_token=self.admin_user_tok,
1712+
content={
1713+
"threepids": [
1714+
{"medium": "email", "address": "bob2@bob.bob"},
1715+
{"medium": "email", "address": "bob3@bob.bob"},
1716+
],
1717+
},
1718+
)
1719+
1720+
self.assertEqual(200, channel.code, msg=channel.json_body)
1721+
self.assertEqual("@user:test", channel.json_body["name"])
1722+
self.assertEqual(2, len(channel.json_body["threepids"]))
16891723
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
1690-
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
1724+
self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
1725+
self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
1726+
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
1727+
self._check_fields(channel.json_body)
16911728

16921729
# Get user
16931730
channel = self.make_request(
@@ -1698,8 +1735,24 @@ def test_set_threepid(self):
16981735

16991736
self.assertEqual(200, channel.code, msg=channel.json_body)
17001737
self.assertEqual("@user:test", channel.json_body["name"])
1738+
self.assertEqual(2, len(channel.json_body["threepids"]))
17011739
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
1702-
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
1740+
self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
1741+
self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
1742+
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
1743+
self._check_fields(channel.json_body)
1744+
1745+
# Remove threepids
1746+
channel = self.make_request(
1747+
"PUT",
1748+
self.url_other_user,
1749+
access_token=self.admin_user_tok,
1750+
content={"threepids": []},
1751+
)
1752+
self.assertEqual(200, channel.code, msg=channel.json_body)
1753+
self.assertEqual("@user:test", channel.json_body["name"])
1754+
self.assertEqual(0, len(channel.json_body["threepids"]))
1755+
self._check_fields(channel.json_body)
17031756

17041757
def test_set_external_id(self):
17051758
"""
@@ -1778,6 +1831,7 @@ def test_set_external_id(self):
17781831

17791832
self.assertEqual(200, channel.code, msg=channel.json_body)
17801833
self.assertEqual("@user:test", channel.json_body["name"])
1834+
self.assertEqual(2, len(channel.json_body["external_ids"]))
17811835
self.assertEqual(
17821836
channel.json_body["external_ids"],
17831837
[

0 commit comments

Comments
 (0)