1818from typing import (
1919 TYPE_CHECKING ,
2020 Awaitable ,
21- Container ,
21+ Callable ,
22+ Dict ,
2223 Iterable ,
2324 List ,
2425 Optional ,
@@ -217,19 +218,19 @@ def get_filter_json(self) -> JsonDict:
217218 return self ._filter_json
218219
219220 def timeline_limit (self ) -> int :
220- return self ._room_timeline_filter .limit ()
221+ return self ._room_timeline_filter .limit
221222
222223 def presence_limit (self ) -> int :
223- return self ._presence_filter .limit ()
224+ return self ._presence_filter .limit
224225
225226 def ephemeral_limit (self ) -> int :
226- return self ._room_ephemeral_filter .limit ()
227+ return self ._room_ephemeral_filter .limit
227228
228229 def lazy_load_members (self ) -> bool :
229- return self ._room_state_filter .lazy_load_members ()
230+ return self ._room_state_filter .lazy_load_members
230231
231232 def include_redundant_members (self ) -> bool :
232- return self ._room_state_filter .include_redundant_members ()
233+ return self ._room_state_filter .include_redundant_members
233234
234235 def filter_presence (
235236 self , events : Iterable [UserPresenceState ]
@@ -276,19 +277,25 @@ class Filter:
276277 def __init__ (self , filter_json : JsonDict ):
277278 self .filter_json = filter_json
278279
279- self .types = self .filter_json .get ("types" , None )
280- self .not_types = self .filter_json .get ("not_types" , [])
280+ self .limit = filter_json .get ("limit" , 10 )
281+ self .lazy_load_members = filter_json .get ("lazy_load_members" , False )
282+ self .include_redundant_members = filter_json .get (
283+ "include_redundant_members" , False
284+ )
285+
286+ self .types = filter_json .get ("types" , None )
287+ self .not_types = filter_json .get ("not_types" , [])
281288
282- self .rooms = self . filter_json .get ("rooms" , None )
283- self .not_rooms = self . filter_json .get ("not_rooms" , [])
289+ self .rooms = filter_json .get ("rooms" , None )
290+ self .not_rooms = filter_json .get ("not_rooms" , [])
284291
285- self .senders = self . filter_json .get ("senders" , None )
286- self .not_senders = self . filter_json .get ("not_senders" , [])
292+ self .senders = filter_json .get ("senders" , None )
293+ self .not_senders = filter_json .get ("not_senders" , [])
287294
288- self .contains_url = self . filter_json .get ("contains_url" , None )
295+ self .contains_url = filter_json .get ("contains_url" , None )
289296
290- self .labels = self . filter_json .get ("org.matrix.labels" , None )
291- self .not_labels = self . filter_json .get ("org.matrix.not_labels" , [])
297+ self .labels = filter_json .get ("org.matrix.labels" , None )
298+ self .not_labels = filter_json .get ("org.matrix.not_labels" , [])
292299
293300 def filters_all_types (self ) -> bool :
294301 return "*" in self .not_types
@@ -302,76 +309,95 @@ def filters_all_rooms(self) -> bool:
302309 def check (self , event : FilterEvent ) -> bool :
303310 """Checks whether the filter matches the given event.
304311
312+ Args:
313+ event: The event, account data, or presence to check against this
314+ filter.
315+
305316 Returns:
306- True if the event matches
317+ True if the event matches the filter.
307318 """
308319 # We usually get the full "events" as dictionaries coming through,
309320 # except for presence which actually gets passed around as its own
310321 # namedtuple type.
311322 if isinstance (event , UserPresenceState ):
312- sender : Optional [str ] = event .user_id
313- room_id = None
314- ev_type = "m.presence"
315- contains_url = False
316- labels : List [str ] = []
323+ user_id = event .user_id
324+ field_matchers = {
325+ "senders" : lambda v : user_id == v ,
326+ "types" : lambda v : "m.presence" == v ,
327+ }
328+ return self ._check_fields (field_matchers )
317329 else :
330+ content = event .get ("content" )
331+ # Content is assumed to be a dict below, so ensure it is. This should
332+ # always be true for events, but account_data has been allowed to
333+ # have non-dict content.
334+ if not isinstance (content , dict ):
335+ content = {}
336+
318337 sender = event .get ("sender" , None )
319338 if not sender :
320339 # Presence events had their 'sender' in content.user_id, but are
321340 # now handled above. We don't know if anything else uses this
322341 # form. TODO: Check this and probably remove it.
323- content = event .get ("content" )
324- # account_data has been allowed to have non-dict content, so
325- # check type first
326- if isinstance (content , dict ):
327- sender = content .get ("user_id" )
342+ sender = content .get ("user_id" )
328343
329344 room_id = event .get ("room_id" , None )
330345 ev_type = event .get ("type" , None )
331346
332- content = event .get ("content" ) or {}
333347 # check if there is a string url field in the content for filtering purposes
334- contains_url = isinstance (content .get ("url" ), str )
335348 labels = content .get (EventContentFields .LABELS , [])
336349
337- return self .check_fields (room_id , sender , ev_type , labels , contains_url )
350+ field_matchers = {
351+ "rooms" : lambda v : room_id == v ,
352+ "senders" : lambda v : sender == v ,
353+ "types" : lambda v : _matches_wildcard (ev_type , v ),
354+ "labels" : lambda v : v in labels ,
355+ }
356+
357+ result = self ._check_fields (field_matchers )
358+ if not result :
359+ return result
360+
361+ contains_url_filter = self .contains_url
362+ if contains_url_filter is not None :
363+ contains_url = isinstance (content .get ("url" ), str )
364+ if contains_url_filter != contains_url :
365+ return False
366+
367+ return True
338368
339- def check_fields (
340- self ,
341- room_id : Optional [str ],
342- sender : Optional [str ],
343- event_type : Optional [str ],
344- labels : Container [str ],
345- contains_url : bool ,
346- ) -> bool :
369+ def _check_fields (self , field_matchers : Dict [str , Callable [[str ], bool ]]) -> bool :
347370 """Checks whether the filter matches the given event fields.
348371
372+ Args:
373+ field_matchers: A map of attribute name to callable to use for checking
374+ particular fields.
375+
376+ The attribute name and an inverse (not_<attribute name>) must
377+ exist on the Filter.
378+
379+ The callable should return true if the event's value matches the
380+ filter's value.
381+
349382 Returns:
350383 True if the event fields match
351384 """
352- literal_keys = {
353- "rooms" : lambda v : room_id == v ,
354- "senders" : lambda v : sender == v ,
355- "types" : lambda v : _matches_wildcard (event_type , v ),
356- "labels" : lambda v : v in labels ,
357- }
358-
359- for name , match_func in literal_keys .items ():
385+
386+ for name , match_func in field_matchers .items ():
387+ # If the event matches one of the disallowed values, reject it.
360388 not_name = "not_%s" % (name ,)
361389 disallowed_values = getattr (self , not_name )
362390 if any (map (match_func , disallowed_values )):
363391 return False
364392
393+ # Other the event does not match at least one of the allowed values,
394+ # reject it.
365395 allowed_values = getattr (self , name )
366396 if allowed_values is not None :
367397 if not any (map (match_func , allowed_values )):
368398 return False
369399
370- contains_url_filter = self .filter_json .get ("contains_url" )
371- if contains_url_filter is not None :
372- if contains_url_filter != contains_url :
373- return False
374-
400+ # Otherwise, accept it.
375401 return True
376402
377403 def filter_rooms (self , room_ids : Iterable [str ]) -> Set [str ]:
@@ -385,10 +411,10 @@ def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
385411 """
386412 room_ids = set (room_ids )
387413
388- disallowed_rooms = set (self .filter_json . get ( " not_rooms" , []) )
414+ disallowed_rooms = set (self .not_rooms )
389415 room_ids -= disallowed_rooms
390416
391- allowed_rooms = self .filter_json . get ( " rooms" , None )
417+ allowed_rooms = self .rooms
392418 if allowed_rooms is not None :
393419 room_ids &= set (allowed_rooms )
394420
@@ -397,15 +423,6 @@ def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
397423 def filter (self , events : Iterable [FilterEvent ]) -> List [FilterEvent ]:
398424 return list (filter (self .check , events ))
399425
400- def limit (self ) -> int :
401- return self .filter_json .get ("limit" , 10 )
402-
403- def lazy_load_members (self ) -> bool :
404- return self .filter_json .get ("lazy_load_members" , False )
405-
406- def include_redundant_members (self ) -> bool :
407- return self .filter_json .get ("include_redundant_members" , False )
408-
409426 def with_room_ids (self , room_ids : Iterable [str ]) -> "Filter" :
410427 """Returns a new filter with the given room IDs appended.
411428
0 commit comments