Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions rust/src/events/internal_metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,17 @@ enum EventInternalMetadataData {
TxnId(Box<str>),
TokenId(i64),
DeviceId(Box<str>),
MediaReferences(Vec<String>),
}

impl EventInternalMetadataData {
/// Convert the field to its name and python object.
fn to_python_pair<'a>(&self, py: Python<'a>) -> (&'a Bound<'a, PyString>, Bound<'a, PyAny>) {
match self {
EventInternalMetadataData::MediaReferences(o) => (
pyo3::intern!(py, "media_references"),
o.into_pyobject(py).unwrap().into_any(),
),
EventInternalMetadataData::OutOfBandMembership(o) => (
pyo3::intern!(py, "out_of_band_membership"),
o.into_pyobject(py)
Expand Down Expand Up @@ -128,6 +133,11 @@ impl EventInternalMetadataData {
let key_str: PyBackedStr = key.extract()?;

let e = match &*key_str {
"media_references" => EventInternalMetadataData::MediaReferences(
value
.extract()
.with_context(|| format!("'{key_str}' has invalid type"))?,
),
"out_of_band_membership" => EventInternalMetadataData::OutOfBandMembership(
value
.extract()
Expand Down Expand Up @@ -469,4 +479,14 @@ impl EventInternalMetadata {
fn set_device_id(&mut self, obj: String) {
set_property!(self, DeviceId, obj.into_boxed_str());
}

/// The media references for the restrictions being set for this event, if any.
#[getter]
fn get_media_references(&self) -> Option<&Vec<String>> {
get_property_opt!(self, MediaReferences)
}
#[setter]
fn set_media_references(&mut self, obj: Vec<String>) {
set_property!(self, MediaReferences, obj);
}
}
89 changes: 89 additions & 0 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Tuple

from canonicaljson import encode_canonical_json
from matrix_common.types.mxc_uri import MXCUri

from twisted.internet.interfaces import IDelayedCall

Expand Down Expand Up @@ -70,6 +71,10 @@
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.databases.main.media_repository import (
LocalMedia,
MediaRestrictions,
)
from synapse.types import (
JsonDict,
PersistedEventPosition,
Expand Down Expand Up @@ -583,6 +588,7 @@ async def create_event(
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
mxc_restriction_list_for_event: Optional[List[MXCUri]] = None,
) -> Tuple[EventBase, UnpersistedEventContextBase]:
"""
Given a dict from a client, create a new event. If bool for_batch is true, will
Expand Down Expand Up @@ -637,6 +643,9 @@ async def create_event(
current_state_group: the current state group, used only for creating events for
batch persisting

mxc_restriction_list_for_event: An optional List of MXCUri objects, to be
used for setting media restrictions

Raises:
ResourceLimitError if server is blocked to some resource being
exceeded
Expand Down Expand Up @@ -716,6 +725,11 @@ async def create_event(
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id

if mxc_restriction_list_for_event is not None:
builder.internal_metadata.media_references = [
str(mxc) for mxc in mxc_restriction_list_for_event
]

builder.internal_metadata.outlier = outlier

event, unpersisted_context = await self.create_new_client_event(
Expand Down Expand Up @@ -956,6 +970,7 @@ async def create_and_send_nonmember_event(
ignore_shadow_ban: bool = False,
outlier: bool = False,
depth: Optional[int] = None,
media_info_for_attachment: Optional[set[LocalMedia]] = None,
) -> Tuple[EventBase, int]:
"""
Creates an event, then sends it.
Expand Down Expand Up @@ -984,6 +999,8 @@ async def create_and_send_nonmember_event(
depth: Override the depth used to order the event in the DAG.
Should normally be set to None, which will cause the depth to be calculated
based on the prev_events.
media_info_for_attachment: An optional set of LocalMedia objects, for use in
restricting media.

Returns:
The event, and its stream ordering (if deduplication happened,
Expand Down Expand Up @@ -1057,6 +1074,7 @@ async def create_and_send_nonmember_event(
ignore_shadow_ban=ignore_shadow_ban,
outlier=outlier,
depth=depth,
media_info_for_attachment=media_info_for_attachment,
)

async def _create_and_send_nonmember_event_locked(
Expand All @@ -1070,6 +1088,7 @@ async def _create_and_send_nonmember_event_locked(
ignore_shadow_ban: bool = False,
outlier: bool = False,
depth: Optional[int] = None,
media_info_for_attachment: Optional[set[LocalMedia]] = None,
) -> Tuple[EventBase, int]:
room_id = event_dict["room_id"]

Expand Down Expand Up @@ -1098,6 +1117,13 @@ async def _create_and_send_nonmember_event_locked(
state_event_ids=state_event_ids,
outlier=outlier,
depth=depth,
mxc_restriction_list_for_event=[
MXCUri(self.server_name, local_media.media_id)
for local_media in media_info_for_attachment
if local_media
]
if media_info_for_attachment is not None
else None,
)
context = await unpersisted_context.persist(event)

Expand Down Expand Up @@ -1158,6 +1184,7 @@ async def _create_and_send_nonmember_event_locked(
events_and_context=[(event, context)],
ratelimit=ratelimit,
ignore_shadow_ban=ignore_shadow_ban,
media_info_for_attachment=media_info_for_attachment,
)

break
Expand Down Expand Up @@ -1431,6 +1458,7 @@ async def handle_new_client_event(
ratelimit: bool = True,
extra_users: Optional[List[UserID]] = None,
ignore_shadow_ban: bool = False,
media_info_for_attachment: Optional[set[LocalMedia]] = None,
) -> EventBase:
"""Processes new events. Please note that if batch persisting events, an error in
handling any one of these events will result in all of the events being dropped.
Expand All @@ -1450,6 +1478,9 @@ async def handle_new_client_event(
ignore_shadow_ban: True if shadow-banned users should be allowed to
send this event.

media_info_for_attachment: An optional set of LocalMedia objects, for use in
restricting media.

Return:
If the event was deduplicated, the previous, duplicate, event. Otherwise,
`event`.
Expand All @@ -1460,6 +1491,16 @@ async def handle_new_client_event(
a room that has been un-partial stated.
"""
extra_users = extra_users or []
media_info_for_attachment = media_info_for_attachment or set()

# filter for the existing media attachments that were passed in based on the
# mxc. The 'attachments' key can be None, representing that an attachment has
# not been formed yet. If they are all None, will be an empty set
media_restrictions: set[MediaRestrictions] = {
local_media.attachments
for local_media in media_info_for_attachment
if local_media.attachments
}

for event, context in events_and_context:
# we don't apply shadow-banning to membership events here. Invites are blocked
Expand All @@ -1482,8 +1523,40 @@ async def handle_new_client_event(
event.event_id,
prev_event.event_id,
)
if media_restrictions:
# Sort out what event_id's were part of the restrictions.
existing_event_ids_from_media_restrictions = {
res.event_id for res in media_restrictions
}

# If the de-duplicated event_id matches one of the existing
# restrictions, then all is well. If it does not, then this
# needs to be denied as invalid
if (
prev_event.event_id
not in existing_event_ids_from_media_restrictions
):
logger.warning(
"De-duplicated state event '%s' was not already attached to this media",
prev_event.event_id,
)
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"De-duplicated state event was not already attached to this media",
Codes.INVALID_PARAM,
)

return prev_event

# Some media was trying to be attached to an event, but that media was
# already attached. Deny
if media_restrictions:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
f"These media ids, '{media_info_for_attachment}' has already been attached to a reference: {media_restrictions}",
Codes.INVALID_PARAM,
)

if not event.is_state() and event.type in [
EventTypes.Message,
EventTypes.Encrypted,
Expand Down Expand Up @@ -1552,6 +1625,7 @@ async def create_and_send_new_client_events(
event_dicts: Sequence[JsonDict],
ratelimit: bool = True,
ignore_shadow_ban: bool = False,
media_info_for_attachment: Optional[set[LocalMedia]] = None,
) -> None:
"""Helper to create and send a batch of new client events.

Expand All @@ -1573,6 +1647,8 @@ async def create_and_send_new_client_events(
ratelimit: Whether to rate limit this send.
ignore_shadow_ban: True if shadow-banned users should be allowed to
send these events.
media_info_for_attachment: An optional set of LocalMedia objects, for use in
restricting media.
"""

if not event_dicts:
Expand Down Expand Up @@ -1634,6 +1710,7 @@ async def create_and_send_new_client_events(
events_and_context,
ignore_shadow_ban=ignore_shadow_ban,
ratelimit=ratelimit,
media_info_for_attachment=media_info_for_attachment,
)

async def _persist_events(
Expand Down Expand Up @@ -2056,6 +2133,18 @@ async def persist_and_notify_client_events(

events_and_pos = []
for event in persisted_events:
# Access the 'media_references' object from the event internal metadata.
# This will be None if it was not attached during creation of the event.
maybe_media_restrictions_to_set = event.internal_metadata.media_references

if maybe_media_restrictions_to_set:
for mxc_str in maybe_media_restrictions_to_set:
mxc = MXCUri.from_str(mxc_str)
await self.store.set_media_restrictions(
mxc.server_name,
mxc.media_id,
{"restrictions": {"event_id": event.event_id}},
)
if self._ephemeral_events_enabled:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
Expand Down
23 changes: 23 additions & 0 deletions synapse/handlers/room_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple

from matrix_common.types.mxc_uri import MXCUri

from synapse import types
from synapse.api.constants import (
AccountDataTypes,
Expand Down Expand Up @@ -52,6 +54,7 @@
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.push import ReplicationCopyPusherRestServlet
from synapse.storage.databases.main.media_repository import LocalMedia
from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.storage.invite_rule import InviteRule
from synapse.types import (
Expand Down Expand Up @@ -402,6 +405,7 @@ async def _local_membership_update(
require_consent: bool = True,
outlier: bool = False,
origin_server_ts: Optional[int] = None,
media_info_for_attachment: Optional[set[LocalMedia]] = None,
) -> Tuple[str, int]:
"""
Internal membership update function to get an existing event or create
Expand Down Expand Up @@ -434,6 +438,8 @@ async def _local_membership_update(
opposed to being inline with the current DAG.
origin_server_ts: The origin_server_ts to use if a new event is created. Uses
the current timestamp if set to None.
media_info_for_attachment: An optional set of LocalMedia objects, for use in
restricting media.

Returns:
Tuple of event ID and stream ordering position
Expand Down Expand Up @@ -486,6 +492,13 @@ async def _local_membership_update(
depth=depth,
require_consent=require_consent,
outlier=outlier,
mxc_restriction_list_for_event=[
MXCUri(self._server_name, local_media.media_id)
for local_media in media_info_for_attachment
if local_media
]
if media_info_for_attachment is not None
else None,
)
context = await unpersisted_context.persist(event)
prev_state_ids = await context.get_prev_state_ids(
Expand All @@ -503,6 +516,7 @@ async def _local_membership_update(
events_and_context=[(event, context)],
extra_users=[target],
ratelimit=ratelimit,
media_info_for_attachment=media_info_for_attachment,
)
)

Expand Down Expand Up @@ -581,6 +595,7 @@ async def update_membership(
state_event_ids: Optional[List[str]] = None,
depth: Optional[int] = None,
origin_server_ts: Optional[int] = None,
media_info_for_attachment: Optional[set[LocalMedia]] = None,
) -> Tuple[str, int]:
"""Update a user's membership in a room.

Expand Down Expand Up @@ -611,6 +626,8 @@ async def update_membership(
based on the prev_events.
origin_server_ts: The origin_server_ts to use if a new event is created. Uses
the current timestamp if set to None.
media_info_for_attachment: An optional set of LocalMedia objects, for use in
restricting media.

Returns:
A tuple of the new event ID and stream ID.
Expand Down Expand Up @@ -673,6 +690,7 @@ async def update_membership(
state_event_ids=state_event_ids,
depth=depth,
origin_server_ts=origin_server_ts,
media_info_for_attachment=media_info_for_attachment,
)

return result
Expand All @@ -695,6 +713,7 @@ async def update_membership_locked(
state_event_ids: Optional[List[str]] = None,
depth: Optional[int] = None,
origin_server_ts: Optional[int] = None,
media_info_for_attachment: Optional[set[LocalMedia]] = None,
) -> Tuple[str, int]:
"""Helper for update_membership.

Expand Down Expand Up @@ -727,6 +746,8 @@ async def update_membership_locked(
based on the prev_events.
origin_server_ts: The origin_server_ts to use if a new event is created. Uses
the current timestamp if set to None.
media_info_for_attachment: An optional set of LocalMedia objects, for use in
restricting media.

Returns:
A tuple of the new event ID and stream ID.
Expand Down Expand Up @@ -931,6 +952,7 @@ async def update_membership_locked(
require_consent=require_consent,
outlier=outlier,
origin_server_ts=origin_server_ts,
media_info_for_attachment=media_info_for_attachment,
)

latest_event_ids = await self.store.get_prev_events_for_room(room_id)
Expand Down Expand Up @@ -1189,6 +1211,7 @@ async def update_membership_locked(
require_consent=require_consent,
outlier=outlier,
origin_server_ts=origin_server_ts,
media_info_for_attachment=media_info_for_attachment,
)

async def check_for_any_membership_in_room(
Expand Down
Loading
Loading