Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Misc typing fixes for tests, part 1 of N #11323

Merged
merged 5 commits into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11323.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type annotations in Synapse's test suite.
4 changes: 3 additions & 1 deletion synapse/rest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable

from synapse.http.server import HttpServer, JsonResource
from synapse.rest import admin
Expand Down Expand Up @@ -62,6 +62,8 @@
if TYPE_CHECKING:
from synapse.server import HomeServer

RegisterServletsFunc = Callable[["HomeServer", HttpServer], None]


class ClientRestResource(JsonResource):
"""Matrix Client API REST resource.
Expand Down
3 changes: 2 additions & 1 deletion synapse/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Mapping,
MutableMapping,
Expand Down Expand Up @@ -219,7 +220,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
'domain' : The domain part of the name
"""

SIGIL: str = abc.abstractproperty() # type: ignore
SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore

localpart = attr.ib(type=str)
domain = attr.ib(type=str)
Expand Down
5 changes: 1 addition & 4 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from twisted.internet.protocol import Protocol
from twisted.web.resource import Resource

from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler
Expand Down Expand Up @@ -220,8 +219,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
unlike `BaseStreamTestCase`.
"""

servlets: List[Callable[[HomeServer, JsonResource], None]] = []

def setUp(self):
super().setUp()

Expand Down
22 changes: 14 additions & 8 deletions tests/rest/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,17 @@
import re
import time
import urllib.parse
from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
from typing import (
Any,
AnyStr,
Dict,
Iterable,
Mapping,
MutableMapping,
Optional,
Tuple,
Union,
)
from unittest.mock import patch

import attr
Expand Down Expand Up @@ -53,9 +63,7 @@ def create_room_as(
tok: Optional[str] = None,
expect_code: int = 200,
extra_content: Optional[Dict] = None,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
) -> str:
"""
Create a room.
Expand Down Expand Up @@ -227,9 +235,7 @@ def send(
txn_id=None,
tok=None,
expect_code=200,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
if body is None:
body = "body_text_here"
Expand Down Expand Up @@ -418,7 +424,7 @@ def upload_media(
path,
content=image_data,
access_token=tok,
custom_headers=[(b"Content-Length", str(image_length))],
custom_headers=[("Content-Length", str(image_length))],
)

assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
Expand Down
15 changes: 11 additions & 4 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,16 @@
import logging
from collections import deque
from io import SEEK_END, BytesIO
from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union
from typing import (
AnyStr,
Callable,
Dict,
Iterable,
MutableMapping,
Optional,
Tuple,
Union,
)

import attr
from typing_extensions import Deque
Expand Down Expand Up @@ -222,9 +231,7 @@ def make_request(
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
Expand Down
32 changes: 21 additions & 11 deletions tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,20 @@
import logging
import secrets
import time
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
from typing import (
Any,
AnyStr,
Callable,
ClassVar,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from unittest.mock import Mock, patch

from canonicaljson import json
Expand All @@ -45,6 +58,7 @@
current_context,
set_current_context,
)
from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
Expand Down Expand Up @@ -204,15 +218,15 @@ class HomeserverTestCase(TestCase):
config dict.

Attributes:
servlets (list[function]): List of servlet registration function.
servlets: List of servlet registration function.
user_id (str): The user ID to assume if auth is hijacked.
hijack_auth (bool): Whether to hijack auth to return the user specified
in user_id.
"""

servlets = []
hijack_auth = True
needs_threadpool = False
servlets: ClassVar[List[RegisterServletsFunc]] = []

def __init__(self, methodName, *args, **kwargs):
super().__init__(methodName, *args, **kwargs)
Expand Down Expand Up @@ -405,12 +419,10 @@ def make_request(
access_token: Optional[str] = None,
request: Type[T] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: str = None,
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
Expand All @@ -425,7 +437,7 @@ def make_request(
a dict.
shorthand: Whether to try and be helpful and prefix the given URL
with the usual REST API path, if it doesn't contain it.
federation_auth_origin (bytes|None): if set to not-None, we will add a fake
federation_auth_origin: if set to not-None, we will add a fake
Authorization header pretenting to be the given server name.
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
Expand Down Expand Up @@ -639,9 +651,7 @@ def login(
username,
password,
device_id=None,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
"""
Log in a user, and get an access token. Requires the Login API be
Expand Down