From 332f85af19a8aed7fa98d80ba25840bbcc0e0b1c Mon Sep 17 00:00:00 2001 From: Alberto Mulone Date: Sat, 3 Jun 2023 17:14:50 +0200 Subject: [PATCH] fixed loop combinator + renamed param inputs --- streamflow/cwl/combinator.py | 14 +++++------ streamflow/cwl/step.py | 8 +++---- streamflow/workflow/combinator.py | 6 ++--- streamflow/workflow/step.py | 40 ++++++++++++++++++------------- 4 files changed, 35 insertions(+), 33 deletions(-) diff --git a/streamflow/cwl/combinator.py b/streamflow/cwl/combinator.py index 93d86a88..fef9e090 100644 --- a/streamflow/cwl/combinator.py +++ b/streamflow/cwl/combinator.py @@ -75,17 +75,15 @@ async def combine( else: outputs = [outputs] tag = schema[self.input_names[0]]["token"].tag - inputs_token_id.extend(schema[self.input_names[0]]["inputs_id"]) + inputs_token_id = schema[self.input_names[0]]["inputs_id"] # Otherwise, merge multiple inputs in a single list else: outputs = [schema[name]["token"] for name in self.input_names] - inputs_token_id.extend( - [ - id - for name in self.input_names - for id in schema[name]["inputs_id"] - ] - ) + inputs_token_id = [ + id + for name in self.input_names + for id in schema[name]["inputs_id"] + ] tag = get_tag(outputs) # Flatten if needed if self.flatten: diff --git a/streamflow/cwl/step.py b/streamflow/cwl/step.py index 0489441d..98dda8c4 100644 --- a/streamflow/cwl/step.py +++ b/streamflow/cwl/step.py @@ -104,7 +104,7 @@ async def _on_true(self, inputs: MutableMapping[str, Token]): await self._persist_token( token=inputs[port_name].update(inputs[port_name].value), port=port, - inputs=_get_tokens_id(inputs.values()), + inputs_token_id=_get_tokens_id(inputs.values()), ) ) @@ -115,7 +115,7 @@ async def _on_false(self, inputs: MutableMapping[str, Token]): await self._persist_token( token=Token(value=None, tag=get_tag(inputs.values())), port=port, - inputs=_get_tokens_id(inputs.values()), + inputs_token_id=_get_tokens_id(inputs.values()), ) ) @@ -216,7 +216,7 @@ async def _on_true(self, inputs: MutableMapping[str, Token]): await self._persist_token( token=inputs[port_name].update(inputs[port_name].value), port=port, - inputs=_get_tokens_id(inputs.values()), + inputs_token_id=_get_tokens_id(inputs.values()), ) ) @@ -234,7 +234,7 @@ async def _on_false(self, inputs: MutableMapping[str, Token]): await self._persist_token( token=ListToken(value=token_value, tag=get_tag(inputs.values())), port=port, - inputs=_get_tokens_id(inputs.values()), + inputs_token_id=_get_tokens_id(inputs.values()), ) ) diff --git a/streamflow/workflow/combinator.py b/streamflow/workflow/combinator.py index d4884b1c..fb8bc530 100644 --- a/streamflow/workflow/combinator.py +++ b/streamflow/workflow/combinator.py @@ -200,15 +200,13 @@ def __init__(self, name: str, workflow: Workflow): def add_output_item(self, item: str) -> None: self.output_items.append(item) - async def _product( - self, - ) -> AsyncIterable[MutableMapping[str, Token]]: + async def _product(self) -> AsyncIterable[MutableMapping[str, Token]]: async for schema in super()._product(): tag = utils.get_tag([t["token"] for t in schema.values()]) yield { k: { "token": IterationTerminationToken(tag=tag), - "inputs_id": schema[k]["inputs_id"], + "inputs_id": [id for t in schema.values() for id in t["inputs_id"]], } for k in self.output_items } diff --git a/streamflow/workflow/step.py b/streamflow/workflow/step.py index aca1f08e..d8af00fd 100644 --- a/streamflow/workflow/step.py +++ b/streamflow/workflow/step.py @@ -79,7 +79,9 @@ def _group_by_tag( def _get_tokens_id(token_list): - return [t.persistent_id for t in token_list if t.persistent_id] + if token_list: + return [t.persistent_id for t in token_list if t.persistent_id] + return [] class BaseStep(Step, ABC): @@ -109,16 +111,16 @@ async def _get_inputs(self, input_ports: MutableMapping[str, Port]): return inputs async def _persist_token( - self, token: Token, port: Port, inputs: Iterable[Token] + self, token: Token, port: Port, inputs_token_id: Iterable[int] ) -> Token: if token.persistent_id: raise WorkflowDefinitionException( f"Token already has an id: {token.persistent_id}" ) await token.save(self.workflow.context, port_id=port.persistent_id) - if inputs: + if inputs_token_id: await self.workflow.context.database.add_provenance( - inputs=inputs, token=token.persistent_id + inputs=inputs_token_id, token=token.persistent_id ) return token @@ -343,7 +345,7 @@ async def run(self): await self._persist_token( token=token["token"], port=self.get_output_port(port_name), - inputs=ins, + inputs_token_id=ins, ) ) @@ -530,7 +532,7 @@ async def run(self): await self._persist_token( token=Token(value=self.deployment_config.name), port=self.get_output_port(), - inputs=[], + inputs_token_id=[], ) ) await self.terminate(Status.COMPLETED) @@ -607,7 +609,7 @@ async def _retrieve_output( await self._persist_token( token=token, port=output_port, - inputs=_get_tokens_id( + inputs_token_id=_get_tokens_id( list(job.inputs.values()) + [ get_job_token( @@ -893,7 +895,7 @@ async def run(self): tag=tag, value=sorted(tokens, key=lambda cur: cur.tag) ), port=output_port, - inputs=_get_tokens_id(tokens), + inputs_token_id=_get_tokens_id(tokens), ) ) break @@ -991,7 +993,7 @@ async def run(self): await self._persist_token( token=await self.process_input(job, token.value), port=self.get_output_port(), - inputs=_get_tokens_id(in_list), + inputs_token_id=_get_tokens_id(in_list), ) ) finally: @@ -1075,7 +1077,7 @@ async def run(self): await self._persist_token( token=token["token"], port=self.get_output_port(port_name), - inputs=ins, + inputs_token_id=ins, ) ) # Create a new task in place of the completed one if the port is not terminated @@ -1169,7 +1171,7 @@ async def run(self): await self._persist_token( token=await self._process_output(prefix), port=self.get_output_port(), - inputs=_get_tokens_id(self.token_map.get(prefix)), + inputs_token_id=_get_tokens_id(self.token_map.get(prefix)), ) ) # If all iterations are terminated, terminate the step @@ -1284,7 +1286,7 @@ async def _propagate_job( await self._persist_token( token=JobToken(value=job), port=self.get_output_port(), - inputs=_get_tokens_id(token_inputs), + inputs_token_id=_get_tokens_id(token_inputs), ) ) @@ -1428,7 +1430,7 @@ async def _scatter(self, token: Token): await self._persist_token( token=t.retag(token.tag + "." + str(i)), port=output_port, - inputs=_get_tokens_id([token]), + inputs_token_id=_get_tokens_id([token]), ) ) else: @@ -1538,7 +1540,7 @@ async def run(self): await self._persist_token( token=await self.transfer(job, token), port=self.get_output_port(port_name), - inputs=_get_tokens_id( + inputs_token_id=_get_tokens_id( list(inputs.values()) + [ get_job_token( @@ -1595,7 +1597,9 @@ async def run(self): await self._persist_token( token=token.update(token.value), port=self.get_output_port(port_name), - inputs=_get_tokens_id(inputs.values()), + inputs_token_id=_get_tokens_id( + inputs.values() + ), ) ) # Otherwise, apply transformation and propagate outputs @@ -1607,7 +1611,9 @@ async def run(self): await self._persist_token( token=token, port=self.get_output_port(port_name), - inputs=_get_tokens_id(inputs.values()), + inputs_token_id=_get_tokens_id( + inputs.values() + ), ) ) else: @@ -1616,7 +1622,7 @@ async def run(self): await self._persist_token( token=token, port=self.get_output_port(port_name), - inputs=[], + inputs_token_id=[], ) ) # Terminate step