Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Factor out an is_mine_server_name method (#15542)
Browse files Browse the repository at this point in the history
Add an `is_mine_server_name` method, similar to `is_mine_id`.

Ideally we would use this consistently, instead of sometimes comparing
against `hs.hostname` and other times reaching into
`hs.config.server.server_name`.

Also fix a bug in the tests where `hs.hostname` would sometimes differ
from `hs.config.server.server_name`.

Signed-off-by: Sean Quah <seanq@matrix.org>
  • Loading branch information
squahtx authored May 5, 2023
1 parent 83e7fa5 commit e46d5f3
Show file tree
Hide file tree
Showing 23 changed files with 64 additions and 36 deletions.
1 change: 1 addition & 0 deletions changelog.d/15542.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Factor out an `is_mine_server_name` method.
4 changes: 2 additions & 2 deletions synapse/api/auth_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, hs: "HomeServer"):
self._mau_limits_reserved_threepids = (
hs.config.server.mau_limits_reserved_threepids
)
self._server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips

async def check_auth_blocking(
Expand Down Expand Up @@ -77,7 +77,7 @@ async def check_auth_blocking(
if requester:
if requester.authenticated_entity.startswith("@"):
user_id = requester.authenticated_entity
elif requester.authenticated_entity == self._server_name:
elif self._is_mine_server_name(requester.authenticated_entity):
# We never block the server from doing actions on behalf of
# users.
return
Expand Down
4 changes: 2 additions & 2 deletions synapse/crypto/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def __init__(
process_batch_callback=self._inner_fetch_key_requests,
)

self._hostname = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name

# build a FetchKeyResult for each of our own keys, to shortcircuit the
# fetcher.
Expand Down Expand Up @@ -277,7 +277,7 @@ async def process_request(self, verify_request: VerifyJsonRequest) -> None:

# If we are the originating server, short-circuit the key-fetch for any keys
# we already have
if verify_request.server_name == self._hostname:
if self._is_mine_server_name(verify_request.server_name):
for key_id in verify_request.key_ids:
if key_id in self._local_verify_keys:
found_keys[key_id] = self._local_verify_keys[key_id]
Expand Down
2 changes: 1 addition & 1 deletion synapse/federation/federation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class FederationBase:
def __init__(self, hs: "HomeServer"):
self.hs = hs

self.server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self.keyring = hs.get_keyring()
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self.store = hs.get_datastores().main
Expand Down
4 changes: 2 additions & 2 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,7 @@ async def _try_destination_list(

for destination in destinations:
# We don't want to ask our own server for information we don't have
if destination == self.server_name:
if self._is_mine_server_name(destination):
continue

try:
Expand Down Expand Up @@ -1536,7 +1536,7 @@ async def forward_third_party_invite(
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
) -> None:
for destination in destinations:
if destination == self.server_name:
if self._is_mine_server_name(destination):
continue

try:
Expand Down
3 changes: 2 additions & 1 deletion synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class FederationServer(FederationBase):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.server_name = hs.hostname
self.handler = hs.get_federation_handler()
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self._federation_event_handler = hs.get_federation_event_handler()
Expand Down Expand Up @@ -942,7 +943,7 @@ async def _on_send_membership_event(
authorising_server = get_domain_from_id(
event.content[EventContentFields.AUTHORISING_USER]
)
if authorising_server != self.server_name:
if not self._is_mine_server_name(authorising_server):
raise SynapseError(
400,
f"Cannot authorise request from resident server: {authorising_server}",
Expand Down
3 changes: 2 additions & 1 deletion synapse/federation/send_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name

# We may have multiple federation sender instances, so we need to track
# their positions separately.
Expand Down Expand Up @@ -198,7 +199,7 @@ def build_and_send_edu(
key: Optional[Hashable] = None,
) -> None:
"""As per FederationSender"""
if destination == self.server_name:
if self.is_mine_server_name(destination):
logger.info("Not sending EDU to ourselves")
return

Expand Down
11 changes: 6 additions & 5 deletions synapse/federation/sender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ def __init__(self, hs: "HomeServer"):

self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name

self._presence_router: Optional["PresenceRouter"] = None
self._transaction_manager = TransactionManager(hs)
Expand Down Expand Up @@ -766,7 +767,7 @@ async def send_read_receipt(self, receipt: ReadReceipt) -> None:
domains = [
d
for d in domains_set
if d != self.server_name
if not self.is_mine_server_name(d)
and self._federation_shard_config.should_handle(self._instance_name, d)
]
if not domains:
Expand Down Expand Up @@ -832,7 +833,7 @@ def send_presence_to_destinations(
assert self.is_mine_id(state.user_id)

for destination in destinations:
if destination == self.server_name:
if self.is_mine_server_name(destination):
continue
if not self._federation_shard_config.should_handle(
self._instance_name, destination
Expand Down Expand Up @@ -860,7 +861,7 @@ def build_and_send_edu(
content: content of EDU
key: clobbering key for this edu
"""
if destination == self.server_name:
if self.is_mine_server_name(destination):
logger.info("Not sending EDU to ourselves")
return

Expand Down Expand Up @@ -897,7 +898,7 @@ def send_edu(self, edu: Edu, key: Optional[Hashable]) -> None:
queue.send_edu(edu)

def send_device_messages(self, destination: str, immediate: bool = True) -> None:
if destination == self.server_name:
if self.is_mine_server_name(destination):
logger.warning("Not sending device update to ourselves")
return

Expand All @@ -919,7 +920,7 @@ def wake_destination(self, destination: str) -> None:
might have come back.
"""

if destination == self.server_name:
if self.is_mine_server_name(destination):
logger.warning("Not waking up ourselves")
return

Expand Down
4 changes: 2 additions & 2 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ class TransportLayerClient:
"""Sends federation HTTP requests to other servers"""

def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.client = hs.get_federation_http_client()
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
self._is_mine_server_name = hs.is_mine_server_name

async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
Expand Down Expand Up @@ -235,7 +235,7 @@ async def send_transaction(
transaction.transaction_id,
)

if transaction.destination == self.server_name:
if self._is_mine_server_name(transaction.destination):
raise RuntimeError("Transport layer cannot send to itself!")

# FIXME: This is only used by the tests. The actual json sent is
Expand Down
5 changes: 4 additions & 1 deletion synapse/federation/transport/server/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self.store = hs.get_datastores().main
self.federation_domain_whitelist = (
hs.config.federation.federation_domain_whitelist
Expand Down Expand Up @@ -100,7 +101,9 @@ async def authenticate_request(
json_request["signatures"].setdefault(origin, {})[key] = sig

# if the origin_server sent a destination along it needs to match our own server_name
if destination is not None and destination != self.server_name:
if destination is not None and not self._is_mine_server_name(
destination
):
raise AuthenticationError(
HTTPStatus.UNAUTHORIZED,
"Destination mismatch in auth header",
Expand Down
5 changes: 3 additions & 2 deletions synapse/handlers/event_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.types import StateMap, StrCollection, get_domain_from_id
from synapse.types import StateMap, StrCollection

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand All @@ -47,6 +47,7 @@ def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
self._state_storage_controller = hs.get_storage_controllers().state
self._server_name = hs.hostname
self._is_mine_id = hs.is_mine_id

async def check_auth_rules_from_context(
self,
Expand Down Expand Up @@ -247,7 +248,7 @@ async def check_restricted_join_rules(
if not await self.is_user_in_rooms(allowed_rooms, user_id):
# If this is a remote request, the user might be in an allowed room
# that we do not know about.
if get_domain_from_id(user_id) != self._server_name:
if not self._is_mine_id(user_id):
for room_id in allowed_rooms:
if not await self._store.is_host_joined(room_id, self._server_name):
raise SynapseError(
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self.event_creation_handler = hs.get_event_creation_handler()
self.event_builder_factory = hs.get_event_builder_factory()
Expand Down Expand Up @@ -453,7 +454,7 @@ async def try_backfill(domains: StrCollection) -> bool:

for dom in domains:
# We don't want to ask our own server for information we don't have
if dom == self.server_name:
if self.is_mine_server_name(dom):
continue

try:
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def __init__(self, hs: "HomeServer"):
self._notifier = hs.get_notifier()

self._is_mine_id = hs.is_mine_id
self._is_mine_server_name = hs.is_mine_server_name
self._server_name = hs.hostname
self._instance_name = hs.get_instance_name()

Expand Down Expand Up @@ -688,7 +689,7 @@ async def backfill(
server from invalid events (there is probably no point in trying to
re-fetch invalid events from every other HS in the room.)
"""
if dest == self._server_name:
if self._is_mine_server_name(dest):
raise SynapseError(400, "Can't backfill from self.")

events = await self._federation_client.backfill(
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(self, hs: "HomeServer"):
self.max_avatar_size = hs.config.server.max_avatar_size
self.allowed_avatar_mimetypes = hs.config.server.allowed_avatar_mimetypes

self.server_name = hs.config.server.server_name
self._is_mine_server_name = hs.is_mine_server_name

self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules

Expand Down Expand Up @@ -309,7 +309,7 @@ async def check_avatar_size_and_mime_type(self, mxc: str) -> bool:
else:
server_name = host

if server_name == self.server_name:
if self._is_mine_server_name(server_name):
media_info = await self.store.get_local_media(media_id)
else:
media_info = await self.store.get_cached_remote_media(server_name, media_id)
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._store = hs.get_datastores().main
self._server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self._registration_handler = hs.get_registration_handler()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
Expand Down Expand Up @@ -802,7 +803,7 @@ def is_allowed_mime_type(content_type: str) -> bool:
if profile["avatar_url"] is not None:
server_name = profile["avatar_url"].split("/")[-2]
media_id = profile["avatar_url"].split("/")[-1]
if server_name == self._server_name:
if self._is_mine_server_name(server_name):
media = await self._media_repo.store.get_local_media(media_id)
if media is not None and upload_name == media["upload_name"]:
logger.info("skipping saving the user avatar")
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(self, hs: "HomeServer"):
self.server_name = hs.config.server.server_name
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name

self.federation = None
if hs.should_send_federation():
Expand Down Expand Up @@ -153,7 +154,7 @@ async def _push_remote(self, member: RoomMember, typing: bool) -> None:
member.room_id
)
for domain in hosts:
if domain != self.server_name:
if not self.is_mine_server_name(domain):
logger.debug("sending typing update to %s", domain)
self.federation.build_and_send_edu(
destination=domain,
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/admin/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,15 +258,15 @@ class DeleteMediaByID(RestServlet):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.auth = hs.get_auth()
self.server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self.media_repository = hs.get_media_repository()

async def on_DELETE(
self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

if self.server_name != server_name:
if not self._is_mine_server_name(server_name):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")

if await self.store.get_local_media(media_id) is None:
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
limit = None

handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server.server_name:
if server and not self.hs.is_mine_server_name(server):
# Ensure the server is valid.
try:
parse_and_validate_server_name(server)
Expand Down Expand Up @@ -551,7 +551,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
limit = None

handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server.server_name:
if server and not self.hs.is_mine_server_name(server):
# Ensure the server is valid.
try:
parse_and_validate_server_name(server)
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/media/download_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class DownloadResource(DirectServeJsonResource):
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
self.server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name

async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request)
Expand All @@ -59,7 +59,7 @@ async def _async_render_GET(self, request: SynapseRequest) -> None:
b"no-referrer",
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
if self._is_mine_server_name(server_name):
await self.media_repo.get_local_media(request, media_id, name)
else:
allow_remote = parse_boolean(request, "allow_remote", default=True)
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/media/thumbnail_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
self.media_repo = media_repo
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self.server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name

async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request)
Expand All @@ -71,7 +71,7 @@ async def _async_render_GET(self, request: SynapseRequest) -> None:
# TODO Parse the Accept header to get an prioritised list of thumbnail types.
m_type = "image/png"

if server_name == self.server_name:
if self._is_mine_server_name(server_name):
if self.dynamic_thumbnails:
await self._select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type
Expand Down
4 changes: 4 additions & 0 deletions synapse/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,10 @@ def is_mine_id(self, string: str) -> bool:
return False
return localpart_hostname[1] == self.hostname

def is_mine_server_name(self, server_name: str) -> bool:
"""Determines whether a server name refers to this homeserver."""
return server_name == self.hostname

@cache_in_self
def get_clock(self) -> Clock:
return Clock(self._reactor)
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,7 @@ async def quarantine_media_by_id(
If it is `None` media will be removed from quarantine
"""
logger.info("Quarantining media: %s/%s", server_name, media_id)
is_local = server_name == self.config.server.server_name
is_local = self.hs.is_mine_server_name(server_name)

def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
local_mxcs = [media_id] if is_local else []
Expand Down
Loading

0 comments on commit e46d5f3

Please sign in to comment.