Skip to content

Commit

Permalink
moved _persist_token method to the BaseStep class
Browse files — browse the repository at this point in the history
  • Loading branch information
LanderOtto committed Jun 3, 2023
1 parent 250b032 commit 62b9f26
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 148 deletions.
26 changes: 14 additions & 12 deletions streamflow/cwl/combinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from streamflow.core.utils import get_tag
from streamflow.core.workflow import Token, Workflow
from streamflow.workflow.combinator import DotProductCombinator
-from streamflow.workflow.step import CombinatorStep
from streamflow.workflow.token import IterationTerminationToken, ListToken


Expand Down Expand Up @@ -62,7 +61,7 @@ async def _save_additional_params(self, context: StreamFlowContext):
}

async def combine(
-self, port_name: str, token: Token, combinator_step: CombinatorStep = None
+self, port_name: str, token: Token
) -> AsyncIterable[MutableMapping[str, Token]]:
if not isinstance(token, IterationTerminationToken):
async for schema in super().combine(port_name, token):
Expand All @@ -80,17 +79,20 @@ async def combine(
# Otherwise, merge multiple inputs in a single list
else:
outputs = [schema[name]["token"] for name in self.input_names]
-inputs_token_id.extend([id for name in self.input_names for id in schema[name]["inputs_id"]])
+inputs_token_id.extend(
+[
+id
+for name in self.input_names
+for id in schema[name]["inputs_id"]
+]
+)
tag = get_tag(outputs)
# Flatten if needed
if self.flatten:
outputs = _flatten_token_list(outputs)
-yield await self._save_token(
-{
-self.output_name: {
-"token": ListToken(value=outputs, tag=tag),
-"inputs_id": inputs_token_id,
-}
-},
-combinator_step,
-)
+yield {
+self.output_name: {
+"token": ListToken(value=outputs, tag=tag),
+"inputs_id": inputs_token_id,
+}
+}
14 changes: 5 additions & 9 deletions streamflow/cwl/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
InputInjectorStep,
LoopOutputStep,
TransferStep,
-_persist_token, _get_tokens_id,
+_get_tokens_id,
)
from streamflow.workflow.token import IterationTerminationToken, ListToken, ObjectToken

Expand Down Expand Up @@ -101,23 +101,21 @@ async def _on_true(self, inputs: MutableMapping[str, Token]):
# Propagate output tokens
for port_name, port in self.get_output_ports().items():
port.put(
-await _persist_token(
+await self._persist_token(
token=inputs[port_name].update(inputs[port_name].value),
port=port,
inputs=_get_tokens_id(inputs.values()),
-context=self.workflow.context,
)
)

async def _on_false(self, inputs: MutableMapping[str, Token]):
# Propagate skip tokens
for port in self.get_skip_ports().values():
port.put(
-await _persist_token(
+await self._persist_token(
token=Token(value=None, tag=get_tag(inputs.values())),
port=port,
inputs=_get_tokens_id(inputs.values()),
-context=self.workflow.context,
)
)

Expand Down Expand Up @@ -215,11 +213,10 @@ async def _on_true(self, inputs: MutableMapping[str, Token]):
# Propagate output tokens
for port_name, port in self.get_output_ports().items():
port.put(
-await _persist_token(
+await self._persist_token(
token=inputs[port_name].update(inputs[port_name].value),
port=port,
inputs=_get_tokens_id(inputs.values()),
-context=self.workflow.context,
)
)

Expand All @@ -234,11 +231,10 @@ async def _on_false(self, inputs: MutableMapping[str, Token]):
# Propagate skip tokens
for port in self.get_skip_ports().values():
port.put(
-await _persist_token(
+await self._persist_token(
token=ListToken(value=token_value, tag=get_tag(inputs.values())),
port=port,
inputs=_get_tokens_id(inputs.values()),
-context=self.workflow.context,
)
)

Expand Down
95 changes: 41 additions & 54 deletions streamflow/workflow/combinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from streamflow.core.exception import WorkflowExecutionException
from streamflow.core.persistence import DatabaseLoadingContext
from streamflow.core.workflow import Token, Workflow
-from streamflow.workflow.step import Combinator, CombinatorStep
+from streamflow.workflow.step import Combinator
from streamflow.workflow.token import IterationTerminationToken


Expand Down Expand Up @@ -37,7 +37,6 @@ async def _product(
self,
port_name: str,
token: Token | MutableSequence[Token],
-combinator_step: CombinatorStep,
) -> AsyncIterable[MutableMapping[str, Token]]:
# Get all combinations of the new element with the others
tag = ".".join(token.tag.split(".")[: -self.depth])
Expand All @@ -57,16 +56,13 @@ async def _product(
else:
schema[key] = config[key]
suffix = [t.tag.split(".")[-1] for t in schema.values()]
-yield await self._save_token(
-{
-k: {
-"token": t.retag(".".join(t.tag.split(".")[:-1] + suffix)),
-"inputs_id": [t.persistent_id],
-}
-for k, t in schema.items()
-},
-combinator_step,
-)
+yield {
+k: {
+"token": t.retag(".".join(t.tag.split(".")[:-1] + suffix)),
+"inputs_id": [t.persistent_id],
+}
+for k, t in schema.items()
+}

async def _save_additional_params(self, context: StreamFlowContext):
return {
Expand All @@ -75,7 +71,9 @@ async def _save_additional_params(self, context: StreamFlowContext):
}

async def combine(
-self, port_name: str, token: Token, combinator_step: CombinatorStep = None
+self,
+port_name: str,
+token: Token,
) -> AsyncIterable[MutableMapping[str, Token]]:
# If port is associated to an inner combinator, call it and put schemas in their related list
if c := self.get_combinator(port_name):
Expand All @@ -84,12 +82,12 @@ async def combine(
c.combine(port_name, token),
):
self._add_to_list(schema, c.name, self.depth)
-async for product in self._product(port_name, token, combinator_step):
+async for product in self._product(port_name, token):
yield product
# If port is associated directly with the current combinator, put the token in the list
elif port_name in self.items:
self._add_to_list(token, port_name, self.depth)
-async for product in self._product(port_name, token, combinator_step):
+async for product in self._product(port_name, token):
yield product
# Otherwise throw Exception
else:
Expand Down Expand Up @@ -118,9 +116,7 @@ def __init__(self, name: str, workflow: Workflow):
str, MutableMapping[str, MutableSequence[Any]]
] = {}

-async def _product(
-self, combinator_step
-) -> AsyncIterable[MutableMapping[str, Token]]:
+async def _product(self) -> AsyncIterable[MutableMapping[str, Token]]:
# Check if some complete input sets are available
for tag in list(self.token_values):
if len(self.token_values[tag]) == len(self.items):
Expand All @@ -138,19 +134,18 @@ async def _product(
"inputs_id": [element.persistent_id],
}
tag = utils.get_tag([t["token"] for t in schema.values()])
-yield await self._save_token(
-{
-k: {
-"token": t["token"].retag(tag),
-"inputs_id": t["inputs_id"],
-}
-for k, t in schema.items()
-},
-combinator_step,
-)
+yield {
+k: {
+"token": t["token"].retag(tag),
+"inputs_id": t["inputs_id"],
+}
+for k, t in schema.items()
+}

async def combine(
-self, port_name: str, token: Token, combinator_step: CombinatorStep = None
+self,
+port_name: str,
+token: Token,
) -> AsyncIterable[MutableMapping[str, Token]]:
# If port is associated to an inner combinator, call it and put schemas in their related list
if c := self.get_combinator(port_name):
Expand All @@ -159,12 +154,12 @@ async def combine(
c.combine(port_name, token),
):
self._add_to_list(schema, c.name)
-async for product in self._product(combinator_step):
+async for product in self._product():
yield product
# If port is associated directly with the current combinator, put the token in the list
elif port_name in self.items:
self._add_to_list(token, port_name)
-async for product in self._product(combinator_step):
+async for product in self._product():
yield product
# Otherwise throw Exception
else:
Expand All @@ -178,10 +173,8 @@ def __init__(self, name: str, workflow: Workflow):
super().__init__(name, workflow)
self.iteration_map: MutableMapping[str, int] = {}

-async def _product(
-self, combinator_step
-) -> AsyncIterable[MutableMapping[str, Token]]:
-async for schema in super()._product(None):
+async def _product(self) -> AsyncIterable[MutableMapping[str, Token]]:
+async for schema in super()._product():
tag = utils.get_tag([t["token"] for t in schema.values()])
prefix = ".".join(tag.split(".")[:-1])
if prefix not in self.iteration_map:
Expand All @@ -190,13 +183,10 @@ async def _product(
else:
self.iteration_map[prefix] += 1
tag = ".".join(tag.split(".")[:-1] + [str(self.iteration_map[prefix])])
-yield await self._save_token(
-{
-k: {"token": t["token"].retag(tag), "inputs_id": t["inputs_id"]}
-for k, t in schema.items()
-},
-combinator_step,
-)
+yield {
+k: {"token": t["token"].retag(tag), "inputs_id": t["inputs_id"]}
+for k, t in schema.items()
+}


class LoopTerminationCombinator(DotProductCombinator):
Expand All @@ -211,17 +201,14 @@ def add_output_item(self, item: str) -> None:
self.output_items.append(item)

async def _product(
-self, combinator_step
+self,
) -> AsyncIterable[MutableMapping[str, Token]]:
-async for schema in super()._product(None):
+async for schema in super()._product():
tag = utils.get_tag([t["token"] for t in schema.values()])
-yield await self._save_token(
-{
-k: {
-"token": IterationTerminationToken(tag=tag),
-"inputs_id": schema[k]["inputs_id"],
-}
-for k in self.output_items
-},
-combinator_step,
-)
+yield {
+k: {
+"token": IterationTerminationToken(tag=tag),
+"inputs_id": schema[k]["inputs_id"],
+}
+for k in self.output_items
+}
Loading

0 comments on commit 62b9f26

Please sign in to comment.