Skip to content

Commit 21c263a

Browse files
committed
high level sklearn pt2
1 parent 05c7828 commit 21c263a

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

stlearn/tests/test_stacking.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,22 +78,23 @@ def test_stacking_essentials():
7878

7979
assert_raises(ValueError, stacking.fit, X, y)
8080
stacking = StackingClassifier(
81-
estimators=[LogisticRegression() for _ in range(3)],
81+
estimators=[LogisticRegression() for _ in range(n_estimators)],
8282
feature_indices=[np.array([-500]), np.array([1]), np.array([2])],
8383
stacking_estimator=LogisticRegression())
8484

8585
assert_raises(ValueError, stacking.fit, X_stacked, y)
8686

8787
stacking = StackingClassifier(
88-
estimators=[LogisticRegression() for _ in range(3)],
88+
estimators=[LogisticRegression() for _ in range(n_estimators)],
8989
feature_indices=[slice(5000, -5000), slice(1, 10), slice(20)],
9090
stacking_estimator=LogisticRegression())
9191
assert_raises(ValueError, stacking.fit, X_stacked, y)
9292

9393

9494
def test_sklearn_high_level():
95+
"""Test high-level sklearn API"""
9596
stacking = StackingClassifier(
96-
estimators=[LogisticRegression() for _ in range(3)],
97+
estimators=[LogisticRegression() for _ in range(n_estimators)],
9798
feature_indices=feature_indices,
9899
stacking_estimator=LogisticRegression())
99100
assert_true(is_classifier(stacking))

0 commit comments

Comments
 (0)