Skip to content

Commit

Permalink
Split pre/post backup actions into dedicated methods (home-assistant#…
Browse files Browse the repository at this point in the history
…110632)

* Split pre/post backup actions into dedicated methods

* Update homeassistant/components/backup/manager.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
  • Loading branch information
ludeeus and MartinHjelmare authored Feb 15, 2024
1 parent b9a8b99 commit 57d3f3f
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 24 deletions.
58 changes: 34 additions & 24 deletions homeassistant/components/backup/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,38 @@ async def _add_platform(
return
self.platforms[integration_domain] = platform

async def pre_backup_actions(self) -> None:
"""Perform pre backup actions."""
if not self.loaded_platforms:
await self.load_platforms()

pre_backup_results = await asyncio.gather(
*(
platform.async_pre_backup(self.hass)
for platform in self.platforms.values()
),
return_exceptions=True,
)
for result in pre_backup_results:
if isinstance(result, Exception):
raise result

async def post_backup_actions(self) -> None:
"""Perform post backup actions."""
if not self.loaded_platforms:
await self.load_platforms()

post_backup_results = await asyncio.gather(
*(
platform.async_post_backup(self.hass)
for platform in self.platforms.values()
),
return_exceptions=True,
)
for result in post_backup_results:
if isinstance(result, Exception):
raise result

async def load_backups(self) -> None:
"""Load data of stored backup files."""
backups = await self.hass.async_add_executor_job(self._read_backups)
Expand Down Expand Up @@ -160,22 +192,9 @@ async def generate_backup(self) -> Backup:
if self.backing_up:
raise HomeAssistantError("Backup already in progress")

if not self.loaded_platforms:
await self.load_platforms()

try:
self.backing_up = True
pre_backup_results = await asyncio.gather(
*(
platform.async_pre_backup(self.hass)
for platform in self.platforms.values()
),
return_exceptions=True,
)
for result in pre_backup_results:
if isinstance(result, Exception):
raise result

await self.pre_backup_actions()
backup_name = f"Core {HAVERSION}"
date_str = dt_util.now().isoformat()
slug = _generate_slug(date_str, backup_name)
Expand Down Expand Up @@ -208,16 +227,7 @@ async def generate_backup(self) -> Backup:
return backup
finally:
self.backing_up = False
post_backup_results = await asyncio.gather(
*(
platform.async_post_backup(self.hass)
for platform in self.platforms.values()
),
return_exceptions=True,
)
for result in post_backup_results:
if isinstance(result, Exception):
raise result
await self.post_backup_actions()

def _mkdir_and_generate_backup_contents(
self,
Expand Down
50 changes: 50 additions & 0 deletions tests/components/backup/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,53 @@ async def _mock_step(hass: HomeAssistant) -> None:

with pytest.raises(HomeAssistantError):
await _mock_backup_generation(manager)


async def test_loading_platforms_when_running_pre_backup_actions(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test loading backup platforms when running post backup actions."""
manager = BackupManager(hass)

assert not manager.loaded_platforms
assert not manager.platforms

await _setup_mock_domain(
hass,
Mock(
async_pre_backup=AsyncMock(),
async_post_backup=AsyncMock(),
),
)
await manager.pre_backup_actions()

assert manager.loaded_platforms
assert len(manager.platforms) == 1

assert "Loaded 1 platforms" in caplog.text


async def test_loading_platforms_when_running_post_backup_actions(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test loading backup platforms when running post backup actions."""
manager = BackupManager(hass)

assert not manager.loaded_platforms
assert not manager.platforms

await _setup_mock_domain(
hass,
Mock(
async_pre_backup=AsyncMock(),
async_post_backup=AsyncMock(),
),
)
await manager.post_backup_actions()

assert manager.loaded_platforms
assert len(manager.platforms) == 1

assert "Loaded 1 platforms" in caplog.text

0 comments on commit 57d3f3f

Please sign in to comment.