Skip to content

Commit

Permalink
ENH make pipeline.named_steps a property, fix pipeline.named_steps doctest
Browse files Browse the repository at this point in the history
  • Loading branch information
amueller committed May 7, 2015
1 parent b17b3aa commit 6bd0844
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,22 @@ class Pipeline(BaseEstimator):
>>> anova_svm.score(X, y) # doctest: +ELLIPSIS
0.77...
>>> # getting the selected features chosen by anova_filter
>>> support = anova_svm.named_steps.get_support()
>>> anova_svm.named_steps['anova'].get_support()
... # doctest: +NORMALIZE_WHITESPACE
array([ True, True, True, False, False, True, False, True, True, True,
False, False, True, False, True, False, False, False, False,
True], dtype=bool)
"""

# BaseEstimator interface

def __init__(self, steps):
self.named_steps = dict(steps)
names, estimators = zip(*steps)
if len(self.named_steps) != len(steps):
raise ValueError("Names provided are not unique: %s" % (names,))
if len(dict(steps)) != len(steps):
raise ValueError("Provided step names are not unique: %s" % (names,))

# shallow copy of steps
self.steps = tosequence(zip(names, estimators))
self.steps = tosequence(steps)
transforms = estimators[:-1]
estimator = estimators[-1]

Expand All @@ -110,14 +113,18 @@ def get_params(self, deep=True):
if not deep:
return super(Pipeline, self).get_params(deep=False)
else:
out = self.named_steps.copy()
out = self.named_steps
for name, step in six.iteritems(self.named_steps):
for key, value in six.iteritems(step.get_params(deep=True)):
out['%s__%s' % (name, key)] = value

out.update(super(Pipeline, self).get_params(deep=False))
return out

@property
def named_steps(self):
return dict(self.steps)

@property
def _final_estimator(self):
return self.steps[-1][1]
Expand Down

0 comments on commit 6bd0844

Please sign in to comment.