Skip to content

Commit c156d04

Browse files
author
Matthieu Ancellin
committed
Fix bug in LabelledPipeline.fix
1 parent 7587286 commit c156d04

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

labelled_functions/pipeline.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# coding: utf-8
33

44
from typing import Set
5-
from collections import namedtuple
5+
from collections import namedtuple, defaultdict
66
from toolz.itertoolz import groupby
77
from toolz.dicttoolz import merge, keyfilter
88

@@ -36,7 +36,8 @@ def __init__(self,
3636
if any(f.output_names is Unknown for f in self.funcs):
3737
raise AttributeError("Cannot build a pipeline with a function whose outputs are unknown.")
3838

39-
pipe_inputs, sub_default_values, pipe_outputs, *_ = self._graph()
39+
self._graph_data = self._graph()
40+
pipe_inputs, sub_default_values, pipe_outputs, *_ = self._graph_data
4041

4142
if default_values is None:
4243
sub_default_values = sub_default_values
@@ -93,15 +94,22 @@ def __ror__(self, other):
9394
else:
9495
return NotImplemented
9596

97+
def _which_input_is_used_by_this_function(self):
98+
_, _, _, _, _, edges = self._graph_data
99+
inputs_used_by = defaultdict(set)
100+
for e in edges:
101+
if e.start is None:
102+
inputs_used_by[e.end].add(e.label)
103+
return inputs_used_by
104+
96105
def fix(self, **names_to_fix):
106+
inputs_used_by = self._which_input_is_used_by_this_function()
107+
97108
fixed_funcs = []
98109
for f in self.funcs:
99-
fixable_names = {n: v for n, v in names_to_fix.items() if n in f.input_names}
110+
fixable_names = {k: v for k, v in names_to_fix.items() if k in inputs_used_by[f.name]}
100111
if len(fixable_names) > 0:
101112
fixed_funcs.append(f.fix(**fixable_names))
102-
103-
# Remove inputs that have been fixed
104-
names_to_fix = {n: v for n, v in names_to_fix.items() if n not in fixable_names}
105113
else:
106114
fixed_funcs.append(f)
107115

test/test_pipeline.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytest
55

6-
from labelled_functions.labels import LabelledFunction
6+
from labelled_functions.labels import LabelledFunction, label
77
from labelled_functions.pipeline import LabelledPipeline, pipeline, compose
88
from labelled_functions.special_functions import let, relabel, show
99
from labelled_functions.decorators import keeping_inputs
@@ -66,9 +66,6 @@ def g(b, c=3):
6666
assert pipe(a=1) == {'output': 1}
6767
assert pipe(a=1, c=0) == {'output': -2}
6868

69-
pipe = pipeline([f, g]).fix(a=1)
70-
assert pipe() == {'output': 1}
71-
7269
###
7370
pipe = pipeline([cube], default_values={'x': 1.0})
7471
assert pipe()['volume'] == 1.0
@@ -79,6 +76,19 @@ def g(b, c=3):
7976
assert pipe(x=10.0)['volume'] == 1000.0
8077

8178

79+
def test_fix():
80+
a, b = 1, 2
81+
f = label(lambda x, y: x + y, name="foo", output_names=['z'])
82+
g = label(lambda x, z: x * z, name="bar", output_names=['u'])
83+
h = label(lambda u: u**2, name="baz", output_names=['w'])
84+
85+
pipe = pipeline([f, g, h]).fix(y=b)
86+
assert pipe(x=a) == {'w': (a*(a+b))**2}
87+
88+
pipe = pipeline([f, g, h]).fix(x=a)
89+
assert pipe(y=b) == {'w': (a*(a+b))**2}
90+
91+
8292
def test_let():
8393
l = let(x=1.0)
8494
assert l.input_names == []

0 commit comments

Comments
 (0)