
Commit 4f713ce

Merge pull request scikit-learn#5182 from MechCoder/predict_proba_fix
[MRG + 2] predict_proba should use the softmax function in the multinomial case
2 parents 0bf7536 + c85f2ad

3 files changed: +42 additions, -3 deletions

sklearn/linear_model/base.py

Lines changed: 1 addition & 1 deletion

@@ -255,7 +255,7 @@ def _predict_proba_lr(self, X):
         np.exp(prob, prob)
         prob += 1
         np.reciprocal(prob, prob)
-        if len(prob.shape) == 1:
+        if prob.ndim == 1:
             return np.vstack([1 - prob, prob]).T
         else:
             # OvR normalization, like LibLinear's predict_probability
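For context, the hunk above sits inside `_predict_proba_lr`, the shared OvR probability path. A minimal standalone sketch of that logic, assuming `scores` stands in for the `decision_function` output (the function name and example array are illustrative, not part of the patch):

import numpy as np

def ovr_proba_sketch(scores):
    # In-place logistic transform, mirroring _predict_proba_lr:
    # prob = 1 / (1 + exp(-scores))
    prob = -scores
    np.exp(prob, prob)
    prob += 1
    np.reciprocal(prob, prob)
    if prob.ndim == 1:
        # Binary case: a single score column, so stack [1 - p, p].
        return np.vstack([1 - prob, prob]).T
    # Multiclass OvR: renormalize rows to sum to 1, like LibLinear's
    # predict_probability.
    prob /= prob.sum(axis=1).reshape((prob.shape[0], -1))
    return prob

The `prob.ndim` check the patch introduces is equivalent to the old `len(prob.shape) == 1`; it is simply the idiomatic NumPy spelling.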

sklearn/linear_model/logistic.py

Lines changed: 19 additions & 2 deletions

@@ -25,7 +25,7 @@
                               squared_norm)
 from ..utils.optimize import newton_cg
 from ..utils.validation import (as_float_array, DataConversionWarning,
-                                check_X_y)
+                                check_X_y, NotFittedError)
 from ..utils.fixes import expit
 from ..externals.joblib import Parallel, delayed
 from ..cross_validation import check_cv

@@ -1088,6 +1088,13 @@ def predict_proba(self, X):
         The returned estimates for all classes are ordered by the
         label of classes.
 
+        For a multi_class problem, if multi_class is set to "multinomial",
+        the softmax function is used to find the predicted probability of
+        each class.
+        Else use a one-vs-rest approach, i.e. calculate the probability
+        of each class assuming it to be positive using the logistic function,
+        and normalize these values across all the classes.
+
         Parameters
         ----------
         X : array-like, shape = [n_samples, n_features]

@@ -1098,7 +1105,17 @@ def predict_proba(self, X):
         Returns the probability of the sample for each class in the model,
         where classes are ordered as they are in ``self.classes_``.
         """
-        return self._predict_proba_lr(X)
+        if not hasattr(self, "coef_"):
+            raise NotFittedError("Call fit before prediction")
+        calculate_ovr = self.coef_.shape[0] == 1 or self.multi_class == "ovr"
+        if calculate_ovr:
+            return super(LogisticRegression, self)._predict_proba_lr(X)
+        else:
+            prob = self.decision_function(X)
+            np.exp(prob, prob)
+            sum_prob = np.sum(prob, axis=1).reshape((-1, 1))
+            prob /= sum_prob
+            return prob
 
     def predict_log_proba(self, X):
         """Log of probability estimates.

sklearn/linear_model/tests/test_logistic.py

Lines changed: 22 additions & 0 deletions

@@ -23,6 +23,7 @@
 )
 from sklearn.cross_validation import StratifiedKFold
 from sklearn.datasets import load_iris, make_classification
+from sklearn.metrics import log_loss
 
 
 X = [[-1, 0], [0, 1], [1, 1]]

@@ -675,3 +676,24 @@ def test_logreg_cv_penalty():
     lr = LogisticRegression(penalty="l1", C=1.0, solver='liblinear')
     lr.fit(X, y)
     assert_equal(np.count_nonzero(lr_cv.coef_), np.count_nonzero(lr.coef_))
+
+
+def test_logreg_predict_proba_multinomial():
+    X, y = make_classification(
+        n_samples=10, n_features=20, random_state=0, n_classes=3,
+        n_informative=10)
+
+    # Predicted probabilities from the true multinomial (cross-entropy)
+    # model should give a smaller log-loss than those from the OvR method.
+    clf_multi = LogisticRegression(multi_class="multinomial", solver="lbfgs")
+    clf_multi.fit(X, y)
+    clf_multi_loss = log_loss(y, clf_multi.predict_proba(X))
+    clf_ovr = LogisticRegression(multi_class="ovr", solver="lbfgs")
+    clf_ovr.fit(X, y)
+    clf_ovr_loss = log_loss(y, clf_ovr.predict_proba(X))
+    assert_greater(clf_ovr_loss, clf_multi_loss)
+
+    # Predicted probabilities using the softmax function should give a
+    # smaller loss than those using the logistic function.
+    clf_multi_loss = log_loss(y, clf_multi.predict_proba(X))
+    clf_wrong_loss = log_loss(y, clf_multi._predict_proba_lr(X))
+    assert_greater(clf_wrong_loss, clf_multi_loss)
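To reproduce the comparison the new test encodes outside the test runner, a rough end-to-end snippet against the same-era API (multi_class="multinomial" with the lbfgs solver, and the private _predict_proba_lr helper, exactly as used in the test above):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

X, y = make_classification(n_samples=10, n_features=20, random_state=0,
                           n_classes=3, n_informative=10)
clf = LogisticRegression(multi_class="multinomial", solver="lbfgs").fit(X, y)

# With this fix, predict_proba applies the softmax, so its log-loss should
# be lower than the OvR-style probabilities built from the same coefficients.
print(log_loss(y, clf.predict_proba(X)))
print(log_loss(y, clf._predict_proba_lr(X)))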
