forked from scikit-learn-contrib/imbalanced-learn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpipeline.py
636 lines (532 loc) · 22.9 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
"""
The :mod:`imblearn.pipeline` module implements utilities to build a
composite estimator, as a chain of transforms, samplers and estimators.
"""
# Adapted from scikit-learn
# Author: Edouard Duchesnay
# Gael Varoquaux
# Virgile Fritsch
# Alexandre Gramfort
# Lars Buitinck
# Christos Aridas
# Guillaume Lemaitre <g.lemaitre58@gmail.com>
# License: BSD
from __future__ import division
from collections import defaultdict
from itertools import islice
from sklearn import pipeline
from sklearn.base import clone
from sklearn.utils.metaestimators import if_delegate_has_method
from sklearn.utils.validation import check_memory
# Public API of this module: the names exported by a star-import.
__all__ = ["Pipeline", "make_pipeline"]
class Pipeline(pipeline.Pipeline):
    """Pipeline of transforms and resamples with a final estimator.

    Sequentially apply a list of transforms, sampling, and a final estimator.
    Intermediate steps of the pipeline must be transformers or samplers,
    that is, they must implement either ``fit`` and ``transform``, or
    ``fit_resample`` -- but not both.
    The samplers are only applied during fit; they are skipped by
    ``predict``, ``transform``, ``score`` and the other inference methods.
    The final estimator only needs to implement fit.
    The transformers and samplers in the pipeline can be cached using
    ``memory`` argument.

    The purpose of the pipeline is to assemble several steps that can be
    cross-validated together while setting different parameters.
    For this, it enables setting parameters of the various steps using their
    names and the parameter name separated by a '__', as in the example below.
    A step's estimator may be replaced entirely by setting the parameter
    with its name to another estimator, or a transformer removed by setting
    it to 'passthrough' or ``None``.

    Parameters
    ----------
    steps : list
        List of (name, transform) tuples (implementing
        fit/transform/fit_resample) that are chained, in the order in which
        they are chained, with the last object an estimator.

    memory : Instance of joblib.Memory or string, optional (default=None)
        Used to cache the fitted transformers of the pipeline. By default,
        no caching is performed. If a string is given, it is the path to
        the caching directory. Any other object exposing the joblib.Memory
        ``cache`` interface is also accepted. Enabling caching triggers a
        clone of the transformers before fitting. Therefore, the transformer
        instance given to the pipeline cannot be inspected
        directly. Use the attribute ``named_steps`` or ``steps`` to
        inspect estimators within the pipeline. Caching the
        transformers is advantageous when fitting is time consuming.

    Attributes
    ----------
    named_steps : dict
        Read-only attribute to access any step parameter by user given name.
        Keys are step names and values are steps parameters.

    Notes
    -----
    See :ref:`sphx_glr_auto_examples_pipeline_plot_pipeline_classification.py`

    See also
    --------
    make_pipeline : helper function to make pipeline.

    Examples
    --------

    >>> from collections import Counter
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.model_selection import train_test_split as tts
    >>> from sklearn.decomposition import PCA
    >>> from sklearn.neighbors import KNeighborsClassifier as KNN
    >>> from sklearn.metrics import classification_report
    >>> from imblearn.over_sampling import SMOTE
    >>> from imblearn.pipeline import Pipeline
    >>> X, y = make_classification(n_classes=2, class_sep=2,
    ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
    ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
    >>> print('Original dataset shape {}'.format(Counter(y)))
    Original dataset shape Counter({1: 900, 0: 100})
    >>> pca = PCA()
    >>> smt = SMOTE(random_state=42)
    >>> knn = KNN()
    >>> pipeline = Pipeline([('smt', smt), ('pca', pca), ('knn', knn)])
    >>> X_train, X_test, y_train, y_test = tts(X, y, random_state=42)
    >>> pipeline.fit(X_train, y_train) # doctest: +ELLIPSIS
    Pipeline(...)
    >>> y_hat = pipeline.predict(X_test)
    >>> print(classification_report(y_test, y_hat))
    ... # doctest: +NORMALIZE_WHITESPACE
                  precision    recall  f1-score   support
    <BLANKLINE>
               0       0.87      1.00      0.93        26
               1       1.00      0.98      0.99       224
    <BLANKLINE>
        accuracy                           0.98       250
       macro avg       0.93      0.99      0.96       250
    weighted avg       0.99      0.98      0.98       250
    <BLANKLINE>

    """

    # BaseEstimator interface

    def _validate_steps(self):
        """Check that ``self.steps`` forms a valid chain.

        Raises
        ------
        TypeError
            If an intermediate step implements neither the transformer
            interface (``fit``/``transform``) nor the sampler interface
            (``fit_resample``), implements both, or is itself a
            :class:`sklearn.pipeline.Pipeline`; or if the last step is not
            ``None``/``'passthrough'`` and does not implement ``fit``.
        """
        names, estimators = zip(*self.steps)

        # validate names
        self._validate_names(names)

        # validate estimators
        transformers = estimators[:-1]
        estimator = estimators[-1]

        for t in transformers:
            if t is None or t == 'passthrough':
                continue
            if (not (hasattr(t, "fit") or
                     hasattr(t, "fit_transform") or
                     hasattr(t, "fit_resample")) or
                    not (hasattr(t, "transform") or
                         hasattr(t, "fit_resample"))):
                raise TypeError(
                    "All intermediate steps of the chain should "
                    "be estimators that implement fit and transform or "
                    "fit_resample (but not both) or be a string 'passthrough' "
                    "'%s' (type %s) doesn't)" % (t, type(t)))

            # A step must be either a transformer or a sampler: _fit
            # dispatches on these attributes and would be ambiguous otherwise.
            if (hasattr(t, "fit_resample") and (hasattr(t, "fit_transform") or
                                                hasattr(t, "transform"))):
                raise TypeError(
                    "All intermediate steps of the chain should "
                    "be estimators that implement fit and transform or sample."
                    " '%s' implements both)" % (t))

            if isinstance(t, pipeline.Pipeline):
                raise TypeError(
                    "All intermediate steps of the chain should not be"
                    " Pipelines")

        # We allow last estimator to be None as an identity transformation
        if (estimator is not None and estimator != 'passthrough'
                and not hasattr(estimator, "fit")):
            raise TypeError("Last step of Pipeline should implement fit or be "
                            "the string 'passthrough'. '%s' (type %s) doesn't"
                            % (estimator, type(estimator)))

    # Estimator interface

    def _fit(self, X, y=None, **fit_params):
        """Fit (or load from the cache) every intermediate step, in order.

        Each fitted step replaces the corresponding entry of ``self.steps``.

        Parameters
        ----------
        X : iterable
            Training data given to the first step.

        y : iterable, default=None
            Training targets. Samplers may change it along the way.

        **fit_params : dict of string -> object
            Step parameters keyed as ``<step name>__<parameter>``.

        Returns
        -------
        Xt : object
            Data after all intermediate steps.

        yt : object
            Target after all intermediate samplers.

        final_fit_params : dict
            Parameters destined to the final estimator. It is empty when the
            final step is ``'passthrough'`` or ``None``.

        Raises
        ------
        ValueError
            If a key of ``fit_params`` does not contain ``'__'``.
        KeyError
            If a key's prefix does not name a (non-``None``) step.
        """
        self.steps = list(self.steps)
        self._validate_steps()
        # Setup the memory
        memory = check_memory(self.memory)

        fit_transform_one_cached = memory.cache(_fit_transform_one)
        fit_resample_one_cached = memory.cache(_fit_resample_one)

        fit_params_steps = {name: {} for name, step in self.steps
                            if step is not None}
        for pname, pval in fit_params.items():
            if '__' not in pname:
                raise ValueError(
                    "Pipeline.fit does not accept the {} parameter. "
                    "You can pass parameters to specific steps of your "
                    "pipeline using the stepname__parameter format, e.g. "
                    "`Pipeline.fit(X, y, logisticregression__sample_weight"
                    "=sample_weight)`.".format(pname))
            step, param = pname.split('__', 1)
            fit_params_steps[step][param] = pval

        # Whether the memory object actually caches is loop-invariant. Do not
        # clone when caching is disabled, to preserve backward compatibility:
        # the user-provided instances are then fitted in place.
        if hasattr(memory, 'location'):
            # joblib >= 0.12
            caching_disabled = memory.location is None
        elif hasattr(memory, 'cachedir'):
            # joblib < 0.11
            caching_disabled = memory.cachedir is None
        else:
            # Any other object with the joblib.Memory ``cache`` interface
            # (accepted by check_memory): assume it really caches. Without
            # this branch the transformer below would be left unbound.
            caching_disabled = False

        for step_idx, name, transformer in self._iter(with_final=False):
            if caching_disabled:
                cloned_transformer = transformer
            else:
                cloned_transformer = clone(transformer)

            # Fit or load from cache the current transformer. _validate_steps
            # guarantees every step is exactly one of transformer / sampler.
            if (hasattr(cloned_transformer, "transform") or
                    hasattr(cloned_transformer, "fit_transform")):
                X, fitted_transformer = fit_transform_one_cached(
                    cloned_transformer, None, X, y,
                    **fit_params_steps[name])
            else:
                X, y, fitted_transformer = fit_resample_one_cached(
                    cloned_transformer, X, y, **fit_params_steps[name])
            # Replace the transformer of the step with the fitted
            # transformer. This is necessary when loading the transformer
            # from the cache.
            self.steps[step_idx] = (name, fitted_transformer)

        if self._final_estimator == 'passthrough':
            return X, y, {}
        return X, y, fit_params_steps[self.steps[-1][0]]

    def fit(self, X, y=None, **fit_params):
        """Fit the model.

        Fit all the transforms/samplers one after the other and
        transform/sample the data, then fit the transformed/sampled
        data using the final estimator.

        Parameters
        ----------
        X : iterable
            Training data. Must fulfill input requirements of first step of the
            pipeline.

        y : iterable, default=None
            Training targets. Must fulfill label requirements for all steps of
            the pipeline.

        **fit_params : dict of string -> object
            Parameters passed to the ``fit`` method of each step, where
            each parameter name is prefixed such that parameter ``p`` for step
            ``s`` has key ``s__p``.

        Returns
        -------
        self : Pipeline
            This estimator
        """
        Xt, yt, fit_params = self._fit(X, y, **fit_params)
        if self._final_estimator != 'passthrough':
            self._final_estimator.fit(Xt, yt, **fit_params)
        return self

    def fit_transform(self, X, y=None, **fit_params):
        """Fit the model and transform with the final estimator.

        Fits all the transformers/samplers one after the other and
        transform/sample the data, then uses fit_transform on
        transformed data with the final estimator.

        Parameters
        ----------
        X : iterable
            Training data. Must fulfill input requirements of first step of the
            pipeline.

        y : iterable, default=None
            Training targets. Must fulfill label requirements for all steps of
            the pipeline.

        **fit_params : dict of string -> object
            Parameters passed to the ``fit`` method of each step, where
            each parameter name is prefixed such that parameter ``p`` for step
            ``s`` has key ``s__p``.

        Returns
        -------
        Xt : array-like, shape = [n_samples, n_transformed_features]
            Transformed samples
        """
        last_step = self._final_estimator
        Xt, yt, fit_params = self._fit(X, y, **fit_params)
        if last_step == 'passthrough':
            return Xt
        elif hasattr(last_step, 'fit_transform'):
            return last_step.fit_transform(Xt, yt, **fit_params)
        else:
            return last_step.fit(Xt, yt, **fit_params).transform(Xt)

    def fit_resample(self, X, y=None, **fit_params):
        """Fit the model and sample with the final estimator.

        Fits all the transformers/samplers one after the other and
        transform/sample the data, then uses fit_resample on transformed
        data with the final estimator.

        Parameters
        ----------
        X : iterable
            Training data. Must fulfill input requirements of first step of the
            pipeline.

        y : iterable, default=None
            Training targets. Must fulfill label requirements for all steps of
            the pipeline.

        **fit_params : dict of string -> object
            Parameters passed to the ``fit`` method of each step, where
            each parameter name is prefixed such that parameter ``p`` for step
            ``s`` has key ``s__p``.

        Returns
        -------
        Xt : array-like, shape = [n_samples_new, n_transformed_features]
            Transformed and resampled samples. When the final step is
            ``'passthrough'`` or ``None``, this is the output of the
            preceding steps.

        yt : array-like, shape = [n_samples_new]
            Resampled target.

        Raises
        ------
        AttributeError
            If the final step is neither ``'passthrough'``/``None`` nor
            implements ``fit_resample``. This is checked before any step is
            fitted.
        """
        last_step = self._final_estimator
        if (last_step != 'passthrough' and
                not hasattr(last_step, 'fit_resample')):
            # Previously this silently returned None after fitting every
            # intermediate step; fail fast instead.
            raise AttributeError(
                "The final step of the pipeline '%s' (type %s) does not "
                "implement fit_resample" % (last_step, type(last_step)))
        Xt, yt, fit_params = self._fit(X, y, **fit_params)
        if last_step == 'passthrough':
            # Keep the documented (X, y) return shape even when there is
            # nothing left to resample.
            return Xt, yt
        return last_step.fit_resample(Xt, yt, **fit_params)

    def _apply_transforms(self, X, with_final=False):
        """Pass ``X`` through the fitted transformers, skipping samplers.

        Samplers (steps exposing ``fit_resample``) only act at fit time, so
        they are ignored when transforming data at inference time.

        Parameters
        ----------
        X : iterable
            Data to transform.

        with_final : bool, default=False
            Whether the final step is also applied (as a transformer).

        Returns
        -------
        Xt : object
            Output of the last transformer applied (``X`` itself when no
            transformer is applied).
        """
        Xt = X
        for _, _, transform in self._iter(with_final=with_final):
            if not hasattr(transform, "fit_resample"):
                Xt = transform.transform(Xt)
        return Xt

    @if_delegate_has_method(delegate='_final_estimator')
    def predict(self, X, **predict_params):
        """Apply transformers/samplers to the data, and predict with the final
        estimator.

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        **predict_params : dict of string -> object
            Parameters to the ``predict`` called at the end of all
            transformations in the pipeline. Note that while this may be
            used to return uncertainties from some models with return_std
            or return_cov, uncertainties that are generated by the
            transformations in the pipeline are not propagated to the
            final estimator.

        Returns
        -------
        y_pred : array-like
        """
        Xt = self._apply_transforms(X)
        return self.steps[-1][-1].predict(Xt, **predict_params)

    @if_delegate_has_method(delegate='_final_estimator')
    def fit_predict(self, X, y=None, **fit_params):
        """Applies fit_predict of last step in pipeline after transforms.

        Applies fit_transforms of a pipeline to the data, followed by the
        fit_predict method of the final estimator in the pipeline. Valid
        only if the final estimator implements fit_predict.

        Parameters
        ----------
        X : iterable
            Training data. Must fulfill input requirements of first step of
            the pipeline.

        y : iterable, default=None
            Training targets. Must fulfill label requirements for all steps
            of the pipeline.

        **fit_params : dict of string -> object
            Parameters passed to the ``fit`` method of each step, where
            each parameter name is prefixed such that parameter ``p`` for step
            ``s`` has key ``s__p``.

        Returns
        -------
        y_pred : array-like
        """
        Xt, yt, fit_params = self._fit(X, y, **fit_params)
        return self.steps[-1][-1].fit_predict(Xt, yt, **fit_params)

    @if_delegate_has_method(delegate='_final_estimator')
    def predict_proba(self, X):
        """Apply transformers/samplers, and predict_proba of the final
        estimator.

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        Returns
        -------
        y_proba : array-like, shape = [n_samples, n_classes]
        """
        Xt = self._apply_transforms(X)
        return self.steps[-1][-1].predict_proba(Xt)

    @if_delegate_has_method(delegate='_final_estimator')
    def score_samples(self, X):
        """Apply transforms, and score_samples of the final estimator.

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        Returns
        -------
        y_score : ndarray, shape (n_samples,)
        """
        Xt = self._apply_transforms(X)
        return self.steps[-1][-1].score_samples(Xt)

    @if_delegate_has_method(delegate='_final_estimator')
    def decision_function(self, X):
        """Apply transformers/samplers, and decision_function of the final
        estimator.

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        Returns
        -------
        y_score : array-like, shape = [n_samples, n_classes]
        """
        Xt = self._apply_transforms(X)
        return self.steps[-1][-1].decision_function(Xt)

    @if_delegate_has_method(delegate='_final_estimator')
    def predict_log_proba(self, X):
        """Apply transformers/samplers, and predict_log_proba of the final
        estimator.

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        Returns
        -------
        y_score : array-like, shape = [n_samples, n_classes]
        """
        Xt = self._apply_transforms(X)
        return self.steps[-1][-1].predict_log_proba(Xt)

    @property
    def transform(self):
        """Apply transformers/samplers, and transform with the final estimator.

        This also works where final estimator is ``None`` or
        ``'passthrough'``: all prior transformations are applied. Samplers
        are skipped.

        Parameters
        ----------
        X : iterable
            Data to transform. Must fulfill input requirements of first step
            of the pipeline.

        Returns
        -------
        Xt : array-like, shape = [n_samples, n_transformed_features]
        """
        # _final_estimator is 'passthrough' or has transform, otherwise
        # raise AttributeError (so that hasattr(pipe, 'transform') is False).
        if self._final_estimator != 'passthrough':
            self._final_estimator.transform
        return self._transform

    def _transform(self, X):
        return self._apply_transforms(X, with_final=True)

    @property
    def inverse_transform(self):
        """Apply inverse transformations in reverse order.

        All non-sampler estimators in the pipeline must support
        ``inverse_transform``; samplers are skipped.

        Parameters
        ----------
        Xt : array-like, shape = [n_samples, n_transformed_features]
            Data samples, where ``n_samples`` is the number of samples and
            ``n_features`` is the number of features. Must fulfill
            input requirements of last step of pipeline's
            ``inverse_transform`` method.

        Returns
        -------
        Xt : array-like, shape = [n_samples, n_features]
        """
        # Raise AttributeError if necessary for hasattr behaviour. Samplers
        # are exempt, consistently with _inverse_transform which skips them;
        # otherwise any pipeline containing a sampler could never invert.
        for _, _, transform in self._iter():
            if not hasattr(transform, "fit_resample"):
                transform.inverse_transform
        return self._inverse_transform

    def _inverse_transform(self, X):
        Xt = X
        for _, _, transform in reversed(list(self._iter())):
            if not hasattr(transform, "fit_resample"):
                Xt = transform.inverse_transform(Xt)
        return Xt

    @if_delegate_has_method(delegate='_final_estimator')
    def score(self, X, y=None, sample_weight=None):
        """Apply transformers/samplers, and score with the final estimator.

        Parameters
        ----------
        X : iterable
            Data to predict on. Must fulfill input requirements of first step
            of the pipeline.

        y : iterable, default=None
            Targets used for scoring. Must fulfill label requirements for all
            steps of the pipeline.

        sample_weight : array-like, default=None
            If not None, this argument is passed as ``sample_weight`` keyword
            argument to the ``score`` method of the final estimator.

        Returns
        -------
        score : float
        """
        Xt = self._apply_transforms(X)
        score_params = {}
        if sample_weight is not None:
            score_params['sample_weight'] = sample_weight
        return self.steps[-1][-1].score(Xt, y, **score_params)
def _fit_transform_one(transformer, weight, X, y, **fit_params):
if hasattr(transformer, 'fit_transform'):
res = transformer.fit_transform(X, y, **fit_params)
else:
res = transformer.fit(X, y, **fit_params).transform(X)
# if we have a weight for this transformer, multiply output
if weight is None:
return res, transformer
return res * weight, transformer
def _fit_resample_one(sampler, X, y, **fit_params):
X_res, y_res = sampler.fit_resample(X, y, **fit_params)
return X_res, y_res, sampler
def make_pipeline(*steps, **kwargs):
    """Construct a Pipeline from the given estimators.

    This is a shorthand for the Pipeline constructor; it does not require, and
    does not permit, naming the estimators. Instead, their names will be set
    to the lowercase of their types automatically.

    Parameters
    ----------
    *steps : list of estimators.

    memory : None, str or object with the joblib.Memory interface, optional
        Used to cache the fitted transformers of the pipeline. By default,
        no caching is performed. If a string is given, it is the path to
        the caching directory. Enabling caching triggers a clone of
        the transformers before fitting. Therefore, the transformer
        instance given to the pipeline cannot be inspected
        directly. Use the attribute ``named_steps`` or ``steps`` to
        inspect estimators within the pipeline. Caching the
        transformers is advantageous when fitting is time consuming.

    Returns
    -------
    p : Pipeline

    Raises
    ------
    TypeError
        If a keyword argument other than ``memory`` is given; the message
        names the first such argument.

    See also
    --------
    imblearn.pipeline.Pipeline : Class for creating a pipeline of
        transforms with a final estimator.

    Examples
    --------
    >>> from sklearn.naive_bayes import GaussianNB
    >>> from sklearn.preprocessing import StandardScaler
    >>> make_pipeline(StandardScaler(), GaussianNB(priors=None))
    ... # doctest: +NORMALIZE_WHITESPACE
    Pipeline(memory=None,
             steps=[('standardscaler',
                     StandardScaler(copy=True, with_mean=True, with_std=True)),
                    ('gaussiannb',
                     GaussianNB(priors=None, var_smoothing=1e-09))],
             verbose=False)
    """
    memory = kwargs.pop('memory', None)
    if kwargs:
        # Only the first unexpected keyword is reported.
        first_unknown = next(iter(kwargs))
        raise TypeError(
            'Unknown keyword arguments: "{}"'.format(first_unknown))
    named_steps = pipeline._name_estimators(steps)
    return Pipeline(named_steps, memory=memory)