Skip to content

Commit 535353b

Browse files
committed
quick fix for target class logic in predict_proba
1 parent 4846832 commit 535353b

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

stlearn/stacking.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,12 @@ def _predict_estimator(clf, X):
5959

6060
def _predict_proba_estimator(clf, X):
6161
"""Helper to get prediction method"""
62+
63+
# XXX this is not safe. Maybe add explicit 1st level scoring param.
6264
# try predict_proba
6365
predict_proba = getattr(clf, "predict_proba", None)
6466
if callable(predict_proba):
65-
return clf.predict_proba(X)[:, 0]
67+
return clf.predict_proba(X)
6668

6769
# or decision_function
6870
decision_function = getattr(clf, "decision_function", None)
@@ -131,6 +133,9 @@ def __init__(self, estimators,
131133
self.feature_indices = feature_indices
132134
self.n_jobs = n_jobs
133135

136+
def _disambiguate_probability(self, x):
137+
return x[:, -1] if np.ndim(x) > 1 else x
138+
134139
def fit(self, X, y):
135140
"""Fit all estimators according to the given training data.
136141
@@ -154,6 +159,8 @@ def fit(self, X, y):
154159
predictions_ = Parallel(n_jobs=self.n_jobs)(
155160
delayed(_predict_proba_estimator)(clf, x)
156161
for x, clf in zip(X_list, self.estimators))
162+
predictions_ = [self._disambiguate_probability(x)
163+
for x in predictions_]
157164
predictions_ = np.array(predictions_).T
158165

159166
self.stacking_estimator.fit(predictions_, y)
@@ -177,6 +184,8 @@ def predict(self, X):
177184
predictions_ = Parallel(n_jobs=self.n_jobs)(
178185
delayed(_predict_proba_estimator)(clf, x)
179186
for x, clf in zip(X_list, self.estimators))
187+
predictions_ = [self._disambiguate_probability(x)
188+
for x in predictions_]
180189
predictions_ = np.array(predictions_).T
181190

182191
return self.stacking_estimator.predict(predictions_)
@@ -199,12 +208,14 @@ def predict_proba(self, X):
199208
predictions_ = Parallel(n_jobs=self.n_jobs)(
200209
delayed(_predict_proba_estimator)(clf, x)
201210
for x, clf in zip(X_list, self.estimators))
211+
predictions_ = [self._disambiguate_probability(x)
212+
for x in predictions_]
202213
predictions_ = np.array(predictions_).T
203214

204215
return _predict_proba_estimator(self.stacking_estimator, predictions_)
205216

206217
def decision_function(self, X):
207-
return self.predict_proba(X)
218+
return self.predict_proba(X)[:, -1]
208219

209220
def score(self, X, y):
210221
"""Returns the mean accuracy on the given test data and labels.

stlearn/tests/test_stacking.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ def test_stacking_essentials():
6767
predictions = stacking.predict(X_stacked)
6868
assert_array_equal(np.unique(predictions), np.array([0, 1]))
6969

70+
proba = stacking.predict_proba(X_stacked)
71+
assert_array_equal(proba.sum(1), np.ones_like(proba[:, 1]))
72+
7073
score = stacking.score(X_stacked, y)
7174
assert_true(np.isscalar(score))
7275

0 commit comments

Comments
 (0)