Skip to content

Commit

Permalink
ENH preserve dtype and type when providing a dataframe with sparse dt…
Browse files Browse the repository at this point in the history
…ype (scikit-learn-contrib#1054)

Co-authored-by: timschulz <tim.schulz@ginkgo-analytics.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
  • Loading branch information
3 people authored Jan 19, 2024
1 parent c7a1838 commit 9e976a4
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 1 deletion.
6 changes: 6 additions & 0 deletions doc/whats_new/v0.12.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,9 @@ Deprecations
- Deprecate `kind_sel` in :class:`~imblearn.under_sampling.NeighbourhoodCleaningRule`.
It will be removed in 0.14. The parameter does not have any effect.
:pr:`1012` by :user:`Guillaume Lemaitre <glemaitre>`.

Enhancements
............

- Allow outputting a dataframe with sparse dtypes when a sparse dataframe is provided as input.
:pr:`1059` by :user:`ts2095 <ts2095>`.
6 changes: 5 additions & 1 deletion imblearn/utils/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from numbers import Integral, Real

import numpy as np
from scipy.sparse import issparse
from sklearn.base import clone
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import check_array, column_or_1d
Expand Down Expand Up @@ -61,7 +62,10 @@ def _transfrom_one(self, array, props):
elif type_ == "dataframe":
import pandas as pd

ret = pd.DataFrame(array, columns=props["columns"])
if issparse(array):
ret = pd.DataFrame.sparse.from_spmatrix(array, columns=props["columns"])
else:
ret = pd.DataFrame(array, columns=props["columns"])
ret = ret.astype(props["dtypes"])
elif type_ == "series":
import pandas as pd
Expand Down
29 changes: 29 additions & 0 deletions imblearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _yield_sampler_checks(sampler):
yield check_samplers_sparse
if "dataframe" in tags["X_types"]:
yield check_samplers_pandas
yield check_samplers_pandas_sparse
if "string" in tags["X_types"]:
yield check_samplers_string
if tags["allow_nan"]:
Expand Down Expand Up @@ -312,6 +313,34 @@ def check_samplers_sparse(name, sampler_orig):
assert_allclose(y_res_sparse, y_res)


def check_samplers_pandas_sparse(name, sampler_orig):
    """Check that a sampler preserves sparse-dtype dataframes.

    The sampler is fitted once on a dataframe whose columns all use
    ``pd.SparseDtype(float, 0)`` (with the target as a named series) and once
    on the underlying arrays. The dataframe/series containers, the sparse
    column dtypes, the column names and the series name must be preserved,
    and the resampled values must match the ones obtained from the arrays.

    Parameters
    ----------
    name : str
        Name of the sampler under test. Not used in the body of this check.

    sampler_orig : estimator object
        Sampler to check. It is cloned before being used.
    """
    pd = pytest.importorskip("pandas")
    sampler = clone(sampler_orig)

    # Build a dataframe in which every column carries a sparse dtype.
    X, y = sample_dataset_generator()
    column_names = list(map(str, range(X.shape[1])))
    sparse_dtype = pd.SparseDtype(float, 0)
    X_df = pd.DataFrame(X, columns=column_names, dtype=sparse_dtype)
    y_s = pd.Series(y, name="class")

    # Resample the pandas containers first, then the raw arrays.
    X_res_df, y_res_s = sampler.fit_resample(X_df, y_s)
    X_res, y_res = sampler.fit_resample(X, y)

    # The container types given as input must be returned as output.
    assert isinstance(X_res_df, pd.DataFrame)
    assert isinstance(y_res_s, pd.Series)

    # Every resampled column must still use a sparse dtype.
    assert all(
        isinstance(column_dtype, pd.SparseDtype) for column_dtype in X_res_df.dtypes
    )

    # Column names and the series name must be carried over unchanged.
    assert X_res_df.columns.tolist() == X_df.columns.tolist()
    assert y_res_s.name == y_s.name

    # FIXME: we should use to_numpy with pandas >= 0.25
    assert_allclose(X_res_df.values, X_res)
    assert_allclose(y_res_s.values, y_res)


def check_samplers_pandas(name, sampler_orig):
pd = pytest.importorskip("pandas")
sampler = clone(sampler_orig)
Expand Down

0 comments on commit 9e976a4

Please sign in to comment.