@@ -437,6 +437,7 @@ def _get_unread_counts_by_pos_txn(
437437 """
438438
439439 counts = NotifCounts ()
440+ thread_counts = {}
440441
441442 # First we pull the counts from the summary table.
442443 #
@@ -453,7 +454,7 @@ def _get_unread_counts_by_pos_txn(
453454 # receipt).
454455 txn .execute (
455456 """
456- SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
457+ SELECT stream_ordering, notif_count, COALESCE(unread_count, 0), thread_id
457458 FROM event_push_summary
458459 WHERE room_id = ? AND user_id = ?
459460 AND (
@@ -463,42 +464,70 @@ def _get_unread_counts_by_pos_txn(
463464 """ ,
464465 (room_id , user_id , receipt_stream_ordering , receipt_stream_ordering ),
465466 )
466- row = txn .fetchone ()
467+ max_summary_stream_ordering = 0
468+ for summary_stream_ordering , notif_count , unread_count , thread_id in txn :
469+ if thread_id == "main" :
470+ counts = NotifCounts (
471+ notify_count = notif_count , unread_count = unread_count
472+ )
473+ # TODO Delete zeroed out threads completely from the database.
474+ elif notif_count or unread_count :
475+ thread_counts [thread_id ] = NotifCounts (
476+ notify_count = notif_count , unread_count = unread_count
477+ )
467478
468- summary_stream_ordering = 0
469- if row :
470- summary_stream_ordering = row [0 ]
471- counts .notify_count += row [1 ]
472- counts .unread_count += row [2 ]
479+ # XXX All threads should have the same stream ordering?
480+ max_summary_stream_ordering = max (
481+ summary_stream_ordering , max_summary_stream_ordering
482+ )
473483
474484 # Next we need to count highlights, which aren't summarised
475485 sql = """
476- SELECT COUNT(*) FROM event_push_actions
486+ SELECT COUNT(*), thread_id FROM event_push_actions
477487 WHERE user_id = ?
478488 AND room_id = ?
479489 AND stream_ordering > ?
480490 AND highlight = 1
491+ GROUP BY thread_id
481492 """
482493 txn .execute (sql , (user_id , room_id , receipt_stream_ordering ))
483- row = txn .fetchone ()
484- if row :
485- counts .highlight_count += row [0 ]
494+ for highlight_count , thread_id in txn :
495+ if thread_id == "main" :
496+ counts .highlight_count += highlight_count
497+ elif highlight_count :
498+ if thread_id in thread_counts :
499+ thread_counts [thread_id ].highlight_count += highlight_count
500+ else :
501+ thread_counts [thread_id ] = NotifCounts (
502+ notify_count = 0 , unread_count = 0 , highlight_count = highlight_count
503+ )
486504
487505 # Finally we need to count push actions that aren't included in the
488506 # summary returned above. This might be due to recent events that haven't
489507 # been summarised yet or the summary is out of date due to a recent read
490508 # receipt.
491509 start_unread_stream_ordering = max (
492- receipt_stream_ordering , summary_stream_ordering
510+ receipt_stream_ordering , max_summary_stream_ordering
493511 )
494- notify_count , unread_count = self ._get_notif_unread_count_for_user_room (
512+ unread_counts = self ._get_notif_unread_count_for_user_room (
495513 txn , room_id , user_id , start_unread_stream_ordering
496514 )
497515
498- counts .notify_count += notify_count
499- counts .unread_count += unread_count
516+ for notif_count , unread_count , thread_id in unread_counts :
517+ if thread_id == "main" :
518+ counts .notify_count += notif_count
519+ counts .unread_count += unread_count
520+ elif thread_id in thread_counts :
521+ thread_counts [thread_id ].notify_count += notif_count
522+ thread_counts [thread_id ].unread_count += unread_count
523+ else :
524+ thread_counts [thread_id ] = NotifCounts (
525+ notify_count = notif_count ,
526+ unread_count = unread_count ,
527+ highlight_count = 0 ,
528+ )
500529
501- return RoomNotifCounts (counts , {} )
530+ return RoomNotifCounts (counts , thread_counts )
502531
503532 def _get_notif_unread_count_for_user_room (
504533 self ,
@@ -507,7 +536,7 @@ def _get_notif_unread_count_for_user_room(
507536 user_id : str ,
508537 stream_ordering : int ,
509538 max_stream_ordering : Optional [int ] = None ,
510- ) -> Tuple [int , int ]:
539+ ) -> List [ Tuple [int , int , str ] ]:
511540 """Returns the notify and unread counts from `event_push_actions` for
512541 the given user/room in the given range.
513542
@@ -523,13 +552,14 @@ def _get_notif_unread_count_for_user_room(
523552 If this is not given, then no maximum is applied.
524553
525554 Return:
526- A tuple of the notif count and unread count in the given range.
555+ A tuple of the notif count and unread count in the given range for
556+ each thread.
527557 """
528558
529559 # If there have been no events in the room since the stream ordering,
530560 # there can't be any push actions either.
531561 if not self ._events_stream_cache .has_entity_changed (room_id , stream_ordering ):
532- return 0 , 0
562+ return []
533563
534564 clause = ""
535565 args = [user_id , room_id , stream_ordering ]
@@ -540,26 +570,23 @@ def _get_notif_unread_count_for_user_room(
540570 # If the max stream ordering is less than the min stream ordering,
541571 # then obviously there are zero push actions in that range.
542572 if max_stream_ordering <= stream_ordering :
543- return 0 , 0
573+ return []
544574
545575 sql = f"""
546576 SELECT
547577 COUNT(CASE WHEN notif = 1 THEN 1 END),
548- COUNT(CASE WHEN unread = 1 THEN 1 END)
549- FROM event_push_actions ea
550- WHERE user_id = ?
578+ COUNT(CASE WHEN unread = 1 THEN 1 END),
579+ thread_id
580+ FROM event_push_actions ea
581+ WHERE user_id = ?
551582 AND room_id = ?
552583 AND ea.stream_ordering > ?
553584 { clause }
585+ GROUP BY thread_id
554586 """
555587
556588 txn .execute (sql , args )
557- row = txn .fetchone ()
558-
559- if row :
560- return cast (Tuple [int , int ], row )
561-
562- return 0 , 0
589+ return cast (List [Tuple [int , int , str ]], txn .fetchall ())
563590
564591 async def get_push_action_users_in_range (
565592 self , min_stream_ordering : int , max_stream_ordering : int
@@ -1103,26 +1130,34 @@ def _handle_new_receipts_for_notifs_txn(self, txn: LoggingTransaction) -> bool:
11031130
11041131 # Fetch the notification counts between the stream ordering of the
11051132 # latest receipt and what was previously summarised.
1106- notif_count , unread_count = self ._get_notif_unread_count_for_user_room (
1133+ unread_counts = self ._get_notif_unread_count_for_user_room (
11071134 txn , room_id , user_id , stream_ordering , old_rotate_stream_ordering
11081135 )
11091136
1110- # Replace the previous summary with the new counts.
1111- #
1112- # TODO(threads): Upsert per-thread instead of setting them all to main.
1113- self .db_pool .simple_upsert_txn (
1137+ # First mark the summary for all threads in the room as cleared.
1138+ self .db_pool .simple_update_txn (
11141139 txn ,
11151140 table = "event_push_summary" ,
1116- keyvalues = {"room_id " : room_id , "user_id " : user_id },
1117- values = {
1118- "notif_count" : notif_count ,
1119- "unread_count" : unread_count ,
1141+ keyvalues = {"user_id " : user_id , "room_id " : room_id },
1142+ updatevalues = {
1143+ "notif_count" : 0 ,
1144+ "unread_count" : 0 ,
11201145 "stream_ordering" : old_rotate_stream_ordering ,
11211146 "last_receipt_stream_ordering" : stream_ordering ,
1122- "thread_id" : "main" ,
11231147 },
11241148 )
11251149
1150+ # Then any updated threads get their notification count and unread
1151+ # count updated.
1152+ self .db_pool .simple_upsert_many_txn (
1153+ txn ,
1154+ table = "event_push_summary" ,
1155+ key_names = ("room_id" , "user_id" , "thread_id" ),
1156+ key_values = [(room_id , user_id , row [2 ]) for row in unread_counts ],
1157+ value_names = ("notif_count" , "unread_count" ),
1158+ value_values = [(row [0 ], row [1 ]) for row in unread_counts ],
1159+ )
1160+
11261161 # We always update `event_push_summary_last_receipt_stream_id` to
11271162 # ensure that we don't rescan the same receipts for remote users.
11281163
@@ -1208,23 +1243,23 @@ def _rotate_notifs_before_txn(
12081243
12091244 # Calculate the new counts that should be upserted into event_push_summary
12101245 sql = """
1211- SELECT user_id, room_id,
1246+ SELECT user_id, room_id, thread_id,
12121247 coalesce(old.%s, 0) + upd.cnt,
12131248 upd.stream_ordering
12141249 FROM (
1215- SELECT user_id, room_id, count(*) as cnt,
1250+ SELECT user_id, room_id, thread_id, count(*) as cnt,
12161251 max(ea.stream_ordering) as stream_ordering
12171252 FROM event_push_actions AS ea
1218- LEFT JOIN event_push_summary AS old USING (user_id, room_id)
1253+ LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id )
12191254 WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ?
12201255 AND (
12211256 old.last_receipt_stream_ordering IS NULL
12221257 OR old.last_receipt_stream_ordering < ea.stream_ordering
12231258 )
12241259 AND %s = 1
1225- GROUP BY user_id, room_id
1260+ GROUP BY user_id, room_id, thread_id
12261261 ) AS upd
1227- LEFT JOIN event_push_summary AS old USING (user_id, room_id)
1262+ LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id )
12281263 """
12291264
12301265 # First get the count of unread messages.
@@ -1238,11 +1273,11 @@ def _rotate_notifs_before_txn(
12381273 # object because we might not have the same amount of rows in each of them. To do
12391274 # this, we use a dict indexed on the user ID and room ID to make it easier to
12401275 # populate.
1241- summaries : Dict [Tuple [str , str ], _EventPushSummary ] = {}
1276+ summaries : Dict [Tuple [str , str , str ], _EventPushSummary ] = {}
12421277 for row in txn :
1243- summaries [(row [0 ], row [1 ])] = _EventPushSummary (
1244- unread_count = row [2 ],
1245- stream_ordering = row [3 ],
1278+ summaries [(row [0 ], row [1 ], row [ 2 ] )] = _EventPushSummary (
1279+ unread_count = row [3 ],
1280+ stream_ordering = row [4 ],
12461281 notif_count = 0 ,
12471282 )
12481283
@@ -1253,34 +1288,35 @@ def _rotate_notifs_before_txn(
12531288 )
12541289
12551290 for row in txn :
1256- if (row [0 ], row [1 ]) in summaries :
1257- summaries [(row [0 ], row [1 ])].notif_count = row [2 ]
1291+ if (row [0 ], row [1 ], row [ 2 ] ) in summaries :
1292+ summaries [(row [0 ], row [1 ], row [ 2 ] )].notif_count = row [3 ]
12581293 else :
12591294 # Because the rules on notifying are different than the rules on marking
12601295 # a message unread, we might end up with messages that notify but aren't
12611296 # marked unread, so we might not have a summary for this (user, room)
12621297 # tuple to complete.
1263- summaries [(row [0 ], row [1 ])] = _EventPushSummary (
1298+ summaries [(row [0 ], row [1 ], row [ 2 ] )] = _EventPushSummary (
12641299 unread_count = 0 ,
1265- stream_ordering = row [3 ],
1266- notif_count = row [2 ],
1300+ stream_ordering = row [4 ],
1301+ notif_count = row [3 ],
12671302 )
12681303
12691304 logger .info ("Rotating notifications, handling %d rows" , len (summaries ))
12701305
1271- # TODO(threads): Update on a per-thread basis.
12721306 self .db_pool .simple_upsert_many_txn (
12731307 txn ,
12741308 table = "event_push_summary" ,
1275- key_names = ("user_id" , "room_id" ),
1276- key_values = [(user_id , room_id ) for user_id , room_id in summaries ],
1277- value_names = ("notif_count" , "unread_count" , "stream_ordering" , "thread_id" ),
1309+ key_names = ("user_id" , "room_id" , "thread_id" ),
1310+ key_values = [
1311+ (user_id , room_id , thread_id )
1312+ for user_id , room_id , thread_id in summaries
1313+ ],
1314+ value_names = ("notif_count" , "unread_count" , "stream_ordering" ),
12781315 value_values = [
12791316 (
12801317 summary .notif_count ,
12811318 summary .unread_count ,
12821319 summary .stream_ordering ,
1283- "main" ,
12841320 )
12851321 for summary in summaries .values ()
12861322 ],
0 commit comments