Skip to content

Commit 34e8079

Browse files
committed
add first API tests
1 parent 354893d commit 34e8079

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

stlearn/stacking.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,18 @@
1212
from sklearn.externals.joblib import Memory, Parallel, delayed
1313

1414

15-
def fit_estimator(clf, X, y):
15+
def _fit_estimator(clf, X, y):
16+
"""Helper to fit estimator"""
1617
return clf.fit(X, y)
1718

1819

19-
def predict_estimator(clf, X):
20+
def _predict_estimator(clf, X):
21+
"""Helper tor predict"""
2022
return clf.predict(X)
2123

2224

23-
def predict_proba_estimator(clf, X):
25+
def _predict_proba_estimator(clf, X):
26+
"""Helper to get prediction method"""
2427
# try predict_proba
2528
predict_proba = getattr(clf, "predict_proba", None)
2629
if callable(predict_proba):
@@ -56,11 +59,11 @@ def fit(self, X, y):
5659
"""
5760

5861
self.estimators = Parallel(n_jobs=self.n_jobs)(
59-
delayed(fit_estimator)(clf, x, y)
62+
delayed(_fit_estimator)(clf, x, y)
6063
for x, clf in zip(X, self.estimators))
6164

6265
predictions_ = Parallel(n_jobs=self.n_jobs)(
63-
delayed(predict_proba_estimator)(clf, x)
66+
delayed(_predict_proba_estimator)(clf, x)
6467
for x, clf in zip(X, self.estimators))
6568
predictions_ = np.array(predictions_).T
6669

@@ -73,7 +76,7 @@ def predict(self, X):
7376
"""
7477

7578
predictions_ = Parallel(n_jobs=self.n_jobs)(
76-
delayed(predict_proba_estimator)(clf, x)
79+
delayed(_predict_proba_estimator)(clf, x)
7780
for x, clf in zip(X, self.estimators))
7881
predictions_ = np.array(predictions_).T
7982

@@ -88,7 +91,7 @@ def predict_estimators(self, X):
8891
""" prediction from separate estimators
8992
"""
9093
predictions_ = Parallel(n_jobs=self.n_jobs)(
91-
delayed(predict_estimator)(clf, x)
94+
delayed(_predict_estimator)(clf, x)
9295
for x, clf in zip(X, self.estimators))
9396
return np.array(predictions_).T
9497

0 commit comments

Comments
 (0)