Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 118 additions & 2 deletions synapse/rest/client/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,26 @@

import logging
import re
from typing import Optional
from typing import Optional, Union

from synapse.api.errors import (
Codes,
NotFoundError,
SynapseError,
)
from synapse.http.server import (
HttpServer,
respond_with_json,
respond_with_json_bytes,
set_corp_headers,
set_cors_headers,
)
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.servlet import (
RestServlet,
parse_integer,
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.media._base import (
DEFAULT_MAX_TIMEOUT_MS,
Expand All @@ -44,6 +54,8 @@
from synapse.rest.media.create_resource import CreateResource
from synapse.rest.media.upload_resource import UploadRestrictedResource
from synapse.server import HomeServer
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
from synapse.types import Requester
from synapse.util.stringutils import parse_and_validate_server_name

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -277,6 +289,109 @@ async def on_GET(
)


class CopyResource(RestServlet):
"""
MSC3911: This is an unstable endpoint that is introduced in msc3911 scope. This
"copy" api is to be used by clients when forwarding events with media attachments.
Rather than just allowing clients to attach media to multiple events, this ensures
that the list of events attached to a media does not grow over time, so that servers
can reliably cache media and impose the correct access restrictions.
"""

# Stable: /_matrix/client/v1/media/copy/{serverName}/{mediaId}
PATTERNS = [
re.compile(
"/_matrix/client/unstable/org.matrix.msc3911/media/copy/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)"
)
]

def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.store = hs.get_datastores().main
self.media_repo = media_repo
self.auth = hs.get_auth()
self._is_mine_server_name = hs.is_mine_server_name
self.limits_dict = {"m.upload.size": hs.config.media.max_upload_size}
self.media_repository_callbacks = hs.get_module_api_callbacks().media_repository
self.clock = hs.get_clock()

async def _validate_user_media_limit(
self, requester: Requester, media_info: Union[LocalMedia, RemoteMedia, None]
) -> None:
"""Check if the request exceeds the user's media limits."""
media_config = await self.media_repository_callbacks.get_media_config_for_user(
requester.user.to_string(),
)
if not media_config:
media_config = self.limits_dict

max_upload_size = media_config.get("m.upload.size")
if max_upload_size and media_info and media_info.media_length:
# We are not counting the amount of media the user uploaded in a previous time period
if media_info.media_length > max_upload_size:
raise SynapseError(400, Codes.RESOURCE_LIMIT_EXCEEDED)

async def on_POST(
self,
request: SynapseRequest,
server_name: str,
media_id: str,
) -> None:
"""
Handles copying a media item referenced by server_name and media_id.
Returns a new MXC URI for the copied media.
"""
requester = await self.auth.get_user_by_req(request)

# Optionally parse request body, must be a JSON object, but no required params.
_ = parse_json_object_from_request(request, allow_empty_body=True)

media_info: Union[LocalMedia, RemoteMedia, None] = None
if self._is_mine_server_name(server_name):
media_info = await self.store.get_local_media(media_id)
else:
media_info = await self.media_repo.get_remote_media_info(
server_name,
media_id,
DEFAULT_MAX_TIMEOUT_MS,
request.getClientAddress().host,
use_federation=True,
allow_authenticated=True,
)

if not media_info:
raise NotFoundError()
if media_info.quarantined_by:
raise NotFoundError()

await self._validate_user_media_limit(requester, media_info)

if media_info:
try:
mxc_uri, _ = await self.media_repo.create_media_id(
requester.user, restricted=True
)
if media_info.media_length and media_info.sha256:
await self.store.update_local_media(
media_id=mxc_uri.split("/")[-1],
media_type=media_info.media_type,
upload_name=media_info.upload_name,
media_length=media_info.media_length,
user_id=requester.user,
sha256=media_info.sha256,
quarantined_by=None,
)
respond_with_json(
request,
200,
{"content_uri": mxc_uri},
send_cors=True,
)
except Exception as e:
logger.error("Failed to copy media: %s", e)
respond_with_json(request, 500, {"error": "Failed to copy media"})


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
media_repo = hs.get_media_repository()
if hs.config.media.url_preview_enabled:
Expand All @@ -289,3 +404,4 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.experimental.msc3911_enabled:
CreateResource(hs, media_repo, restricted=True).register(http_server)
UploadRestrictedResource(hs, media_repo).register(http_server)
CopyResource(hs, media_repo).register(http_server)
145 changes: 144 additions & 1 deletion tests/rest/client/test_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -3029,7 +3029,7 @@ def test_unrestricted_resource_upload_disabled(self) -> None:
)


class RestrictedResourceTestCase(unittest.HomeserverTestCase):
class RestrictedResourceUploadTestCase(unittest.HomeserverTestCase):
"""
Tests restricted media creation and upload endpoints when `msc3911_enabled` is
configured to be True.
Expand Down Expand Up @@ -3176,3 +3176,146 @@ def test_async_upload_restricted_resource(self) -> None:
access_token=self.other_user_tok,
)
assert channel.code == 404


class CopyRestrictedResource(unittest.HomeserverTestCase):
"""
Tests copy API when `msc3911_enabled` is configured to be True.
"""

extra_config = {
"experimental_features": {"msc3911_enabled": True},
}

servlets = [
media.register_servlets,
login.register_servlets,
admin.register_servlets,
]

def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config.update(self.extra_config)
return self.setup_test_homeserver(config=config)

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.media_repo = hs.get_media_repository()
self.user = self.register_user("user", "testpass")
self.user_tok = self.login("user", "testpass")
self.other_user = self.register_user("other", "testpass")
self.other_user_tok = self.login("other", "testpass")

def create_resource_dict(self) -> dict[str, Resource]:
resources = super().create_resource_dict()
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
return resources

def test_copy_local_restricted_resource(self) -> None:
"""
Tests that the new copy endpoint creates a new mxc uri for restricted resource.
"""
# The media is created with user_tok
content = io.BytesIO(SMALL_PNG)
content_uri = self.get_success(
self.media_repo.create_or_update_content(
"image/png",
"test_png_upload",
content,
67,
UserID.from_string(self.user),
restricted=True,
)
)
media_id = content_uri.media_id

# The other_user copies the media from local server
channel = self.make_request(
"POST",
f"/_matrix/client/unstable/org.matrix.msc3911/media/copy/{self.hs.hostname}/{media_id}",
access_token=self.other_user_tok,
)
self.assertEqual(channel.code, 200)
self.assertIn("content_uri", channel.json_body)
new_media_id = channel.json_body["content_uri"].split("/")[-1]
assert new_media_id != media_id

# Check if the original media there.
original_media = self.get_success(
self.hs.get_datastores().main.get_local_media(media_id)
)
assert original_media is not None
assert original_media.user_id == self.user

# Check the copied media.
copied_media = self.get_success(
self.hs.get_datastores().main.get_local_media(new_media_id)
)
assert copied_media is not None
assert copied_media.user_id == self.other_user

# Check if they are referencing the same image.
assert original_media.sha256 == copied_media.sha256

# Check if media is unattached to any event or user profile yet.
assert copied_media.attachments is None

def test_copy_remote_restricted_resource(self) -> None:
"""
Tests that the new copy endpoint creates a new mxc uri for restricted resource.
"""
# create remote media
remote_server = "remoteserver.com"
remote_file_id = "remote1"
file_info = FileInfo(server_name=remote_server, file_id=remote_file_id)

media_storage = self.hs.get_media_repository().media_storage
ctx = media_storage.store_into_file(file_info)
(f, _) = self.get_success(ctx.__aenter__())
f.write(SMALL_PNG)
self.get_success(ctx.__aexit__(None, None, None))
media_id = "remotemedia"
self.get_success(
self.hs.get_datastores().main.store_cached_remote_media(
origin=remote_server,
media_id=media_id,
media_type="image/png",
media_length=1,
time_now_ms=self.clock.time_msec(),
upload_name="test.png",
filesystem_id=remote_file_id,
sha256=remote_file_id,
)
)

# The other_user copies the media from remote server
channel = self.make_request(
"POST",
f"/_matrix/client/unstable/org.matrix.msc3911/media/copy/{remote_server}/{media_id}",
access_token=self.other_user_tok,
)
self.assertEqual(channel.code, 200)
self.assertIn("content_uri", channel.json_body)
new_media_id = channel.json_body["content_uri"].split("/")[-1]
assert new_media_id != media_id

# Check if the original media there.
original_media = self.get_success(
self.hs.get_datastores().main.get_cached_remote_media(
remote_server, media_id
)
)
assert original_media is not None
assert original_media.upload_name == "test.png"

# Check the copied media.
copied_media = self.get_success(
self.hs.get_datastores().main.get_local_media(new_media_id)
)
assert copied_media is not None
assert copied_media.user_id == self.other_user

# Check if they are referencing the same image.
assert original_media.sha256 == copied_media.sha256

# Check if copied media is unattached to any event or user profile yet.
assert copied_media.attachments is None
Loading