Skip to content

Commit 9b9d081

Browse files
committed
stacking predict proba
1 parent 21c263a commit 9b9d081

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

stlearn/stacking.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,31 @@ def predict(self, X):
181181

182182
return self.stacking_estimator.predict(predictions_)
183183

184+
def predict_proba(self, X):
185+
"""Predict class probability for samples in X.
186+
187+
Parameters
188+
----------
189+
X : {array-like, sparse matrix}, shape = (n_samples, n_features)
190+
The multi-input samples.
191+
192+
Returns
193+
-------
194+
C : array, shape = (n_samples)
195+
Predicted class label per sample.
196+
"""
197+
_check_Xy(self, X)
198+
X_list = _split_features(X, self.feature_indices)
199+
predictions_ = Parallel(n_jobs=self.n_jobs)(
200+
delayed(_predict_proba_estimator)(clf, x)
201+
for x, clf in zip(X_list, self.estimators))
202+
predictions_ = np.array(predictions_).T
203+
204+
return _predict_proba_estimator(self.stacking_estimator, predictions_)
205+
206+
def decision_function(self, X):
207+
return self.predict_proba(X)
208+
184209
def score(self, X, y):
185210
"""Returns the mean accuracy on the given test data and labels.
186211

0 commit comments

Comments
 (0)