Skip to content

Commit 928f724

Browse files
olologin authored and jnothman committed
[MRG] fix scikit-learn#5269: Overflow error with sklearn.datasets.load_svmlight (scikit-learn#7101)
* fix for scikit-learn#5269, overflow error
* test with long qid added
* What's new section added
1 parent f2df3db commit 928f724

File tree

4 files changed

+45
-5
lines changed

4 files changed

+45
-5
lines changed

doc/whats_new.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,10 @@ Bug fixes
312312
(`#5460 <https://github.com/scikit-learn/scikit-learn/pull/5460>`_)
313313
By `Tom Dupre la Tour`_.
314314

315+
- :func:`datasets.load_svmlight_file` is now able to read long int QID values.
316+
(`#7101 <https://github.com/scikit-learn/scikit-learn/pull/7101>`_)
317+
By `Ibraim Ganiev`_.
318+
315319
API changes summary
316320
-------------------
317321

@@ -4312,3 +4316,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
43124316
.. _Nelson Liu: https://github.com/nelson-liu
43134317

43144318
.. _Manvendra Singh: https://github.com/manu-chroma
4319+
4320+
.. _Ibraim Ganiev: https://github.com/olologin

sklearn/datasets/_svmlight_format.pyx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ cdef bytes COLON = u':'.encode('ascii')
2727
@cython.wraparound(False)
2828
def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
2929
bint query_id):
30-
cdef array.array data, indices, indptr, query
30+
cdef array.array data, indices, indptr
3131
cdef bytes line
3232
cdef char *hash_ptr
3333
cdef char *line_cstr
@@ -45,7 +45,7 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
4545
data = array.array("d")
4646
indices = array.array("i")
4747
indptr = array.array("i", [0])
48-
query = array.array("i")
48+
query = np.arange(0, dtype=np.int64)
4949

5050
if multilabel:
5151
labels = []
@@ -80,8 +80,8 @@ def _load_svmlight_file(f, dtype, bint multilabel, bint zero_based,
8080
if n_features and features[0].startswith(qid_prefix):
8181
_, value = features[0].split(COLON, 1)
8282
if query_id:
83-
array.resize_smart(query, len(query) + 1)
84-
query[len(query) - 1] = int(value)
83+
query.resize(len(query) + 1)
84+
query[len(query) - 1] = np.int64(value)
8585
features.pop(0)
8686
n_features -= 1
8787

sklearn/datasets/svmlight_format.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def _open_and_load(f, dtype, multilabel, zero_based, query_id):
166166
data = frombuffer_empty(data, actual_dtype)
167167
indices = frombuffer_empty(ind, np.intc)
168168
indptr = np.frombuffer(indptr, dtype=np.intc) # never empty
169-
query = frombuffer_empty(query, np.intc)
169+
query = frombuffer_empty(query, np.int64)
170170

171171
data = np.asarray(data, dtype=dtype) # no-op for float{32,64}
172172
return data, indices, indptr, labels, query

sklearn/datasets/tests/test_svmlight_format.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,37 @@ def test_dump_query_id():
368368
assert_array_almost_equal(X, X1.toarray())
369369
assert_array_almost_equal(y, y1)
370370
assert_array_almost_equal(query_id, query_id1)
371+
372+
373+
def test_load_with_long_qid():
374+
# load svmfile with longint qid attribute
375+
data = b("""
376+
1 qid:0 0:1 1:2 2:3
377+
0 qid:72048431380967004 0:1440446648 1:72048431380967004 2:236784985
378+
0 qid:-9223372036854775807 0:1440446648 1:72048431380967004 2:236784985
379+
3 qid:9223372036854775807 0:1440446648 1:72048431380967004 2:236784985""")
380+
X, y, qid = load_svmlight_file(BytesIO(data), query_id=True)
381+
382+
true_X = [[1, 2, 3],
383+
[1440446648, 72048431380967004, 236784985],
384+
[1440446648, 72048431380967004, 236784985],
385+
[1440446648, 72048431380967004, 236784985]]
386+
387+
true_y = [1, 0, 0, 3]
388+
trueQID = [0, 72048431380967004, -9223372036854775807, 9223372036854775807]
389+
assert_array_equal(y, true_y)
390+
assert_array_equal(X.toarray(), true_X)
391+
assert_array_equal(qid, trueQID)
392+
393+
f = BytesIO()
394+
dump_svmlight_file(X, y, f, query_id=qid, zero_based=True)
395+
f.seek(0)
396+
X, y, qid = load_svmlight_file(f, query_id=True, zero_based=True)
397+
assert_array_equal(y, true_y)
398+
assert_array_equal(X.toarray(), true_X)
399+
assert_array_equal(qid, trueQID)
400+
401+
f.seek(0)
402+
X, y = load_svmlight_file(f, query_id=False, zero_based=True)
403+
assert_array_equal(y, true_y)
404+
assert_array_equal(X.toarray(), true_X)

0 commit comments

Comments
 (0)