Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add type hints to appservice api.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Nov 16, 2021
1 parent bf838f5 commit 2a1449f
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ disallow_untyped_defs = True
[mypy-synapse.app.*]
disallow_untyped_defs = True

[mypy-synapse.appservice.*]
disallow_untyped_defs = True

[mypy-synapse.crypto.*]
disallow_untyped_defs = True

Expand Down
48 changes: 35 additions & 13 deletions synapse/appservice/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
from typing import TYPE_CHECKING, List, Optional, Tuple
import urllib.parse
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple

from prometheus_client import Counter

Expand Down Expand Up @@ -53,15 +53,15 @@
APP_SERVICE_PREFIX = "/_matrix/app/unstable"


def _is_valid_3pe_metadata(info):
def _is_valid_3pe_metadata(info: JsonDict) -> bool:
if "instances" not in info:
return False
if not isinstance(info["instances"], list):
return False
return True


def _is_valid_3pe_result(r, field):
def _is_valid_3pe_result(r: JsonDict, field: str) -> bool:
if not isinstance(r, dict):
return False

Expand Down Expand Up @@ -93,9 +93,13 @@ def __init__(self, hs: "HomeServer"):
hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
)

async def query_user(self, service, user_id):
async def query_user(self, service: "ApplicationService", user_id: str) -> bool:
if service.url is None:
return False

# This is required by the configuration.
assert service.hs_token is not None

uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
try:
response = await self.get_json(uri, {"access_token": service.hs_token})
Expand All @@ -109,9 +113,13 @@ async def query_user(self, service, user_id):
logger.warning("query_user to %s threw exception %s", uri, ex)
return False

async def query_alias(self, service, alias):
async def query_alias(self, service: "ApplicationService", alias: str) -> bool:
if service.url is None:
return False

# This is required by the configuration.
assert service.hs_token is not None

uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
try:
response = await self.get_json(uri, {"access_token": service.hs_token})
Expand All @@ -125,7 +133,13 @@ async def query_alias(self, service, alias):
logger.warning("query_alias to %s threw exception %s", uri, ex)
return False

async def query_3pe(self, service, kind, protocol, fields):
async def query_3pe(
self,
service: "ApplicationService",
kind: str,
protocol: str,
fields: Dict[bytes, List[bytes]],
) -> List[JsonDict]:
if kind == ThirdPartyEntityKind.USER:
required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION:
Expand Down Expand Up @@ -205,11 +219,14 @@ async def push_bulk(
events: List[EventBase],
ephemeral: List[JsonDict],
txn_id: Optional[int] = None,
):
) -> bool:
if service.url is None:
return True

events = self._serialize(service, events)
# This is required by the configuration.
assert service.hs_token is not None

serialized_events = self._serialize(service, events)

if txn_id is None:
logger.warning(
Expand All @@ -221,9 +238,12 @@ async def push_bulk(

# Never send ephemeral events to appservices that do not support it
if service.supports_ephemeral:
body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
body = {
"events": serialized_events,
"de.sorunome.msc2409.ephemeral": ephemeral,
}
else:
body = {"events": events}
body = {"events": serialized_events}

try:
await self.put_json(
Expand All @@ -232,7 +252,7 @@ async def push_bulk(
args={"access_token": service.hs_token},
)
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events))
sent_events_counter.labels(service.id).inc(len(serialized_events))
return True
except CodeMessageException as e:
logger.warning("push_bulk to %s received %s", uri, e.code)
Expand All @@ -241,7 +261,9 @@ async def push_bulk(
failed_transactions_counter.labels(service.id).inc()
return False

def _serialize(self, service, events):
def _serialize(
self, service: "ApplicationService", events: Iterable[EventBase]
) -> List[JsonDict]:
time_now = self.clock.time_msec()
return [
serialize_event(
Expand Down

0 comments on commit 2a1449f

Please sign in to comment.