Skip to content

Commit

Permalink
EHN POC sparse handling for RandomUnderSampler
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed Aug 12, 2017
1 parent e9c2756 commit a68e8eb
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 27 deletions.
21 changes: 11 additions & 10 deletions imblearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,25 @@ def sample(self, X, y):
Parameters
----------
X : ndarray, shape (n_samples, n_features)
X : {array-like, sparse matrix}, shape (n_samples, n_features)
Matrix containing the data which have to be sampled.
y : ndarray, shape (n_samples, )
y : array-like, shape (n_samples,)
Corresponding label for each sample in X.
Returns
-------
X_resampled : ndarray, shape (n_samples_new, n_features)
X_resampled : {array-like, sparse matrix}, shape \
(n_samples_new, n_features)
The array containing the resampled data.
y_resampled : ndarray, shape (n_samples_new)
y_resampled : array-like, shape (n_samples_new)
The corresponding label of `X_resampled`
"""

# Check the consistency of X and y
X, y = check_X_y(X, y)
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])

check_is_fitted(self, 'ratio_')
self._check_X_y(X, y)
Expand All @@ -70,15 +71,15 @@ def fit_sample(self, X, y):
X : ndarray, shape (n_samples, n_features)
Matrix containing the data which have to be sampled.
y : ndarray, shape (n_samples, )
y : ndarray, shape (n_samples,)
Corresponding label for each sample in X.
Returns
-------
X_resampled : ndarray, shape (n_samples_new, n_features)
The array containing the resampled data.
y_resampled : ndarray, shape (n_samples_new)
y_resampled : ndarray, shape (n_samples_new,)
The corresponding label of `X_resampled`
"""
Expand Down Expand Up @@ -138,10 +139,10 @@ def fit(self, X, y):
Parameters
----------
X : ndarray, shape (n_samples, n_features)
X : {array-like, sparse matrix}, shape (n_samples, n_features)
Matrix containing the data which have to be sampled.
y : ndarray, shape (n_samples, )
y : array-like, shape (n_samples,)
Corresponding label for each sample in X.
Returns
Expand All @@ -150,7 +151,7 @@ def fit(self, X, y):
Return self.
"""
X, y = check_X_y(X, y)
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
y = check_target_type(y)
self.X_hash_, self.y_hash_ = hash_X_y(X, y)
# self.sampling_type is already checked in check_ratio
Expand Down
58 changes: 58 additions & 0 deletions imblearn/over_sampling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# Christos Aridas
# License: MIT

from sklearn.utils import check_X_y

from ..base import BaseSampler


Expand All @@ -16,3 +18,59 @@ class BaseOverSampler(BaseSampler):
"""

_sampling_type = 'over-sampling'

def fit(self, X, y):
"""Find the classes statistics before to perform sampling.
Parameters
----------
X : array-like, shape (n_samples, n_features)
Matrix containing the data which have to be sampled.
y : array-like, shape (n_samples,)
Corresponding label for each sample in X.
Returns
-------
self : object,
Return self.
Notes
-----
Over-samplers do not accept sparse matrices.
"""
# over-sampling method does not handle sparse matrix
X, y = check_X_y(X, y)

return super(BaseOverSampler, self).fit(X, y)

def sample(self, X, y):
"""Resample the dataset.
Parameters
----------
X : array-like, shape (n_samples, n_features)
Matrix containing the data which have to be sampled.
y : array-like, shape (n_samples,)
Corresponding label for each sample in X.
Returns
-------
X_resampled : array-like, shape (n_samples_new, n_features)
The array containing the resampled data.
y_resampled : array-like, shape (n_samples_new,)
The corresponding label of `X_resampled`
Notes
-----
Over-samplers do not accept sparse matrices.
"""

# Check the consistency of X and y
X, y = check_X_y(X, y)

return super(BaseOverSampler, self).sample(X, y)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from __future__ import division

import numpy as np
from sklearn.utils import check_random_state
from sklearn.utils import check_random_state, safe_indexing

from ..base import BaseUnderSampler

Expand Down Expand Up @@ -110,10 +110,7 @@ def _sample(self, X, y):
"""
random_state = check_random_state(self.random_state)

X_resampled = np.empty((0, X.shape[1]), dtype=X.dtype)
y_resampled = np.empty((0, ), dtype=y.dtype)
if self.return_indices:
idx_under = np.empty((0, ), dtype=int)
idx_under = np.empty((0, ), dtype=int)

for target_class in np.unique(y):
if target_class in self.ratio_.keys():
Expand All @@ -125,18 +122,12 @@ def _sample(self, X, y):
else:
index_target_class = slice(None)

X_resampled = np.concatenate(
(X_resampled, X[y == target_class][index_target_class]),
axis=0)
y_resampled = np.concatenate(
(y_resampled, y[y == target_class][index_target_class]),
axis=0)
if self.return_indices:
idx_under = np.concatenate(
(idx_under, np.flatnonzero(y == target_class)[
index_target_class]), axis=0)
idx_under = np.concatenate(
(idx_under, np.flatnonzero(y == target_class)[
index_target_class]), axis=0)

if self.return_indices:
return X_resampled, y_resampled, idx_under
return (safe_indexing(X, idx_under), safe_indexing(y, idx_under),
idx_under)
else:
return X_resampled, y_resampled
return safe_indexing(X, idx_under), safe_indexing(y, idx_under)

0 comments on commit a68e8eb

Please sign in to comment.