Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add storage helper and migrate config entries #15045

Merged
merged 8 commits into from
Jun 25, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add more instance variables
  • Loading branch information
balloob committed Jun 25, 2018
commit 37b18adceb3b49e041cf4107309033aca84fb742
18 changes: 9 additions & 9 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ async def async_step_discovery(info):
import logging
import uuid

from . import data_entry_flow
from .core import callback
from .exceptions import HomeAssistantError
from .setup import async_setup_component, async_process_deps_reqs
from .util.decorator import Registry
from homeassistant import data_entry_flow
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component, async_process_deps_reqs
from homeassistant.util.decorator import Registry


_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -273,7 +273,7 @@ def __init__(self, hass, hass_config):
hass, self._async_create_flow, self._async_finish_flow)
self._hass_config = hass_config
self._entries = None
self._store = hass.helpers.storage.Store(STORAGE_KEY)
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)

@callback
def async_domains(self):
Expand Down Expand Up @@ -319,12 +319,12 @@ async def async_load(self):
"""Handle loading the config."""
# Migrating for config entries stored before 0.73
config = await self.hass.helpers.storage.async_migrator(
self.hass.config.path(PATH_CONFIG), self._store, STORAGE_VERSION,
self.hass.config.path(PATH_CONFIG), self._store,
old_conf_migrate_func=_old_conf_migrator
)

if config is None:
config = await self._store.async_load(STORAGE_VERSION)
config = await self._store.async_load()

self._entries = [ConfigEntry(**entry) for entry in config['entries']]

Expand Down Expand Up @@ -426,7 +426,7 @@ async def _async_schedule_save(self):
data = {
'entries': [entry.as_dict() for entry in self._entries]
}
await self._store.async_save(STORAGE_VERSION, data, delay=SAVE_DELAY)
await self._store.async_save(data, delay=SAVE_DELAY)


async def _old_conf_migrator(old_config):
Expand Down
48 changes: 29 additions & 19 deletions homeassistant/helpers/storage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Helper to help store data."""
import asyncio
import logging
import os
from typing import Dict, Optional
Expand All @@ -14,8 +15,7 @@


@bind_hass
async def async_migrator(hass, old_path, store, version, *,
old_conf_migrate_func=None):
async def async_migrator(hass, old_path, store, *, old_conf_migrate_func=None):
"""Helper function to migrate old data to a store and then load data.

async def old_conf_migrate_func(old_data)
Expand All @@ -35,7 +35,7 @@ def load_old_config():
if old_conf_migrate_func is not None:
config = await old_conf_migrate_func(config)

await store.async_save(version, config)
await store.async_save(config)
await hass.async_add_executor_job(os.remove, old_path)
return config

Expand All @@ -44,16 +44,22 @@ def load_old_config():
class Store:
"""Class to help storing data."""

def __init__(self, hass, key: str):
def __init__(self, hass, version: int, key: str):
"""Initialize storage class."""
self.hass = hass
self.version = version
self.key = key
self.path = hass.config.path(STORAGE_DIR, key)
self.hass = hass
self._data = None
self._unsub_delay_listener = None
self._unsub_stop_listener = None
self._write_lock = asyncio.Lock()

async def async_load(self, expected_version, *, migrate_func=None):
@property
def path(self):
"""Return the config path."""
return self.hass.config.path(STORAGE_DIR, self.key)

async def async_load(self):
"""Load data.

If the expected version does not match the given version, the migrate
Expand All @@ -68,16 +74,15 @@ async def async_load(self, expected_version, *, migrate_func=None):
if data is None:
return {}

if data['version'] == expected_version:
if data['version'] == self.version:
return data['data']

return await migrate_func(data['version'], data['data'])
return await self._async_migrate_func(data['version'], data['data'])

async def async_save(self, data_version, data: Dict, *,
delay: Optional[int] = None):
async def async_save(self, data: Dict, *, delay: Optional[int] = None):
"""Save data with an optional delay."""
self._data = {
'version': data_version,
'version': self.version,
'key': self.key,
'data': data,
}
Expand Down Expand Up @@ -132,13 +137,14 @@ async def _handle_write_data(self, *_args):
data = self._data
self._data = None

try:
await self.hass.async_add_executor_job(
self._write_data, self.path, data)
except json.SerializationError as err:
_LOGGER.error('Error writing config for %s: %s', self.key, err)
except json.WriteError as err:
_LOGGER.error('Error writing config for %s: %s', self.key, err)
async with self._write_lock:
try:
await self.hass.async_add_executor_job(
self._write_data, self.path, data)
except json.SerializationError as err:
_LOGGER.error('Error writing config for %s: %s', self.key, err)
except json.WriteError as err:
_LOGGER.error('Error writing config for %s: %s', self.key, err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should take that toggether like:
except (json.xy, json.ser) as err:


def _write_data(self, path: str, data: Dict):
"""Write the data."""
Expand All @@ -147,3 +153,7 @@ def _write_data(self, path: str, data: Dict):

_LOGGER.debug('Writing data for %s', self.key)
json.save_json(path, data)

async def _async_migrate_func(self, old_version, old_data):
"""Migrate to the new version."""
raise NotImplementedError
37 changes: 19 additions & 18 deletions tests/helpers/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@


MOCK_VERSION = 1
MOCK_KEY = 'storage-test'
MOCK_DATA = {'hello': 'world'}


Expand All @@ -35,28 +36,28 @@ def mock_load(mock_save):
@pytest.fixture
def store(hass):
"""Fixture of a store that prevents writing on HASS stop."""
store = storage.Store(hass, 'test')
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
store._async_ensure_stop_listener = lambda: None
yield store


async def test_loading(hass, store, mock_save, mock_load):
"""Test we can save and load data."""
await store.async_save(MOCK_VERSION, MOCK_DATA)
data = await store.async_load(MOCK_VERSION)
await store.async_save(MOCK_DATA)
data = await store.async_load()
assert data == MOCK_DATA


async def test_loading_non_existing(hass, store):
"""Test we can save and load data."""
with patch('homeassistant.util.json.open', side_effect=FileNotFoundError):
data = await store.async_load(MOCK_VERSION)
data = await store.async_load()
assert data == {}


async def test_saving_with_delay(hass, store, mock_save):
"""Test saving data after a delay."""
await store.async_save(MOCK_VERSION, MOCK_DATA, delay=1)
await store.async_save(MOCK_DATA, delay=1)
assert len(mock_save) == 0

async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
Expand All @@ -66,8 +67,8 @@ async def test_saving_with_delay(hass, store, mock_save):

async def test_saving_on_stop(hass, mock_save):
"""Test delayed saves trigger when we quit Home Assistant."""
store = storage.Store(hass, 'test')
await store.async_save(MOCK_VERSION, MOCK_DATA, delay=1)
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
await store.async_save(MOCK_DATA, delay=1)
assert len(mock_save) == 0

hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
Expand All @@ -77,28 +78,28 @@ async def test_saving_on_stop(hass, mock_save):

async def test_loading_while_delay(hass, store, mock_save, mock_load):
"""Test we load new data even if not written yet."""
await store.async_save(MOCK_VERSION, {'delay': 'no'})
await store.async_save({'delay': 'no'})
assert len(mock_save) == 1

await store.async_save(MOCK_VERSION, {'delay': 'yes'}, delay=1)
await store.async_save({'delay': 'yes'}, delay=1)
assert len(mock_save) == 1

data = await store.async_load(MOCK_VERSION)
data = await store.async_load()
assert data == {'delay': 'yes'}


async def test_writing_while_writing_delay(hass, store, mock_save, mock_load):
"""Test a write while a write with delay is active."""
await store.async_save(MOCK_VERSION, {'delay': 'yes'}, delay=1)
await store.async_save({'delay': 'yes'}, delay=1)
assert len(mock_save) == 0
await store.async_save(MOCK_VERSION, {'delay': 'no'})
await store.async_save({'delay': 'no'})
assert len(mock_save) == 1

async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
await hass.async_block_till_done()
assert len(mock_save) == 1

data = await store.async_load(MOCK_VERSION)
data = await store.async_load()
assert data == {'delay': 'no'}


Expand All @@ -108,7 +109,7 @@ async def test_migrator_no_existing_config(hass, store, mock_save):
patch.object(store, 'async_load',
return_value=mock_coro({'cur': 'config'})):
data = await storage.async_migrator(
hass, 'old-path', store, MOCK_VERSION)
hass, 'old-path', store)

assert data is None
assert len(mock_save) == 0
Expand All @@ -121,13 +122,13 @@ async def test_migrator_existing_config(hass, store, mock_save):
patch('homeassistant.util.json.load_json',
return_value={'old': 'config'}):
data = await storage.async_migrator(
hass, 'old-path', store, MOCK_VERSION)
hass, 'old-path', store)

assert len(mock_remove.mock_calls) == 1
assert data == {'old': 'config'}
assert len(mock_save) == 1
assert mock_save[0][1] == {
'key': 'test',
'key': MOCK_KEY,
'version': MOCK_VERSION,
'data': data,
}
Expand All @@ -144,14 +145,14 @@ async def old_conf_migrate_func(old_config):
patch('homeassistant.util.json.load_json',
return_value={'old': 'config'}):
data = await storage.async_migrator(
hass, 'old-path', store, MOCK_VERSION,
hass, 'old-path', store,
old_conf_migrate_func=old_conf_migrate_func)

assert len(mock_remove.mock_calls) == 1
assert data == {'new': 'config'}
assert len(mock_save) == 1
assert mock_save[0][1] == {
'key': 'test',
'key': MOCK_KEY,
'version': MOCK_VERSION,
'data': data,
}