Skip to content

Commit

Permalink
Refactor recorder data migration (home-assistant#121009)
Browse files Browse the repository at this point in the history
* Refactor recorder data migration

* Fix stale docstrings

* Don't store a session object in BaseRunTimeMigration instances

* Simplify logic in EntityIDMigration.migration_done

* Fix tests
  • Loading branch information
emontnemery authored Jul 16, 2024
1 parent baa97ca commit 9970b7e
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 183 deletions.
58 changes: 8 additions & 50 deletions homeassistant/components/recorder/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@
SupportedDialect,
)
from .db_schema import (
LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX,
LEGACY_STATES_EVENT_ID_INDEX,
SCHEMA_VERSION,
TABLE_STATES,
Expand All @@ -91,7 +90,6 @@
)
from .executor import DBInterruptibleThreadPoolExecutor
from .migration import (
BaseRunTimeMigration,
EntityIDMigration,
EventsContextIDMigration,
EventTypeIDMigration,
Expand All @@ -115,7 +113,6 @@
CommitTask,
CompileMissingStatisticsTask,
DatabaseLockTask,
EntityIDPostMigrationTask,
EventIdMigrationTask,
ImportStatisticsTask,
KeepAliveTask,
Expand Down Expand Up @@ -804,37 +801,14 @@ def _activate_and_set_db_ready(self) -> None:
for row in execute_stmt_lambda_element(session, get_migration_changes())
}

migrator: BaseRunTimeMigration
for migrator_cls in (StatesContextIDMigration, EventsContextIDMigration):
migrator = migrator_cls(session, schema_version, migration_changes)
if migrator.needs_migrate():
self.queue_task(migrator.task())

migrator = EventTypeIDMigration(session, schema_version, migration_changes)
if migrator.needs_migrate():
self.queue_task(migrator.task())
else:
_LOGGER.debug("Activating event_types manager as all data is migrated")
self.event_type_manager.active = True

migrator = EntityIDMigration(session, schema_version, migration_changes)
if migrator.needs_migrate():
self.queue_task(migrator.task())
else:
_LOGGER.debug("Activating states_meta manager as all data is migrated")
self.states_meta_manager.active = True
with contextlib.suppress(SQLAlchemyError):
# If ix_states_entity_id_last_updated_ts still exists
# on the states table it means the entity id migration
# finished by the EntityIDPostMigrationTask did not
# complete because they restarted in the middle of it. We need
# to pick back up where we left off.
if get_index_by_name(
session,
TABLE_STATES,
LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX,
):
self.queue_task(EntityIDPostMigrationTask())
for migrator_cls in (
StatesContextIDMigration,
EventsContextIDMigration,
EventTypeIDMigration,
EntityIDMigration,
):
migrator = migrator_cls(schema_version, migration_changes)
migrator.do_migrate(self, session)

if self.schema_version > LEGACY_STATES_EVENT_ID_INDEX_SCHEMA_VERSION:
with contextlib.suppress(SQLAlchemyError):
Expand Down Expand Up @@ -1319,22 +1293,6 @@ def _legacy_event_id_foreign_key_exists(self) -> bool:
)
)

def _migrate_states_context_ids(self) -> bool:
"""Migrate states context ids if needed."""
return migration.migrate_states_context_ids(self)

def _migrate_events_context_ids(self) -> bool:
"""Migrate events context ids if needed."""
return migration.migrate_events_context_ids(self)

def _migrate_event_type_ids(self) -> bool:
"""Migrate event type ids if needed."""
return migration.migrate_event_type_ids(self)

def _migrate_entity_ids(self) -> bool:
"""Migrate entity_ids if needed."""
return migration.migrate_entity_ids(self)

def _post_migrate_entity_ids(self) -> bool:
"""Post migrate entity_ids if needed."""
return migration.post_migrate_entity_ids(self)
Expand Down
120 changes: 101 additions & 19 deletions homeassistant/components/recorder/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,9 @@
from .statistics import get_start_time
from .tasks import (
CommitTask,
EntityIDMigrationTask,
EventsContextIDMigrationTask,
EventTypeIDMigrationTask,
EntityIDPostMigrationTask,
PostSchemaMigrationTask,
RecorderTask,
StatesContextIDMigrationTask,
StatisticsTimestampMigrationCleanupTask,
)
from .util import (
Expand Down Expand Up @@ -2001,9 +1998,6 @@ def migrate_event_type_ids(instance: Recorder) -> bool:
if is_done := not events:
_mark_migration_done(session, EventTypeIDMigration)

if is_done:
instance.event_type_manager.active = True

_LOGGER.debug("Migrating event_types done=%s", is_done)
return is_done

Expand Down Expand Up @@ -2182,27 +2176,62 @@ def initialize_database(session_maker: Callable[[], Session]) -> bool:
return False


@dataclass(slots=True)
class MigrationTask(RecorderTask):
"""Base class for migration tasks."""

migrator: BaseRunTimeMigration
commit_before = False

def run(self, instance: Recorder) -> None:
"""Run migration task."""
if not self.migrator.migrate_data(instance):
# Schedule a new migration task if this one didn't finish
instance.queue_task(MigrationTask(self.migrator))
else:
self.migrator.migration_done(instance)


@dataclass(slots=True)
class CommitBeforeMigrationTask(MigrationTask):
"""Base class for migration tasks which commit first."""

commit_before = True


class BaseRunTimeMigration(ABC):
"""Base class for run time migrations."""

required_schema_version = 0
migration_version = 1
migration_id: str
task: Callable[[], RecorderTask]
task = MigrationTask

def __init__(
self, session: Session, schema_version: int, migration_changes: dict[str, int]
) -> None:
def __init__(self, schema_version: int, migration_changes: dict[str, int]) -> None:
"""Initialize a new BaseRunTimeMigration."""
self.schema_version = schema_version
self.session = session
self.migration_changes = migration_changes

def do_migrate(self, instance: Recorder, session: Session) -> None:
"""Start migration if needed."""
if self.needs_migrate(session):
instance.queue_task(self.task(self))
else:
self.migration_done(instance)

@staticmethod
@abstractmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""

def migration_done(self, instance: Recorder) -> None:
"""Will be called after migrate returns True or if migration is not needed."""

@abstractmethod
def needs_migrate_query(self) -> StatementLambdaElement:
"""Return the query to check if the migration needs to run."""

def needs_migrate(self) -> bool:
def needs_migrate(self, session: Session) -> bool:
"""Return if the migration needs to run.
If the migration needs to run, it will return True.
Expand All @@ -2220,8 +2249,8 @@ def needs_migrate(self) -> bool:
# We do not know if the migration is done from the
# migration changes table so we must check the data
# This is the slow path
if not execute_stmt_lambda_element(self.session, self.needs_migrate_query()):
_mark_migration_done(self.session, self.__class__)
if not execute_stmt_lambda_element(session, self.needs_migrate_query()):
_mark_migration_done(session, self.__class__)
return False
return True

Expand All @@ -2231,7 +2260,11 @@ class StatesContextIDMigration(BaseRunTimeMigration):

required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION
migration_id = "state_context_id_as_binary"
task = StatesContextIDMigrationTask

@staticmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
return migrate_states_context_ids(instance)

def needs_migrate_query(self) -> StatementLambdaElement:
"""Return the query to check if the migration needs to run."""
Expand All @@ -2243,7 +2276,11 @@ class EventsContextIDMigration(BaseRunTimeMigration):

required_schema_version = CONTEXT_ID_AS_BINARY_SCHEMA_VERSION
migration_id = "event_context_id_as_binary"
task = EventsContextIDMigrationTask

@staticmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
return migrate_events_context_ids(instance)

def needs_migrate_query(self) -> StatementLambdaElement:
"""Return the query to check if the migration needs to run."""
Expand All @@ -2255,7 +2292,20 @@ class EventTypeIDMigration(BaseRunTimeMigration):

required_schema_version = EVENT_TYPE_IDS_SCHEMA_VERSION
migration_id = "event_type_id_migration"
task = EventTypeIDMigrationTask
task = CommitBeforeMigrationTask
# We have to commit before to make sure there are
# no new pending event_types about to be added to
# the db since this happens live

@staticmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
return migrate_event_type_ids(instance)

def migration_done(self, instance: Recorder) -> None:
"""Will be called after migrate returns True."""
_LOGGER.debug("Activating event_types manager as all data is migrated")
instance.event_type_manager.active = True

def needs_migrate_query(self) -> StatementLambdaElement:
"""Check if the data is migrated."""
Expand All @@ -2267,7 +2317,39 @@ class EntityIDMigration(BaseRunTimeMigration):

required_schema_version = STATES_META_SCHEMA_VERSION
migration_id = "entity_id_migration"
task = EntityIDMigrationTask
task = CommitBeforeMigrationTask
# We have to commit before to make sure there are
# no new pending states_meta about to be added to
# the db since this happens live

@staticmethod
def migrate_data(instance: Recorder) -> bool:
"""Migrate some data, returns True if migration is completed."""
return migrate_entity_ids(instance)

def migration_done(self, instance: Recorder) -> None:
"""Will be called after migrate returns True."""
# The migration has finished, now we start the post migration
# to remove the old entity_id data from the states table
# at this point we can also start using the StatesMeta table
# so we set active to True
_LOGGER.debug("Activating states_meta manager as all data is migrated")
instance.states_meta_manager.active = True
with (
contextlib.suppress(SQLAlchemyError),
session_scope(session=instance.get_session()) as session,
):
# If ix_states_entity_id_last_updated_ts still exists
# on the states table it means the entity id migration
# finished by the EntityIDPostMigrationTask did not
# complete because they restarted in the middle of it. We need
# to pick back up where we left off.
if get_index_by_name(
session,
TABLE_STATES,
LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX,
):
instance.queue_task(EntityIDPostMigrationTask())

def needs_migrate_query(self) -> StatementLambdaElement:
"""Check if the data is migrated."""
Expand Down
69 changes: 0 additions & 69 deletions homeassistant/components/recorder/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,75 +358,6 @@ def run(self, instance: Recorder) -> None:
instance._adjust_lru_size() # noqa: SLF001


@dataclass(slots=True)
class StatesContextIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate states context ids."""

commit_before = False

def run(self, instance: Recorder) -> None:
"""Run context id migration task."""
if (
not instance._migrate_states_context_ids() # noqa: SLF001
):
# Schedule a new migration task if this one didn't finish
instance.queue_task(StatesContextIDMigrationTask())


@dataclass(slots=True)
class EventsContextIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate events context ids."""

commit_before = False

def run(self, instance: Recorder) -> None:
"""Run context id migration task."""
if (
not instance._migrate_events_context_ids() # noqa: SLF001
):
# Schedule a new migration task if this one didn't finish
instance.queue_task(EventsContextIDMigrationTask())


@dataclass(slots=True)
class EventTypeIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate event type ids."""

commit_before = True
# We have to commit before to make sure there are
# no new pending event_types about to be added to
# the db since this happens live

def run(self, instance: Recorder) -> None:
"""Run event type id migration task."""
if not instance._migrate_event_type_ids(): # noqa: SLF001
# Schedule a new migration task if this one didn't finish
instance.queue_task(EventTypeIDMigrationTask())


@dataclass(slots=True)
class EntityIDMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to migrate entity_ids to StatesMeta."""

commit_before = True
# We have to commit before to make sure there are
# no new pending states_meta about to be added to
# the db since this happens live

def run(self, instance: Recorder) -> None:
"""Run entity_id migration task."""
if not instance._migrate_entity_ids(): # noqa: SLF001
# Schedule a new migration task if this one didn't finish
instance.queue_task(EntityIDMigrationTask())
else:
# The migration has finished, now we start the post migration
# to remove the old entity_id data from the states table
# at this point we can also start using the StatesMeta table
# so we set active to True
instance.states_meta_manager.active = True
instance.queue_task(EntityIDPostMigrationTask())


@dataclass(slots=True)
class EntityIDPostMigrationTask(RecorderTask):
"""An object to insert into the recorder queue to cleanup after entity_ids migration."""
Expand Down
10 changes: 9 additions & 1 deletion tests/components/recorder/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,14 @@ def get_schema_module_path(schema_version_postfix: str) -> str:
return f"tests.components.recorder.db_schema_{schema_version_postfix}"


@dataclass(slots=True)
class MockMigrationTask(migration.MigrationTask):
"""Mock migration task which does nothing."""

def run(self, instance: Recorder) -> None:
"""Run migration task."""


@contextmanager
def old_db_schema(schema_version_postfix: str) -> Iterator[None]:
"""Fixture to initialize the db with the old schema."""
Expand All @@ -434,7 +442,7 @@ def old_db_schema(schema_version_postfix: str) -> Iterator[None]:
patch.object(core, "States", old_db_schema.States),
patch.object(core, "Events", old_db_schema.Events),
patch.object(core, "StateAttributes", old_db_schema.StateAttributes),
patch.object(migration.EntityIDMigration, "task", core.RecorderTask),
patch.object(migration.EntityIDMigration, "task", MockMigrationTask),
patch(
CREATE_ENGINE_TARGET,
new=partial(
Expand Down
Loading

0 comments on commit 9970b7e

Please sign in to comment.