Skip to content

Commit f2d756f

Browse files
authored
Merge pull request #107 from common-workflow-language/scatter-inputs
Evaluate valueFrom after scattering.
2 parents 81ff56f + c8d672e commit f2d756f

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

cwltool/workflow.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -242,23 +242,23 @@ def try_make_job(self, step, **kwargs):
242242
raise WorkflowException("Workflow step contains valueFrom but StepInputExpressionRequirement not in requirements")
243243

244244
vfinputs = {shortname(k): v for k,v in inputobj.iteritems()}
245-
def valueFromFunc(k, v): # type: (Any, Any) -> Any
246-
if k in valueFrom:
247-
return expression.do_eval(
248-
valueFrom[k], vfinputs, self.workflow.requirements,
249-
None, None, {}, context=v)
250-
else:
251-
return v
245+
def postScatterEval(io):
246+
shortio = {shortname(k): v for k,v in io.iteritems()}
247+
def valueFromFunc(k, v): # type: (Any, Any) -> Any
248+
if k in valueFrom:
249+
return expression.do_eval(
250+
valueFrom[k], shortio, self.workflow.requirements,
251+
None, None, {}, context=v)
252+
else:
253+
return v
254+
return {k: valueFromFunc(k, v) for k,v in io.items()}
252255

253256
if "scatter" in step.tool:
254257
scatter = aslist(step.tool["scatter"])
255258
method = step.tool.get("scatterMethod")
256259
if method is None and len(scatter) != 1:
257260
raise WorkflowException("Must specify scatterMethod when scattering over multiple inputs")
258-
kwargs["valueFrom"] = valueFromFunc
259-
260-
inputobj = {k: valueFromFunc(k, v) if k not in scatter else v
261-
for k,v in inputobj.items()}
261+
kwargs["postScatterEval"] = postScatterEval
262262

263263
if method == "dotproduct" or method is None:
264264
jobs = dotproduct_scatter(step, inputobj, scatter,
@@ -280,7 +280,7 @@ def valueFromFunc(k, v): # type: (Any, Any) -> Any
280280
callback), 0, **kwargs)
281281
else:
282282
_logger.debug(u"[job %s] job input %s", step.name, json.dumps(inputobj, indent=4))
283-
inputobj = {k: valueFromFunc(k, v) for k,v in inputobj.items()}
283+
inputobj = postScatterEval(inputobj)
284284
_logger.debug(u"[job %s] evaluated job input to %s", step.name, json.dumps(inputobj, indent=4))
285285
jobs = step.job(inputobj, callback, **kwargs)
286286

@@ -592,7 +592,9 @@ def dotproduct_scatter(process, joborder, scatter_keys, output_callback, **kwarg
592592
for n in range(0, l):
593593
jo = copy.copy(joborder)
594594
for s in scatter_keys:
595-
jo[s] = kwargs["valueFrom"](s, joborder[s][n])
595+
jo[s] = joborder[s][n]
596+
597+
jo = kwargs["postScatterEval"](jo)
596598

597599
for j in process.job(jo, functools.partial(rc.receive_scatter_output, n), **kwargs):
598600
yield j
@@ -612,9 +614,10 @@ def nested_crossproduct_scatter(process, joborder, scatter_keys, output_callback
612614

613615
for n in range(0, l):
614616
jo = copy.copy(joborder)
615-
jo[scatter_key] = kwargs["valueFrom"](scatter_key, joborder[scatter_key][n])
617+
jo[scatter_key] = joborder[scatter_key][n]
616618

617619
if len(scatter_keys) == 1:
620+
jo = kwargs["postScatterEval"](jo)
618621
for j in process.job(jo, functools.partial(rc.receive_scatter_output, n), **kwargs):
619622
yield j
620623
else:
@@ -661,9 +664,10 @@ def flat_crossproduct_scatter(process, joborder, scatter_keys, output_callback,
661664
put = startindex
662665
for n in range(0, l):
663666
jo = copy.copy(joborder)
664-
jo[scatter_key] = kwargs["valueFrom"](scatter_key, joborder[scatter_key][n])
667+
jo[scatter_key] = joborder[scatter_key][n]
665668

666669
if len(scatter_keys) == 1:
670+
jo = kwargs["postScatterEval"](jo)
667671
for j in process.job(jo, functools.partial(rc.receive_scatter_output, put), **kwargs):
668672
yield j
669673
put += 1

0 commit comments

Comments
 (0)