Skip to content

Commit 8994d0e

Browse files
manu-chromaagramfort
authored andcommitted
[MRG+1] added return_X_y option to toy datasets in sklearn.datasets (scikit-learn#7154)
* added return_X_y support to more dataset loaders * fix typo * updated whats_new.rst * fix indentation for version added tag * call astype before the branching * better formatting in whats_new.rst * better formatting * updated what's new
1 parent 12d5f07 commit 8994d0e

File tree

3 files changed

+104
-14
lines changed

3 files changed

+104
-14
lines changed

doc/whats_new.rst

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,11 +230,16 @@ Enhancements
230230
(`#6846 <https://github.com/scikit-learn/scikit-learn/pull/6846>`_)
231231
By `Sebastian Säger`_ and `YenChen Lin`_.
232232

233-
- Added new return type ``(data, target)`` : tuple option to
234-
:func:`load_iris` dataset,
235-
(`#7049 <https://github.com/scikit-learn/scikit-learn/pull/7049>`_)
233+
- Added parameter ``return_X_y`` and return type ``(data, target) : tuple`` option to
234+
:func:`load_iris` dataset
235+
`#7049 <https://github.com/scikit-learn/scikit-learn/pull/7049>`_,
236236
:func:`load_breast_cancer` dataset
237-
(`#7152 <https://github.com/scikit-learn/scikit-learn/pull/7152>`_) by
237+
`#7152 <https://github.com/scikit-learn/scikit-learn/pull/7152>`_,
238+
:func:`load_digits` dataset,
239+
:func:`load_diabetes` dataset,
240+
:func:`load_linnerud` dataset,
241+
:func:`load_boston` dataset
242+
`#7154 <https://github.com/scikit-learn/scikit-learn/pull/7154>`_ by
238243
`Manvendra Singh`_.
239244

240245
Bug fixes

sklearn/datasets/base.py

Lines changed: 69 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def load_iris(return_X_y=False):
264264
If True, returns ``(data, target)`` instead of a Bunch object.
265265
See below for more information about the `data` and `target` object.
266266
267-
.. versionadded:: 0.18
267+
.. versionadded:: 0.18
268268
269269
Returns
270270
-------
@@ -277,7 +277,7 @@ def load_iris(return_X_y=False):
277277
278278
(data, target) : tuple if ``return_X_y`` is True
279279
280-
.. versionadded:: 0.18
280+
.. versionadded:: 0.18
281281
282282
Examples
283283
--------
@@ -338,7 +338,7 @@ def load_breast_cancer(return_X_y=False):
338338
If True, returns ``(data, target)`` instead of a Bunch object.
339339
See below for more information about the `data` and `target` object.
340340
341-
.. versionadded:: 0.18
341+
.. versionadded:: 0.18
342342
343343
Returns
344344
-------
@@ -351,7 +351,7 @@ def load_breast_cancer(return_X_y=False):
351351
352352
(data, target) : tuple if ``return_X_y`` is True
353353
354-
.. versionadded:: 0.18
354+
.. versionadded:: 0.18
355355
356356
The copy of UCI ML Breast Cancer Wisconsin (Diagnostic) dataset is
357357
downloaded from:
@@ -411,7 +411,7 @@ def load_breast_cancer(return_X_y=False):
411411
feature_names=feature_names)
412412

413413

414-
def load_digits(n_class=10):
414+
def load_digits(n_class=10, return_X_y=False):
415415
"""Load and return the digits dataset (classification).
416416
417417
Each datapoint is a 8x8 image of a digit.
@@ -431,6 +431,12 @@ def load_digits(n_class=10):
431431
n_class : integer, between 0 and 10, optional (default=10)
432432
The number of classes to return.
433433
434+
return_X_y : boolean, default=False.
435+
If True, returns ``(data, target)`` instead of a Bunch object.
436+
See below for more information about the `data` and `target` object.
437+
438+
.. versionadded:: 0.18
439+
434440
Returns
435441
-------
436442
data : Bunch
@@ -440,6 +446,10 @@ def load_digits(n_class=10):
440446
sample, 'target_names', the meaning of the labels, and 'DESCR',
441447
the full description of the dataset.
442448
449+
(data, target) : tuple if ``return_X_y`` is True
450+
451+
.. versionadded:: 0.18
452+
443453
Examples
444454
--------
445455
To load the data and visualize the images::
@@ -458,7 +468,7 @@ def load_digits(n_class=10):
458468
delimiter=',')
459469
with open(join(module_path, 'descr', 'digits.rst')) as f:
460470
descr = f.read()
461-
target = data[:, -1]
471+
target = data[:, -1].astype(np.int)
462472
flat_data = data[:, :-1]
463473
images = flat_data.view()
464474
images.shape = (-1, 8, 8)
@@ -468,14 +478,17 @@ def load_digits(n_class=10):
468478
flat_data, target = flat_data[idx], target[idx]
469479
images = images[idx]
470480

481+
if return_X_y:
482+
return flat_data, target
483+
471484
return Bunch(data=flat_data,
472-
target=target.astype(np.int),
485+
target=target,
473486
target_names=np.arange(10),
474487
images=images,
475488
DESCR=descr)
476489

477490

478-
def load_diabetes():
491+
def load_diabetes(return_X_y=False):
479492
"""Load and return the diabetes dataset (regression).
480493
481494
============== ==================
@@ -487,34 +500,62 @@ def load_diabetes():
487500
488501
Read more in the :ref:`User Guide <datasets>`.
489502
503+
Parameters
504+
----------
505+
return_X_y : boolean, default=False.
506+
If True, returns ``(data, target)`` instead of a Bunch object.
507+
See below for more information about the `data` and `target` object.
508+
509+
.. versionadded:: 0.18
510+
490511
Returns
491512
-------
492513
data : Bunch
493514
Dictionary-like object, the interesting attributes are:
494515
'data', the data to learn and 'target', the regression target for each
495516
sample.
517+
518+
(data, target) : tuple if ``return_X_y`` is True
519+
520+
.. versionadded:: 0.18
496521
"""
497522
base_dir = join(dirname(__file__), 'data')
498523
data = np.loadtxt(join(base_dir, 'diabetes_data.csv.gz'))
499524
target = np.loadtxt(join(base_dir, 'diabetes_target.csv.gz'))
525+
526+
if return_X_y:
527+
return data, target
528+
500529
return Bunch(data=data, target=target)
501530

502531

503-
def load_linnerud():
532+
def load_linnerud(return_X_y=False):
504533
"""Load and return the linnerud dataset (multivariate regression).
505534
506535
Samples total: 20
507536
Dimensionality: 3 for both data and targets
508537
Features: integer
509538
Targets: integer
510539
540+
Parameters
541+
----------
542+
return_X_y : boolean, default=False.
543+
If True, returns ``(data, target)`` instead of a Bunch object.
544+
See below for more information about the `data` and `target` object.
545+
546+
.. versionadded:: 0.18
547+
511548
Returns
512549
-------
513550
data : Bunch
514551
Dictionary-like object, the interesting attributes are: 'data' and
515552
'targets', the two multivariate datasets, with 'data' corresponding to
516553
the exercise and 'targets' corresponding to the physiological
517554
measurements, as well as 'feature_names' and 'target_names'.
555+
556+
(data, target) : tuple if ``return_X_y`` is True
557+
558+
.. versionadded:: 0.18
518559
"""
519560
base_dir = join(dirname(__file__), 'data/')
520561
# Read data
@@ -529,13 +570,16 @@ def load_linnerud():
529570
with open(dirname(__file__) + '/descr/linnerud.rst') as f:
530571
descr = f.read()
531572

573+
if return_X_y:
574+
return data_exercise, data_physiological
575+
532576
return Bunch(data=data_exercise, feature_names=header_exercise,
533577
target=data_physiological,
534578
target_names=header_physiological,
535579
DESCR=descr)
536580

537581

538-
def load_boston():
582+
def load_boston(return_X_y=False):
539583
"""Load and return the boston house-prices dataset (regression).
540584
541585
============== ==============
@@ -545,13 +589,25 @@ def load_boston():
545589
Targets real 5. - 50.
546590
============== ==============
547591
592+
Parameters
593+
----------
594+
return_X_y : boolean, default=False.
595+
If True, returns ``(data, target)`` instead of a Bunch object.
596+
See below for more information about the `data` and `target` object.
597+
598+
.. versionadded:: 0.18
599+
548600
Returns
549601
-------
550602
data : Bunch
551603
Dictionary-like object, the interesting attributes are:
552604
'data', the data to learn, 'target', the regression targets,
553605
and 'DESCR', the full description of the dataset.
554606
607+
(data, target) : tuple if ``return_X_y`` is True
608+
609+
.. versionadded:: 0.18
610+
555611
Examples
556612
--------
557613
>>> from sklearn.datasets import load_boston
@@ -580,6 +636,9 @@ def load_boston():
580636
data[i] = np.asarray(d[:-1], dtype=np.float64)
581637
target[i] = np.asarray(d[-1], dtype=np.float64)
582638

639+
if return_X_y:
640+
return data, target
641+
583642
return Bunch(data=data,
584643
target=target,
585644
# last column is target value

sklearn/datasets/tests/test_base.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,13 @@ def test_load_digits():
128128
assert_equal(digits.data.shape, (1797, 64))
129129
assert_equal(numpy.unique(digits.target).size, 10)
130130

131+
# test return_X_y option
132+
X_y_tuple = load_digits(return_X_y=True)
133+
bunch = load_digits()
134+
assert_true(isinstance(X_y_tuple, tuple))
135+
assert_array_equal(X_y_tuple[0], bunch.data)
136+
assert_array_equal(X_y_tuple[1], bunch.target)
137+
131138

132139
def test_load_digits_n_class_lt_10():
133140
digits = load_digits(9)
@@ -165,6 +172,13 @@ def test_load_diabetes():
165172
assert_equal(res.data.shape, (442, 10))
166173
assert_true(res.target.size, 442)
167174

175+
# test return_X_y option
176+
X_y_tuple = load_diabetes(return_X_y=True)
177+
bunch = load_diabetes()
178+
assert_true(isinstance(X_y_tuple, tuple))
179+
assert_array_equal(X_y_tuple[0], bunch.data)
180+
assert_array_equal(X_y_tuple[1], bunch.target)
181+
168182

169183
def test_load_linnerud():
170184
res = load_linnerud()
@@ -173,6 +187,12 @@ def test_load_linnerud():
173187
assert_equal(len(res.target_names), 3)
174188
assert_true(res.DESCR)
175189

190+
# test return_X_y option
191+
X_y_tuple = load_linnerud(return_X_y=True)
192+
bunch = load_linnerud()
193+
assert_true(isinstance(X_y_tuple, tuple))
194+
assert_array_equal(X_y_tuple[0], bunch.data)
195+
assert_array_equal(X_y_tuple[1], bunch.target)
176196

177197
def test_load_iris():
178198
res = load_iris()
@@ -211,6 +231,12 @@ def test_load_boston():
211231
assert_equal(res.feature_names.size, 13)
212232
assert_true(res.DESCR)
213233

234+
# test return_X_y option
235+
X_y_tuple = load_boston(return_X_y=True)
236+
bunch = load_boston()
237+
assert_true(isinstance(X_y_tuple, tuple))
238+
assert_array_equal(X_y_tuple[0], bunch.data)
239+
assert_array_equal(X_y_tuple[1], bunch.target)
214240

215241
def test_loads_dumps_bunch():
216242
bunch = Bunch(x="x")

0 commit comments

Comments
 (0)