|
2 | 2 | # coding: utf-8
|
3 | 3 |
|
4 | 4 | from typing import Set
|
5 |
| -from collections import namedtuple |
| 5 | +from collections import namedtuple, defaultdict |
6 | 6 | from toolz.itertoolz import groupby
|
7 | 7 | from toolz.dicttoolz import merge, keyfilter
|
8 | 8 |
|
@@ -36,7 +36,8 @@ def __init__(self,
|
36 | 36 | if any(f.output_names is Unknown for f in self.funcs):
|
37 | 37 | raise AttributeError("Cannot build a pipeline with a function whose outputs are unknown.")
|
38 | 38 |
|
39 |
| - pipe_inputs, sub_default_values, pipe_outputs, *_ = self._graph() |
| 39 | + self._graph_data = self._graph() |
| 40 | + pipe_inputs, sub_default_values, pipe_outputs, *_ = self._graph_data |
40 | 41 |
|
41 | 42 | if default_values is None:
|
42 | 43 | sub_default_values = sub_default_values
|
@@ -93,15 +94,22 @@ def __ror__(self, other):
|
93 | 94 | else:
|
94 | 95 | return NotImplemented
|
95 | 96 |
|
| 97 | + def _which_input_is_used_by_this_function(self): |
| 98 | + _, _, _, _, _, edges = self._graph_data |
| 99 | + inputs_used_by = defaultdict(set) |
| 100 | + for e in edges: |
| 101 | + if e.start is None: |
| 102 | + inputs_used_by[e.end].add(e.label) |
| 103 | + return inputs_used_by |
| 104 | + |
96 | 105 | def fix(self, **names_to_fix):
|
| 106 | + inputs_used_by = self._which_input_is_used_by_this_function() |
| 107 | + |
97 | 108 | fixed_funcs = []
|
98 | 109 | for f in self.funcs:
|
99 |
| - fixable_names = {n: v for n, v in names_to_fix.items() if n in f.input_names} |
| 110 | + fixable_names = {k: v for k, v in names_to_fix.items() if k in inputs_used_by[f.name]} |
100 | 111 | if len(fixable_names) > 0:
|
101 | 112 | fixed_funcs.append(f.fix(**fixable_names))
|
102 |
| - |
103 |
| - # Remove inputs that have been fixed |
104 |
| - names_to_fix = {n: v for n, v in names_to_fix.items() if n not in fixable_names} |
105 | 113 | else:
|
106 | 114 | fixed_funcs.append(f)
|
107 | 115 |
|
|
0 commit comments