Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 7e44052

Browse files
authored
Add type hints to filtering classes. (#10958)
1 parent 9e5a429 commit 7e44052

File tree

3 files changed

+81
-45
lines changed

3 files changed

+81
-45
lines changed

changelog.d/10958.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add type hints to filtering classes.

synapse/api/filtering.py

Lines changed: 74 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,29 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717
import json
18-
from typing import List
18+
from typing import (
19+
TYPE_CHECKING,
20+
Awaitable,
21+
Container,
22+
Iterable,
23+
List,
24+
Optional,
25+
Set,
26+
TypeVar,
27+
Union,
28+
)
1929

2030
import jsonschema
2131
from jsonschema import FormatChecker
2232

2333
from synapse.api.constants import EventContentFields
2434
from synapse.api.errors import SynapseError
2535
from synapse.api.presence import UserPresenceState
26-
from synapse.types import RoomID, UserID
36+
from synapse.events import EventBase
37+
from synapse.types import JsonDict, RoomID, UserID
38+
39+
if TYPE_CHECKING:
40+
from synapse.server import HomeServer
2741

2842
FILTER_SCHEMA = {
2943
"additionalProperties": False,
@@ -120,39 +134,43 @@
120134

121135

122136
@FormatChecker.cls_checks("matrix_room_id")
123-
def matrix_room_id_validator(room_id_str):
137+
def matrix_room_id_validator(room_id_str: str) -> RoomID:
124138
return RoomID.from_string(room_id_str)
125139

126140

127141
@FormatChecker.cls_checks("matrix_user_id")
128-
def matrix_user_id_validator(user_id_str):
142+
def matrix_user_id_validator(user_id_str: str) -> UserID:
129143
return UserID.from_string(user_id_str)
130144

131145

132146
class Filtering:
133-
def __init__(self, hs):
147+
def __init__(self, hs: "HomeServer"):
134148
super().__init__()
135149
self.store = hs.get_datastore()
136150

137-
async def get_user_filter(self, user_localpart, filter_id):
151+
async def get_user_filter(
152+
self, user_localpart: str, filter_id: Union[int, str]
153+
) -> "FilterCollection":
138154
result = await self.store.get_user_filter(user_localpart, filter_id)
139155
return FilterCollection(result)
140156

141-
def add_user_filter(self, user_localpart, user_filter):
157+
def add_user_filter(
158+
self, user_localpart: str, user_filter: JsonDict
159+
) -> Awaitable[int]:
142160
self.check_valid_filter(user_filter)
143161
return self.store.add_user_filter(user_localpart, user_filter)
144162

145163
# TODO(paul): surely we should probably add a delete_user_filter or
146164
# replace_user_filter at some point? There's no REST API specified for
147165
# them however
148166

149-
def check_valid_filter(self, user_filter_json):
167+
def check_valid_filter(self, user_filter_json: JsonDict) -> None:
150168
"""Check if the provided filter is valid.
151169
152170
This inspects all definitions contained within the filter.
153171
154172
Args:
155-
user_filter_json(dict): The filter
173+
user_filter_json: The filter
156174
Raises:
157175
SynapseError: If the filter is not valid.
158176
"""
@@ -167,8 +185,12 @@ def check_valid_filter(self, user_filter_json):
167185
raise SynapseError(400, str(e))
168186

169187

188+
# Filters work across events, presence EDUs, and account data.
189+
FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
190+
191+
170192
class FilterCollection:
171-
def __init__(self, filter_json):
193+
def __init__(self, filter_json: JsonDict):
172194
self._filter_json = filter_json
173195

174196
room_filter_json = self._filter_json.get("room", {})
@@ -188,25 +210,25 @@ def __init__(self, filter_json):
188210
self.event_fields = filter_json.get("event_fields", [])
189211
self.event_format = filter_json.get("event_format", "client")
190212

191-
def __repr__(self):
213+
def __repr__(self) -> str:
192214
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
193215

194-
def get_filter_json(self):
216+
def get_filter_json(self) -> JsonDict:
195217
return self._filter_json
196218

197-
def timeline_limit(self):
219+
def timeline_limit(self) -> int:
198220
return self._room_timeline_filter.limit()
199221

200-
def presence_limit(self):
222+
def presence_limit(self) -> int:
201223
return self._presence_filter.limit()
202224

203-
def ephemeral_limit(self):
225+
def ephemeral_limit(self) -> int:
204226
return self._room_ephemeral_filter.limit()
205227

206-
def lazy_load_members(self):
228+
def lazy_load_members(self) -> bool:
207229
return self._room_state_filter.lazy_load_members()
208230

209-
def include_redundant_members(self):
231+
def include_redundant_members(self) -> bool:
210232
return self._room_state_filter.include_redundant_members()
211233

212234
def filter_presence(self, events):
@@ -218,29 +240,31 @@ def filter_account_data(self, events):
218240
def filter_room_state(self, events):
219241
return self._room_state_filter.filter(self._room_filter.filter(events))
220242

221-
def filter_room_timeline(self, events):
243+
def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
222244
return self._room_timeline_filter.filter(self._room_filter.filter(events))
223245

224-
def filter_room_ephemeral(self, events):
246+
def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
225247
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
226248

227-
def filter_room_account_data(self, events):
249+
def filter_room_account_data(
250+
self, events: Iterable[FilterEvent]
251+
) -> List[FilterEvent]:
228252
return self._room_account_data.filter(self._room_filter.filter(events))
229253

230-
def blocks_all_presence(self):
254+
def blocks_all_presence(self) -> bool:
231255
return (
232256
self._presence_filter.filters_all_types()
233257
or self._presence_filter.filters_all_senders()
234258
)
235259

236-
def blocks_all_room_ephemeral(self):
260+
def blocks_all_room_ephemeral(self) -> bool:
237261
return (
238262
self._room_ephemeral_filter.filters_all_types()
239263
or self._room_ephemeral_filter.filters_all_senders()
240264
or self._room_ephemeral_filter.filters_all_rooms()
241265
)
242266

243-
def blocks_all_room_timeline(self):
267+
def blocks_all_room_timeline(self) -> bool:
244268
return (
245269
self._room_timeline_filter.filters_all_types()
246270
or self._room_timeline_filter.filters_all_senders()
@@ -249,7 +273,7 @@ def blocks_all_room_timeline(self):
249273

250274

251275
class Filter:
252-
def __init__(self, filter_json):
276+
def __init__(self, filter_json: JsonDict):
253277
self.filter_json = filter_json
254278

255279
self.types = self.filter_json.get("types", None)
@@ -266,20 +290,20 @@ def __init__(self, filter_json):
266290
self.labels = self.filter_json.get("org.matrix.labels", None)
267291
self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
268292

269-
def filters_all_types(self):
293+
def filters_all_types(self) -> bool:
270294
return "*" in self.not_types
271295

272-
def filters_all_senders(self):
296+
def filters_all_senders(self) -> bool:
273297
return "*" in self.not_senders
274298

275-
def filters_all_rooms(self):
299+
def filters_all_rooms(self) -> bool:
276300
return "*" in self.not_rooms
277301

278-
def check(self, event):
302+
def check(self, event: FilterEvent) -> bool:
279303
"""Checks whether the filter matches the given event.
280304
281305
Returns:
282-
bool: True if the event matches
306+
True if the event matches
283307
"""
284308
# We usually get the full "events" as dictionaries coming through,
285309
# except for presence which actually gets passed around as its own
@@ -305,18 +329,25 @@ def check(self, event):
305329
room_id = event.get("room_id", None)
306330
ev_type = event.get("type", None)
307331

308-
content = event.get("content", {})
332+
content = event.get("content") or {}
309333
# check if there is a string url field in the content for filtering purposes
310334
contains_url = isinstance(content.get("url"), str)
311335
labels = content.get(EventContentFields.LABELS, [])
312336

313337
return self.check_fields(room_id, sender, ev_type, labels, contains_url)
314338

315-
def check_fields(self, room_id, sender, event_type, labels, contains_url):
339+
def check_fields(
340+
self,
341+
room_id: Optional[str],
342+
sender: Optional[str],
343+
event_type: Optional[str],
344+
labels: Container[str],
345+
contains_url: bool,
346+
) -> bool:
316347
"""Checks whether the filter matches the given event fields.
317348
318349
Returns:
319-
bool: True if the event fields match
350+
True if the event fields match
320351
"""
321352
literal_keys = {
322353
"rooms": lambda v: room_id == v,
@@ -343,14 +374,14 @@ def check_fields(self, room_id, sender, event_type, labels, contains_url):
343374

344375
return True
345376

346-
def filter_rooms(self, room_ids):
377+
def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
347378
"""Apply the 'rooms' filter to a given list of rooms.
348379
349380
Args:
350-
room_ids (list): A list of room_ids.
381+
room_ids: A list of room_ids.
351382
352383
Returns:
353-
list: A list of room_ids that match the filter
384+
A list of room_ids that match the filter
354385
"""
355386
room_ids = set(room_ids)
356387

@@ -363,23 +394,23 @@ def filter_rooms(self, room_ids):
363394

364395
return room_ids
365396

366-
def filter(self, events):
397+
def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
367398
return list(filter(self.check, events))
368399

369-
def limit(self):
400+
def limit(self) -> int:
370401
return self.filter_json.get("limit", 10)
371402

372-
def lazy_load_members(self):
403+
def lazy_load_members(self) -> bool:
373404
return self.filter_json.get("lazy_load_members", False)
374405

375-
def include_redundant_members(self):
406+
def include_redundant_members(self) -> bool:
376407
return self.filter_json.get("include_redundant_members", False)
377408

378-
def with_room_ids(self, room_ids):
409+
def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
379410
"""Returns a new filter with the given room IDs appended.
380411
381412
Args:
382-
room_ids (iterable[unicode]): The room_ids to add
413+
room_ids: The room_ids to add
383414
384415
Returns:
385416
filter: A new filter including the given rooms and the old
@@ -390,8 +421,8 @@ def with_room_ids(self, room_ids):
390421
return newFilter
391422

392423

393-
def _matches_wildcard(actual_value, filter_value):
394-
if filter_value.endswith("*"):
424+
def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
425+
if filter_value.endswith("*") and isinstance(actual_value, str):
395426
type_prefix = filter_value[:-1]
396427
return actual_value.startswith(type_prefix)
397428
else:

synapse/storage/databases/main/filtering.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Union
16+
1517
from canonicaljson import encode_canonical_json
1618

1719
from synapse.api.errors import Codes, SynapseError
@@ -22,7 +24,9 @@
2224

2325
class FilteringStore(SQLBaseStore):
2426
@cached(num_args=2)
25-
async def get_user_filter(self, user_localpart, filter_id):
27+
async def get_user_filter(
28+
self, user_localpart: str, filter_id: Union[int, str]
29+
) -> JsonDict:
2630
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
2731
# with a coherent error message rather than 500 M_UNKNOWN.
2832
try:
@@ -40,7 +44,7 @@ async def get_user_filter(self, user_localpart, filter_id):
4044

4145
return db_to_json(def_json)
4246

43-
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
47+
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
4448
def_json = encode_canonical_json(user_filter)
4549

4650
# Need an atomic transaction to SELECT the maximal ID so far then

0 commit comments

Comments
 (0)