Skip to content

Commit

Permalink
Address scikit-learn-contrib#176 - Fix "fit then sample" bug in pipeli…
Browse files Browse the repository at this point in the history
  • Loading branch information
chkoar authored and glemaitre committed Oct 30, 2016
1 parent ee33364 commit 0573d83
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
8 changes: 6 additions & 2 deletions imblearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,14 @@ def sample(self, X, y):
Xt = X
for _, transform in self.steps[:-1]:
if hasattr(transform, "fit_sample"):
pass
# XXX: Calling sample on a pipeline means that the
# last estimator is a sampler. Samplers don't carry
# the sampled data. So, call 'fit_sample' on all intermediate
# steps to get the sampled data for the last estimator.
Xt, y = transform.fit_sample(Xt, y)
else:
Xt = transform.transform(Xt)
return self.steps[-1][-1].sample(Xt, y)
return self.steps[-1][-1].fit_sample(Xt, y)

@if_delegate_has_method(delegate='_final_estimator')
def predict(self, X):
Expand Down
36 changes: 34 additions & 2 deletions imblearn/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
assert_true, assert_warns_message)

from imblearn.pipeline import Pipeline, make_pipeline
from imblearn.under_sampling import RandomUnderSampler
from imblearn.under_sampling import RandomUnderSampler, EditedNearestNeighbours as ENN

JUNK_FOOD_DOCS = (
"the pizza pizza beer copyright",
Expand Down Expand Up @@ -473,4 +473,36 @@ def test_pipeline_with_step_that_it_is_pipeline():
filter1 = SelectKBest(f_classif, k=2)
pipe1 = Pipeline([('rus', rus), ('anova', filter1)])
assert_raises(TypeError, Pipeline, [('pipe1', pipe1), ('logistic', clf)])


def test_pipeline_fit_then_sample_with_sampler_last_estimator():
    """Check ``fit`` + ``sample`` against ``fit_sample`` on a two-step pipeline.

    A pipeline made of a ``RandomUnderSampler`` followed by ``ENN`` is run
    once through ``fit_sample`` and once, after being rebuilt from the same
    estimator objects, through ``fit`` followed by ``sample``. Both routes
    must return identical ``X`` and ``y`` arrays.
    """
    X, y = make_classification(
        n_classes=2, class_sep=2, weights=[0.1, 0.9],
        n_informative=3, n_redundant=1, flip_y=0,
        n_features=20, n_clusters_per_class=1,
        n_samples=50000, random_state=0)

    under_sampler = RandomUnderSampler(random_state=42)
    cleaner = ENN()

    # Reference result: fit and resample in a single call.
    one_shot = make_pipeline(under_sampler, cleaner)
    X_one_shot, y_one_shot = one_shot.fit_sample(X, y)

    # Same steps in a new pipeline, but fitted first and sampled afterwards.
    two_step = make_pipeline(under_sampler, cleaner)
    two_step.fit(X, y)
    X_two_step, y_two_step = two_step.sample(X, y)

    assert_array_equal(X_one_shot, X_two_step)
    assert_array_equal(y_one_shot, y_two_step)


def test_pipeline_fit_then_sample_of_three_samplers_with_sampler_last_estimator():
    """Check ``fit`` + ``sample`` against ``fit_sample`` with three samplers.

    The pipeline chains ``RandomUnderSampler``, ``ENN`` and the same
    ``RandomUnderSampler`` object again. It is run once through
    ``fit_sample`` and once, after being rebuilt from the same estimator
    objects, through ``fit`` followed by ``sample``. Both routes must return
    identical ``X`` and ``y`` arrays.
    """
    X, y = make_classification(
        n_classes=2, class_sep=2, weights=[0.1, 0.9],
        n_informative=3, n_redundant=1, flip_y=0,
        n_features=20, n_clusters_per_class=1,
        n_samples=50000, random_state=0)

    under_sampler = RandomUnderSampler(random_state=42)
    cleaner = ENN()

    # Reference result: fit and resample in a single call.
    one_shot = make_pipeline(under_sampler, cleaner, under_sampler)
    X_one_shot, y_one_shot = one_shot.fit_sample(X, y)

    # Same steps in a new pipeline, but fitted first and sampled afterwards.
    two_step = make_pipeline(under_sampler, cleaner, under_sampler)
    two_step.fit(X, y)
    X_two_step, y_two_step = two_step.sample(X, y)

    assert_array_equal(X_one_shot, X_two_step)
    assert_array_equal(y_one_shot, y_two_step)

0 comments on commit 0573d83

Please sign in to comment.