This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Add type hints to state database module. (#10823)
clokep authored Sep 15, 2021
1 parent b932590 commit 3eba047
Showing 6 changed files with 133 additions and 72 deletions.
1 change: 1 addition & 0 deletions changelog.d/10823.misc
@@ -0,0 +1 @@
+Add type hints to the state database.
1 change: 1 addition & 0 deletions mypy.ini
@@ -60,6 +60,7 @@ files =
   synapse/storage/databases/main/session.py,
   synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,
+  synapse/storage/databases/state,
   synapse/storage/database.py,
   synapse/storage/engines,
   synapse/storage/keys.py,
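Listing `synapse/storage/databases/state` under `files =` brings the whole package into mypy's checked set, so code that violates the annotations added below fails the lint run instead of surfacing at runtime. A minimal, purely illustrative sketch of the kind of mismatch this catches (the function is hypothetical, not Synapse code):

from typing import Optional


def count_hops(state_group: Optional[int]) -> int:
    """Walk a pretend chain of state groups and count the hops."""
    hops = 0
    while state_group is not None:
        state_group = None  # pretend the chain ends after one link
        hops += 1
    return hops


# Once the containing module is listed under `files =`, mypy rejects a call
# such as count_hops("42") at check time, since str is not Optional[int].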
60 changes: 41 additions & 19 deletions synapse/storage/databases/state/bg_updates.py
@@ -13,12 +13,20 @@
 # limitations under the License.

 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union

 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.state import StateFilter
+from synapse.types import MutableStateMap, StateMap
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer

 logger = logging.getLogger(__name__)
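The `HomeServer` import in the hunk above sits under `TYPE_CHECKING`, so it is evaluated by type checkers only and never at runtime, which is the usual way to add an annotation-only import without creating an import cycle. A minimal sketch of the pattern, with hypothetical module and class names:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by mypy only; never executed, so a circular runtime import
    # between the two modules cannot occur.
    from myapp.server import Server  # hypothetical module


class Store:
    def __init__(self, server: "Server") -> None:
        # The quoted annotation is resolved lazily, so constructing the
        # class works even though Server was never imported at runtime.
        self._server = server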

@@ -31,7 +39,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
     updates.
     """

-    def _count_state_group_hops_txn(self, txn, state_group):
+    def _count_state_group_hops_txn(
+        self, txn: LoggingTransaction, state_group: int
+    ) -> int:
         """Given a state group, count how many hops there are in the tree.

         This is used to ensure the delta chains don't get too long.
@@ -56,7 +66,7 @@ def _count_state_group_hops_txn(self, txn, state_group):
         else:
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
-            next_group = state_group
+            next_group: Optional[int] = state_group
             count = 0

             while next_group:
@@ -73,11 +83,14 @@ def _count_state_group_hops_txn(self, txn, state_group):
             return count

     def _get_state_groups_from_groups_txn(
-        self, txn, groups, state_filter: Optional[StateFilter] = None
-    ):
+        self,
+        txn: LoggingTransaction,
+        groups: List[int],
+        state_filter: Optional[StateFilter] = None,
+    ) -> Mapping[int, StateMap[str]]:
         state_filter = state_filter or StateFilter.all()

-        results = {group: {} for group in groups}
+        results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}

         where_clause, where_args = state_filter.make_sql_filter_clause()

@@ -117,7 +130,7 @@ def _get_state_groups_from_groups_txn(
             """

             for group in groups:
-                args = [group]
+                args: List[Union[int, str]] = [group]
                 args.extend(where_args)

                 txn.execute(sql % (where_clause,), args)
@@ -131,7 +144,7 @@ def _get_state_groups_from_groups_txn(
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
             for group in groups:
-                next_group = group
+                next_group: Optional[int] = group

                 while next_group:
                     # We did this before by getting the list of group ids, and
@@ -173,6 +186,7 @@ def _get_state_groups_from_groups_txn(
                         allow_none=True,
                     )

+        # The results shouldn't be considered mutable.
         return results
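The new comment and the `Mapping[int, StateMap[str]]` return annotation make the same point: the method builds a mutable `Dict` internally but hands it back through the read-only `Mapping` protocol, so callers that try to mutate the result get flagged by mypy. A small sketch of that pattern, with illustrative names:

from typing import Dict, List, Mapping


def load_labels(ids: List[int]) -> Mapping[int, str]:
    # Built as a mutable dict while it is being filled in ...
    results: Dict[int, str] = {i: f"group-{i}" for i in ids}
    # ... but returned as a Mapping, so a caller writing
    # load_labels([1])[1] = "x" gets a type error, even though the
    # object is still an ordinary dict at runtime.
    return results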


@@ -182,7 +196,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"

-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self.db_pool.updates.register_background_update_handler(
             self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
@@ -198,7 +217,9 @@ def __init__(self, database: DatabasePool, db_conn, hs):
             columns=["room_id"],
         )

-    async def _background_deduplicate_state(self, progress, batch_size):
+    async def _background_deduplicate_state(
+        self, progress: dict, batch_size: int
+    ) -> int:
         """This background update will slowly deduplicate state by reencoding
         them as deltas.
         """
@@ -218,7 +239,7 @@ async def _background_deduplicate_state(self, progress, batch_size):
             )
             max_group = rows[0][0]

-        def reindex_txn(txn):
+        def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]:
             new_last_state_group = last_state_group
             for count in range(batch_size):
                 txn.execute(
@@ -251,7 +272,8 @@ def reindex_txn(txn):
                     " WHERE id < ? AND room_id = ?",
                     (state_group, room_id),
                 )
-                (prev_group,) = txn.fetchone()
+                # There will be a result due to the coalesce.
+                (prev_group,) = txn.fetchone()  # type: ignore
                 new_last_state_group = state_group

                 if prev_group:
@@ -261,15 +283,15 @@ def reindex_txn(txn):
                         # otherwise read performance degrades.
                         continue

-                    prev_state = self._get_state_groups_from_groups_txn(
+                    prev_state_by_group = self._get_state_groups_from_groups_txn(
                         txn, [prev_group]
                     )
-                    prev_state = prev_state[prev_group]
+                    prev_state = prev_state_by_group[prev_group]

-                    curr_state = self._get_state_groups_from_groups_txn(
+                    curr_state_by_group = self._get_state_groups_from_groups_txn(
                         txn, [state_group]
                     )
-                    curr_state = curr_state[state_group]
+                    curr_state = curr_state_by_group[state_group]

                     if not set(prev_state.keys()) - set(curr_state.keys()):
                         # We can only do a delta if the current has a strict super set
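Two typing details in the hunks above are worth a note. `txn.fetchone()` is declared as returning an `Optional` row, but the query's `coalesce` guarantees a result, so the unpacking gets a `# type: ignore` rather than an extra runtime check; and the intermediate results of `_get_state_groups_from_groups_txn` are renamed to `prev_state_by_group` / `curr_state_by_group` so the group-indexed mapping and the per-group state map keep distinct, precisely typed variables. An alternative to the ignore comment is an explicit `assert`, which narrows the type for mypy at the cost of a runtime check; a hedged sketch with a stand-in cursor:

from typing import Optional, Tuple


class FakeCursor:
    """Stand-in for a DB-API cursor whose query always returns one row."""

    def fetchone(self) -> Optional[Tuple[int]]:
        return (42,)


def read_prev_group(txn: FakeCursor) -> int:
    row = txn.fetchone()
    # assert narrows Optional[Tuple[int]] to Tuple[int] for mypy, at the
    # cost of a runtime check; `# type: ignore` skips that cost instead.
    assert row is not None
    (prev_group,) = row
    return prev_group


print(read_prev_group(FakeCursor()))  # 42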
@@ -340,8 +362,8 @@ def reindex_txn(txn):

         return result * BATCH_SIZE_SCALE_FACTOR

-    async def _background_index_state(self, progress, batch_size):
-        def reindex_txn(conn):
+    async def _background_index_state(self, progress: dict, batch_size: int) -> int:
+        def reindex_txn(conn: LoggingDatabaseConnection) -> None:
             conn.rollback()
             if isinstance(self.database_engine, PostgresEngine):
                 # postgres insists on autocommit for the index