Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 4022d2d

Browse files
committed
media/{download,thumbnail}: add support for stall parameter
The stall will default to 20s if MSC2246 is not enabled. For remote media, the max_stall_ms parameter is passed through to the remote server. This way, if servers over federation attempt to request the media, the request will only 404 if the media is large (or the client doesn't upload as soon as it can). Signed-off-by: Sumner Evans <sumner@beeper.com>
1 parent ae43011 commit 4022d2d

File tree

5 files changed

+159
-60
lines changed

5 files changed

+159
-60
lines changed

synapse/rest/media/v1/_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
"text/xml",
5050
]
5151

52+
DEFAULT_MSC2246_DELAY = 20_000
53+
5254

5355
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
5456
"""Parses the server name, media ID and optional file name from the request URI

synapse/rest/media/v1/download_resource.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
from typing import TYPE_CHECKING
1717

1818
from synapse.http.server import DirectServeJsonResource, set_cors_headers
19-
from synapse.http.servlet import parse_boolean
19+
from synapse.http.servlet import parse_boolean, parse_integer
2020
from synapse.http.site import SynapseRequest
2121

22-
from ._base import parse_media_id, respond_404
22+
from ._base import DEFAULT_MSC2246_DELAY, parse_media_id, respond_404
2323

2424
if TYPE_CHECKING:
2525
from synapse.rest.media.v1.media_repository import MediaRepository
@@ -35,6 +35,7 @@ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
3535
super().__init__()
3636
self.media_repo = media_repo
3737
self.server_name = hs.hostname
38+
self.enable_msc2246 = hs.config.experimental.msc2246_enabled
3839

3940
async def _async_render_GET(self, request: SynapseRequest) -> None:
4041
set_cors_headers(request)
@@ -50,13 +51,14 @@ async def _async_render_GET(self, request: SynapseRequest) -> None:
5051
)
5152
# Limited non-standard form of CSP for IE11
5253
request.setHeader(b"X-Content-Security-Policy", b"sandbox;")
53-
request.setHeader(
54-
b"Referrer-Policy",
55-
b"no-referrer",
56-
)
54+
request.setHeader(b"Referrer-Policy", b"no-referrer")
5755
server_name, media_id, name = parse_media_id(request)
56+
max_stall_ms = parse_integer(
57+
request, "fi.mau.msc2246.max_stall_ms", default=DEFAULT_MSC2246_DELAY
58+
)
59+
5860
if server_name == self.server_name:
59-
await self.media_repo.get_local_media(request, media_id, name)
61+
await self.media_repo.get_local_media(request, media_id, name, max_stall_ms)
6062
else:
6163
allow_remote = parse_boolean(request, "allow_remote", default=True)
6264
if not allow_remote:
@@ -68,4 +70,6 @@ async def _async_render_GET(self, request: SynapseRequest) -> None:
6870
respond_404(request)
6971
return
7072

71-
await self.media_repo.get_remote_media(request, server_name, media_id, name)
73+
await self.media_repo.get_remote_media(
74+
request, server_name, media_id, name, max_stall_ms
75+
)

synapse/rest/media/v1/media_repository.py

Lines changed: 91 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import shutil
1919
from enum import Enum
2020
from io import BytesIO
21-
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
21+
from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
2222

2323
import twisted.internet.error
2424
import twisted.web.http
@@ -32,13 +32,14 @@
3232
NotFoundError,
3333
RequestSendFailed,
3434
SynapseError,
35+
cs_error,
3536
)
3637
from synapse.config._base import ConfigError
3738
from synapse.config.repository import ThumbnailRequirement
39+
from synapse.http.server import respond_with_json
3840
from synapse.http.site import SynapseRequest
3941
from synapse.logging.context import defer_to_thread
4042
from synapse.metrics.background_process_metrics import run_as_background_process
41-
from synapse.rest.media.v1.create_resource import CreateResource
4243
from synapse.types import UserID
4344
from synapse.util.async_helpers import Linearizer
4445
from synapse.util.retryutils import NotRetryingDestination
@@ -53,6 +54,7 @@
5354
respond_with_responder,
5455
)
5556
from .config_resource import MediaConfigResource
57+
from .create_resource import CreateResource
5658
from .download_resource import DownloadResource
5759
from .filepath import MediaFilePaths
5860
from .media_storage import MediaStorage
@@ -288,8 +290,64 @@ async def create_content(
288290

289291
return "mxc://%s/%s" % (self.server_name, media_id)
290292

293+
def respond_not_yet_uploaded(self, request: SynapseRequest) -> None:
    """Answer ``request`` with a 404 JSON error saying the media is not uploaded yet.

    The error body carries the ``FI.MAU.MSC2246_NOT_YET_UPLOADED`` code and a
    ``retry_after_ms`` value of 5000; CORS headers are requested on the
    response.

    Args:
        request: The incoming request to respond to.
    """
    respond_with_json(
        request,
        404,
        cs_error(
            "Media has not been uploaded yet",
            code="FI.MAU.MSC2246_NOT_YET_UPLOADED",
            retry_after_ms=5_000,
        ),
        send_cors=True,
    )
300+
301+
async def get_local_media_info(
    self, request: SynapseRequest, media_id: str, max_stall_ms: int
) -> Optional[Dict[str, Any]]:
    """Gets the info dictionary for given local media ID. If the media has
    not been uploaded yet, this function will wait up to ``max_stall_ms``
    milliseconds for the media to be uploaded.

    The store is polled at most every 500ms. The final sleep is shortened to
    the time remaining so that the total stall does not overshoot
    ``max_stall_ms``; one last lookup is made at the deadline before giving
    up.

    Args:
        request: The incoming request.
        media_id: The media ID of the content. (This is the same as
            the file_id for local content.)
        max_stall_ms: the maximum number of milliseconds to wait for the
            media to be uploaded. A value of zero or less means the store
            is checked exactly once.

    Returns:
        Either the info dictionary for the given local media ID or
        ``None``. If ``None``, then no further processing is necessary as
        this function will send the necessary JSON response (a 404 when the
        media is unknown or quarantined, the "not yet uploaded" error when
        the stall deadline passes).
    """
    wait_until = self.clock.time_msec() + max_stall_ms
    while True:
        # Get the info for the media
        media_info = await self.store.get_local_media(media_id)
        if not media_info:
            respond_404(request)
            return None

        if media_info["quarantined_by"]:
            logger.info("Media is quarantined")
            respond_404(request)
            return None

        # The file has been uploaded, so stop looping
        if media_info.get("media_length") is not None:
            return media_info

        remaining_ms = wait_until - self.clock.time_msec()
        if remaining_ms <= 0:
            break

        # Never sleep past the deadline: cap the poll interval (seconds) at
        # the time that is left, so callers wait at most ``max_stall_ms``.
        await self.clock.sleep(min(0.5, remaining_ms / 1000))

    self.respond_not_yet_uploaded(request)
    return None
344+
291345
async def get_local_media(
292-
self, request: SynapseRequest, media_id: str, name: Optional[str]
346+
self,
347+
request: SynapseRequest,
348+
media_id: str,
349+
name: Optional[str],
350+
max_stall_ms: int,
293351
) -> None:
294352
"""Responds to requests for local media, if exists, or returns 404.
295353
@@ -299,13 +357,14 @@ async def get_local_media(
299357
the file_id for local content.)
300358
name: Optional name that, if specified, will be used as
301359
the filename in the Content-Disposition header of the response.
360+
max_stall_ms: the maximum number of milliseconds to wait for the
361+
media to be uploaded.
302362
303363
Returns:
304364
Resolves once a response has successfully been written to request
305365
"""
306-
media_info = await self.store.get_local_media(media_id)
307-
if not media_info or media_info["quarantined_by"]:
308-
respond_404(request)
366+
media_info = await self.get_local_media_info(request, media_id, max_stall_ms)
367+
if not media_info:
309368
return
310369

311370
self.mark_recently_accessed(None, media_id)
@@ -330,6 +389,7 @@ async def get_remote_media(
330389
server_name: str,
331390
media_id: str,
332391
name: Optional[str],
392+
max_stall_ms: int,
333393
) -> None:
334394
"""Respond to requests for remote media.
335395
@@ -339,6 +399,8 @@ async def get_remote_media(
339399
media_id: The media ID of the content (as defined by the remote server).
340400
name: Optional name that, if specified, will be used as
341401
the filename in the Content-Disposition header of the response.
402+
max_stall_ms: the maximum number of milliseconds to wait for the
403+
media to be uploaded.
342404
343405
Returns:
344406
Resolves once a response has successfully been written to request
@@ -353,33 +415,37 @@ async def get_remote_media(
353415

354416
# We linearize here to ensure that we don't try and download remote
355417
# media multiple times concurrently
356-
key = (server_name, media_id)
357-
async with self.remote_media_linearizer.queue(key):
418+
async with self.remote_media_linearizer.queue((server_name, media_id)):
358419
responder, media_info = await self._get_remote_media_impl(
359-
server_name, media_id
420+
server_name, media_id, max_stall_ms
360421
)
361422

362-
# We deliberately stream the file outside the lock
363-
if responder:
423+
if responder and media_info:
364424
media_type = media_info["media_type"]
365425
media_length = media_info["media_length"]
366426
upload_name = name if name else media_info["upload_name"]
367427
await respond_with_responder(
368428
request, responder, media_type, media_length, upload_name
369429
)
370430
else:
371-
respond_404(request)
431+
self.respond_not_yet_uploaded(request)
432+
return
372433

373-
async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
434+
async def get_remote_media_info(
435+
self, server_name: str, media_id: str, max_stall_ms: int
436+
) -> Optional[dict]:
374437
"""Gets the media info associated with the remote file, downloading
375438
if necessary.
376439
377440
Args:
378441
server_name: Remote server_name where the media originated.
379442
media_id: The media ID of the content (as defined by the remote server).
443+
max_stall_ms: the maximum number of milliseconds to wait for the
444+
media to be uploaded.
380445
381446
Returns:
382-
The media info of the file
447+
The media info of the file or ``None`` if the media wasn't uploaded
448+
in time.
383449
"""
384450
if (
385451
self.federation_domain_whitelist is not None
@@ -389,10 +455,9 @@ async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
389455

390456
# We linearize here to ensure that we don't try and download remote
391457
# media multiple times concurrently
392-
key = (server_name, media_id)
393-
async with self.remote_media_linearizer.queue(key):
458+
async with self.remote_media_linearizer.queue((server_name, media_id)):
394459
responder, media_info = await self._get_remote_media_impl(
395-
server_name, media_id
460+
server_name, media_id, max_stall_ms
396461
)
397462

398463
# Ensure we actually use the responder so that it releases resources
@@ -403,7 +468,7 @@ async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
403468
return media_info
404469

405470
async def _get_remote_media_impl(
406-
self, server_name: str, media_id: str
471+
self, server_name: str, media_id: str, max_stall_ms: int
407472
) -> Tuple[Optional[Responder], dict]:
408473
"""Looks for media in local cache, if not there then attempt to
409474
download from remote server.
@@ -412,6 +477,8 @@ async def _get_remote_media_impl(
412477
server_name (str): Remote server_name where the media originated.
413478
media_id (str): The media ID of the content (as defined by the
414479
remote server).
480+
max_stall_ms: the maximum number of milliseconds to wait for the
481+
media to be uploaded.
415482
416483
Returns:
417484
A tuple of responder and the media info of the file.
@@ -442,8 +509,7 @@ async def _get_remote_media_impl(
442509

443510
try:
444511
media_info = await self._download_remote_file(
445-
server_name,
446-
media_id,
512+
server_name, media_id, max_stall_ms
447513
)
448514
except SynapseError:
449515
raise
@@ -476,6 +542,7 @@ async def _download_remote_file(
476542
self,
477543
server_name: str,
478544
media_id: str,
545+
max_stall_ms: int,
479546
) -> dict:
480547
"""Attempt to download the remote file from the given server name,
481548
using the given file_id as the local id.
@@ -485,7 +552,8 @@ async def _download_remote_file(
485552
media_id: The media ID of the content (as defined by the
486553
remote server). This is different than the file_id, which is
487554
locally generated.
488-
file_id: Local file ID
555+
max_stall_ms: the maximum number of milliseconds to wait for the
556+
media to be uploaded.
489557
490558
Returns:
491559
The media info of the file.
@@ -509,7 +577,8 @@ async def _download_remote_file(
509577
# tell the remote server to 404 if it doesn't
510578
# recognise the server_name, to make sure we don't
511579
# end up with a routing loop.
512-
"allow_remote": "false"
580+
"allow_remote": "false",
581+
"fi.mau.msc2246.max_stall_ms": str(max_stall_ms),
513582
},
514583
)
515584
except RequestSendFailed as e:

0 commit comments

Comments
 (0)