1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from typing import Collection , Optional
1516
1617from synapse .api .constants import ReceiptTypes
1718from synapse .types import UserID , create_requester
@@ -84,6 +85,33 @@ def prepare(self, reactor, clock, homeserver) -> None:
8485 )
8586 )
8687
88+ def get_last_unthreaded_receipt (
89+ self , receipt_types : Collection [str ], room_id : Optional [str ] = None
90+ ) -> Optional [str ]:
91+ """
92+ Fetch the event ID for the latest unthreaded receipt in the test room for the test user.
93+
94+ Args:
95+ receipt_types: The receipt types to fetch.
96+
97+ Returns:
98+ The latest receipt, if one exists.
99+ """
100+ result = self .get_success (
101+ self .store .db_pool .runInteraction (
102+ "get_last_receipt_event_id_for_user" ,
103+ self .store .get_last_unthreaded_receipt_for_user_txn ,
104+ OUR_USER_ID ,
105+ room_id or self .room_id1 ,
106+ receipt_types ,
107+ )
108+ )
109+ if not result :
110+ return None
111+
112+ event_id , _ = result
113+ return event_id
114+
87115 def test_return_empty_with_no_data (self ) -> None :
88116 res = self .get_success (
89117 self .store .get_receipts_for_user (
@@ -107,16 +135,10 @@ def test_return_empty_with_no_data(self) -> None:
107135 )
108136 self .assertEqual (res , {})
109137
110- res = self .get_success (
111- self .store .get_last_receipt_event_id_for_user (
112- OUR_USER_ID ,
113- self .room_id1 ,
114- [
115- ReceiptTypes .READ ,
116- ReceiptTypes .READ_PRIVATE ,
117- ],
118- )
138+ res = self .get_last_unthreaded_receipt (
139+ [ReceiptTypes .READ , ReceiptTypes .READ_PRIVATE ]
119140 )
141+
120142 self .assertEqual (res , None )
121143
122144 def test_get_receipts_for_user (self ) -> None :
@@ -228,29 +250,17 @@ def test_get_last_receipt_event_id_for_user(self) -> None:
228250 )
229251
230252 # Test we get the latest event when we want both private and public receipts
231- res = self .get_success (
232- self .store .get_last_receipt_event_id_for_user (
233- OUR_USER_ID ,
234- self .room_id1 ,
235- [ReceiptTypes .READ , ReceiptTypes .READ_PRIVATE ],
236- )
253+ res = self .get_last_unthreaded_receipt (
254+ [ReceiptTypes .READ , ReceiptTypes .READ_PRIVATE ]
237255 )
238256 self .assertEqual (res , event1_2_id )
239257
240258 # Test we get the older event when we want only public receipt
241- res = self .get_success (
242- self .store .get_last_receipt_event_id_for_user (
243- OUR_USER_ID , self .room_id1 , [ReceiptTypes .READ ]
244- )
245- )
259+ res = self .get_last_unthreaded_receipt ([ReceiptTypes .READ ])
246260 self .assertEqual (res , event1_1_id )
247261
248262 # Test we get the latest event when we want only the private receipt
249- res = self .get_success (
250- self .store .get_last_receipt_event_id_for_user (
251- OUR_USER_ID , self .room_id1 , [ReceiptTypes .READ_PRIVATE ]
252- )
253- )
263+ res = self .get_last_unthreaded_receipt ([ReceiptTypes .READ_PRIVATE ])
254264 self .assertEqual (res , event1_2_id )
255265
256266 # Test receipt updating
@@ -259,11 +269,7 @@ def test_get_last_receipt_event_id_for_user(self) -> None:
259269 self .room_id1 , ReceiptTypes .READ , OUR_USER_ID , [event1_2_id ], None , {}
260270 )
261271 )
262- res = self .get_success (
263- self .store .get_last_receipt_event_id_for_user (
264- OUR_USER_ID , self .room_id1 , [ReceiptTypes .READ ]
265- )
266- )
272+ res = self .get_last_unthreaded_receipt ([ReceiptTypes .READ ])
267273 self .assertEqual (res , event1_2_id )
268274
269275 # Send some events into the second room
@@ -282,11 +288,7 @@ def test_get_last_receipt_event_id_for_user(self) -> None:
282288 {},
283289 )
284290 )
285- res = self .get_success (
286- self .store .get_last_receipt_event_id_for_user (
287- OUR_USER_ID ,
288- self .room_id2 ,
289- [ReceiptTypes .READ , ReceiptTypes .READ_PRIVATE ],
290- )
291+ res = self .get_last_unthreaded_receipt (
292+ [ReceiptTypes .READ , ReceiptTypes .READ_PRIVATE ], room_id = self .room_id2
291293 )
292294 self .assertEqual (res , event2_1_id )
0 commit comments