Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Addtional type hints for the REST servlets. (#10665)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Aug 23, 2021
1 parent 31dac7f commit 2af6d31
Show file tree
Hide file tree
Showing 14 changed files with 204 additions and 107 deletions.
1 change: 1 addition & 0 deletions changelog.d/10665.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to REST servlets.
39 changes: 17 additions & 22 deletions synapse/rest/client/account_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,27 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Tuple

from synapse.api.errors import SynapseError
from synapse.http.server import respond_with_html
from synapse.http.servlet import RestServlet
from twisted.web.server import Request

from synapse.http.server import HttpServer, respond_with_html
from synapse.http.servlet import RestServlet, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict

from ._base import client_patterns

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class AccountValidityRenewServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/renew$")

def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()

self.hs = hs
Expand All @@ -46,18 +49,14 @@ def __init__(self, hs):
hs.config.account_validity.account_validity_invalid_token_template
)

async def on_GET(self, request):
if b"token" not in request.args:
raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0]
async def on_GET(self, request: Request) -> None:
renewal_token = parse_string(request, "token", required=True)

(
token_valid,
token_stale,
expiration_ts,
) = await self.account_activity_handler.renew_account(
renewal_token.decode("utf8")
)
) = await self.account_activity_handler.renew_account(renewal_token)

if token_valid:
status_code = 200
Expand All @@ -77,11 +76,7 @@ async def on_GET(self, request):
class AccountValiditySendMailServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/send_mail$")

def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()

self.hs = hs
Expand All @@ -91,14 +86,14 @@ def __init__(self, hs):
hs.config.account_validity.account_validity_renew_by_email_enabled
)

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
await self.account_activity_handler.send_renewal_email_to_user(user_id)

return 200, {}


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AccountValidityRenewServlet(hs).register(http_server)
AccountValiditySendMailServlet(hs).register(http_server)
3 changes: 2 additions & 1 deletion synapse/rest/client/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Tuple

from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
Expand Down Expand Up @@ -75,5 +76,5 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
return 200, response


def register_servlets(hs: "HomeServer", http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
CapabilitiesRestServlet(hs).register(http_server)
78 changes: 49 additions & 29 deletions synapse/rest/client/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from typing import TYPE_CHECKING, Tuple

from twisted.web.server import Request

from synapse.api.errors import (
AuthError,
Expand All @@ -22,14 +24,19 @@
NotFoundError,
SynapseError,
)
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import RoomAlias
from synapse.types import JsonDict, RoomAlias

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ClientDirectoryServer(hs).register(http_server)
ClientDirectoryListServer(hs).register(http_server)
ClientAppserviceDirectoryListServer(hs).register(http_server)
Expand All @@ -38,21 +45,23 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(RestServlet):
PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()

async def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
async def on_GET(self, request: Request, room_alias: str) -> Tuple[int, JsonDict]:
room_alias_obj = RoomAlias.from_string(room_alias)

res = await self.directory_handler.get_association(room_alias)
res = await self.directory_handler.get_association(room_alias_obj)

return 200, res

async def on_PUT(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
async def on_PUT(
self, request: SynapseRequest, room_alias: str
) -> Tuple[int, JsonDict]:
room_alias_obj = RoomAlias.from_string(room_alias)

content = parse_json_object_from_request(request)
if "room_id" not in content:
Expand All @@ -61,7 +70,7 @@ async def on_PUT(self, request, room_alias):
)

logger.debug("Got content: %s", content)
logger.debug("Got room name: %s", room_alias.to_string())
logger.debug("Got room name: %s", room_alias_obj.to_string())

room_id = content["room_id"]
servers = content["servers"] if "servers" in content else None
Expand All @@ -78,22 +87,25 @@ async def on_PUT(self, request, room_alias):
requester = await self.auth.get_user_by_req(request)

await self.directory_handler.create_association(
requester, room_alias, room_id, servers
requester, room_alias_obj, room_id, servers
)

return 200, {}

async def on_DELETE(self, request, room_alias):
async def on_DELETE(
self, request: SynapseRequest, room_alias: str
) -> Tuple[int, JsonDict]:
room_alias_obj = RoomAlias.from_string(room_alias)

try:
service = self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
await self.directory_handler.delete_appservice_association(
service, room_alias
service, room_alias_obj
)
logger.info(
"Application service at %s deleted alias %s",
service.url,
room_alias.to_string(),
room_alias_obj.to_string(),
)
return 200, {}
except InvalidClientCredentialsError:
Expand All @@ -103,12 +115,10 @@ async def on_DELETE(self, request, room_alias):
requester = await self.auth.get_user_by_req(request)
user = requester.user

room_alias = RoomAlias.from_string(room_alias)

await self.directory_handler.delete_association(requester, room_alias)
await self.directory_handler.delete_association(requester, room_alias_obj)

logger.info(
"User %s deleted alias %s", user.to_string(), room_alias.to_string()
"User %s deleted alias %s", user.to_string(), room_alias_obj.to_string()
)

return 200, {}
Expand All @@ -117,20 +127,22 @@ async def on_DELETE(self, request, room_alias):
class ClientDirectoryListServer(RestServlet):
PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()

async def on_GET(self, request, room_id):
async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
room = await self.store.get_room(room_id)
if room is None:
raise NotFoundError("Unknown room")

return 200, {"visibility": "public" if room["is_public"] else "private"}

async def on_PUT(self, request, room_id):
async def on_PUT(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)

content = parse_json_object_from_request(request)
Expand All @@ -142,7 +154,9 @@ async def on_PUT(self, request, room_id):

return 200, {}

async def on_DELETE(self, request, room_id):
async def on_DELETE(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)

await self.directory_handler.edit_published_room_list(
Expand All @@ -157,21 +171,27 @@ class ClientAppserviceDirectoryListServer(RestServlet):
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$", v1=True
)

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()

def on_PUT(self, request, network_id, room_id):
async def on_PUT(
self, request: SynapseRequest, network_id: str, room_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public")
return self._edit(request, network_id, room_id, visibility)
return await self._edit(request, network_id, room_id, visibility)

def on_DELETE(self, request, network_id, room_id):
return self._edit(request, network_id, room_id, "private")
async def on_DELETE(
self, request: SynapseRequest, network_id: str, room_id: str
) -> Tuple[int, JsonDict]:
return await self._edit(request, network_id, room_id, "private")

async def _edit(self, request, network_id, room_id, visibility):
async def _edit(
self, request: SynapseRequest, network_id: str, room_id: str, visibility: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if not requester.app_service:
raise AuthError(
Expand Down
22 changes: 14 additions & 8 deletions synapse/rest/client/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,36 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Tuple

from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.http.server import respond_with_html_bytes
from synapse.http.server import HttpServer, respond_with_html_bytes
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.push import PusherConfigException
from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()

async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user

Expand All @@ -50,14 +56,14 @@ async def on_GET(self, request):
class PushersSetRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/set$", v1=True)

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()

async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = requester.user

Expand Down Expand Up @@ -132,14 +138,14 @@ class PushersRemoveRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/remove$", v1=True)
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_auth()
self.pusher_pool = self.hs.get_pusherpool()

async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user

Expand All @@ -165,7 +171,7 @@ async def on_GET(self, request):
return None


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PushersRestServlet(hs).register(http_server)
PushersSetRestServlet(hs).register(http_server)
PushersRemoveRestServlet(hs).register(http_server)
Loading

0 comments on commit 2af6d31

Please sign in to comment.