Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 822e92a

Browse files
committed
Refactor storage methods to retrieve to-device messages
This commit refactors the previously rather duplicated 'get_new_messages_for_device' and 'get_new_messages' methods into one new private method with combined logic, and two small public methods. The public methods expose the correct interface for querying to-device messages for either a single device (where a limit can be used) and multiple devices (where using a limit is infeasible).
1 parent 0ac079b commit 822e92a

File tree

1 file changed

+194
-87
lines changed

1 file changed

+194
-87
lines changed

synapse/storage/databases/main/deviceinbox.py

+194-87
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515

1616
import logging
17-
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, cast
17+
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
1818

1919
from synapse.logging import issue9533_logger
2020
from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -137,134 +137,241 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
137137
def get_to_device_stream_token(self):
138138
return self._device_inbox_id_gen.get_current_token()
139139

140-
async def get_new_messages(
140+
async def get_messages_for_user_devices(
141141
self,
142142
user_ids: Collection[str],
143143
from_stream_id: int,
144144
to_stream_id: int,
145145
) -> Dict[Tuple[str, str], List[JsonDict]]:
146146
"""
147-
Retrieve to-device messages for a given set of user IDs.
147+
Retrieve to-device messages for a given set of users.
148148
149149
Only to-device messages with stream ids between the given boundaries
150150
(from < X <= to) are returned.
151151
152-
Note that a stream ID can be shared by multiple copies of the same message with
153-
different recipient devices. Each (device, message_content) tuple has their own
154-
row in the device_inbox table.
155-
156152
Args:
157153
user_ids: The users to retrieve to-device messages for.
158154
from_stream_id: The lower boundary of stream id to filter with (exclusive).
159155
to_stream_id: The upper boundary of stream id to filter with (inclusive).
160156
161157
Returns:
162-
A list of to-device messages.
158+
A dictionary of (user id, device id) -> list of to-device messages.
163159
"""
164-
# Bail out if none of these users have any messages
165-
for user_id in user_ids:
166-
if self._device_inbox_stream_cache.has_entity_changed(
167-
user_id, from_stream_id
168-
):
169-
break
170-
else:
171-
return {}
172-
173-
def get_new_messages_txn(txn: LoggingTransaction):
174-
# Build a query to select messages from any of the given users that are between
175-
# the given stream id bounds
160+
# We expect the stream ID returned by _get_new_device_messages to always
161+
# return to_stream_id. So, no need to return it from this function.
162+
user_id_device_id_to_messages, _ = await self._get_device_messages(
163+
user_ids=user_ids,
164+
from_stream_id=from_stream_id,
165+
to_stream_id=to_stream_id,
166+
)
176167

177-
# Scope to only the given users. We need to use this method as doing so is
178-
# different across database engines.
179-
many_clause_sql, many_clause_args = make_in_list_sql_clause(
180-
self.database_engine, "user_id", user_ids
181-
)
168+
return user_id_device_id_to_messages
182169

183-
sql = f"""
184-
SELECT user_id, device_id, message_json FROM device_inbox
185-
WHERE {many_clause_sql}
186-
AND ? < stream_id AND stream_id <= ?
187-
ORDER BY stream_id ASC
188-
"""
170+
async def get_messages_for_device(
171+
self,
172+
user_id: str,
173+
device_id: str,
174+
from_stream_id: int,
175+
to_stream_id: int,
176+
limit: int = 100,
177+
) -> Tuple[List[JsonDict], int]:
178+
"""
179+
Retrieve to-device messages for a single user device.
189180
190-
txn.execute(sql, (*many_clause_args, from_stream_id, to_stream_id))
181+
Only to-device messages with stream ids between the given boundaries
182+
(from < X <= to) are returned.
191183
192-
# Create a dictionary of (user ID, device ID) -> list of messages that
193-
# that device is meant to receive.
194-
recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
184+
Args:
185+
user_id: The ID of the user to retrieve messages for.
186+
device_id: The ID of the device to retrieve to-device messages for.
187+
from_stream_id: The lower boundary of stream id to filter with (exclusive).
188+
to_stream_id: The upper boundary of stream id to filter with (inclusive).
189+
limit: A limit on the number of to-device messages returned.
195190
196-
for row in txn:
197-
recipient_user_id = row[0]
198-
recipient_device_id = row[1]
199-
message_dict = db_to_json(row[2])
191+
Returns:
192+
A tuple containing:
193+
* A dictionary of (user id, device id) -> list of to-device messages.
194+
* The last-processed stream ID. Subsequent calls of this function with the
195+
same device should pass this value as 'from_stream_id'.
196+
"""
197+
(
198+
user_id_device_id_to_messages,
199+
last_processed_stream_id,
200+
) = await self._get_device_messages(
201+
user_ids=[user_id],
202+
device_ids=[device_id],
203+
from_stream_id=from_stream_id,
204+
to_stream_id=to_stream_id,
205+
limit=limit,
206+
)
200207

201-
recipient_device_to_messages.setdefault(
202-
(recipient_user_id, recipient_device_id), []
203-
).append(message_dict)
208+
if not user_id_device_id_to_messages:
209+
# There were no messages!
210+
return [], to_stream_id
204211

205-
return recipient_device_to_messages
212+
# Extract the messages, no need to return the user and device ID again
213+
to_device_messages = list(user_id_device_id_to_messages.values())[0]
206214

207-
return await self.db_pool.runInteraction(
208-
"get_new_messages", get_new_messages_txn
209-
)
215+
return to_device_messages, last_processed_stream_id
210216

211-
async def get_new_messages_for_device(
217+
async def _get_device_messages(
212218
self,
213-
user_id: str,
214-
device_id: Optional[str],
215-
last_stream_id: int,
216-
current_stream_id: int,
217-
limit: int = 100,
218-
) -> Tuple[List[dict], int]:
219+
user_ids: Collection[str],
220+
from_stream_id: int,
221+
to_stream_id: int,
222+
device_ids: Optional[Collection[str]] = None,
223+
limit: Optional[int] = None,
224+
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
219225
"""
226+
Retrieve pending to-device messages for a collection of user devices.
227+
228+
Only to-device messages with stream ids between the given boundaries
229+
(from < X <= to) are returned.
230+
231+
Note that a stream ID can be shared by multiple copies of the same message with
232+
different recipient devices. Stream IDs are only unique in the context of a single
233+
user ID / device ID pair Thus, applying a limit (of messages to return) when working
234+
with a sliding window of stream IDs is only possible when querying messages of a
235+
single user device.
236+
237+
Finally, note that device IDs are not unique across users.
238+
220239
Args:
221-
user_id: The recipient user_id.
222-
device_id: The recipient device_id.
223-
last_stream_id: The last stream ID checked.
224-
current_stream_id: The current position of the to device
225-
message stream.
226-
limit: The maximum number of messages to retrieve.
240+
user_ids: The user IDs to filter device messages by.
241+
from_stream_id: The lower boundary of stream id to filter with (exclusive).
242+
to_stream_id: The upper boundary of stream id to filter with (inclusive).
243+
device_ids: If provided, only messages destined for these device IDs will be returned.
244+
If not provided, all device IDs for the given user IDs will be used.
245+
limit: The maximum number of to-device messages to return. Can only be used when
246+
passing a single user ID / device ID tuple.
227247
228248
Returns:
229249
A tuple containing:
230-
* A list of messages for the device.
231-
* The max stream token of these messages. There may be more to retrieve
232-
if the given limit was reached.
250+
* A dict of (user_id, device_id) -> list of to-device messages
251+
* The last-processed stream ID. If this is less than `to_stream_id`, then
252+
there may be more messages to retrieve. If `limit` is not set, then this
253+
is always equal to 'to_stream_id'.
233254
"""
234-
has_changed = self._device_inbox_stream_cache.has_entity_changed(
235-
user_id, last_stream_id
236-
)
237-
if not has_changed:
238-
return [], current_stream_id
255+
# A limit can only be applied when querying for a single user ID / device ID tuple.
256+
if limit:
257+
if not device_ids:
258+
raise AssertionError(
259+
"Programming error: _get_new_device_messages was passed 'limit' "
260+
"but not device_ids. This could lead to querying multiple user ID "
261+
"/ device ID pairs, which is not compatible with 'limit'"
262+
)
239263

240-
def get_new_messages_for_device_txn(txn):
241-
sql = (
242-
"SELECT stream_id, message_json FROM device_inbox"
243-
" WHERE user_id = ? AND device_id = ?"
244-
" AND ? < stream_id AND stream_id <= ?"
245-
" ORDER BY stream_id ASC"
246-
" LIMIT ?"
264+
if len(user_ids) > 1 or len(device_ids) > 1:
265+
raise AssertionError(
266+
"Programming error: _get_new_device_messages was passed 'limit' "
267+
"with >1 user id/device id pair"
268+
)
269+
270+
user_ids_to_query: Set[str] = set()
271+
device_ids_to_query: Set[str] = set()
272+
273+
if device_ids is not None:
274+
# If a collection of device IDs were passed, use them to filter results.
275+
# Otherwise, device IDs will be derived from the given collection of user IDs.
276+
device_ids_to_query.update(device_ids)
277+
278+
# Determine which users have devices with pending messages
279+
for user_id in user_ids:
280+
if self._device_inbox_stream_cache.has_entity_changed(
281+
user_id, from_stream_id
282+
):
283+
# This user has new messages sent to them. Query messages for them
284+
user_ids_to_query.add(user_id)
285+
286+
def get_new_device_messages_txn(txn: LoggingTransaction):
287+
# Build a query to select messages from any of the given devices that
288+
# are between the given stream id bounds.
289+
290+
# If a list of device IDs was not provided, retrieve all devices IDs
291+
# for the given users. We explicitly do not query hidden devices, as
292+
# hidden devices should not receive to-device messages.
293+
if not device_ids:
294+
user_device_dicts = self.db_pool.simple_select_many_txn(
295+
txn,
296+
table="devices",
297+
column="user_id",
298+
iterable=user_ids_to_query,
299+
keyvalues={"user_id": user_id, "hidden": False},
300+
retcols=("device_id",),
301+
)
302+
303+
device_ids_to_query.update(
304+
{row["device_id"] for row in user_device_dicts}
305+
)
306+
307+
if not user_ids_to_query or not device_ids_to_query:
308+
# We've ended up with no devices to query.
309+
return {}, to_stream_id
310+
311+
# We include both user IDs and device IDs in this query, as we have an index
312+
# (device_inbox_user_stream_id) for them.
313+
user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
314+
self.database_engine, "user_id", user_ids_to_query
247315
)
248-
txn.execute(
249-
sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
316+
(
317+
device_id_many_clause_sql,
318+
device_id_many_clause_args,
319+
) = make_in_list_sql_clause(
320+
self.database_engine, "device_id", device_ids_to_query
250321
)
251322

252-
messages = []
253-
stream_pos = current_stream_id
323+
sql = f"""
324+
SELECT stream_id, user_id, device_id, message_json FROM device_inbox
325+
WHERE {user_id_many_clause_sql}
326+
AND {device_id_many_clause_sql}
327+
AND ? < stream_id AND stream_id <= ?
328+
ORDER BY stream_id ASC
329+
"""
330+
sql_args = (
331+
*user_id_many_clause_args,
332+
*device_id_many_clause_args,
333+
from_stream_id,
334+
to_stream_id,
335+
)
254336

255-
for row in txn:
256-
stream_pos = row[0]
257-
messages.append(db_to_json(row[1]))
337+
# If a limit was provided, limit the data retrieved from the database
338+
if limit:
339+
sql += "LIMIT ?"
340+
sql_args += (limit,)
258341

259-
# If the limit was not reached we know that there's no more data for this
260-
# user/device pair up to current_stream_id.
261-
if len(messages) < limit:
262-
stream_pos = current_stream_id
342+
txn.execute(sql, sql_args)
263343

264-
return messages, stream_pos
344+
# Create and fill a dictionary of (user ID, device ID) -> list of messages
345+
# intended for each device.
346+
recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
347+
for message_count, row in enumerate(txn, start=1):
348+
last_processed_stream_pos = row[0]
349+
recipient_user_id = row[1]
350+
recipient_device_id = row[2]
351+
message_dict = db_to_json(row[3])
352+
353+
# Store the device details
354+
recipient_device_to_messages.setdefault(
355+
(recipient_user_id, recipient_device_id), []
356+
).append(message_dict)
357+
358+
if limit and message_count == limit:
359+
# We ended up hitting the message limit. There may be more messages to retrieve.
360+
# Return what we have, as well as the last stream position that was processed.
361+
#
362+
# The caller is expected to set this as the lower (exclusive) bound
363+
# for the next query of this device.
364+
return recipient_device_to_messages, last_processed_stream_pos
365+
366+
# The limit was not reached, thus we know that recipient_device_to_messages
367+
# contains all to-device messages for the given device and stream id range.
368+
#
369+
# We return to_stream_id, which the caller should then provide as the lower
370+
# (exclusive) bound on the next query of this device.
371+
return recipient_device_to_messages, to_stream_id
265372

266373
return await self.db_pool.runInteraction(
267-
"get_new_messages_for_device", get_new_messages_for_device_txn
374+
"get_new_device_messages", get_new_device_messages_txn
268375
)
269376

270377
@trace

0 commit comments

Comments
 (0)