15
15
# See the License for the specific language governing permissions and
16
16
# limitations under the License.
17
17
import json
18
- from typing import List
18
+ from typing import (
19
+ TYPE_CHECKING ,
20
+ Awaitable ,
21
+ Container ,
22
+ Iterable ,
23
+ List ,
24
+ Optional ,
25
+ Set ,
26
+ TypeVar ,
27
+ Union ,
28
+ )
19
29
20
30
import jsonschema
21
31
from jsonschema import FormatChecker
22
32
23
33
from synapse .api .constants import EventContentFields
24
34
from synapse .api .errors import SynapseError
25
35
from synapse .api .presence import UserPresenceState
26
- from synapse .types import RoomID , UserID
36
+ from synapse .events import EventBase
37
+ from synapse .types import JsonDict , RoomID , UserID
38
+
39
+ if TYPE_CHECKING :
40
+ from synapse .server import HomeServer
27
41
28
42
FILTER_SCHEMA = {
29
43
"additionalProperties" : False ,
120
134
121
135
122
136
@FormatChecker .cls_checks ("matrix_room_id" )
123
- def matrix_room_id_validator (room_id_str ) :
137
+ def matrix_room_id_validator (room_id_str : str ) -> RoomID :
124
138
return RoomID .from_string (room_id_str )
125
139
126
140
127
141
@FormatChecker .cls_checks ("matrix_user_id" )
128
- def matrix_user_id_validator (user_id_str ) :
142
+ def matrix_user_id_validator (user_id_str : str ) -> UserID :
129
143
return UserID .from_string (user_id_str )
130
144
131
145
132
146
class Filtering :
133
- def __init__ (self , hs ):
147
+ def __init__ (self , hs : "HomeServer" ):
134
148
super ().__init__ ()
135
149
self .store = hs .get_datastore ()
136
150
137
- async def get_user_filter (self , user_localpart , filter_id ):
151
+ async def get_user_filter (
152
+ self , user_localpart : str , filter_id : Union [int , str ]
153
+ ) -> "FilterCollection" :
138
154
result = await self .store .get_user_filter (user_localpart , filter_id )
139
155
return FilterCollection (result )
140
156
141
- def add_user_filter (self , user_localpart , user_filter ):
157
+ def add_user_filter (
158
+ self , user_localpart : str , user_filter : JsonDict
159
+ ) -> Awaitable [int ]:
142
160
self .check_valid_filter (user_filter )
143
161
return self .store .add_user_filter (user_localpart , user_filter )
144
162
145
163
# TODO(paul): surely we should probably add a delete_user_filter or
146
164
# replace_user_filter at some point? There's no REST API specified for
147
165
# them however
148
166
149
- def check_valid_filter (self , user_filter_json ) :
167
+ def check_valid_filter (self , user_filter_json : JsonDict ) -> None :
150
168
"""Check if the provided filter is valid.
151
169
152
170
This inspects all definitions contained within the filter.
153
171
154
172
Args:
155
- user_filter_json(dict) : The filter
173
+ user_filter_json: The filter
156
174
Raises:
157
175
SynapseError: If the filter is not valid.
158
176
"""
@@ -167,8 +185,12 @@ def check_valid_filter(self, user_filter_json):
167
185
raise SynapseError (400 , str (e ))
168
186
169
187
188
+ # Filters work across events, presence EDUs, and account data.
189
+ FilterEvent = TypeVar ("FilterEvent" , EventBase , UserPresenceState , JsonDict )
190
+
191
+
170
192
class FilterCollection :
171
- def __init__ (self , filter_json ):
193
+ def __init__ (self , filter_json : JsonDict ):
172
194
self ._filter_json = filter_json
173
195
174
196
room_filter_json = self ._filter_json .get ("room" , {})
@@ -188,25 +210,25 @@ def __init__(self, filter_json):
188
210
self .event_fields = filter_json .get ("event_fields" , [])
189
211
self .event_format = filter_json .get ("event_format" , "client" )
190
212
191
- def __repr__ (self ):
213
+ def __repr__ (self ) -> str :
192
214
return "<FilterCollection %s>" % (json .dumps (self ._filter_json ),)
193
215
194
- def get_filter_json (self ):
216
+ def get_filter_json (self ) -> JsonDict :
195
217
return self ._filter_json
196
218
197
- def timeline_limit (self ):
219
+ def timeline_limit (self ) -> int :
198
220
return self ._room_timeline_filter .limit ()
199
221
200
- def presence_limit (self ):
222
+ def presence_limit (self ) -> int :
201
223
return self ._presence_filter .limit ()
202
224
203
- def ephemeral_limit (self ):
225
+ def ephemeral_limit (self ) -> int :
204
226
return self ._room_ephemeral_filter .limit ()
205
227
206
- def lazy_load_members (self ):
228
+ def lazy_load_members (self ) -> bool :
207
229
return self ._room_state_filter .lazy_load_members ()
208
230
209
- def include_redundant_members (self ):
231
+ def include_redundant_members (self ) -> bool :
210
232
return self ._room_state_filter .include_redundant_members ()
211
233
212
234
def filter_presence (self , events ):
@@ -218,29 +240,31 @@ def filter_account_data(self, events):
218
240
def filter_room_state (self , events ):
219
241
return self ._room_state_filter .filter (self ._room_filter .filter (events ))
220
242
221
- def filter_room_timeline (self , events ) :
243
+ def filter_room_timeline (self , events : Iterable [ FilterEvent ]) -> List [ FilterEvent ] :
222
244
return self ._room_timeline_filter .filter (self ._room_filter .filter (events ))
223
245
224
- def filter_room_ephemeral (self , events ) :
246
+ def filter_room_ephemeral (self , events : Iterable [ FilterEvent ]) -> List [ FilterEvent ] :
225
247
return self ._room_ephemeral_filter .filter (self ._room_filter .filter (events ))
226
248
227
- def filter_room_account_data (self , events ):
249
+ def filter_room_account_data (
250
+ self , events : Iterable [FilterEvent ]
251
+ ) -> List [FilterEvent ]:
228
252
return self ._room_account_data .filter (self ._room_filter .filter (events ))
229
253
230
- def blocks_all_presence (self ):
254
+ def blocks_all_presence (self ) -> bool :
231
255
return (
232
256
self ._presence_filter .filters_all_types ()
233
257
or self ._presence_filter .filters_all_senders ()
234
258
)
235
259
236
- def blocks_all_room_ephemeral (self ):
260
+ def blocks_all_room_ephemeral (self ) -> bool :
237
261
return (
238
262
self ._room_ephemeral_filter .filters_all_types ()
239
263
or self ._room_ephemeral_filter .filters_all_senders ()
240
264
or self ._room_ephemeral_filter .filters_all_rooms ()
241
265
)
242
266
243
- def blocks_all_room_timeline (self ):
267
+ def blocks_all_room_timeline (self ) -> bool :
244
268
return (
245
269
self ._room_timeline_filter .filters_all_types ()
246
270
or self ._room_timeline_filter .filters_all_senders ()
@@ -249,7 +273,7 @@ def blocks_all_room_timeline(self):
249
273
250
274
251
275
class Filter :
252
- def __init__ (self , filter_json ):
276
+ def __init__ (self , filter_json : JsonDict ):
253
277
self .filter_json = filter_json
254
278
255
279
self .types = self .filter_json .get ("types" , None )
@@ -266,20 +290,20 @@ def __init__(self, filter_json):
266
290
self .labels = self .filter_json .get ("org.matrix.labels" , None )
267
291
self .not_labels = self .filter_json .get ("org.matrix.not_labels" , [])
268
292
269
- def filters_all_types (self ):
293
+ def filters_all_types (self ) -> bool :
270
294
return "*" in self .not_types
271
295
272
- def filters_all_senders (self ):
296
+ def filters_all_senders (self ) -> bool :
273
297
return "*" in self .not_senders
274
298
275
- def filters_all_rooms (self ):
299
+ def filters_all_rooms (self ) -> bool :
276
300
return "*" in self .not_rooms
277
301
278
- def check (self , event ) :
302
+ def check (self , event : FilterEvent ) -> bool :
279
303
"""Checks whether the filter matches the given event.
280
304
281
305
Returns:
282
- bool: True if the event matches
306
+ True if the event matches
283
307
"""
284
308
# We usually get the full "events" as dictionaries coming through,
285
309
# except for presence which actually gets passed around as its own
@@ -305,18 +329,25 @@ def check(self, event):
305
329
room_id = event .get ("room_id" , None )
306
330
ev_type = event .get ("type" , None )
307
331
308
- content = event .get ("content" , {})
332
+ content = event .get ("content" ) or {}
309
333
# check if there is a string url field in the content for filtering purposes
310
334
contains_url = isinstance (content .get ("url" ), str )
311
335
labels = content .get (EventContentFields .LABELS , [])
312
336
313
337
return self .check_fields (room_id , sender , ev_type , labels , contains_url )
314
338
315
- def check_fields (self , room_id , sender , event_type , labels , contains_url ):
339
+ def check_fields (
340
+ self ,
341
+ room_id : Optional [str ],
342
+ sender : Optional [str ],
343
+ event_type : Optional [str ],
344
+ labels : Container [str ],
345
+ contains_url : bool ,
346
+ ) -> bool :
316
347
"""Checks whether the filter matches the given event fields.
317
348
318
349
Returns:
319
- bool: True if the event fields match
350
+ True if the event fields match
320
351
"""
321
352
literal_keys = {
322
353
"rooms" : lambda v : room_id == v ,
@@ -343,14 +374,14 @@ def check_fields(self, room_id, sender, event_type, labels, contains_url):
343
374
344
375
return True
345
376
346
- def filter_rooms (self , room_ids ) :
377
+ def filter_rooms (self , room_ids : Iterable [ str ]) -> Set [ str ] :
347
378
"""Apply the 'rooms' filter to a given list of rooms.
348
379
349
380
Args:
350
- room_ids (list) : A list of room_ids.
381
+ room_ids: A list of room_ids.
351
382
352
383
Returns:
353
- list: A list of room_ids that match the filter
384
+ A list of room_ids that match the filter
354
385
"""
355
386
room_ids = set (room_ids )
356
387
@@ -363,23 +394,23 @@ def filter_rooms(self, room_ids):
363
394
364
395
return room_ids
365
396
366
- def filter (self , events ) :
397
+ def filter (self , events : Iterable [ FilterEvent ]) -> List [ FilterEvent ] :
367
398
return list (filter (self .check , events ))
368
399
369
- def limit (self ):
400
+ def limit (self ) -> int :
370
401
return self .filter_json .get ("limit" , 10 )
371
402
372
- def lazy_load_members (self ):
403
+ def lazy_load_members (self ) -> bool :
373
404
return self .filter_json .get ("lazy_load_members" , False )
374
405
375
- def include_redundant_members (self ):
406
+ def include_redundant_members (self ) -> bool :
376
407
return self .filter_json .get ("include_redundant_members" , False )
377
408
378
- def with_room_ids (self , room_ids ) :
409
+ def with_room_ids (self , room_ids : Iterable [ str ]) -> "Filter" :
379
410
"""Returns a new filter with the given room IDs appended.
380
411
381
412
Args:
382
- room_ids (iterable[unicode]) : The room_ids to add
413
+ room_ids: The room_ids to add
383
414
384
415
Returns:
385
416
filter: A new filter including the given rooms and the old
@@ -390,8 +421,8 @@ def with_room_ids(self, room_ids):
390
421
return newFilter
391
422
392
423
393
- def _matches_wildcard (actual_value , filter_value ) :
394
- if filter_value .endswith ("*" ):
424
+ def _matches_wildcard (actual_value : Optional [ str ] , filter_value : str ) -> bool :
425
+ if filter_value .endswith ("*" ) and isinstance ( actual_value , str ) :
395
426
type_prefix = filter_value [:- 1 ]
396
427
return actual_value .startswith (type_prefix )
397
428
else :
0 commit comments