Skip to content

Commit

Permalink
ENH: Add indicator features to imputer output
Browse files Browse the repository at this point in the history
  • Loading branch information
maniteja123 committed Apr 14, 2016
1 parent b2002dc commit 18396be
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 28 deletions.
31 changes: 26 additions & 5 deletions doc/modules/preprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -445,28 +445,49 @@ values, either using the mean, the median or the most frequent value of
the row or column in which the missing values are located. This class
also allows for different missing values encodings.

Imputing missing values ordinarily discards the information of which values
were missing. Setting ``add_indicator_features=True`` allows the knowledge of
which features were imputed to be exploited by a downstream estimator
by adding features that indicate which elements have been imputed.

The following snippet demonstrates how to replace missing values,
encoded as ``np.nan``, using the mean value of the columns (axis 0)
that contain the missing values::
that contain the missing values. In case there is a feature which has
all missing features, it is discarded when transformed. Also if the
indicator matrix is requested (``add_indicator_features=True``),
then the shape of the transformed input is
``(n_samples, n_features_new + len(imputed_features_))`` ::

>>> import numpy as np
>>> from sklearn.preprocessing import Imputer
>>> imp = Imputer(missing_values='NaN', strategy='mean', axis=0)
>>> imp.fit([[1, 2], [np.nan, 3], [7, 6]])
Imputer(axis=0, copy=True, missing_values='NaN', strategy='mean', verbose=0)
>>> imp.fit([[1, 2], [np.nan, 3], [7, 6]]) # doctest: +NORMALIZE_WHITESPACE
Imputer(add_indicator_features=False, axis=0, copy=True, missing_values='NaN',
strategy='mean', verbose=0)
>>> X = [[np.nan, 2], [6, np.nan], [7, 6]]
>>> print(imp.transform(X)) # doctest: +ELLIPSIS
[[ 4. 2. ]
[ 6. 3.666...]
[ 7. 6. ]]
>>> imp_with_in = Imputer(missing_values='NaN', strategy='mean', axis=0,add_indicator_features=True)
>>> imp_with_in.fit([[1, 2], [np.nan, 3], [7, 6]])
Imputer(add_indicator_features=True, axis=0, copy=True, missing_values='NaN',
strategy='mean', verbose=0)
>>> print(imp_with_in.transform(X)) # doctest: +ELLIPSIS
[[ 4. 2. 1. 0. ]
[ 6. 3.66666667 0. 1. ]
[ 7. 6. 0. 0. ]]
>>> print(imp_with_in.imputed_features_)
[0 1]

The :class:`Imputer` class also supports sparse matrices::

>>> import scipy.sparse as sp
>>> X = sp.csc_matrix([[1, 2], [0, 3], [7, 6]])
>>> imp = Imputer(missing_values=0, strategy='mean', axis=0)
>>> imp.fit(X)
Imputer(axis=0, copy=True, missing_values=0, strategy='mean', verbose=0)
>>> imp.fit(X) # doctest: +NORMALIZE_WHITESPACE
Imputer(add_indicator_features=False, axis=0, copy=True, missing_values=0,
strategy='mean', verbose=0)
>>> X_test = sp.csc_matrix([[0, 2], [6, 0], [7, 6]])
>>> print(imp.transform(X_test)) # doctest: +ELLIPSIS
[[ 4. 2. ]
Expand Down
5 changes: 5 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ Enhancements
- :class:`naive_bayes.GaussianNB` now accepts data-independent class-priors
through the parameter ``priors``. By `Guillaume Lemaitre`_.

- Add option to show ``indicator features`` in the output of Imputer.
By `Mani Teja`_.

Bug fixes
.........

Expand Down Expand Up @@ -4143,3 +4146,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
.. _Guillaume Lemaitre: https://github.com/glemaitre

.. _JPFrancoia: https://github.com/JPFrancoia

.. _Mani Teja: https://github.com/maniteja123
21 changes: 18 additions & 3 deletions examples/missing_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,22 @@
Imputing does not always improve the predictions, so please check via cross-validation.
Sometimes dropping rows or using marker values is more effective.
In this example, we artificially mark some of the elements in complete
dataset as missing. Then we estimate performance using the complete dataset,
dataset without the missing samples, after imputation without the indicator
matrix and imputation with the indicator matrix for the missing values.
Missing values can be replaced by the mean, the median or the most frequent
value using the ``strategy`` hyper-parameter.
The median is a more robust estimator for data with high magnitude variables
which could dominate results (otherwise known as a 'long tail').
Script output::
Score with the entire dataset = 0.56
Score with the complete dataset = 0.56
Score without the samples containing missing values = 0.48
Score after imputation of the missing values = 0.55
Score after imputation with indicator features = 0.57
In this case, imputing helps the classifier get close to the original score.
Expand All @@ -40,11 +46,11 @@
# Estimate the score on the entire dataset, with no missing values
estimator = RandomForestRegressor(random_state=0, n_estimators=100)
score = cross_val_score(estimator, X_full, y_full).mean()
print("Score with the entire dataset = %.2f" % score)
print("Score with the complete dataset = %.2f" % score)

# Add missing values in 75% of the lines
missing_rate = 0.75
n_missing_samples = np.floor(n_samples * missing_rate)
n_missing_samples = int(n_samples * missing_rate)
missing_samples = np.hstack((np.zeros(n_samples - n_missing_samples,
dtype=np.bool),
np.ones(n_missing_samples,
Expand All @@ -70,3 +76,12 @@
n_estimators=100))])
score = cross_val_score(estimator, X_missing, y_missing).mean()
print("Score after imputation of the missing values = %.2f" % score)

# Estimate score after imputation of the missing values with indicator matrix
estimator = Pipeline([("imputer", Imputer(missing_values=0,
strategy="mean",
axis=0, add_indicator_features=True)),
("forest", RandomForestRegressor(random_state=0,
n_estimators=100))])
score = cross_val_score(estimator, X_missing, y_missing).mean()
print("Score after imputation with indicator features = %.2f" % score)
101 changes: 81 additions & 20 deletions sklearn/preprocessing/imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ..base import BaseEstimator, TransformerMixin
from ..utils import check_array
from ..utils import safe_mask
from ..utils.fixes import astype
from ..utils.sparsefuncs import _get_median
from ..utils.validation import check_is_fitted
Expand Down Expand Up @@ -102,11 +103,21 @@ class Imputer(BaseEstimator, TransformerMixin):
- If `axis=0` and X is encoded as a CSR matrix;
- If `axis=1` and X is encoded as a CSC matrix.
add_indicator_features : boolean, optional (default=False)
If True, the transformed ``X`` will have binary indicator features
appended. These correspond to input features with at least one
missing value marking which elements have been imputed.
Attributes
----------
statistics_ : array of shape (n_features,)
The imputation fill value for each feature if axis == 0.
imputed_features_ : array of shape (n_features_with_missing, )
The input features which have been imputed during transform.
The size of this attribute will be the number of features with
at least one missing value (and fewer than all in the axis=0 case).
Notes
-----
- When ``axis=0``, columns which only contained missing values at `fit`
Expand All @@ -116,12 +127,13 @@ class Imputer(BaseEstimator, TransformerMixin):
contain missing values).
"""
def __init__(self, missing_values="NaN", strategy="mean",
axis=0, verbose=0, copy=True):
axis=0, verbose=0, copy=True, add_indicator_features=False):
self.missing_values = missing_values
self.strategy = strategy
self.axis = axis
self.verbose = verbose
self.copy = copy
self.add_indicator_features = add_indicator_features

def fit(self, X, y=None):
"""Fit the imputer on X.
Expand Down Expand Up @@ -299,13 +311,74 @@ def _dense_fit(self, X, strategy, missing_values, axis):

return most_frequent

def _sparse_transform(self, X, valid_stats, valid_idx):
"""transformer on sparse data."""
mask = _get_mask(X.data, self.missing_values)
indexes = np.repeat(np.arange(len(X.indptr) - 1, dtype=np.int),
np.diff(X.indptr))[mask]

X.data[mask] = astype(valid_stats[indexes], X.dtype,
copy=False)

mask_matrix = X.__class__((mask, X.indices.copy(),
X.indptr.copy()), shape=X.shape,
dtype=X.dtype)
mask_matrix.eliminate_zeros() # removes explicit False entries
features_with_missing_values = mask_matrix.sum(axis=0).A.nonzero()[1]
features_mask = safe_mask(mask_matrix, features_with_missing_values)
imputed_mask = mask_matrix[:, features_mask]
if self.axis == 0:
self.imputed_features_ = valid_idx[features_with_missing_values]
else:
self.imputed_features_ = features_with_missing_values

if self.add_indicator_features:
X = sparse.hstack((X, imputed_mask))

return X

def _dense_transform(self, X, valid_stats, valid_idx):
"""transformer on dense data."""
mask = _get_mask(X, self.missing_values)
n_missing = np.sum(mask, axis=self.axis)
values = np.repeat(valid_stats, n_missing)

if self.axis == 0:
coordinates = np.where(mask.transpose())[::-1]
else:
coordinates = mask

X[coordinates] = values

features_with_missing_values = np.where(np.any
(mask, axis=0))[0]
imputed_mask = mask[:, features_with_missing_values]
if self.axis == 0:
self.imputed_features_ = valid_idx[features_with_missing_values]
else:
self.imputed_features_ = features_with_missing_values

if self.add_indicator_features:
X = np.hstack((X, imputed_mask))

return X

def transform(self, X):
"""Impute all missing values in X.
Parameters
----------
X : {array-like, sparse matrix}, shape = [n_samples, n_features]
X : {array-like, sparse matrix}, shape = (n_samples, n_features)
The input data to complete.
Return
------
X_new : {array-like, sparse matrix},
Transformed array.
shape (n_samples, n_features_new) when
``add_indicator_features`` is False,
shape (n_samples, n_features_new + len(imputed_features_)
when ``add_indicator_features`` is True.
"""
if self.axis == 0:
check_is_fitted(self, 'statistics_')
Expand Down Expand Up @@ -337,39 +410,27 @@ def transform(self, X):
invalid_mask = np.isnan(statistics)
valid_mask = np.logical_not(invalid_mask)
valid_statistics = statistics[valid_mask]
valid_statistics_indexes = np.where(valid_mask)[0]
valid_idx = np.where(valid_mask)[0]
missing = np.arange(X.shape[not self.axis])[invalid_mask]

if self.axis == 0 and invalid_mask.any():
if self.verbose:
warnings.warn("Deleting features without "
"observed values: %s" % missing)
X = X[:, valid_statistics_indexes]
X = X[:, valid_idx]
elif self.axis == 1 and invalid_mask.any():
raise ValueError("Some rows only contain "
"missing values: %s" % missing)

# Do actual imputation
if sparse.issparse(X) and self.missing_values != 0:
mask = _get_mask(X.data, self.missing_values)
indexes = np.repeat(np.arange(len(X.indptr) - 1, dtype=np.int),
np.diff(X.indptr))[mask]

X.data[mask] = astype(valid_statistics[indexes], X.dtype,
copy=False)
# sparse matrix and missing values is not zero
X = self._sparse_transform(X, valid_statistics, valid_idx)
else:
# sparse with zero as missing value and dense matrix
if sparse.issparse(X):
X = X.toarray()

mask = _get_mask(X, self.missing_values)
n_missing = np.sum(mask, axis=self.axis)
values = np.repeat(valid_statistics, n_missing)

if self.axis == 0:
coordinates = np.where(mask.transpose())[::-1]
else:
coordinates = mask

X[coordinates] = values
X = self._dense_transform(X, valid_statistics, valid_idx)

return X
48 changes: 48 additions & 0 deletions sklearn/preprocessing/tests/test_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from scipy import sparse

from sklearn.base import clone
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_raises
Expand Down Expand Up @@ -358,3 +359,50 @@ def test_imputation_copy():

# Note: If X is sparse and if missing_values=0, then a (dense) copy of X is
# made, even if copy=False.


def check_indicator(X, expected_imputed_features, axis):
n_samples, n_features = X.shape
imputer = Imputer(missing_values=-1, strategy='mean', axis=axis)
imputer_with_in = clone(imputer).set_params(add_indicator_features=True)
Xt = imputer.fit_transform(X)
Xt_with_in = imputer_with_in.fit_transform(X)
imputed_features_mask = X[:, expected_imputed_features] == -1
n_features_new = Xt.shape[1]
n_imputed_features = len(imputer_with_in.imputed_features_)
assert_array_equal(imputer.imputed_features_, expected_imputed_features)
assert_array_equal(imputer_with_in.imputed_features_,
expected_imputed_features)
assert_equal(Xt_with_in.shape,
(n_samples, n_features_new + n_imputed_features))
assert_array_equal(Xt_with_in, np.hstack((Xt, imputed_features_mask)))
imputer_with_in = clone(imputer).set_params(add_indicator_features=True)
assert_array_equal(Xt_with_in,
imputer_with_in.fit_transform(sparse.csc_matrix(X)).A)
assert_array_equal(Xt_with_in,
imputer_with_in.fit_transform(sparse.csr_matrix(X)).A)


def test_indicator_features():
# one feature with all missng values
X = np.array([
[-1, -1, 2, 3],
[4, -1, 6, -1],
[8, -1, 10, 11],
[12, -1, -1, 15],
[16, -1, 18, 19]
])
check_indicator(X, np.array([0, 2, 3]), axis=0)
check_indicator(X, np.array([0, 1, 2, 3]), axis=1)

# one feature with all missing values and one with no missing value
# when axis=0 the feature gets discarded
X = np.array([
[-1, -1, 1, 3],
[4, -1, 0, -1],
[8, -1, 1, 0],
[0, -1, 0, 15],
[16, -1, 1, 19]
])
check_indicator(X, np.array([0, 3]), axis=0)
check_indicator(X, np.array([0, 1, 3]), axis=1)

0 comments on commit 18396be

Please sign in to comment.