Skip to content

Commit

Permalink
Fixed Config persistence (#289)
Browse files Browse the repository at this point in the history
This commit fixes the `ScheduleStep` persistence, particularly the `BindingFilter` object.
The `filters` attribute in `BindingFilter` was not being saved correctly in the database.
A new class extends the `Config`, called `FilterConfig`, and a new table was added to the database.

The extention of `PersistableEntity` was moved from `Config` to its superclasses only where needed, i.e. `DeploymentConfig` and `FilterConfig`.
  • Loading branch information
LanderOtto authored Nov 30, 2023
1 parent 1a1c7fb commit 52f8882
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 52 deletions.
12 changes: 6 additions & 6 deletions streamflow/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, name: str, config: MutableMapping[str, Any]) -> None:
super().__init__(
name=name, type=workflow_config["type"], config=workflow_config["config"]
)
self.deplyoments = config.get("deployments", {})
self.deployments = config.get("deployments", {})
self.policies = {
k: Config(name=k, type=v["type"], config=v["config"])
for k, v in config.get("scheduling", {}).get("policies", {}).items()
Expand All @@ -38,16 +38,16 @@ def __init__(self, name: str, config: MutableMapping[str, Any]) -> None:
k: Config(name=k, type=v["type"], config=v["config"])
for k, v in config.get("bindingFilters", {}).items()
}
if not self.deplyoments:
self.deplyoments = config.get("models", {})
if self.deplyoments:
if not self.deployments:
self.deployments = config.get("models", {})
if self.deployments:
if logger.isEnabledFor(logging.WARNING):
logger.warning(
"The `models` keyword is deprecated and will be removed in StreamFlow 0.3.0. "
"Use `deployments` instead."
)
self.scheduling_groups: MutableMapping[str, MutableSequence[str]] = {}
for name, deployment in self.deplyoments.items():
for name, deployment in self.deployments.items():
deployment["name"] = name
self.filesystem = {"children": {}}
for binding in workflow_config.get("bindings", []):
Expand All @@ -68,7 +68,7 @@ def _process_binding(self, binding: MutableMapping[str, Any]):
for target in targets:
policy = target.get(
"policy",
self.deplyoments[target.get("deployment", target.get("model", {}))].get(
self.deployments[target.get("deployment", target.get("model", {}))].get(
"policy", "__DEFAULT__"
),
)
Expand Down
34 changes: 10 additions & 24 deletions streamflow/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,31 @@
import asyncio
from typing import Any, MutableMapping, MutableSequence, TYPE_CHECKING, cast

from streamflow.core.persistence import PersistableEntity

if TYPE_CHECKING:
from streamflow.core.context import StreamFlowContext
from streamflow.core.deployment import Target
from streamflow.core.deployment import Target, FilterConfig
from streamflow.core.persistence import DatabaseLoadingContext


class Config(PersistableEntity):
class Config:
__slots__ = ("name", "type", "config")

def __init__(self, name: str, type: str, config: MutableMapping[str, Any]) -> None:
super().__init__()
self.name: str = name
self.type: str = type
self.config: MutableMapping[str, Any] = config or {}

@classmethod
async def load(
cls,
context: StreamFlowContext,
row: MutableMapping[str, Any],
loading_context: DatabaseLoadingContext,
) -> Config:
return cls(row["name"], row["type"], row["config"])

async def save(self, context: StreamFlowContext):
return {"name": self.name, "type": self.type, "config": self.config}


class BindingConfig:
__slots__ = ("targets", "filters")

def __init__(
self, targets: MutableSequence[Target], filters: MutableSequence[Config] = None
self,
targets: MutableSequence[Target],
filters: MutableSequence[FilterConfig] = None,
):
self.targets: MutableSequence[Target] = targets
self.filters: MutableSequence[Config] = filters or []
self.filters: MutableSequence[FilterConfig] = filters or []

@classmethod
async def load(
Expand All @@ -63,7 +50,7 @@ async def load(
MutableSequence,
await asyncio.gather(
*(
asyncio.create_task(Config.load(context, f, loading_context))
asyncio.create_task(loading_context.load_filter(context, f))
for f in row["filters"]
)
),
Expand All @@ -72,11 +59,10 @@ async def load(

async def save(self, context: StreamFlowContext):
await asyncio.gather(
*(asyncio.create_task(t.save(context)) for t in self.targets)
*(asyncio.create_task(t.save(context)) for t in self.targets),
*(asyncio.create_task(f.save(context)) for f in self.filters),
)
return {
"targets": [t.persistent_id for t in self.targets],
"filters": await asyncio.gather(
*(asyncio.create_task(f.save(context)) for f in self.filters)
),
"filters": [f.persistent_id for f in self.filters],
}
39 changes: 36 additions & 3 deletions streamflow/core/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ async def undeploy_all(self):
...


class DeploymentConfig(Config):
class DeploymentConfig(Config, PersistableEntity):
__slots__ = ("name", "type", "config", "external", "lazy", "workdir", "wraps")

def __init__(
Expand All @@ -175,7 +175,8 @@ def __init__(
workdir: str | None = None,
wraps: str | None = None,
) -> None:
super().__init__(name, type, config)
Config.__init__(self, name, type, config)
PersistableEntity.__init__(self)
self.external: bool = external
self.lazy: bool = lazy
self.workdir: str | None = workdir
Expand Down Expand Up @@ -307,5 +308,37 @@ async def _load(
context: StreamFlowContext,
row: MutableMapping[str, Any],
loading_context: DatabaseLoadingContext,
) -> Target:
) -> LocalTarget:
return cls(workdir=row["workdir"])


class FilterConfig(Config, PersistableEntity):
def __init__(self, name: str, type: str, config: MutableMapping[str, Any]):
Config.__init__(self, name, type, config)
PersistableEntity.__init__(self)

@classmethod
async def load(
cls,
context: StreamFlowContext,
persistent_id: int,
loading_context: DatabaseLoadingContext,
) -> FilterConfig:
row = await context.database.get_filter(persistent_id)
obj = cls(
name=row["name"],
type=row["type"],
config=json.loads(row["config"]),
)
obj.persistent_id = persistent_id
loading_context.add_filter(persistent_id, obj)
return obj

async def save(self, context: StreamFlowContext) -> None:
async with self.persistence_lock:
if not self.persistent_id:
self.persistent_id = await context.database.add_filter(
name=self.name,
type=self.type,
config=json.dumps(self.config),
)
31 changes: 29 additions & 2 deletions streamflow/core/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from streamflow.core.context import SchemaEntity, StreamFlowContext

if TYPE_CHECKING:
from streamflow.core.deployment import DeploymentConfig, Target
from streamflow.core.deployment import DeploymentConfig, Target, FilterConfig
from streamflow.core.workflow import Port, Step, Token, Workflow


Expand All @@ -17,6 +17,10 @@ class DatabaseLoadingContext(ABC):
def add_deployment(self, persistent_id: int, deployment: DeploymentConfig):
...

@abstractmethod
def add_filter(self, persistent_id: int, filter_config: FilterConfig):
...

@abstractmethod
def add_port(self, persistent_id: int, port: Port):
...
Expand All @@ -41,6 +45,10 @@ def add_workflow(self, persistent_id: int, workflow: Workflow):
async def load_deployment(self, context: StreamFlowContext, persistent_id: int):
...

@abstractmethod
async def load_filter(self, context: StreamFlowContext, persistent_id: int):
...

@abstractmethod
async def load_port(self, context: StreamFlowContext, persistent_id: int):
...
Expand Down Expand Up @@ -113,6 +121,15 @@ async def add_deployment(
) -> int:
...

@abstractmethod
async def add_filter(
self,
name: str,
type: str,
config: str,
) -> int:
...

@abstractmethod
async def add_port(
self,
Expand Down Expand Up @@ -189,7 +206,11 @@ async def get_commands_by_step(
...

@abstractmethod
async def get_deployment(self, deplyoment_id: int) -> MutableMapping[str, Any]:
async def get_deployment(self, deployment_id: int) -> MutableMapping[str, Any]:
...

@abstractmethod
async def get_filter(self, filter_id: int) -> MutableMapping[str, Any]:
...

@abstractmethod
Expand Down Expand Up @@ -270,6 +291,12 @@ async def update_deployment(
) -> int:
...

@abstractmethod
async def update_filter(
self, filter_id: int, updates: MutableMapping[str, Any]
) -> int:
...

@abstractmethod
async def update_port(self, port_id: int, updates: MutableMapping[str, Any]) -> int:
...
Expand Down
19 changes: 15 additions & 4 deletions streamflow/deployment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from typing import TYPE_CHECKING

from streamflow.core.config import BindingConfig
from streamflow.core.deployment import DeploymentConfig, LocalTarget, Target
from streamflow.core.deployment import (
DeploymentConfig,
LocalTarget,
Target,
FilterConfig,
)
from streamflow.deployment.connector import LocalConnector
from streamflow.log_handler import logger

Expand All @@ -27,9 +32,9 @@ def get_binding_config(
for target in config["targets"]:
workdir = target.get("workdir") if target is not None else None
if "deployment" in target:
target_deployment = workflow_config.deplyoments[target["deployment"]]
target_deployment = workflow_config.deployments[target["deployment"]]
else:
target_deployment = workflow_config.deplyoments[target["model"]]
target_deployment = workflow_config.deployments[target["model"]]
if logger.isEnabledFor(logging.WARNING):
logger.warning(
"The `model` keyword is deprecated and will be removed in StreamFlow 0.3.0. "
Expand Down Expand Up @@ -63,7 +68,13 @@ def get_binding_config(
workdir=workdir,
)
)
return BindingConfig(targets=targets, filters=config.get("filters"))
return BindingConfig(
targets=targets,
filters=[
FilterConfig(name=c.name, type=c.type, config=c.config)
for c in config.get("filters")
],
)
else:
return BindingConfig(targets=[LocalTarget()])

Expand Down
1 change: 1 addition & 0 deletions streamflow/persistence/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ def __init__(self, context: StreamFlowContext):
self.port_cache: Cache = LRUCache(maxsize=sys.maxsize)
self.step_cache: Cache = LRUCache(maxsize=sys.maxsize)
self.target_cache: Cache = LRUCache(maxsize=sys.maxsize)
self.filter_cache: Cache = LRUCache(maxsize=sys.maxsize)
self.token_cache: Cache = LRUCache(maxsize=sys.maxsize)
self.workflow_cache: Cache = LRUCache(maxsize=sys.maxsize)
11 changes: 10 additions & 1 deletion streamflow/persistence/loading_context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import MutableMapping

from streamflow.core.context import StreamFlowContext
from streamflow.core.deployment import DeploymentConfig, Target
from streamflow.core.deployment import DeploymentConfig, Target, FilterConfig
from streamflow.core.persistence import DatabaseLoadingContext
from streamflow.core.workflow import Port, Step, Token, Workflow

Expand All @@ -13,12 +13,16 @@ def __init__(self):
self._ports: MutableMapping[int, Port] = {}
self._steps: MutableMapping[int, Step] = {}
self._targets: MutableMapping[int, Target] = {}
self._filter_configs: MutableMapping[int, FilterConfig] = {}
self._tokens: MutableMapping[int, Token] = {}
self._workflows: MutableMapping[int, Workflow] = {}

def add_deployment(self, persistent_id: int, deployment: DeploymentConfig):
self._deployment_configs[persistent_id] = deployment

def add_filter(self, persistent_id: int, filter_config: FilterConfig):
self._filter_configs[persistent_id] = filter_config

def add_port(self, persistent_id: int, port: Port):
self._ports[persistent_id] = port

Expand All @@ -39,6 +43,11 @@ async def load_deployment(self, context: StreamFlowContext, persistent_id: int):
persistent_id
) or await DeploymentConfig.load(context, persistent_id, self)

async def load_filter(self, context: StreamFlowContext, persistent_id: int):
return self._filter_configs.get(persistent_id) or await FilterConfig.load(
context, persistent_id, self
)

async def load_port(self, context: StreamFlowContext, persistent_id: int):
return self._ports.get(persistent_id) or await Port.load(
context, persistent_id, self
Expand Down
9 changes: 9 additions & 0 deletions streamflow/persistence/schemas/sqlite.sql
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,12 @@ CREATE TABLE IF NOT EXISTS target
params TEXT,
FOREIGN KEY (deployment) REFERENCES deployment (id)
);


CREATE TABLE IF NOT EXISTS filter
(
id INTEGER PRIMARY KEY,
name TEXT,
type TEXT,
config TEXT
);
Loading

0 comments on commit 52f8882

Please sign in to comment.