Skip to content

Commit

Permalink
MAINT Add fixture for dataset generation in common tests (scikit-lear…
Browse files Browse the repository at this point in the history
…n-contrib#932)

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
  • Loading branch information
awinml and glemaitre authored Jul 8, 2023
1 parent 5be5670 commit 3a7633d
Showing 1 changed file with 25 additions and 63 deletions.
88 changes: 25 additions & 63 deletions imblearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@
sklearn_version = parse_version(sklearn.__version__)


def sample_dataset_generator():
    """Build the three-class classification dataset shared by the sampler checks.

    The dataset is produced by ``make_classification`` with 1000 samples,
    4 informative features and class weights of 0.2 / 0.3 / 0.5, so the
    classes are imbalanced. ``random_state`` is fixed to 0 so that every
    caller requests the dataset with identical arguments.

    Returns
    -------
    X, y
        The feature matrix and target vector, exactly as returned by
        ``make_classification``.
    """
    X, y = make_classification(
        n_samples=1000,
        n_classes=3,
        n_informative=4,
        weights=[0.2, 0.3, 0.5],
        random_state=0,
    )
    return X, y


@pytest.fixture(name="sample_dataset_generator")
def sample_dataset_generator_fixture():
    """Expose the shared dataset as the ``sample_dataset_generator`` fixture.

    The fixture is registered under ``name="sample_dataset_generator"`` while
    the function itself carries a ``_fixture`` suffix, so it does not shadow
    the plain ``sample_dataset_generator`` helper that the check functions
    call directly.

    Returns
    -------
    X, y
        Whatever ``sample_dataset_generator()`` returns.
    """
    return sample_dataset_generator()


def _set_checking_parameters(estimator):
params = estimator.get_params()
name = estimator.__class__.__name__
Expand Down Expand Up @@ -233,13 +249,7 @@ def check_samplers_fit(name, sampler_orig):

def check_samplers_fit_resample(name, sampler_orig):
sampler = clone(sampler_orig)
X, y = make_classification(
n_samples=1000,
n_classes=3,
n_informative=4,
weights=[0.2, 0.3, 0.5],
random_state=0,
)
X, y = sample_dataset_generator()
target_stats = Counter(y)
X_res, y_res = sampler.fit_resample(X, y)
if isinstance(sampler, BaseOverSampler):
Expand Down Expand Up @@ -269,13 +279,7 @@ def check_samplers_fit_resample(name, sampler_orig):
def check_samplers_sampling_strategy_fit_resample(name, sampler_orig):
sampler = clone(sampler_orig)
# in this test we will force all samplers to not change the class 1
X, y = make_classification(
n_samples=1000,
n_classes=3,
n_informative=4,
weights=[0.2, 0.3, 0.5],
random_state=0,
)
X, y = sample_dataset_generator()
expected_stat = Counter(y)[1]
if isinstance(sampler, BaseOverSampler):
sampling_strategy = {2: 498, 0: 498}
Expand All @@ -298,13 +302,7 @@ def check_samplers_sparse(name, sampler_orig):
sampler = clone(sampler_orig)
# check that sparse matrices can be passed through the sampler leading to
# the same results as dense
X, y = make_classification(
n_samples=1000,
n_classes=3,
n_informative=4,
weights=[0.2, 0.3, 0.5],
random_state=0,
)
X, y = sample_dataset_generator()
X_sparse = sparse.csr_matrix(X)
X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
sampler = clone(sampler)
Expand All @@ -318,13 +316,7 @@ def check_samplers_pandas(name, sampler_orig):
pd = pytest.importorskip("pandas")
sampler = clone(sampler_orig)
# Check that the samplers handle pandas dataframe and pandas series
X, y = make_classification(
n_samples=1000,
n_classes=3,
n_informative=4,
weights=[0.2, 0.3, 0.5],
random_state=0,
)
X, y = sample_dataset_generator()
X_df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
y_df = pd.DataFrame(y)
y_s = pd.Series(y, name="class")
Expand All @@ -351,13 +343,7 @@ def check_samplers_pandas(name, sampler_orig):
def check_samplers_list(name, sampler_orig):
sampler = clone(sampler_orig)
# Check that the samplers can handle simple lists
X, y = make_classification(
n_samples=1000,
n_classes=3,
n_informative=4,
weights=[0.2, 0.3, 0.5],
random_state=0,
)
X, y = sample_dataset_generator()
X_list = X.tolist()
y_list = y.tolist()

Expand All @@ -374,13 +360,7 @@ def check_samplers_list(name, sampler_orig):
def check_samplers_multiclass_ova(name, sampler_orig):
sampler = clone(sampler_orig)
# Check that a multiclass target leads to the same results as OVA encoding
X, y = make_classification(
n_samples=1000,
n_classes=3,
n_informative=4,
weights=[0.2, 0.3, 0.5],
random_state=0,
)
X, y = sample_dataset_generator()
y_ova = label_binarize(y, classes=np.unique(y))
X_res, y_res = sampler.fit_resample(X, y)
X_res_ova, y_res_ova = sampler.fit_resample(X, y_ova)
Expand All @@ -391,27 +371,15 @@ def check_samplers_multiclass_ova(name, sampler_orig):

def check_samplers_2d_target(name, sampler_orig):
    """Check that a sampler can be fit-resampled with a 2d column target.

    Parameters
    ----------
    name
        Name of the sampler under test; not used in the body of this check.
    sampler_orig
        The sampler instance to check. It is cloned first, so the instance
        passed in is not the one that gets fitted.

    Notes
    -----
    The check makes no assertion on the output: it passes as long as
    ``fit_resample`` does not raise.
    """
    sampler = clone(sampler_orig)
    # Use the shared dataset helper; a separate ``make_classification`` call
    # here would be overwritten immediately and only waste work.
    X, y = sample_dataset_generator()

    y = y.reshape(-1, 1)  # Make the target 2d
    sampler.fit_resample(X, y)


def check_samplers_preserve_dtype(name, sampler_orig):
sampler = clone(sampler_orig)
X, y = make_classification(
n_samples=1000,
n_classes=3,
n_informative=4,
weights=[0.2, 0.3, 0.5],
random_state=0,
)
X, y = sample_dataset_generator()
# Cast X and y to not default dtype
X = X.astype(np.float32)
y = y.astype(np.int32)
Expand All @@ -422,13 +390,7 @@ def check_samplers_preserve_dtype(name, sampler_orig):

def check_samplers_sample_indices(name, sampler_orig):
sampler = clone(sampler_orig)
X, y = make_classification(
n_samples=1000,
n_classes=3,
n_informative=4,
weights=[0.2, 0.3, 0.5],
random_state=0,
)
X, y = sample_dataset_generator()
sampler.fit_resample(X, y)
sample_indices = sampler._get_tags().get("sample_indices", None)
if sample_indices:
Expand Down

0 comments on commit 3a7633d

Please sign in to comment.