Skip to content

Commit

Permalink
ENH add categorical_encoder to SMOTEN (scikit-learn-contrib#1001)
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre authored Jul 7, 2023
1 parent a1d9f3c commit d69acd5
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 4 deletions.
6 changes: 6 additions & 0 deletions doc/whats_new/v0.11.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,9 @@ Enhancements
allowing users to specify a :class:`~sklearn.preprocessing.OneHotEncoder` with custom
parameters.
:pr:`1000` by :user:`Guillaume Lemaitre <glemaitre>`.

- :class:`~imblearn.over_sampling.SMOTEN` now accepts a parameter `categorical_encoder`
allowing users to specify a :class:`~sklearn.preprocessing.OrdinalEncoder` with custom
parameters. A new fitted attribute `categorical_encoder_` is exposed to access the
fitted encoder.
:pr:`1001` by :user:`Guillaume Lemaitre <glemaitre>`.
43 changes: 39 additions & 4 deletions imblearn/over_sampling/_smote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,10 @@ class SMOTEN(SMOTE):
Parameters
----------
categorical_encoder : estimator, default=None
Ordinal encoder used to encode the categorical features. If `None`, a
:class:`~sklearn.preprocessing.OrdinalEncoder` with `dtype=np.int32` is used.
{sampling_strategy}
{random_state}
Expand Down Expand Up @@ -791,6 +795,9 @@ class SMOTEN(SMOTE):
Attributes
----------
categorical_encoder_ : estimator
The encoder used to encode the categorical features.
sampling_strategy_ : dict
Dictionary containing the information to sample the dataset. The keys
correspond to the class labels from which to sample and the values
Expand Down Expand Up @@ -853,6 +860,31 @@ class SMOTEN(SMOTE):
Class counts after resampling Counter({{0: 40, 1: 40}})
"""

# Validation rules for the constructor parameters: everything SMOTE already
# declares, plus `categorical_encoder`, which must be either None or an object
# exposing both `fit_transform` and `inverse_transform`.
_parameter_constraints: dict = {
    **SMOTE._parameter_constraints,
    "categorical_encoder": [HasMethods(["fit_transform", "inverse_transform"]), None],
}

def __init__(
    self,
    categorical_encoder=None,
    *,
    sampling_strategy="auto",
    random_state=None,
    k_neighbors=5,
    n_jobs=None,
):
    # The inherited sampling parameters are forwarded untouched to the parent
    # constructor.
    super().__init__(
        sampling_strategy=sampling_strategy,
        random_state=random_state,
        k_neighbors=k_neighbors,
        n_jobs=n_jobs,
    )
    # Stored as given, without validation or copying; `_fit_resample` clones it
    # (or builds a default encoder when it is None) before fitting.
    self.categorical_encoder = categorical_encoder

def _check_X_y(self, X, y):
"""Check should accept strings and not sparse matrices."""
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
Expand Down Expand Up @@ -900,11 +932,14 @@ def _fit_resample(self, X, y):
X_resampled = [X.copy()]
y_resampled = [y.copy()]

encoder = OrdinalEncoder(dtype=np.int32)
X_encoded = encoder.fit_transform(X)
if self.categorical_encoder is None:
self.categorical_encoder_ = OrdinalEncoder(dtype=np.int32)
else:
self.categorical_encoder_ = clone(self.categorical_encoder)
X_encoded = self.categorical_encoder_.fit_transform(X)

vdm = ValueDifferenceMetric(
n_categories=[len(cat) for cat in encoder.categories_]
n_categories=[len(cat) for cat in self.categorical_encoder_.categories_]
).fit(X_encoded, y)

for class_sample, n_samples in self.sampling_strategy_.items():
Expand All @@ -922,7 +957,7 @@ def _fit_resample(self, X, y):
X_class, class_sample, y.dtype, nn_indices, n_samples
)

X_new = encoder.inverse_transform(X_new)
X_new = self.categorical_encoder_.inverse_transform(X_new)
X_resampled.append(X_new)
y_resampled.append(y_new)

Expand Down
21 changes: 21 additions & 0 deletions imblearn/over_sampling/_smote/tests/test_smoten.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pytest
from sklearn.preprocessing import OrdinalEncoder

from imblearn.over_sampling import SMOTEN

Expand Down Expand Up @@ -27,6 +28,7 @@ def test_smoten(data):

assert X_res.shape == (80, 3)
assert y_res.shape == (80,)
assert isinstance(sampler.categorical_encoder_, OrdinalEncoder)


def test_smoten_resampling():
Expand All @@ -52,3 +54,22 @@ def test_smoten_resampling():
X_generated, y_generated = X_res[X.shape[0] :], y_res[X.shape[0] :]
np.testing.assert_array_equal(X_generated, "blue")
np.testing.assert_array_equal(y_generated, "not apple")


def test_smoten_categorical_encoder(data):
    """Check both the default and a user-supplied `categorical_encoder`.

    Without an explicit encoder, the fitted `categorical_encoder_` must be an
    `OrdinalEncoder` with `dtype=np.int32`. With one, the fitted encoder must
    be a distinct object carrying the user's settings, while the
    `categorical_encoder` parameter still refers to the original instance.
    """
    X, y = data

    # Default path: no encoder passed to the sampler.
    smoten = SMOTEN(random_state=0)
    smoten.fit_resample(X, y)

    fitted_default = smoten.categorical_encoder_
    assert isinstance(fitted_default, OrdinalEncoder)
    assert fitted_default.dtype == np.int32

    # Custom path: the sampler is reconfigured with an encoder using another dtype.
    user_encoder = OrdinalEncoder(dtype=np.int64)
    smoten.set_params(categorical_encoder=user_encoder)
    smoten.fit_resample(X, y)

    fitted_custom = smoten.categorical_encoder_
    assert isinstance(fitted_custom, OrdinalEncoder)
    assert smoten.categorical_encoder is user_encoder
    assert fitted_custom is not user_encoder
    assert fitted_custom.dtype == np.int64

0 comments on commit d69acd5

Please sign in to comment.