Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 33784c9

Browse files
committed
media/upload: add support for async uploads
Signed-off-by: Sumner Evans <sumner@beeper.com>
1 parent 59c040d commit 33784c9

File tree

3 files changed

+149
-7
lines changed

3 files changed

+149
-7
lines changed

synapse/media/media_repository.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from twisted.internet.defer import Deferred
2727

2828
from synapse.api.errors import (
29+
Codes,
2930
FederationDeniedError,
3031
HttpResponseException,
3132
NotFoundError,
@@ -193,6 +194,70 @@ async def create_media_id(self, auth_user: UserID) -> Tuple[str, int]:
193194
)
194195
return f"mxc://{self.server_name}/{media_id}", unused_expires_at
195196

197+
async def verify_can_upload(self, media_id: str, auth_user: UserID) -> None:
198+
"""Verify that the media ID can be uploaded to by the given user. This
199+
function checks that:
200+
* the media ID exists
201+
* the media ID does not already have content
202+
* the user uploading is the same as the one who created the media ID
203+
* the media ID has not expired
204+
Args:
205+
media_id: The media ID to verify
206+
auth_user: The user_id of the uploader
207+
"""
208+
media = await self.store.get_local_media(media_id)
209+
if media is None:
210+
raise SynapseError(404, "Unknow media ID", errcode=Codes.NOT_FOUND)
211+
212+
if media["user_id"] != str(auth_user):
213+
raise SynapseError(
214+
403,
215+
"Only the creator of the media ID can upload to it",
216+
errcode=Codes.FORBIDDEN,
217+
)
218+
219+
if media.get("media_length") is not None:
220+
raise SynapseError(
221+
409,
222+
"Media ID already has content",
223+
errcode="M_CANNOT_OVERWRITE_MEDIA",
224+
)
225+
226+
if media.get("unused_expires_at", 0) < self.clock.time_msec():
227+
raise NotFoundError("Media ID has expired")
228+
229+
async def update_content(
230+
self,
231+
media_id: str,
232+
media_type: str,
233+
upload_name: Optional[str],
234+
content: IO,
235+
content_length: int,
236+
auth_user: UserID,
237+
) -> None:
238+
"""Update the content of the given media ID.
239+
Args:
240+
media_id: The media ID to replace.
241+
media_type: The content type of the file.
242+
upload_name: The name of the file, if provided.
243+
content: A file like object that is the content to store
244+
content_length: The length of the content
245+
auth_user: The user_id of the uploader
246+
"""
247+
file_info = FileInfo(server_name=None, file_id=media_id)
248+
fname = await self.media_storage.store_file(content, file_info)
249+
logger.info("Stored local media in file %r", fname)
250+
251+
await self.store.update_local_media(
252+
media_id=media_id,
253+
media_type=media_type,
254+
upload_name=upload_name,
255+
media_length=content_length,
256+
user_id=auth_user,
257+
)
258+
259+
await self._generate_thumbnails(None, media_id, media_id, media_type)
260+
196261
async def create_content(
197262
self,
198263
media_type: str,

synapse/rest/media/upload_resource.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,25 @@
1414
# limitations under the License.
1515

1616
import logging
17-
from typing import IO, TYPE_CHECKING, Dict, List, Optional
17+
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Tuple
1818

1919
from synapse.api.errors import Codes, SynapseError
2020
from synapse.http.server import DirectServeJsonResource, respond_with_json
2121
from synapse.http.servlet import parse_bytes_from_args
2222
from synapse.http.site import SynapseRequest
2323
from synapse.media.media_storage import SpamMediaException
2424

25+
from synapse.media._base import parse_media_id, respond_404
26+
2527
if TYPE_CHECKING:
2628
from synapse.media.media_repository import MediaRepository
2729
from synapse.server import HomeServer
2830

2931
logger = logging.getLogger(__name__)
3032

33+
# The name of the lock to use when uploading media.
34+
_UPLOAD_MEDIA_LOCK_NAME = "upload_media"
35+
3136

3237
class UploadResource(DirectServeJsonResource):
3338
isLeaf = True
@@ -38,16 +43,13 @@ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
3843
self.media_repo = media_repo
3944
self.filepaths = media_repo.filepaths
4045
self.store = hs.get_datastores().main
41-
self.clock = hs.get_clock()
4246
self.server_name = hs.hostname
4347
self.auth = hs.get_auth()
4448
self.max_upload_size = hs.config.media.max_upload_size
4549

46-
async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
47-
respond_with_json(request, 200, {}, send_cors=True)
48-
49-
async def _async_render_POST(self, request: SynapseRequest) -> None:
50-
requester = await self.auth.get_user_by_req(request)
50+
def _get_file_metadata(
51+
self, request: SynapseRequest
52+
) -> Tuple[int, Optional[str], str]:
5153
raw_content_length = request.getHeader("Content-Length")
5254
if raw_content_length is None:
5355
raise SynapseError(msg="Request must specify a Content-Length", code=400)
@@ -90,6 +92,15 @@ async def _async_render_POST(self, request: SynapseRequest) -> None:
9092
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
9193
# TODO(markjh): parse content-dispostion
9294

95+
return content_length, upload_name, media_type
96+
97+
async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
98+
respond_with_json(request, 200, {}, send_cors=True)
99+
100+
async def _async_render_POST(self, request: SynapseRequest) -> None:
101+
requester = await self.auth.get_user_by_req(request)
102+
content_length, upload_name, media_type = self._get_file_metadata(request)
103+
93104
try:
94105
content: IO = request.content # type: ignore
95106
content_uri = await self.media_repo.create_content(
@@ -105,3 +116,44 @@ async def _async_render_POST(self, request: SynapseRequest) -> None:
105116
respond_with_json(
106117
request, 200, {"content_uri": str(content_uri)}, send_cors=True
107118
)
119+
120+
async def _async_render_PUT(self, request: SynapseRequest) -> None:
121+
requester = await self.auth.get_user_by_req(request)
122+
server_name, media_id, _ = parse_media_id(request)
123+
124+
if server_name != self.server_name:
125+
raise SynapseError(
126+
404,
127+
"Non-local server name specified",
128+
errcode=Codes.NOT_FOUND,
129+
)
130+
131+
lock = await self.store.try_acquire_lock(_UPLOAD_MEDIA_LOCK_NAME, media_id)
132+
if not lock:
133+
raise SynapseError(
134+
409,
135+
"Media ID is is locked and cannot be uploaded to",
136+
errcode="M_CANNOT_OVERWRITE_MEDIA",
137+
)
138+
139+
async with lock:
140+
await self.media_repo.verify_can_upload(media_id, requester.user)
141+
content_length, upload_name, media_type = self._get_file_metadata(request)
142+
143+
try:
144+
content: IO = request.content # type: ignore
145+
await self.media_repo.update_content(
146+
media_id,
147+
media_type,
148+
upload_name,
149+
content,
150+
content_length,
151+
requester.user,
152+
)
153+
except SpamMediaException:
154+
# For uploading of media we want to respond with a 400, instead of
155+
# the default 404, as that would just be confusing.
156+
raise SynapseError(400, "Bad content")
157+
158+
logger.info("Uploaded content to URI %r", media_id)
159+
respond_with_json(request, 200, {}, send_cors=True)

synapse/storage/databases/main/media_repository.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
166166
"quarantined_by",
167167
"url_cache",
168168
"safe_from_quarantine",
169+
"user_id",
169170
),
170171
allow_none=True,
171172
desc="get_local_media",
@@ -370,6 +371,30 @@ async def store_local_media(
370371
desc="store_local_media",
371372
)
372373

374+
async def update_local_media(
375+
self,
376+
media_id: str,
377+
media_type: str,
378+
upload_name: Optional[str],
379+
media_length: int,
380+
user_id: UserID,
381+
url_cache: Optional[str] = None,
382+
) -> None:
383+
await self.db_pool.simple_update_one(
384+
"local_media_repository",
385+
keyvalues={
386+
"user_id": user_id.to_string(),
387+
"media_id": media_id,
388+
},
389+
updatevalues={
390+
"media_type": media_type,
391+
"upload_name": upload_name,
392+
"media_length": media_length,
393+
"url_cache": url_cache,
394+
},
395+
desc="update_local_media",
396+
)
397+
373398
async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None:
374399
"""Mark a local media as safe or unsafe from quarantining."""
375400
await self.db_pool.simple_update_one(

0 commit comments

Comments
 (0)