|
| 1 | +from nose.tools import assert_equal |
| 2 | +from nose.tools import assert_true |
| 3 | +from nose.tools import assert_raises |
| 4 | +import numpy as np |
| 5 | +from numpy.testing import assert_array_equal |
| 6 | + |
| 7 | +from sklearn.datasets import make_classification |
| 8 | +from sklearn.linear_model import LogisticRegression |
| 9 | +from sklearn.model_selection import ShuffleSplit |
| 10 | +from stlearn import StackingClassifier |
| 11 | + |
| 12 | +n_samples = 200 |
| 13 | +n_estimators = 2 |
| 14 | +X0, y = make_classification(n_samples=200, random_state=42) |
| 15 | +X1 = X0 ** 2 |
| 16 | +X = np.array([X0, X1]) |
| 17 | +ss = ShuffleSplit(n_splits=10, test_size=0.2, random_state=42) |
| 18 | + |
| 19 | + |
| 20 | +def test_stacking_essentials(): |
| 21 | + """Test initializaing and essential basic function""" |
| 22 | + stacking = StackingClassifier( |
| 23 | + estimators=n_estimators * [LogisticRegression()], |
| 24 | + stacking_estimator=LogisticRegression()) |
| 25 | + # assert_equal(getattr(stacking, 'predictions_', None), None) |
| 26 | + assert_equal(stacking.stacking_estimator.__class__, |
| 27 | + LogisticRegression) |
| 28 | + assert_equal([ee.__class__ for ee in stacking.estimators], |
| 29 | + n_estimators * [LogisticRegression]) |
| 30 | + assert_raises(ValueError, stacking.fit, X[0], y) |
| 31 | + assert_raises(ValueError, stacking.fit, X[:1], y) |
| 32 | + assert_raises(ValueError, stacking.fit, X[:, :1], y) |
| 33 | + |
| 34 | + stacking.fit(X, y) |
| 35 | + |
| 36 | + predictions = stacking.predict(X) |
| 37 | + assert_array_equal(np.unique(predictions), np.array([0, 1])) |
| 38 | + |
| 39 | + score = stacking.score(X, y) |
| 40 | + assert_true(np.isscalar(score)) |
| 41 | + |
| 42 | + predictions_estimators = stacking.predict_estimators(X) |
| 43 | + assert_array_equal( |
| 44 | + predictions_estimators.shape, (n_samples, n_estimators)) |
| 45 | + scores_estimators = stacking.score_estimators(X, y) |
| 46 | + assert_equal(len(scores_estimators), n_estimators) |
0 commit comments