Skip to content

Commit 05c7828

Browse files
committed
high level sklearn
1 parent 53dddf8 commit 05c7828

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

stlearn/tests/test_stacking.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from sklearn.linear_model import LogisticRegression
99
from stlearn import StackingClassifier
1010
from stlearn import stack_features
11+
from sklearn.model_selection import cross_val_score
12+
from sklearn.base import is_classifier
13+
1114

1215
n_samples = 200
1316
n_estimators = 3
@@ -86,3 +89,13 @@ def test_stacking_essentials():
8689
feature_indices=[slice(5000, -5000), slice(1, 10), slice(20)],
8790
stacking_estimator=LogisticRegression())
8891
assert_raises(ValueError, stacking.fit, X_stacked, y)
92+
93+
94+
def test_sklearn_high_level():
95+
stacking = StackingClassifier(
96+
estimators=[LogisticRegression() for _ in range(3)],
97+
feature_indices=feature_indices,
98+
stacking_estimator=LogisticRegression())
99+
assert_true(is_classifier(stacking))
100+
scores = cross_val_score(X=X_stacked, y=y, estimator=stacking)
101+
assert_equal(len(scores), 3)

0 commit comments

Comments
 (0)