Skip to content

Commit

Permalink
add test for parallelisation
Browse files Browse the repository at this point in the history
  • Loading branch information
Microsheep committed Feb 19, 2019
1 parent 4b5546f commit d8b061f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
16 changes: 16 additions & 0 deletions imblearn/combine/tests/test_smote_enn.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,22 @@ def test_validate_estimator_default():
assert_array_equal(y_resampled, y_gt)


def test_parallelisation():
# Check if default job count is 1
smt = SMOTEENN(random_state=RND_SEED)
smt._validate_estimator()
assert smt.n_jobs == 1
assert smt.smote_.n_jobs == 1
assert smt.enn_.n_jobs == 1

# Check if job count is set
smt = SMOTEENN(random_state=RND_SEED, n_jobs=8)
smt._validate_estimator()
assert smt.n_jobs == 8
assert smt.smote_.n_jobs == 8
assert smt.enn_.n_jobs == 8


@pytest.mark.parametrize(
"smote_params, err_msg",
[({'smote': 'rnd'}, "smote needs to be a SMOTE"),
Expand Down
16 changes: 16 additions & 0 deletions imblearn/combine/tests/test_smote_tomek.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,22 @@ def test_validate_estimator_default():
assert_array_equal(y_resampled, y_gt)


def test_parallelisation():
# Check if default job count is 1
smt = SMOTETomek(random_state=RND_SEED)
smt._validate_estimator()
assert smt.n_jobs == 1
assert smt.smote_.n_jobs == 1
assert smt.tomek_.n_jobs == 1

# Check if job count is set
smt = SMOTETomek(random_state=RND_SEED, n_jobs=8)
smt._validate_estimator()
assert smt.n_jobs == 8
assert smt.smote_.n_jobs == 8
assert smt.tomek_.n_jobs == 8


@pytest.mark.parametrize(
"smote_params, err_msg",
[({'smote': 'rnd'}, "smote needs to be a SMOTE"),
Expand Down

0 comments on commit d8b061f

Please sign in to comment.