Skip to content

Commit

Permalink
Fixed the loop combinator and renamed the `inputs` parameter to `inputs_token_id`
Browse files Browse the repository at this point in the history
  • Loading branch information
LanderOtto committed Jun 3, 2023
1 parent 62b9f26 commit 332f85a
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 33 deletions.
14 changes: 6 additions & 8 deletions streamflow/cwl/combinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,15 @@ async def combine(
else:
outputs = [outputs]
tag = schema[self.input_names[0]]["token"].tag
inputs_token_id.extend(schema[self.input_names[0]]["inputs_id"])
inputs_token_id = schema[self.input_names[0]]["inputs_id"]
# Otherwise, merge multiple inputs in a single list
else:
outputs = [schema[name]["token"] for name in self.input_names]
inputs_token_id.extend(
[
id
for name in self.input_names
for id in schema[name]["inputs_id"]
]
)
inputs_token_id = [
id
for name in self.input_names
for id in schema[name]["inputs_id"]
]
tag = get_tag(outputs)
# Flatten if needed
if self.flatten:
Expand Down
8 changes: 4 additions & 4 deletions streamflow/cwl/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ async def _on_true(self, inputs: MutableMapping[str, Token]):
await self._persist_token(
token=inputs[port_name].update(inputs[port_name].value),
port=port,
inputs=_get_tokens_id(inputs.values()),
inputs_token_id=_get_tokens_id(inputs.values()),
)
)

Expand All @@ -115,7 +115,7 @@ async def _on_false(self, inputs: MutableMapping[str, Token]):
await self._persist_token(
token=Token(value=None, tag=get_tag(inputs.values())),
port=port,
inputs=_get_tokens_id(inputs.values()),
inputs_token_id=_get_tokens_id(inputs.values()),
)
)

Expand Down Expand Up @@ -216,7 +216,7 @@ async def _on_true(self, inputs: MutableMapping[str, Token]):
await self._persist_token(
token=inputs[port_name].update(inputs[port_name].value),
port=port,
inputs=_get_tokens_id(inputs.values()),
inputs_token_id=_get_tokens_id(inputs.values()),
)
)

Expand All @@ -234,7 +234,7 @@ async def _on_false(self, inputs: MutableMapping[str, Token]):
await self._persist_token(
token=ListToken(value=token_value, tag=get_tag(inputs.values())),
port=port,
inputs=_get_tokens_id(inputs.values()),
inputs_token_id=_get_tokens_id(inputs.values()),
)
)

Expand Down
6 changes: 2 additions & 4 deletions streamflow/workflow/combinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,15 +200,13 @@ def __init__(self, name: str, workflow: Workflow):
def add_output_item(self, item: str) -> None:
self.output_items.append(item)

async def _product(self) -> AsyncIterable[MutableMapping[str, Token]]:
    """Yield a termination schema for every schema produced upstream.

    For each combined schema coming from the parent combinator, emit one
    entry per declared output item, each carrying an
    ``IterationTerminationToken`` tagged with the tag common to all input
    tokens, and the merged ``inputs_id`` list of every input entry.
    """
    # NOTE: the raw diff rendering duplicated the old signature and the old
    # `"inputs_id"` dict entry here; both stale lines are removed so the
    # merged-ids version (the added side of the diff) is the one that runs.
    async for schema in super()._product():
        # All input tokens share one tag; derive the termination tag from it.
        tag = utils.get_tag([entry["token"] for entry in schema.values()])
        yield {
            name: {
                "token": IterationTerminationToken(tag=tag),
                # Merge the provenance ids of every input entry (renamed loop
                # variable so the builtin `id` is not shadowed).
                "inputs_id": [
                    token_id
                    for entry in schema.values()
                    for token_id in entry["inputs_id"]
                ],
            }
            for name in self.output_items
        }
40 changes: 23 additions & 17 deletions streamflow/workflow/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def _group_by_tag(


def _get_tokens_id(token_list):
return [t.persistent_id for t in token_list if t.persistent_id]
if token_list:
return [t.persistent_id for t in token_list if t.persistent_id]
return []


class BaseStep(Step, ABC):
Expand Down Expand Up @@ -109,16 +111,16 @@ async def _get_inputs(self, input_ports: MutableMapping[str, Port]):
return inputs

async def _persist_token(
    self, token: Token, port: Port, inputs_token_id: Iterable[int]
) -> Token:
    """Save *token* on *port* and record its provenance, returning the token.

    :param token: the token to persist; must not already have a persistent id
    :param port: the (already persisted) port the token belongs to
    :param inputs_token_id: persistent ids of the input tokens this token was
        derived from; when non-empty, a provenance link is stored for each
    :raises WorkflowDefinitionException: if *token* was already persisted
    """
    # The raw diff interleaved the old `inputs: Iterable[Token]` signature and
    # body lines with the renamed `inputs_token_id` ones; only the renamed
    # (added) side is kept so the parameter list is valid Python again.
    if token.persistent_id:
        raise WorkflowDefinitionException(
            f"Token already has an id: {token.persistent_id}"
        )
    await token.save(self.workflow.context, port_id=port.persistent_id)
    if inputs_token_id:
        await self.workflow.context.database.add_provenance(
            inputs=inputs_token_id, token=token.persistent_id
        )
    return token

Expand Down Expand Up @@ -343,7 +345,7 @@ async def run(self):
await self._persist_token(
token=token["token"],
port=self.get_output_port(port_name),
inputs=ins,
inputs_token_id=ins,
)
)

Expand Down Expand Up @@ -530,7 +532,7 @@ async def run(self):
await self._persist_token(
token=Token(value=self.deployment_config.name),
port=self.get_output_port(),
inputs=[],
inputs_token_id=[],
)
)
await self.terminate(Status.COMPLETED)
Expand Down Expand Up @@ -607,7 +609,7 @@ async def _retrieve_output(
await self._persist_token(
token=token,
port=output_port,
inputs=_get_tokens_id(
inputs_token_id=_get_tokens_id(
list(job.inputs.values())
+ [
get_job_token(
Expand Down Expand Up @@ -893,7 +895,7 @@ async def run(self):
tag=tag, value=sorted(tokens, key=lambda cur: cur.tag)
),
port=output_port,
inputs=_get_tokens_id(tokens),
inputs_token_id=_get_tokens_id(tokens),
)
)
break
Expand Down Expand Up @@ -991,7 +993,7 @@ async def run(self):
await self._persist_token(
token=await self.process_input(job, token.value),
port=self.get_output_port(),
inputs=_get_tokens_id(in_list),
inputs_token_id=_get_tokens_id(in_list),
)
)
finally:
Expand Down Expand Up @@ -1075,7 +1077,7 @@ async def run(self):
await self._persist_token(
token=token["token"],
port=self.get_output_port(port_name),
inputs=ins,
inputs_token_id=ins,
)
)
# Create a new task in place of the completed one if the port is not terminated
Expand Down Expand Up @@ -1169,7 +1171,7 @@ async def run(self):
await self._persist_token(
token=await self._process_output(prefix),
port=self.get_output_port(),
inputs=_get_tokens_id(self.token_map.get(prefix)),
inputs_token_id=_get_tokens_id(self.token_map.get(prefix)),
)
)
# If all iterations are terminated, terminate the step
Expand Down Expand Up @@ -1284,7 +1286,7 @@ async def _propagate_job(
await self._persist_token(
token=JobToken(value=job),
port=self.get_output_port(),
inputs=_get_tokens_id(token_inputs),
inputs_token_id=_get_tokens_id(token_inputs),
)
)

Expand Down Expand Up @@ -1428,7 +1430,7 @@ async def _scatter(self, token: Token):
await self._persist_token(
token=t.retag(token.tag + "." + str(i)),
port=output_port,
inputs=_get_tokens_id([token]),
inputs_token_id=_get_tokens_id([token]),
)
)
else:
Expand Down Expand Up @@ -1538,7 +1540,7 @@ async def run(self):
await self._persist_token(
token=await self.transfer(job, token),
port=self.get_output_port(port_name),
inputs=_get_tokens_id(
inputs_token_id=_get_tokens_id(
list(inputs.values())
+ [
get_job_token(
Expand Down Expand Up @@ -1595,7 +1597,9 @@ async def run(self):
await self._persist_token(
token=token.update(token.value),
port=self.get_output_port(port_name),
inputs=_get_tokens_id(inputs.values()),
inputs_token_id=_get_tokens_id(
inputs.values()
),
)
)
# Otherwise, apply transformation and propagate outputs
Expand All @@ -1607,7 +1611,9 @@ async def run(self):
await self._persist_token(
token=token,
port=self.get_output_port(port_name),
inputs=_get_tokens_id(inputs.values()),
inputs_token_id=_get_tokens_id(
inputs.values()
),
)
)
else:
Expand All @@ -1616,7 +1622,7 @@ async def run(self):
await self._persist_token(
token=token,
port=self.get_output_port(port_name),
inputs=[],
inputs_token_id=[],
)
)
# Terminate step
Expand Down

0 comments on commit 332f85a

Please sign in to comment.