Skip to content

Commit a207c70

Browse files
committed
Use a RuntimeError instead of a warning to avoid raising a ValueError randomly later
1 parent 73b1b82 commit a207c70

File tree

2 files changed

+11
-14
lines changed

2 files changed

+11
-14
lines changed

sklearn/metrics/tests/test_metrics.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -854,11 +854,10 @@ def test_classification_report_multiclass_with_unicode_label():
854854
avg / total 0.51 0.53 0.47 75
855855
"""
856856
if np_version[:3] < (1, 6, 1):
857-
# check that we get a warning about a bug in numpy
858-
with warnings.catch_warnings(record=True) as record:
859-
warnings.simplefilter('always')
860-
classification_report(y_true, y_pred)
861-
assert_true(len(record) != 0)
857+
expected_message = ("NumPy < 1.6.1 does not implement"
858+
" searchsorted on unicode data correctly.")
859+
assert_raise_message(RuntimeError, expected_message,
860+
classification_report, y_true, y_pred)
862861
else:
863862
report = classification_report(y_true, y_pred)
864863
assert_equal(report, expected_report)

sklearn/preprocessing/label.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
# Andreas Mueller <amueller@ais.uni-bonn.de>
55
# License: BSD 3 clause
66

7-
import warnings
8-
97
import numpy as np
108

119
from ..base import BaseEstimator, TransformerMixin
@@ -27,7 +25,7 @@
2725
'LabelEncoder',
2826
]
2927

30-
def _warn_numpy_unicode_bug(labels):
28+
def _check_numpy_unicode_bug(labels):
3129
"""Check that user is not subject to an old numpy bug
3230
3331
Fixed in master before 1.7.0:
@@ -37,9 +35,9 @@ def _warn_numpy_unicode_bug(labels):
3735
and then backported to 1.6.1.
3836
"""
3937
if np_version[:3] < (1, 6, 1) and labels.dtype.kind == 'U':
40-
warnings.warn("NumPy < 1.6.1 does not implement searchsorted"
41-
" on unicode data correctly. Please upgrade"
42-
" NumPy to use LabelEncoder with unicode inputs.")
38+
raise RuntimeError("NumPy < 1.6.1 does not implement searchsorted"
39+
" on unicode data correctly. Please upgrade"
40+
" NumPy to use LabelEncoder with unicode inputs.")
4341

4442

4543
class LabelEncoder(BaseEstimator, TransformerMixin):
@@ -97,7 +95,7 @@ def fit(self, y):
9795
self : returns an instance of self.
9896
"""
9997
y = column_or_1d(y, warn=True)
100-
_warn_numpy_unicode_bug(y)
98+
_check_numpy_unicode_bug(y)
10199
self.classes_ = np.unique(y)
102100
return self
103101

@@ -114,7 +112,7 @@ def fit_transform(self, y):
114112
y : array-like of shape [n_samples]
115113
"""
116114
y = column_or_1d(y, warn=True)
117-
_warn_numpy_unicode_bug(y)
115+
_check_numpy_unicode_bug(y)
118116
self.classes_, y = unique(y, return_inverse=True)
119117
return y
120118

@@ -133,7 +131,7 @@ def transform(self, y):
133131
self._check_fitted()
134132

135133
classes = np.unique(y)
136-
_warn_numpy_unicode_bug(classes)
134+
_check_numpy_unicode_bug(classes)
137135
if len(np.intersect1d(classes, self.classes_)) < len(classes):
138136
diff = np.setdiff1d(classes, self.classes_)
139137
raise ValueError("y contains new labels: %s" % str(diff))

0 commit comments

Comments
 (0)