|
41 | 41 | from synapse.replication.tcp.streams.events import EventsStream
|
42 | 42 | from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
43 | 43 | from synapse.storage.database import Database
|
| 44 | +from synapse.storage.types import Cursor |
44 | 45 | from synapse.storage.util.id_generators import StreamIdGenerator
|
45 | 46 | from synapse.types import get_domain_from_id
|
46 |
| -from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks |
| 47 | +from synapse.util.caches.descriptors import ( |
| 48 | + Cache, |
| 49 | + _CacheContext, |
| 50 | + cached, |
| 51 | + cachedInlineCallbacks, |
| 52 | +) |
47 | 53 | from synapse.util.iterutils import batch_iter
|
48 | 54 | from synapse.util.metrics import Measure
|
49 | 55 |
|
@@ -1358,6 +1364,84 @@ def get_next_event_to_expire_txn(txn):
|
1358 | 1364 | desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
|
1359 | 1365 | )
|
1360 | 1366 |
|
| 1367 | + @cached(tree=True, cache_context=True) |
| 1368 | + async def get_unread_message_count_for_user( |
| 1369 | + self, room_id: str, user_id: str, cache_context: _CacheContext, |
| 1370 | + ) -> int: |
| 1371 | + """Retrieve the count of unread messages for the given room and user. |
| 1372 | +
|
| 1373 | + Args: |
| 1374 | + room_id: The ID of the room to count unread messages in. |
| 1375 | + user_id: The ID of the user to count unread messages for. |
| 1376 | +
|
| 1377 | + Returns: |
| 1378 | + The number of unread messages for the given user in the given room. |
| 1379 | + """ |
| 1380 | + with Measure(self._clock, "get_unread_message_count_for_user"): |
| 1381 | + last_read_event_id = await self.get_last_receipt_event_id_for_user( |
| 1382 | + user_id=user_id, |
| 1383 | + room_id=room_id, |
| 1384 | + receipt_type="m.read", |
| 1385 | + on_invalidate=cache_context.invalidate, |
| 1386 | + ) |
| 1387 | + |
| 1388 | + return await self.db.runInteraction( |
| 1389 | + "get_unread_message_count_for_user", |
| 1390 | + self._get_unread_message_count_for_user_txn, |
| 1391 | + user_id, |
| 1392 | + room_id, |
| 1393 | + last_read_event_id, |
| 1394 | + ) |
| 1395 | + |
| 1396 | + def _get_unread_message_count_for_user_txn( |
| 1397 | + self, |
| 1398 | + txn: Cursor, |
| 1399 | + user_id: str, |
| 1400 | + room_id: str, |
| 1401 | + last_read_event_id: Optional[str], |
| 1402 | + ) -> int: |
| 1403 | + if last_read_event_id: |
| 1404 | + # Get the stream ordering for the last read event. |
| 1405 | + stream_ordering = self.db.simple_select_one_onecol_txn( |
| 1406 | + txn=txn, |
| 1407 | + table="events", |
| 1408 | + keyvalues={"room_id": room_id, "event_id": last_read_event_id}, |
| 1409 | + retcol="stream_ordering", |
| 1410 | + ) |
| 1411 | + else: |
| 1412 | + # If there's no read receipt for that room, it probably means the user hasn't |
| 1413 | + # opened it yet, in which case use the stream ID of their join event. |
| 1414 | + # We can't just set it to 0 otherwise messages from other local users from |
| 1415 | + # before this user joined will be counted as well. |
| 1416 | + txn.execute( |
| 1417 | + """ |
| 1418 | + SELECT stream_ordering FROM local_current_membership |
| 1419 | + LEFT JOIN events USING (event_id, room_id) |
| 1420 | + WHERE membership = 'join' |
| 1421 | + AND user_id = ? |
| 1422 | + AND room_id = ? |
| 1423 | + """, |
| 1424 | + (user_id, room_id), |
| 1425 | + ) |
| 1426 | + row = txn.fetchone() |
| 1427 | + |
| 1428 | + if row is None: |
| 1429 | + return 0 |
| 1430 | + |
| 1431 | + stream_ordering = row[0] |
| 1432 | + |
| 1433 | + # Count the messages that qualify as unread after the stream ordering we've just |
| 1434 | + # retrieved. |
| 1435 | + sql = """ |
| 1436 | + SELECT COUNT(*) FROM events |
| 1437 | + WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread |
| 1438 | + """ |
| 1439 | + |
| 1440 | + txn.execute(sql, (user_id, room_id, stream_ordering)) |
| 1441 | + row = txn.fetchone() |
| 1442 | + |
| 1443 | + return row[0] if row else 0 |
| 1444 | + |
1361 | 1445 |
|
1362 | 1446 | AllNewEventsResult = namedtuple(
|
1363 | 1447 | "AllNewEventsResult",
|
|
0 commit comments