Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add an API for listing threads in a room. #13394

Merged
merged 18 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13394.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API.
3 changes: 3 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,8 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# MSC3848: Introduce errcodes for specific event sending failures
self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)

# MSC3856: Threads list API
self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False)

# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
89 changes: 89 additions & 0 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import logging
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple

Expand All @@ -31,6 +32,13 @@
logger = logging.getLogger(__name__)


class ThreadsListInclude(str, enum.Enum):
"""Valid values for the 'include' flag of /threads."""

all = "all"
participated = "participated"


@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
Expand Down Expand Up @@ -482,3 +490,84 @@ async def get_bundled_aggregations(
results.setdefault(event_id, BundledAggregations()).replace = edit

return results

async def get_threads(
self,
requester: Requester,
room_id: str,
include: ThreadsListInclude,
limit: int = 5,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.

Args:
requester: The user requesting the relations.
room_id: The room the event belongs to.
include: One of "all" or "participated" to indicate which threads should
be returned.
limit: Only fetch the most recent `limit` events.
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.

Returns:
The pagination chunk.
"""

user_id = requester.user.to_string()

# TODO Properly handle a user leaving a room.
(_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
room_id, requester, allow_departed_users=True
)

# Note that ignored users are not passed into get_relations_for_event
# below. Ignored users are handled in filter_events_for_client (and by
# not passing them in here we should get a better cache hit rate).
thread_roots, next_token = await self._main_store.get_threads(
room_id=room_id, limit=limit, from_token=from_token, to_token=to_token
)

events = await self._main_store.get_events_as_list(thread_roots)

if include == ThreadsListInclude.participated:
# Pre-seed thread participation with whether the requester sent the event.
participated = {event.event_id: event.sender == user_id for event in events}
# For events the requester did not send, check the database for whether
# the requester sent a threaded reply.
participated.update(
await self._main_store.get_threads_participated(
[eid for eid, p in participated.items() if not p],
user_id,
)
)

# Limit the returned threads to those the user has participated in.
events = [event for event in events if participated[event.event_id]]

events = await filter_events_for_client(
self._storage_controllers,
user_id,
events,
is_peeking=(member_event_id is None),
)

now = self._clock.time_msec()

aggregations = await self.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)

return_value: JsonDict = {"chunk": serialized_events}

if next_token:
return_value["next_batch"] = await next_token.to_string(self._main_store)

if from_token:
return_value["prev_batch"] = await from_token.to_string(self._main_store)

return return_value
52 changes: 52 additions & 0 deletions synapse/rest/client/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.

import logging
import re
from typing import TYPE_CHECKING, Optional, Tuple

from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
Expand Down Expand Up @@ -91,5 +93,55 @@ async def on_GET(
return 200, result


class ThreadsServlet(RestServlet):
PATTERNS = (
re.compile(
"^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P<room_id>[^/]*)/threads"
),
)

def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self._relations_handler = hs.get_relations_handler()

async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)

limit = parse_integer(request, "limit", default=5)
from_token_str = parse_string(request, "from")
to_token_str = parse_string(request, "to")
include = parse_string(
request,
"include",
default=ThreadsListInclude.all.value,
allowed_values=[v.value for v in ThreadsListInclude],
)

# Return the relations
from_token = None
if from_token_str:
from_token = await StreamToken.from_string(self.store, from_token_str)
to_token = None
if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str)
clokep marked this conversation as resolved.
Show resolved Hide resolved

result = await self._relations_handler.get_threads(
requester=requester,
room_id=room_id,
include=ThreadsListInclude(include),
limit=limit,
from_token=from_token,
to_token=to_token,
)

return 200, result


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationPaginationServlet(hs).register(http_server)
if hs.config.experimental.msc3856_enabled:
ThreadsServlet(hs).register(http_server)
7 changes: 5 additions & 2 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1596,7 +1596,7 @@ def _update_metadata_tables_txn(
)

# Remove from relations table.
self._handle_redact_relations(txn, event.redacts)
self._handle_redact_relations(txn, event.room_id, event.redacts)

# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
Expand Down Expand Up @@ -1911,6 +1911,7 @@ def _handle_event_relations(
self.store.get_thread_participated.invalidate,
(relation.parent_id, event.sender),
)
txn.call_after(self.store.get_threads.invalidate, (event.room_id,))

def _handle_insertion_event(
self, txn: LoggingTransaction, event: EventBase
Expand Down Expand Up @@ -2035,13 +2036,14 @@ def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None
txn.execute(sql, (batch_id,))

def _handle_redact_relations(
self, txn: LoggingTransaction, redacted_event_id: str
self, txn: LoggingTransaction, room_id: str, redacted_event_id: str
) -> None:
"""Handles receiving a redaction and checking whether the redacted event
has any relations which must be removed from the database.

Args:
txn
room_id: The room ID of the event that was redacted.
redacted_event_id: The event that was redacted.
"""

Expand Down Expand Up @@ -2070,6 +2072,7 @@ def _handle_redact_relations(
self.store._invalidate_cache_and_stream(
txn, self.store.get_thread_participated, (redacted_relates_to,)
)
txn.call_after(self.store.get_threads.invalidate, (room_id,))
self.store._invalidate_cache_and_stream(
txn,
self.store.get_mutual_event_relations_for_rel_type,
Expand Down
87 changes: 87 additions & 0 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,93 @@ def _get_event_relations(
"get_event_relations", _get_event_relations
)

@cached(tree=True)
async def get_threads(
self,
room_id: str,
limit: int = 5,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> Tuple[List[str], Optional[StreamToken]]:
"""Get a list of thread IDs, ordered by topological ordering of their
latest reply.

Args:
room_id: The room the event belongs to.
limit: Only fetch the most recent `limit` threads.
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.

Returns:
A tuple of:
A list of thread root event IDs.

The next stream token, if one exists.
"""
pagination_clause = generate_pagination_where_clause(
direction="b",
column_names=("topological_ordering", "stream_ordering"),
from_token=from_token.room_key.as_historical_tuple()
if from_token
else None,
to_token=to_token.room_key.as_historical_tuple() if to_token else None,
engine=self.database_engine,
)

if pagination_clause:
pagination_clause = "AND " + pagination_clause

sql = f"""
SELECT relates_to_id, MAX(topological_ordering), MAX(stream_ordering)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE
room_id = ? AND
relation_type = '{RelationTypes.THREAD}'
{pagination_clause}
GROUP BY relates_to_id
ORDER BY MAX(topological_ordering) DESC, MAX(stream_ordering) DESC
LIMIT ?
clokep marked this conversation as resolved.
Show resolved Hide resolved
"""

def _get_threads_txn(
txn: LoggingTransaction,
) -> Tuple[List[str], Optional[StreamToken]]:
txn.execute(sql, [room_id, limit + 1])

last_topo_id = None
last_stream_id = None
thread_ids = []
for thread_id, topo_id, stream_id in txn:
thread_ids.append(thread_id)
last_topo_id = topo_id
last_stream_id = stream_id

# If there are more events, generate the next pagination key.
next_token = None
if len(thread_ids) > limit and last_topo_id and last_stream_id:
next_key = RoomStreamToken(last_topo_id, last_stream_id)
if from_token:
next_token = from_token.copy_and_replace(
StreamKeyType.ROOM, next_key
)
else:
next_token = StreamToken(
room_key=next_key,
presence_key=0,
typing_key=0,
receipt_key=0,
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=0,
groups_key=0,
)

return thread_ids[:limit], next_token

return await self.db_pool.runInteraction("get_threads", _get_threads_txn)


class RelationsStore(RelationsWorkerStore):
pass
Loading