Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit b359061

Browse files
authored
Require type hints in the handlers module. (#10831)
Adds missing type hints to methods in the synapse.handlers module and requires all methods to have type hints there. This also removes the unused construct_auth_difference method from the FederationHandler.
1 parent 4379617 commit b359061

35 files changed

+194
-295
lines changed

changelog.d/10831.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add missing type hints to handlers.

mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ files =
9191
tests/util/test_itertools.py,
9292
tests/util/test_stream_change_cache.py
9393

94+
[mypy-synapse.handlers.*]
95+
disallow_untyped_defs = True
96+
9497
[mypy-synapse.rest.*]
9598
disallow_untyped_defs = True
9699

synapse/config/password_auth_providers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, List
15+
from typing import Any, List, Tuple, Type
1616

1717
from synapse.util.module_loader import load_module
1818

@@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config):
2525
section = "authproviders"
2626

2727
def read_config(self, config, **kwargs):
28-
self.password_providers: List[Any] = []
28+
self.password_providers: List[Tuple[Type, Any]] = []
2929
providers = []
3030

3131
# We want to be backwards compatible with the old `ldap_config`

synapse/handlers/_base.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import TYPE_CHECKING, Optional
1717

1818
from synapse.api.ratelimiting import Ratelimiter
19+
from synapse.types import Requester
1920

2021
if TYPE_CHECKING:
2122
from synapse.server import HomeServer
@@ -63,16 +64,21 @@ def __init__(self, hs: "HomeServer"):
6364

6465
self.event_builder_factory = hs.get_event_builder_factory()
6566

66-
async def ratelimit(self, requester, update=True, is_admin_redaction=False):
67+
async def ratelimit(
68+
self,
69+
requester: Requester,
70+
update: bool = True,
71+
is_admin_redaction: bool = False,
72+
) -> None:
6773
"""Ratelimits requests.
6874
6975
Args:
70-
requester (Requester)
71-
update (bool): Whether to record that a request is being processed.
76+
requester
77+
update: Whether to record that a request is being processed.
7278
Set to False when doing multiple checks for one request (e.g.
7379
to check up front if we would reject the request), and set to
7480
True for the last call for a given request.
75-
is_admin_redaction (bool): Whether this is a room admin/moderator
81+
is_admin_redaction: Whether this is a room admin/moderator
7682
redacting an event. If so then we may apply different
7783
ratelimits depending on config.
7884

synapse/handlers/account_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import random
16-
from typing import TYPE_CHECKING, List, Tuple
16+
from typing import TYPE_CHECKING, Any, List, Tuple
1717

1818
from synapse.replication.http.account_data import (
1919
ReplicationAddTagRestServlet,
@@ -171,7 +171,7 @@ def get_current_key(self, direction: str = "f") -> int:
171171
return self.store.get_max_account_data_stream_id()
172172

173173
async def get_new_events(
174-
self, user: UserID, from_key: int, **kwargs
174+
self, user: UserID, from_key: int, **kwargs: Any
175175
) -> Tuple[List[JsonDict], int]:
176176
user_id = user.to_string()
177177
last_stream_id = from_key

synapse/handlers/account_validity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def register_account_validity_callbacks(
9999
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
100100
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
101101
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
102-
):
102+
) -> None:
103103
"""Register callbacks from module for each hook."""
104104
if is_user_expired is not None:
105105
self._is_user_expired_callbacks.append(is_user_expired)
@@ -165,7 +165,7 @@ async def is_user_expired(self, user_id: str) -> bool:
165165

166166
return False
167167

168-
async def on_user_registration(self, user_id: str):
168+
async def on_user_registration(self, user_id: str) -> None:
169169
"""Tell third-party modules about a user's registration.
170170
171171
Args:

synapse/handlers/appservice.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15-
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Union
15+
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union
1616

1717
from prometheus_client import Counter
1818

@@ -58,7 +58,7 @@ def __init__(self, hs: "HomeServer"):
5858
self.current_max = 0
5959
self.is_processing = False
6060

61-
def notify_interested_services(self, max_token: RoomStreamToken):
61+
def notify_interested_services(self, max_token: RoomStreamToken) -> None:
6262
"""Notifies (pushes) all application services interested in this event.
6363
6464
Pushing is done asynchronously, so this method won't block for any
@@ -82,7 +82,7 @@ def notify_interested_services(self, max_token: RoomStreamToken):
8282
self._notify_interested_services(max_token)
8383

8484
@wrap_as_background_process("notify_interested_services")
85-
async def _notify_interested_services(self, max_token: RoomStreamToken):
85+
async def _notify_interested_services(self, max_token: RoomStreamToken) -> None:
8686
with Measure(self.clock, "notify_interested_services"):
8787
self.is_processing = True
8888
try:
@@ -100,7 +100,7 @@ async def _notify_interested_services(self, max_token: RoomStreamToken):
100100
for event in events:
101101
events_by_room.setdefault(event.room_id, []).append(event)
102102

103-
async def handle_event(event):
103+
async def handle_event(event: EventBase) -> None:
104104
# Gather interested services
105105
services = await self._get_services_for_event(event)
106106
if len(services) == 0:
@@ -116,9 +116,9 @@ async def handle_event(event):
116116

117117
if not self.started_scheduler:
118118

119-
async def start_scheduler():
119+
async def start_scheduler() -> None:
120120
try:
121-
return await self.scheduler.start()
121+
await self.scheduler.start()
122122
except Exception:
123123
logger.error("Application Services Failure")
124124

@@ -137,7 +137,7 @@ async def start_scheduler():
137137
"appservice_sender"
138138
).observe((now - ts) / 1000)
139139

140-
async def handle_room_events(events):
140+
async def handle_room_events(events: Iterable[EventBase]) -> None:
141141
for event in events:
142142
await handle_event(event)
143143

@@ -184,7 +184,7 @@ def notify_interested_services_ephemeral(
184184
stream_key: str,
185185
new_token: Optional[int],
186186
users: Optional[Collection[Union[str, UserID]]] = None,
187-
):
187+
) -> None:
188188
"""This is called by the notifier in the background
189189
when a ephemeral event handled by the homeserver.
190190
@@ -226,7 +226,7 @@ async def _notify_interested_services_ephemeral(
226226
stream_key: str,
227227
new_token: Optional[int],
228228
users: Collection[Union[str, UserID]],
229-
):
229+
) -> None:
230230
logger.debug("Checking interested services for %s" % (stream_key))
231231
with Measure(self.clock, "notify_interested_services_ephemeral"):
232232
for service in services:

synapse/handlers/auth.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
Mapping,
3030
Optional,
3131
Tuple,
32+
Type,
3233
Union,
3334
cast,
3435
)
@@ -439,7 +440,7 @@ async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
439440

440441
return ui_auth_types
441442

442-
def get_enabled_auth_types(self):
443+
def get_enabled_auth_types(self) -> Iterable[str]:
443444
"""Return the enabled user-interactive authentication types
444445
445446
Returns the UI-Auth types which are supported by the homeserver's current
@@ -702,7 +703,7 @@ async def get_session_data(
702703
except StoreError:
703704
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
704705

705-
async def _expire_old_sessions(self):
706+
async def _expire_old_sessions(self) -> None:
706707
"""
707708
Invalidate any user interactive authentication sessions that have expired.
708709
"""
@@ -1352,7 +1353,7 @@ async def validate_short_term_login_token(
13521353
await self.auth.check_auth_blocking(res.user_id)
13531354
return res
13541355

1355-
async def delete_access_token(self, access_token: str):
1356+
async def delete_access_token(self, access_token: str) -> None:
13561357
"""Invalidate a single access token
13571358
13581359
Args:
@@ -1381,7 +1382,7 @@ async def delete_access_tokens_for_user(
13811382
user_id: str,
13821383
except_token_id: Optional[int] = None,
13831384
device_id: Optional[str] = None,
1384-
):
1385+
) -> None:
13851386
"""Invalidate access tokens belonging to a user
13861387
13871388
Args:
@@ -1409,7 +1410,7 @@ async def delete_access_tokens_for_user(
14091410

14101411
async def add_threepid(
14111412
self, user_id: str, medium: str, address: str, validated_at: int
1412-
):
1413+
) -> None:
14131414
# check if medium has a valid value
14141415
if medium not in ["email", "msisdn"]:
14151416
raise SynapseError(
@@ -1480,7 +1481,7 @@ async def hash(self, password: str) -> str:
14801481
Hashed password.
14811482
"""
14821483

1483-
def _do_hash():
1484+
def _do_hash() -> str:
14841485
# Normalise the Unicode in the password
14851486
pw = unicodedata.normalize("NFKC", password)
14861487

@@ -1504,7 +1505,7 @@ async def validate_hash(
15041505
Whether self.hash(password) == stored_hash.
15051506
"""
15061507

1507-
def _do_validate_hash(checked_hash: bytes):
1508+
def _do_validate_hash(checked_hash: bytes) -> bool:
15081509
# Normalise the Unicode in the password
15091510
pw = unicodedata.normalize("NFKC", password)
15101511

@@ -1581,7 +1582,7 @@ async def complete_sso_login(
15811582
client_redirect_url: str,
15821583
extra_attributes: Optional[JsonDict] = None,
15831584
new_user: bool = False,
1584-
):
1585+
) -> None:
15851586
"""Having figured out a mxid for this user, complete the HTTP request
15861587
15871588
Args:
@@ -1627,7 +1628,7 @@ def _complete_sso_login(
16271628
extra_attributes: Optional[JsonDict] = None,
16281629
new_user: bool = False,
16291630
user_profile_data: Optional[ProfileInfo] = None,
1630-
):
1631+
) -> None:
16311632
"""
16321633
The synchronous portion of complete_sso_login.
16331634
@@ -1726,17 +1727,17 @@ def _expire_sso_extra_attributes(self) -> None:
17261727
del self._extra_attributes[user_id]
17271728

17281729
@staticmethod
1729-
def add_query_param_to_url(url: str, param_name: str, param: Any):
1730+
def add_query_param_to_url(url: str, param_name: str, param: Any) -> str:
17301731
url_parts = list(urllib.parse.urlparse(url))
17311732
query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
17321733
query.append((param_name, param))
17331734
url_parts[4] = urllib.parse.urlencode(query)
17341735
return urllib.parse.urlunparse(url_parts)
17351736

17361737

1737-
@attr.s(slots=True)
1738+
@attr.s(slots=True, auto_attribs=True)
17381739
class MacaroonGenerator:
1739-
hs = attr.ib()
1740+
hs: "HomeServer"
17401741

17411742
def generate_guest_access_token(self, user_id: str) -> str:
17421743
macaroon = self._generate_base_macaroon(user_id)
@@ -1816,15 +1817,17 @@ class PasswordProvider:
18161817
"""
18171818

18181819
@classmethod
1819-
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
1820+
def load(
1821+
cls, module: Type, config: JsonDict, module_api: ModuleApi
1822+
) -> "PasswordProvider":
18201823
try:
18211824
pp = module(config=config, account_handler=module_api)
18221825
except Exception as e:
18231826
logger.error("Error while initializing %r: %s", module, e)
18241827
raise
18251828
return cls(pp, module_api)
18261829

1827-
def __init__(self, pp, module_api: ModuleApi):
1830+
def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
18281831
self._pp = pp
18291832
self._module_api = module_api
18301833

@@ -1838,7 +1841,7 @@ def __init__(self, pp, module_api: ModuleApi):
18381841
if g:
18391842
self._supported_login_types.update(g())
18401843

1841-
def __str__(self):
1844+
def __str__(self) -> str:
18421845
return str(self._pp)
18431846

18441847
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
@@ -1876,19 +1879,19 @@ async def check_auth(
18761879
"""
18771880
# first grandfather in a call to check_password
18781881
if login_type == LoginType.PASSWORD:
1879-
g = getattr(self._pp, "check_password", None)
1880-
if g:
1882+
check_password = getattr(self._pp, "check_password", None)
1883+
if check_password:
18811884
qualified_user_id = self._module_api.get_qualified_user_id(username)
1882-
is_valid = await self._pp.check_password(
1885+
is_valid = await check_password(
18831886
qualified_user_id, login_dict["password"]
18841887
)
18851888
if is_valid:
18861889
return qualified_user_id, None
18871890

1888-
g = getattr(self._pp, "check_auth", None)
1889-
if not g:
1891+
check_auth = getattr(self._pp, "check_auth", None)
1892+
if not check_auth:
18901893
return None
1891-
result = await g(username, login_type, login_dict)
1894+
result = await check_auth(username, login_type, login_dict)
18921895

18931896
# Check if the return value is a str or a tuple
18941897
if isinstance(result, str):

synapse/handlers/cas.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,20 @@
3434
class CasError(Exception):
3535
"""Used to catch errors when validating the CAS ticket."""
3636

37-
def __init__(self, error, error_description=None):
37+
def __init__(self, error: str, error_description: Optional[str] = None):
3838
self.error = error
3939
self.error_description = error_description
4040

41-
def __str__(self):
41+
def __str__(self) -> str:
4242
if self.error_description:
4343
return f"{self.error}: {self.error_description}"
4444
return self.error
4545

4646

47-
@attr.s(slots=True, frozen=True)
47+
@attr.s(slots=True, frozen=True, auto_attribs=True)
4848
class CasResponse:
49-
username = attr.ib(type=str)
50-
attributes = attr.ib(type=Dict[str, List[Optional[str]]])
49+
username: str
50+
attributes: Dict[str, List[Optional[str]]]
5151

5252

5353
class CasHandler:
@@ -133,11 +133,9 @@ async def _validate_ticket(
133133
body = pde.response
134134
except HttpResponseException as e:
135135
description = (
136-
(
137-
'Authorization server responded with a "{status}" error '
138-
"while exchanging the authorization code."
139-
).format(status=e.code),
140-
)
136+
'Authorization server responded with a "{status}" error '
137+
"while exchanging the authorization code."
138+
).format(status=e.code)
141139
raise CasError("server_error", description) from e
142140

143141
return self._parse_cas_response(body)

synapse/handlers/device.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def __init__(self, hs: "HomeServer"):
267267

268268
hs.get_distributor().observe("user_left_room", self.user_left_room)
269269

270-
def _check_device_name_length(self, name: Optional[str]):
270+
def _check_device_name_length(self, name: Optional[str]) -> None:
271271
"""
272272
Checks whether a device name is longer than the maximum allowed length.
273273

0 commit comments

Comments
 (0)