From 52f8882de77938bc8a8b4eef158aef2a1d55542e Mon Sep 17 00:00:00 2001 From: LanderOtto <48457093+LanderOtto@users.noreply.github.com> Date: Thu, 30 Nov 2023 08:28:34 +0100 Subject: [PATCH] Fixed Config persistence (#289) This commit fixes the `ScheduleStep` persistence, particularly the `BindingFilter` object. The `filters` attribute in `BindingFilter` was not being saved correctly in the database. A new class extends the `Config`, called `FilterConfig`, and a new table was added to the database. The extension of `PersistableEntity` was moved from `Config` to its subclasses only where needed, i.e. `DeploymentConfig` and `FilterConfig`. --- streamflow/config/config.py | 12 +++---- streamflow/core/config.py | 34 ++++++------------ streamflow/core/deployment.py | 39 ++++++++++++++++++-- streamflow/core/persistence.py | 31 ++++++++++++++-- streamflow/deployment/utils.py | 19 +++++++--- streamflow/persistence/base.py | 1 + streamflow/persistence/loading_context.py | 11 +++++- streamflow/persistence/schemas/sqlite.sql | 9 +++++ streamflow/persistence/sqlite.py | 43 +++++++++++++++++++++-- tests/test_persistence.py | 38 ++++++++++++++------ 10 files changed, 185 insertions(+), 52 deletions(-) diff --git a/streamflow/config/config.py b/streamflow/config/config.py index 58ff52b5..dc7c67cb 100644 --- a/streamflow/config/config.py +++ b/streamflow/config/config.py @@ -26,7 +26,7 @@ def __init__(self, name: str, config: MutableMapping[str, Any]) -> None: super().__init__( name=name, type=workflow_config["type"], config=workflow_config["config"] ) - self.deplyoments = config.get("deployments", {}) + self.deployments = config.get("deployments", {}) self.policies = { k: Config(name=k, type=v["type"], config=v["config"]) for k, v in config.get("scheduling", {}).get("policies", {}).items() @@ -38,16 +38,16 @@ def __init__(self, name: str, config: MutableMapping[str, Any]) -> None: k: Config(name=k, type=v["type"], config=v["config"]) for k, v in config.get("bindingFilters", 
{}).items() } - if not self.deplyoments: - self.deplyoments = config.get("models", {}) - if self.deplyoments: + if not self.deployments: + self.deployments = config.get("models", {}) + if self.deployments: if logger.isEnabledFor(logging.WARNING): logger.warning( "The `models` keyword is deprecated and will be removed in StreamFlow 0.3.0. " "Use `deployments` instead." ) self.scheduling_groups: MutableMapping[str, MutableSequence[str]] = {} - for name, deployment in self.deplyoments.items(): + for name, deployment in self.deployments.items(): deployment["name"] = name self.filesystem = {"children": {}} for binding in workflow_config.get("bindings", []): @@ -68,7 +68,7 @@ def _process_binding(self, binding: MutableMapping[str, Any]): for target in targets: policy = target.get( "policy", - self.deplyoments[target.get("deployment", target.get("model", {}))].get( + self.deployments[target.get("deployment", target.get("model", {}))].get( "policy", "__DEFAULT__" ), ) diff --git a/streamflow/core/config.py b/streamflow/core/config.py index 38b1716e..59651123 100644 --- a/streamflow/core/config.py +++ b/streamflow/core/config.py @@ -3,44 +3,31 @@ import asyncio from typing import Any, MutableMapping, MutableSequence, TYPE_CHECKING, cast -from streamflow.core.persistence import PersistableEntity - if TYPE_CHECKING: from streamflow.core.context import StreamFlowContext - from streamflow.core.deployment import Target + from streamflow.core.deployment import Target, FilterConfig from streamflow.core.persistence import DatabaseLoadingContext -class Config(PersistableEntity): +class Config: __slots__ = ("name", "type", "config") def __init__(self, name: str, type: str, config: MutableMapping[str, Any]) -> None: - super().__init__() self.name: str = name self.type: str = type self.config: MutableMapping[str, Any] = config or {} - @classmethod - async def load( - cls, - context: StreamFlowContext, - row: MutableMapping[str, Any], - loading_context: DatabaseLoadingContext, - ) -> 
Config: - return cls(row["name"], row["type"], row["config"]) - - async def save(self, context: StreamFlowContext): - return {"name": self.name, "type": self.type, "config": self.config} - class BindingConfig: __slots__ = ("targets", "filters") def __init__( - self, targets: MutableSequence[Target], filters: MutableSequence[Config] = None + self, + targets: MutableSequence[Target], + filters: MutableSequence[FilterConfig] = None, ): self.targets: MutableSequence[Target] = targets - self.filters: MutableSequence[Config] = filters or [] + self.filters: MutableSequence[FilterConfig] = filters or [] @classmethod async def load( @@ -63,7 +50,7 @@ async def load( MutableSequence, await asyncio.gather( *( - asyncio.create_task(Config.load(context, f, loading_context)) + asyncio.create_task(loading_context.load_filter(context, f)) for f in row["filters"] ) ), @@ -72,11 +59,10 @@ async def load( async def save(self, context: StreamFlowContext): await asyncio.gather( - *(asyncio.create_task(t.save(context)) for t in self.targets) + *(asyncio.create_task(t.save(context)) for t in self.targets), + *(asyncio.create_task(f.save(context)) for f in self.filters), ) return { "targets": [t.persistent_id for t in self.targets], - "filters": await asyncio.gather( - *(asyncio.create_task(f.save(context)) for f in self.filters) - ), + "filters": [f.persistent_id for f in self.filters], } diff --git a/streamflow/core/deployment.py b/streamflow/core/deployment.py index 98a09f95..d2153fdd 100644 --- a/streamflow/core/deployment.py +++ b/streamflow/core/deployment.py @@ -162,7 +162,7 @@ async def undeploy_all(self): ... 
-class DeploymentConfig(Config): +class DeploymentConfig(Config, PersistableEntity): __slots__ = ("name", "type", "config", "external", "lazy", "workdir", "wraps") def __init__( @@ -175,7 +175,8 @@ def __init__( workdir: str | None = None, wraps: str | None = None, ) -> None: - super().__init__(name, type, config) + Config.__init__(self, name, type, config) + PersistableEntity.__init__(self) self.external: bool = external self.lazy: bool = lazy self.workdir: str | None = workdir @@ -307,5 +308,37 @@ async def _load( context: StreamFlowContext, row: MutableMapping[str, Any], loading_context: DatabaseLoadingContext, - ) -> Target: + ) -> LocalTarget: return cls(workdir=row["workdir"]) + + +class FilterConfig(Config, PersistableEntity): + def __init__(self, name: str, type: str, config: MutableMapping[str, Any]): + Config.__init__(self, name, type, config) + PersistableEntity.__init__(self) + + @classmethod + async def load( + cls, + context: StreamFlowContext, + persistent_id: int, + loading_context: DatabaseLoadingContext, + ) -> FilterConfig: + row = await context.database.get_filter(persistent_id) + obj = cls( + name=row["name"], + type=row["type"], + config=json.loads(row["config"]), + ) + obj.persistent_id = persistent_id + loading_context.add_filter(persistent_id, obj) + return obj + + async def save(self, context: StreamFlowContext) -> None: + async with self.persistence_lock: + if not self.persistent_id: + self.persistent_id = await context.database.add_filter( + name=self.name, + type=self.type, + config=json.dumps(self.config), + ) diff --git a/streamflow/core/persistence.py b/streamflow/core/persistence.py index 2fc6d5cb..74f63c80 100644 --- a/streamflow/core/persistence.py +++ b/streamflow/core/persistence.py @@ -8,7 +8,7 @@ from streamflow.core.context import SchemaEntity, StreamFlowContext if TYPE_CHECKING: - from streamflow.core.deployment import DeploymentConfig, Target + from streamflow.core.deployment import DeploymentConfig, Target, FilterConfig 
from streamflow.core.workflow import Port, Step, Token, Workflow @@ -17,6 +17,10 @@ class DatabaseLoadingContext(ABC): def add_deployment(self, persistent_id: int, deployment: DeploymentConfig): ... + @abstractmethod + def add_filter(self, persistent_id: int, filter_config: FilterConfig): + ... + @abstractmethod def add_port(self, persistent_id: int, port: Port): ... @@ -41,6 +45,10 @@ def add_workflow(self, persistent_id: int, workflow: Workflow): async def load_deployment(self, context: StreamFlowContext, persistent_id: int): ... + @abstractmethod + async def load_filter(self, context: StreamFlowContext, persistent_id: int): + ... + @abstractmethod async def load_port(self, context: StreamFlowContext, persistent_id: int): ... @@ -113,6 +121,15 @@ async def add_deployment( ) -> int: ... + @abstractmethod + async def add_filter( + self, + name: str, + type: str, + config: str, + ) -> int: + ... + @abstractmethod async def add_port( self, @@ -189,7 +206,11 @@ async def get_commands_by_step( ... @abstractmethod - async def get_deployment(self, deplyoment_id: int) -> MutableMapping[str, Any]: + async def get_deployment(self, deployment_id: int) -> MutableMapping[str, Any]: + ... + + @abstractmethod + async def get_filter(self, filter_id: int) -> MutableMapping[str, Any]: ... @abstractmethod @@ -270,6 +291,12 @@ async def update_deployment( ) -> int: ... + @abstractmethod + async def update_filter( + self, filter_id: int, updates: MutableMapping[str, Any] + ) -> int: + ... + @abstractmethod async def update_port(self, port_id: int, updates: MutableMapping[str, Any]) -> int: ... 
diff --git a/streamflow/deployment/utils.py b/streamflow/deployment/utils.py index 75548630..17e98b89 100644 --- a/streamflow/deployment/utils.py +++ b/streamflow/deployment/utils.py @@ -8,7 +8,12 @@ from typing import TYPE_CHECKING from streamflow.core.config import BindingConfig -from streamflow.core.deployment import DeploymentConfig, LocalTarget, Target +from streamflow.core.deployment import ( + DeploymentConfig, + LocalTarget, + Target, + FilterConfig, +) from streamflow.deployment.connector import LocalConnector from streamflow.log_handler import logger @@ -27,9 +32,9 @@ def get_binding_config( for target in config["targets"]: workdir = target.get("workdir") if target is not None else None if "deployment" in target: - target_deployment = workflow_config.deplyoments[target["deployment"]] + target_deployment = workflow_config.deployments[target["deployment"]] else: - target_deployment = workflow_config.deplyoments[target["model"]] + target_deployment = workflow_config.deployments[target["model"]] if logger.isEnabledFor(logging.WARNING): logger.warning( "The `model` keyword is deprecated and will be removed in StreamFlow 0.3.0. 
" @@ -63,7 +68,13 @@ def get_binding_config( workdir=workdir, ) ) - return BindingConfig(targets=targets, filters=config.get("filters")) + return BindingConfig( + targets=targets, + filters=[ + FilterConfig(name=c.name, type=c.type, config=c.config) + for c in config.get("filters") + ], + ) else: return BindingConfig(targets=[LocalTarget()]) diff --git a/streamflow/persistence/base.py b/streamflow/persistence/base.py index c72fcbd4..4fca0eb9 100644 --- a/streamflow/persistence/base.py +++ b/streamflow/persistence/base.py @@ -14,5 +14,6 @@ def __init__(self, context: StreamFlowContext): self.port_cache: Cache = LRUCache(maxsize=sys.maxsize) self.step_cache: Cache = LRUCache(maxsize=sys.maxsize) self.target_cache: Cache = LRUCache(maxsize=sys.maxsize) + self.filter_cache: Cache = LRUCache(maxsize=sys.maxsize) self.token_cache: Cache = LRUCache(maxsize=sys.maxsize) self.workflow_cache: Cache = LRUCache(maxsize=sys.maxsize) diff --git a/streamflow/persistence/loading_context.py b/streamflow/persistence/loading_context.py index d70fd6a3..92800920 100644 --- a/streamflow/persistence/loading_context.py +++ b/streamflow/persistence/loading_context.py @@ -1,7 +1,7 @@ from typing import MutableMapping from streamflow.core.context import StreamFlowContext -from streamflow.core.deployment import DeploymentConfig, Target +from streamflow.core.deployment import DeploymentConfig, Target, FilterConfig from streamflow.core.persistence import DatabaseLoadingContext from streamflow.core.workflow import Port, Step, Token, Workflow @@ -13,12 +13,16 @@ def __init__(self): self._ports: MutableMapping[int, Port] = {} self._steps: MutableMapping[int, Step] = {} self._targets: MutableMapping[int, Target] = {} + self._filter_configs: MutableMapping[int, FilterConfig] = {} self._tokens: MutableMapping[int, Token] = {} self._workflows: MutableMapping[int, Workflow] = {} def add_deployment(self, persistent_id: int, deployment: DeploymentConfig): self._deployment_configs[persistent_id] = 
deployment + def add_filter(self, persistent_id: int, filter_config: FilterConfig): + self._filter_configs[persistent_id] = filter_config + def add_port(self, persistent_id: int, port: Port): self._ports[persistent_id] = port @@ -39,6 +43,11 @@ async def load_deployment(self, context: StreamFlowContext, persistent_id: int): persistent_id ) or await DeploymentConfig.load(context, persistent_id, self) + async def load_filter(self, context: StreamFlowContext, persistent_id: int): + return self._filter_configs.get(persistent_id) or await FilterConfig.load( + context, persistent_id, self + ) + async def load_port(self, context: StreamFlowContext, persistent_id: int): return self._ports.get(persistent_id) or await Port.load( context, persistent_id, self diff --git a/streamflow/persistence/schemas/sqlite.sql b/streamflow/persistence/schemas/sqlite.sql index 9547587e..7c09b01f 100644 --- a/streamflow/persistence/schemas/sqlite.sql +++ b/streamflow/persistence/schemas/sqlite.sql @@ -101,3 +101,12 @@ CREATE TABLE IF NOT EXISTS target params TEXT, FOREIGN KEY (deployment) REFERENCES deployment (id) ); + + +CREATE TABLE IF NOT EXISTS filter +( + id INTEGER PRIMARY KEY, + name TEXT, + type TEXT, + config TEXT +); \ No newline at end of file diff --git a/streamflow/persistence/sqlite.py b/streamflow/persistence/sqlite.py index f659a335..7e9c460a 100644 --- a/streamflow/persistence/sqlite.py +++ b/streamflow/persistence/sqlite.py @@ -129,6 +129,24 @@ async def add_deployment( ) as cursor: return cursor.lastrowid + async def add_filter( + self, + name: str, + type: str, + config: str, + ) -> int: + async with self.connection as db: + async with db.execute( + "INSERT INTO filter(name, type, config) " + "VALUES (:name, :type, :config)", + { + "name": name, + "type": type, + "config": config, + }, + ) as cursor: + return cursor.lastrowid + async def add_port( self, name: str, @@ -280,10 +298,18 @@ async def get_commands_by_step( return await cursor.fetchall() @cachedmethod(lambda 
self: self.deployment_cache) - async def get_deployment(self, deplyoment_id: int) -> MutableMapping[str, Any]: + async def get_deployment(self, deployment_id: int) -> MutableMapping[str, Any]: async with self.connection as db: async with db.execute( - "SELECT * FROM deployment WHERE id = :id", {"id": deplyoment_id} + "SELECT * FROM deployment WHERE id = :id", {"id": deployment_id} + ) as cursor: + return await cursor.fetchone() + + @cachedmethod(lambda self: self.filter_cache) + async def get_filter(self, filter_id: int) -> MutableMapping[str, Any]: + async with self.connection as db: + async with db.execute( + "SELECT * FROM filter WHERE id = :id", {"id": filter_id} ) as cursor: return await cursor.fetchone() @@ -459,6 +485,19 @@ async def update_deployment( self.deployment_cache.pop(deployment_id, None) return deployment_id + async def update_filter( + self, filter_id: int, updates: MutableMapping[str, Any] + ) -> int: + async with self.connection as db: + await db.execute( + "UPDATE filter SET {} WHERE id = :id".format( # nosec + ", ".join([f"{k} = :{k}" for k in updates]) + ), + {**updates, **{"id": filter_id}}, + ) + self.filter_cache.pop(filter_id, None) + return filter_id + async def update_port(self, port_id: int, updates: MutableMapping[str, Any]) -> int: async with self.connection as db: await db.execute( diff --git a/tests/test_persistence.py b/tests/test_persistence.py index e227903a..4daf7317 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,7 +5,7 @@ from streamflow.core import utils from streamflow.core.config import BindingConfig from streamflow.core.context import StreamFlowContext -from streamflow.core.deployment import LocalTarget, Target +from streamflow.core.deployment import LocalTarget, Target, FilterConfig from streamflow.core.workflow import Job, Token, Workflow from streamflow.workflow.combinator import ( CartesianProductCombinator, @@ -32,6 +32,7 @@ ) from tests.conftest import save_load_and_test from 
tests.utils.deployment import get_docker_deployment_config +from tests.utils.workflow import create_workflow @pytest.mark.asyncio @@ -135,21 +136,31 @@ async def test_deploy_step(context: StreamFlowContext): @pytest.mark.asyncio async def test_schedule_step(context: StreamFlowContext): """Test saving and loading ScheduleStep from database""" - workflow = Workflow( - context=context, type="cwl", name=utils.random_name(), config={} + workflow, _ = await create_workflow(context, 0) + binding_config = BindingConfig( + targets=[ + LocalTarget(workdir=utils.random_name()), + Target( + deployment=get_docker_deployment_config(), + workdir=utils.random_name(), + ), + ], + filters=[FilterConfig(config={}, name=utils.random_name(), type="shuffle")], ) - port = workflow.create_port() + connector_ports = { + target.deployment.name: workflow.create_port(ConnectorPort) + for target in binding_config.targets + } await workflow.save(context) - binding_config = BindingConfig(targets=[LocalTarget(workdir=utils.random_name())]) schedule_step = workflow.create_step( cls=ScheduleStep, - name=posixpath.join(utils.random_name() + "-injector", "__schedule__"), + name=posixpath.join(utils.random_name(), "__schedule__"), job_prefix="something", - connector_ports={binding_config.targets[0].deployment.name: port}, - input_directory=binding_config.targets[0].workdir, - output_directory=binding_config.targets[0].workdir, - tmp_directory=binding_config.targets[0].workdir, + connector_ports=connector_ports, + input_directory=posixpath.join(*(utils.random_name() for _ in range(2))), + output_directory=posixpath.join(*(utils.random_name() for _ in range(2))), + tmp_directory=posixpath.join(*(utils.random_name() for _ in range(2))), binding_config=binding_config, ) await save_load_and_test(schedule_step, context) @@ -319,6 +330,13 @@ async def test_iteration_termination_token(context: StreamFlowContext): await save_load_and_test(token, context) +@pytest.mark.asyncio +async def 
test_filter_config(context: StreamFlowContext): + """Test saving and loading filter configuration from database""" + config = FilterConfig(config={}, name=utils.random_name(), type="shuffle") + await save_load_and_test(config, context) + + @pytest.mark.asyncio async def test_deployment(context: StreamFlowContext): """Test saving and loading deployment configuration from database"""