Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Don't pull out the full state when storing state #13274

Merged
merged 14 commits into from
Jul 15, 2022
1 change: 1 addition & 0 deletions changelog.d/13274.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Don't pull out state in `compute_event_context` for unconflicted state.
36 changes: 21 additions & 15 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,18 @@ async def compute_event_context(

state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids
state_ids_before_event = None

# We make sure that we have a state group assigned to the state.
if entry.state_group is None:
state_ids_before_event = await entry.get_state(
self._state_storage_controller, StateFilter.all()
)
# store_state_group requires us to have either a previous state group
# (with deltas) or the complete state map. So, if we don't have a
# previous state group, load the complete state map now.
if state_group_before_event_prev_group is None:
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
state_ids_before_event = await entry.get_state(
self._state_storage_controller, StateFilter.all()
)

state_group_before_event = (
await self._state_storage_controller.store_state_group(
event.event_id,
Expand All @@ -316,7 +322,6 @@ async def compute_event_context(
entry.state_group = state_group_before_event
else:
state_group_before_event = entry.state_group
state_ids_before_event = None

#
# now if it's not a state event, we're done
Expand All @@ -336,19 +341,20 @@ async def compute_event_context(
#
# otherwise, we'll need to create a new state group for after the event
#
if state_ids_before_event is None:
state_ids_before_event = await entry.get_state(
self._state_storage_controller, StateFilter.all()
)

key = (event.type, event.state_key)
if key in state_ids_before_event:
replaces = state_ids_before_event[key]
if replaces != event.event_id:
event.unsigned["replaces_state"] = replaces

state_ids_after_event = dict(state_ids_before_event)
state_ids_after_event[key] = event.event_id
if state_ids_before_event is not None:
replaces = state_ids_before_event.get(key)
else:
replaces_state_map = await entry.get_state(
self._state_storage_controller, StateFilter.from_types([key])
)
replaces = replaces_state_map.get(key)

if replaces and replaces != event.event_id:
event.unsigned["replaces_state"] = replaces

delta_ids = {key: event.event_id}

state_group_after_event = (
Expand All @@ -357,7 +363,7 @@ async def compute_event_context(
event.room_id,
prev_group=state_group_before_event,
delta_ids=delta_ids,
current_state_ids=state_ids_after_event,
current_state_ids=None,
)
)

Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/controllers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ async def store_state_group(
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[StateMap[str]],
current_state_ids: StateMap[str],
current_state_ids: Optional[StateMap[str]],
) -> int:
"""Store a new set of state, returning a newly assigned state group.

Expand Down
156 changes: 103 additions & 53 deletions synapse/storage/databases/state/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,14 +400,17 @@ async def store_state_group(
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[StateMap[str]],
current_state_ids: StateMap[str],
current_state_ids: Optional[StateMap[str]],
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
) -> int:
"""Store a new set of state, returning a newly assigned state group.

At least one of `current_state_ids` and `prev_group` must be provided. Whenever
`prev_group` is not None, `delta_ids` must also not be None.

Args:
event_id: The event ID for which the state was calculated
room_id
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
prev_group: A previous state group for the room, optional.
prev_group: A previous state group for the room.
delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
Expand All @@ -418,10 +421,41 @@ async def store_state_group(
The state group ID
"""

def _store_state_group_txn(txn: LoggingTransaction) -> int:
if current_state_ids is None:
# AFAIK, this can never happen
raise Exception("current_state_ids cannot be None")
if prev_group is None and current_state_ids is None:
raise Exception("current_state_ids and prev_group can't both be None")

if prev_group is not None and delta_ids is None:
raise Exception("delta_ids is None when prev_group is not None")

def insert_delta_group_txn(
txn: LoggingTransaction, prev_group: int, delta_ids: StateMap[str]
) -> Optional[int]:
"""Try and persist the new group as a delta.

Requires that we have the state as a delta from a previous state group.

Returns:
The state group if successfully created, or None if the state
needs to be persisted as a full state.
"""
is_in_db = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (prev_group,)
)

# if the chain of state group deltas is going too long, we fall back to
# persisting a complete state group.
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if potential_hops >= MAX_STATE_DELTA_HOPS:
return None

state_group = self._state_group_seq_gen.get_next_id_txn(txn)

Expand All @@ -431,51 +465,45 @@ def _store_state_group_txn(txn: LoggingTransaction) -> int:
values={"id": state_group, "room_id": room_id, "event_id": event_id},
)

# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't tooo long, as otherwise read performance degrades.
if prev_group:
is_in_db = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (prev_group,)
)

potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
assert delta_ids is not None

self.db_pool.simple_insert_txn(
txn,
table="state_group_edges",
values={"state_group": state_group, "prev_state_group": prev_group},
)
self.db_pool.simple_insert_txn(
txn,
table="state_group_edges",
values={"state_group": state_group, "prev_state_group": prev_group},
)

self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
keys=("state_group", "room_id", "type", "state_key", "event_id"),
values=[
(state_group, room_id, key[0], key[1], state_id)
for key, state_id in delta_ids.items()
],
)
else:
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
keys=("state_group", "room_id", "type", "state_key", "event_id"),
values=[
(state_group, room_id, key[0], key[1], state_id)
for key, state_id in current_state_ids.items()
],
)
self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
keys=("state_group", "room_id", "type", "state_key", "event_id"),
values=[
(state_group, room_id, key[0], key[1], state_id)
for key, state_id in delta_ids.items()
],
)

return state_group

def insert_full_state_txn(
txn: LoggingTransaction, current_state_ids: StateMap[str]
) -> int:
"""Persist the full state, returning the new state group."""
state_group = self._state_group_seq_gen.get_next_id_txn(txn)

self.db_pool.simple_insert_txn(
txn,
table="state_groups",
values={"id": state_group, "room_id": room_id, "event_id": event_id},
)

self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
keys=("state_group", "room_id", "type", "state_key", "event_id"),
values=[
(state_group, room_id, key[0], key[1], state_id)
for key, state_id in current_state_ids.items()
],
)

# Prefill the state group caches with this group.
# It's fine to use the sequence like this as the state group map
Expand All @@ -491,7 +519,7 @@ def _store_state_group_txn(txn: LoggingTransaction) -> int:
self._state_group_members_cache.update,
self._state_group_members_cache.sequence,
key=state_group,
value=dict(current_member_state_ids),
value=current_member_state_ids,
)

current_non_member_state_ids = {
Expand All @@ -503,13 +531,35 @@ def _store_state_group_txn(txn: LoggingTransaction) -> int:
self._state_group_cache.update,
self._state_group_cache.sequence,
key=state_group,
value=dict(current_non_member_state_ids),
value=current_non_member_state_ids,
)

return state_group

if prev_group is not None:
state_group = await self.db_pool.runInteraction(
"store_state_group.insert_delta_group",
insert_delta_group_txn,
prev_group,
delta_ids,
)
if state_group is not None:
return state_group

# We're going to persist the state as a complete group rather than
# a delta, so first we need to ensure we have loaded the state map
# from the database.
if current_state_ids is None:
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
assert prev_group is not None
assert delta_ids is not None
groups = await self._get_state_for_groups([prev_group])
current_state_ids = dict(groups[prev_group])
current_state_ids.update(delta_ids)

return await self.db_pool.runInteraction(
"store_state_group", _store_state_group_txn
"store_state_group.insert_full_state",
insert_full_state_txn,
current_state_ids,
)

async def purge_unreferenced_state_groups(
Expand Down
4 changes: 2 additions & 2 deletions tests/rest/client/test_rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def test_post_room_no_keys(self) -> None:
self.assertEqual(200, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
self.assertEqual(32, channel.resource_usage.db_txn_count)
self.assertEqual(36, channel.resource_usage.db_txn_count)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit surprised these transaction counts increase. AFAICT the only reason we would end up doing more transactions is if we start doing an insert_delta_group_txn but abort that because the delta chain is too long. Any ideas?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think its where we add the replaces_state to the unsigned section:

https://github.com/matrix-org/synapse/pull/13274/files#diff-d05c474c9fe45057f52616f38e54f6cdb3fa80a5a596ecb9c8fec3026ff8d68eR343-R346

So in the test we pull out some partial state for the group there, and then when we come to create the next event we pull out some other partial state for the group


def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
Expand All @@ -722,7 +722,7 @@ def test_post_room_initial_state(self) -> None:
self.assertEqual(200, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
self.assertEqual(35, channel.resource_usage.db_txn_count)
self.assertEqual(40, channel.resource_usage.db_txn_count)

def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
Expand Down
4 changes: 4 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ async def store_state_group(
state_group = self._next_group
self._next_group += 1

if current_state_ids is None:
current_state_ids = dict(self._group_to_state[prev_group])
current_state_ids.update(delta_ids)

self._group_to_state[state_group] = dict(current_state_ids)

return state_group
Expand Down