|
14 | 14 | # limitations under the License.
|
15 | 15 |
|
16 | 16 | import logging
|
17 |
| -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, cast |
| 17 | +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast |
18 | 18 |
|
19 | 19 | from synapse.logging import issue9533_logger
|
20 | 20 | from synapse.logging.opentracing import log_kv, set_tag, trace
|
@@ -137,134 +137,241 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
|
137 | 137 | def get_to_device_stream_token(self):
|
138 | 138 | return self._device_inbox_id_gen.get_current_token()
|
139 | 139 |
|
140 |
| - async def get_new_messages( |
| 140 | + async def get_messages_for_user_devices( |
141 | 141 | self,
|
142 | 142 | user_ids: Collection[str],
|
143 | 143 | from_stream_id: int,
|
144 | 144 | to_stream_id: int,
|
145 | 145 | ) -> Dict[Tuple[str, str], List[JsonDict]]:
|
146 | 146 | """
|
147 |
| - Retrieve to-device messages for a given set of user IDs. |
| 147 | + Retrieve to-device messages for a given set of users. |
148 | 148 |
|
149 | 149 | Only to-device messages with stream ids between the given boundaries
|
150 | 150 | (from < X <= to) are returned.
|
151 | 151 |
|
152 |
| - Note that a stream ID can be shared by multiple copies of the same message with |
153 |
| - different recipient devices. Each (device, message_content) tuple has their own |
154 |
| - row in the device_inbox table. |
155 |
| -
|
156 | 152 | Args:
|
157 | 153 | user_ids: The users to retrieve to-device messages for.
|
158 | 154 | from_stream_id: The lower boundary of stream id to filter with (exclusive).
|
159 | 155 | to_stream_id: The upper boundary of stream id to filter with (inclusive).
|
160 | 156 |
|
161 | 157 | Returns:
|
162 |
| - A list of to-device messages. |
| 158 | + A dictionary of (user id, device id) -> list of to-device messages. |
163 | 159 | """
|
164 |
| - # Bail out if none of these users have any messages |
165 |
| - for user_id in user_ids: |
166 |
| - if self._device_inbox_stream_cache.has_entity_changed( |
167 |
| - user_id, from_stream_id |
168 |
| - ): |
169 |
| - break |
170 |
| - else: |
171 |
| - return {} |
172 |
| - |
173 |
| - def get_new_messages_txn(txn: LoggingTransaction): |
174 |
| - # Build a query to select messages from any of the given users that are between |
175 |
| - # the given stream id bounds |
| 160 | + # No limit is passed here, so the stream ID returned by _get_device_messages |
| 161 | + # is always to_stream_id. Thus there is no need to return it from this function. |
| 162 | + user_id_device_id_to_messages, _ = await self._get_device_messages( |
| 163 | + user_ids=user_ids, |
| 164 | + from_stream_id=from_stream_id, |
| 165 | + to_stream_id=to_stream_id, |
| 166 | + ) |
176 | 167 |
|
177 |
| - # Scope to only the given users. We need to use this method as doing so is |
178 |
| - # different across database engines. |
179 |
| - many_clause_sql, many_clause_args = make_in_list_sql_clause( |
180 |
| - self.database_engine, "user_id", user_ids |
181 |
| - ) |
| 168 | + return user_id_device_id_to_messages |
182 | 169 |
|
183 |
| - sql = f""" |
184 |
| - SELECT user_id, device_id, message_json FROM device_inbox |
185 |
| - WHERE {many_clause_sql} |
186 |
| - AND ? < stream_id AND stream_id <= ? |
187 |
| - ORDER BY stream_id ASC |
188 |
| - """ |
| 170 | + async def get_messages_for_device( |
| 171 | + self, |
| 172 | + user_id: str, |
| 173 | + device_id: str, |
| 174 | + from_stream_id: int, |
| 175 | + to_stream_id: int, |
| 176 | + limit: int = 100, |
| 177 | + ) -> Tuple[List[JsonDict], int]: |
| 178 | + """ |
| 179 | + Retrieve to-device messages for a single user device. |
189 | 180 |
|
190 |
| - txn.execute(sql, (*many_clause_args, from_stream_id, to_stream_id)) |
| 181 | + Only to-device messages with stream ids between the given boundaries |
| 182 | + (from < X <= to) are returned. |
191 | 183 |
|
192 |
| - # Create a dictionary of (user ID, device ID) -> list of messages that |
193 |
| - # that device is meant to receive. |
194 |
| - recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {} |
| 184 | + Args: |
| 185 | + user_id: The ID of the user to retrieve messages for. |
| 186 | + device_id: The ID of the device to retrieve to-device messages for. |
| 187 | + from_stream_id: The lower boundary of stream id to filter with (exclusive). |
| 188 | + to_stream_id: The upper boundary of stream id to filter with (inclusive). |
| 189 | + limit: A limit on the number of to-device messages returned. |
195 | 190 |
|
196 |
| - for row in txn: |
197 |
| - recipient_user_id = row[0] |
198 |
| - recipient_device_id = row[1] |
199 |
| - message_dict = db_to_json(row[2]) |
| 191 | + Returns: |
| 192 | + A tuple containing: |
| 193 | + * A list of to-device messages for the given device. |
| 194 | + * The last-processed stream ID. Subsequent calls of this function with the |
| 195 | + same device should pass this value as 'from_stream_id'. |
| 196 | + """ |
| 197 | + ( |
| 198 | + user_id_device_id_to_messages, |
| 199 | + last_processed_stream_id, |
| 200 | + ) = await self._get_device_messages( |
| 201 | + user_ids=[user_id], |
| 202 | + device_ids=[device_id], |
| 203 | + from_stream_id=from_stream_id, |
| 204 | + to_stream_id=to_stream_id, |
| 205 | + limit=limit, |
| 206 | + ) |
200 | 207 |
|
201 |
| - recipient_device_to_messages.setdefault( |
202 |
| - (recipient_user_id, recipient_device_id), [] |
203 |
| - ).append(message_dict) |
| 208 | + if not user_id_device_id_to_messages: |
| 209 | + # There were no messages! |
| 210 | + return [], to_stream_id |
204 | 211 |
|
205 |
| - return recipient_device_to_messages |
| 212 | + # Extract the messages, no need to return the user and device ID again |
| 213 | + to_device_messages = list(user_id_device_id_to_messages.values())[0] |
206 | 214 |
|
207 |
| - return await self.db_pool.runInteraction( |
208 |
| - "get_new_messages", get_new_messages_txn |
209 |
| - ) |
| 215 | + return to_device_messages, last_processed_stream_id |
210 | 216 |
|
211 |
| - async def get_new_messages_for_device( |
| 217 | + async def _get_device_messages( |
212 | 218 | self,
|
213 |
| - user_id: str, |
214 |
| - device_id: Optional[str], |
215 |
| - last_stream_id: int, |
216 |
| - current_stream_id: int, |
217 |
| - limit: int = 100, |
218 |
| - ) -> Tuple[List[dict], int]: |
| 219 | + user_ids: Collection[str], |
| 220 | + from_stream_id: int, |
| 221 | + to_stream_id: int, |
| 222 | + device_ids: Optional[Collection[str]] = None, |
| 223 | + limit: Optional[int] = None, |
| 224 | + ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]: |
219 | 225 | """
|
| 226 | + Retrieve pending to-device messages for a collection of user devices. |
| 227 | +
|
| 228 | + Only to-device messages with stream ids between the given boundaries |
| 229 | + (from < X <= to) are returned. |
| 230 | +
|
| 231 | + Note that a stream ID can be shared by multiple copies of the same message with |
| 232 | + different recipient devices. Stream IDs are only unique in the context of a single |
| 233 | + user ID / device ID pair. Thus, applying a limit (of messages to return) when working |
| 234 | + with a sliding window of stream IDs is only possible when querying messages of a |
| 235 | + single user device. |
| 236 | +
|
| 237 | + Finally, note that device IDs are not unique across users. |
| 238 | +
|
220 | 239 | Args:
|
221 |
| - user_id: The recipient user_id. |
222 |
| - device_id: The recipient device_id. |
223 |
| - last_stream_id: The last stream ID checked. |
224 |
| - current_stream_id: The current position of the to device |
225 |
| - message stream. |
226 |
| - limit: The maximum number of messages to retrieve. |
| 240 | + user_ids: The user IDs to filter device messages by. |
| 241 | + from_stream_id: The lower boundary of stream id to filter with (exclusive). |
| 242 | + to_stream_id: The upper boundary of stream id to filter with (inclusive). |
| 243 | + device_ids: If provided, only messages destined for these device IDs will be returned. |
| 244 | + If not provided, all device IDs for the given user IDs will be used. |
| 245 | + limit: The maximum number of to-device messages to return. Can only be used when |
| 246 | + passing a single user ID / device ID tuple. |
227 | 247 |
|
228 | 248 | Returns:
|
229 | 249 | A tuple containing:
|
230 |
| - * A list of messages for the device. |
231 |
| - * The max stream token of these messages. There may be more to retrieve |
232 |
| - if the given limit was reached. |
| 250 | + * A dict of (user_id, device_id) -> list of to-device messages |
| 251 | + * The last-processed stream ID. If this is less than `to_stream_id`, then |
| 252 | + there may be more messages to retrieve. If `limit` is not set, then this |
| 253 | + is always equal to 'to_stream_id'. |
233 | 254 | """
|
234 |
| - has_changed = self._device_inbox_stream_cache.has_entity_changed( |
235 |
| - user_id, last_stream_id |
236 |
| - ) |
237 |
| - if not has_changed: |
238 |
| - return [], current_stream_id |
| 255 | + # A limit can only be applied when querying for a single user ID / device ID tuple. |
| 256 | + if limit: |
| 257 | + if not device_ids: |
| 258 | + raise AssertionError( |
| 259 | + "Programming error: _get_new_device_messages was passed 'limit' " |
| 260 | + "but not device_ids. This could lead to querying multiple user ID " |
| 261 | + "/ device ID pairs, which is not compatible with 'limit'" |
| 262 | + ) |
239 | 263 |
|
240 |
| - def get_new_messages_for_device_txn(txn): |
241 |
| - sql = ( |
242 |
| - "SELECT stream_id, message_json FROM device_inbox" |
243 |
| - " WHERE user_id = ? AND device_id = ?" |
244 |
| - " AND ? < stream_id AND stream_id <= ?" |
245 |
| - " ORDER BY stream_id ASC" |
246 |
| - " LIMIT ?" |
| 264 | + if len(user_ids) > 1 or len(device_ids) > 1: |
| 265 | + raise AssertionError( |
| 266 | + "Programming error: _get_new_device_messages was passed 'limit' " |
| 267 | + "with >1 user id/device id pair" |
| 268 | + ) |
| 269 | + |
| 270 | + user_ids_to_query: Set[str] = set() |
| 271 | + device_ids_to_query: Set[str] = set() |
| 272 | + |
| 273 | + if device_ids is not None: |
| 274 | + # If a collection of device IDs was passed, use them to filter results. |
| 275 | + # Otherwise, device IDs will be derived from the given collection of user IDs. |
| 276 | + device_ids_to_query.update(device_ids) |
| 277 | + |
| 278 | + # Determine which users have devices with pending messages |
| 279 | + for user_id in user_ids: |
| 280 | + if self._device_inbox_stream_cache.has_entity_changed( |
| 281 | + user_id, from_stream_id |
| 282 | + ): |
| 283 | + # This user has new messages sent to them. Query messages for them |
| 284 | + user_ids_to_query.add(user_id) |
| 285 | + |
| 286 | + def get_new_device_messages_txn(txn: LoggingTransaction): |
| 287 | + # Build a query to select messages from any of the given devices that |
| 288 | + # are between the given stream id bounds. |
| 289 | + |
| 290 | + # If a list of device IDs was not provided, retrieve all device IDs |
| 291 | + # for the given users. We explicitly do not query hidden devices, as |
| 292 | + # hidden devices should not receive to-device messages. |
| 293 | + if not device_ids: |
| 294 | + user_device_dicts = self.db_pool.simple_select_many_txn( |
| 295 | + txn, |
| 296 | + table="devices", |
| 297 | + column="user_id", |
| 298 | + iterable=user_ids_to_query, |
| 299 | + keyvalues={"user_id": user_id, "hidden": False}, |
| 300 | + retcols=("device_id",), |
| 301 | + ) |
| 302 | + |
| 303 | + device_ids_to_query.update( |
| 304 | + {row["device_id"] for row in user_device_dicts} |
| 305 | + ) |
| 306 | + |
| 307 | + if not user_ids_to_query or not device_ids_to_query: |
| 308 | + # We've ended up with no devices to query. |
| 309 | + return {}, to_stream_id |
| 310 | + |
| 311 | + # We include both user IDs and device IDs in this query, as we have an index |
| 312 | + # (device_inbox_user_stream_id) for them. |
| 313 | + user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause( |
| 314 | + self.database_engine, "user_id", user_ids_to_query |
247 | 315 | )
|
248 |
| - txn.execute( |
249 |
| - sql, (user_id, device_id, last_stream_id, current_stream_id, limit) |
| 316 | + ( |
| 317 | + device_id_many_clause_sql, |
| 318 | + device_id_many_clause_args, |
| 319 | + ) = make_in_list_sql_clause( |
| 320 | + self.database_engine, "device_id", device_ids_to_query |
250 | 321 | )
|
251 | 322 |
|
252 |
| - messages = [] |
253 |
| - stream_pos = current_stream_id |
| 323 | + sql = f""" |
| 324 | + SELECT stream_id, user_id, device_id, message_json FROM device_inbox |
| 325 | + WHERE {user_id_many_clause_sql} |
| 326 | + AND {device_id_many_clause_sql} |
| 327 | + AND ? < stream_id AND stream_id <= ? |
| 328 | + ORDER BY stream_id ASC |
| 329 | + """ |
| 330 | + sql_args = ( |
| 331 | + *user_id_many_clause_args, |
| 332 | + *device_id_many_clause_args, |
| 333 | + from_stream_id, |
| 334 | + to_stream_id, |
| 335 | + ) |
254 | 336 |
|
255 |
| - for row in txn: |
256 |
| - stream_pos = row[0] |
257 |
| - messages.append(db_to_json(row[1])) |
| 337 | + # If a limit was provided, limit the data retrieved from the database |
| 338 | + if limit: |
| 339 | + sql += "LIMIT ?" |
| 340 | + sql_args += (limit,) |
258 | 341 |
|
259 |
| - # If the limit was not reached we know that there's no more data for this |
260 |
| - # user/device pair up to current_stream_id. |
261 |
| - if len(messages) < limit: |
262 |
| - stream_pos = current_stream_id |
| 342 | + txn.execute(sql, sql_args) |
263 | 343 |
|
264 |
| - return messages, stream_pos |
| 344 | + # Create and fill a dictionary of (user ID, device ID) -> list of messages |
| 345 | + # intended for each device. |
| 346 | + recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {} |
| 347 | + for message_count, row in enumerate(txn, start=1): |
| 348 | + last_processed_stream_pos = row[0] |
| 349 | + recipient_user_id = row[1] |
| 350 | + recipient_device_id = row[2] |
| 351 | + message_dict = db_to_json(row[3]) |
| 352 | + |
| 353 | + # Store the device details |
| 354 | + recipient_device_to_messages.setdefault( |
| 355 | + (recipient_user_id, recipient_device_id), [] |
| 356 | + ).append(message_dict) |
| 357 | + |
| 358 | + if limit and message_count == limit: |
| 359 | + # We ended up hitting the message limit. There may be more messages to retrieve. |
| 360 | + # Return what we have, as well as the last stream position that was processed. |
| 361 | + # |
| 362 | + # The caller is expected to set this as the lower (exclusive) bound |
| 363 | + # for the next query of this device. |
| 364 | + return recipient_device_to_messages, last_processed_stream_pos |
| 365 | + |
| 366 | + # The limit was not reached, thus we know that recipient_device_to_messages |
| 367 | + # contains all to-device messages for the given device and stream id range. |
| 368 | + # |
| 369 | + # We return to_stream_id, which the caller should then provide as the lower |
| 370 | + # (exclusive) bound on the next query of this device. |
| 371 | + return recipient_device_to_messages, to_stream_id |
265 | 372 |
|
266 | 373 | return await self.db_pool.runInteraction(
|
267 |
| - "get_new_messages_for_device", get_new_messages_for_device_txn |
| 374 | + "get_new_device_messages", get_new_device_messages_txn |
268 | 375 | )
|
269 | 376 |
|
270 | 377 | @trace
|
|
0 commit comments