Skip to content

Commit 499c2e4

Browse files
committed
explain_prediction for sklearn linear classifiers: use predicted class
1 parent 26a7ab4 commit 499c2e4

File tree

4 files changed

+42
-19
lines changed

4 files changed

+42
-19
lines changed

eli5/sklearn/explain_prediction.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,14 @@ def explain_prediction_linear_classifier(clf, doc,
198198
add_weighted_spans(doc, vec, vectorized, target_expl)
199199
res.targets.append(target_expl)
200200
else:
201+
label_id = 1 if score > 0 else 0
202+
scale = -1 if label_id == 0 else 1
203+
201204
target_expl = TargetExplanation(
202-
target=display_names[1][1],
203-
feature_weights=_weights(0),
205+
target=display_names[label_id][1],
206+
feature_weights=_weights(0, scale=scale),
204207
score=score,
205-
proba=proba[1] if proba is not None else None,
208+
proba=proba[label_id] if proba is not None else None,
206209
)
207210
add_weighted_spans(doc, vec, vectorized, target_expl)
208211
res.targets.append(target_expl)
@@ -606,8 +609,8 @@ def _multiply(X, coef):
606609
def _linear_weights(clf, x, top, feature_names, flt_indices):
607610
""" Return top weights getter for label_id.
608611
"""
609-
def _weights(label_id):
610-
coef = get_coef(clf, label_id)
612+
def _weights(label_id, scale=1.0):
613+
coef = get_coef(clf, label_id) * scale
611614
_x = x
612615
scores = _multiply(_x, coef)
613616
if flt_indices is not None:

tests/test_ipython.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_show_prediction():
4343
html = eli5.show_prediction(clf, doc)
4444
write_html(clf, html.data, '')
4545
assert isinstance(html, HTML)
46-
assert 'y=b' in html.data
46+
assert 'y=a' in html.data
4747
assert 'BIAS' in html.data
4848
assert 'x1' in html.data
4949

@@ -56,6 +56,14 @@ def test_show_prediction():
5656
# format_as_html arguments are supported
5757
html = eli5.show_prediction(clf, doc, show=['method'])
5858
write_html(clf, html.data, '')
59-
assert 'y=b' not in html.data
59+
assert 'y=a' not in html.data
6060
assert 'BIAS' not in html.data
6161
assert 'Explained as' in html.data
62+
63+
# top target is used
64+
html = eli5.show_prediction(clf, np.array([1, 1]))
65+
write_html(clf, html.data, '')
66+
assert 'y=b' in html.data
67+
assert 'BIAS' in html.data
68+
assert 'x1' in html.data
69+

tests/test_sklearn_explain_prediction.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,23 @@ def assert_binary_linear_classifier_explained(newsgroups_train_binary, clf,
128128
X = vec.fit_transform(docs)
129129
clf.fit(X, y)
130130

131-
get_res = lambda **kwargs: explain_prediction(
132-
clf, docs[2], vec=vec, target_names=target_names, top=20, **kwargs)
133-
res = get_res()
134-
pprint(res)
131+
assert y[2] == 1
132+
cg_document = docs[2]
133+
res = explain_prediction(clf, cg_document, vec=vec,
134+
target_names=target_names, top=20)
135135
expl_text, expl_html = format_as_all(res, clf)
136136
for expl in [expl_text, expl_html]:
137137
assert 'software' in expl
138+
assert target_names[1] in expl
139+
140+
assert y[15] == 0
141+
atheism_document = docs[15]
142+
res = explain_prediction(clf, atheism_document, vec=vec,
143+
target_names=target_names, top=20)
144+
expl_text, expl_html = format_as_all(res, clf)
145+
for expl in [expl_text, expl_html]:
146+
assert 'god' in expl
147+
assert target_names[0] in expl
138148

139149

140150
def assert_linear_regression_explained(boston_train, reg, explain_prediction,
@@ -288,6 +298,8 @@ def test_explain_linear(newsgroups_train, clf):
288298

289299
@pytest.mark.parametrize(['clf'], [
290300
[LogisticRegression(random_state=42)],
301+
[LogisticRegressionCV(random_state=42)],
302+
[OneVsRestClassifier(LogisticRegression(random_state=42))],
291303
[SGDClassifier(random_state=42)],
292304
[SVC(kernel='linear', random_state=42)],
293305
[SVC(kernel='linear', random_state=42, decision_function_shape='ovr')],

tests/test_sklearn_vectorizers.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@
1818
from .utils import format_as_all, get_all_features, get_names_coefs, write_html
1919

2020

21-
def check_explain_linear_binary(res, clf):
21+
def check_explain_linear_binary(res, clf, target='alt.atheism'):
2222
expl_text, expl_html = format_as_all(res, clf)
2323
assert len(res.targets) == 1
2424
e = res.targets[0]
25-
assert e.target == 'comp.graphics'
26-
neg = get_all_features(e.feature_weights.neg)
27-
assert 'objective' in neg
25+
assert e.target == target
26+
pos = get_all_features(e.feature_weights.pos)
27+
assert 'objective' in pos
2828
for expl in [expl_text, expl_html]:
29-
assert 'comp.graphics' in expl
29+
assert target in expl
3030
assert 'objective' in expl
3131

3232

@@ -50,9 +50,9 @@ def test_explain_linear_binary(vec, newsgroups_train_binary):
5050
top=20, vectorized=True)
5151
if isinstance(vec, HashingVectorizer):
5252
# InvertableHashingVectorizer must be passed with vectorized=True
53-
neg_weights = res_vectorized.targets[0].feature_weights.neg
54-
neg_vectorized = get_all_features(neg_weights)
55-
assert all(name.startswith('x') for name in neg_vectorized)
53+
pos_weights = res_vectorized.targets[0].feature_weights.pos
54+
pos_vectorized = get_all_features(pos_weights)
55+
assert all(name.startswith('x') for name in pos_vectorized)
5656
else:
5757
assert res_vectorized == _without_weighted_spans(res)
5858

0 commit comments

Comments
 (0)