MAINT validate parameters for public functions (scikit-learn-contrib#956)

glemaitre authored Dec 5, 2022
1 parent f8c27ae commit ad71707
Showing 10 changed files with 277 additions and 109 deletions.
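
Every file follows the same pattern: each public function gains an @validate_params decorator that maps parameter names to a list of accepted constraints (concrete types, "array-like", "boolean", StrOptions sets, Interval ranges), so invalid arguments are rejected before the function body runs and hand-written type checks can be dropped. The snippet below is a minimal sketch of how such a decorator can work; it is an assumption that the vendored imblearn.utils._param_validation mirrors scikit-learn's implementation, and the names validate_params_sketch, _matches and demo are illustrative only.

import functools
from inspect import signature


def validate_params_sketch(constraints):
    """Illustrative stand-in for a ``validate_params``-style decorator."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = signature(func).bind(*args, **kwargs)
            bound.apply_defaults()
            for name, value in bound.arguments.items():
                declared = constraints.get(name)
                if declared is None:
                    continue  # parameter not constrained
                if not any(_matches(value, c) for c in declared):
                    raise ValueError(
                        f"Invalid value for parameter {name!r}: got {value!r}."
                    )
            return func(*args, **kwargs)

        return wrapper

    return decorator


def _matches(value, constraint):
    # Only the simplest constraint kinds appearing in this commit are sketched.
    if constraint is None:
        return value is None
    if constraint == "boolean":
        return isinstance(value, bool)
    if constraint == "array-like":
        return hasattr(value, "__len__") or hasattr(value, "__array__")
    if isinstance(constraint, type):  # e.g. str or collections.abc.Mapping
        return isinstance(value, constraint)
    if callable(constraint):  # e.g. the built-in ``callable`` itself
        return constraint(value)
    return False


@validate_params_sketch({"verbose": ["boolean"]})
def demo(*, verbose=False):
    return verbose


demo(verbose=True)  # accepted
# demo(verbose="yes") would raise ValueError before the body runs.
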
25 changes: 15 additions & 10 deletions imblearn/datasets/_imbalance.py
@@ -6,11 +6,22 @@
# License: MIT

from collections import Counter
from collections.abc import Mapping

from ..under_sampling import RandomUnderSampler
from ..utils import check_sampling_strategy


from ..utils._param_validation import validate_params


@validate_params(
{
"X": ["array-like"],
"y": ["array-like"],
"sampling_strategy": [Mapping, callable, None],
"random_state": ["random_state"],
"verbose": ["boolean"],
}
)
def make_imbalance(
X, y, *, sampling_strategy=None, random_state=None, verbose=False, **kwargs
):
@@ -26,7 +37,7 @@ def make_imbalance(
X : {array-like, dataframe} of shape (n_samples, n_features)
Matrix containing the data to be imbalanced.
y : ndarray of shape (n_samples,)
y : array-like of shape (n_samples,)
Corresponding label for each sample in X.
sampling_strategy : dict or callable,
@@ -86,16 +97,10 @@ def make_imbalance(
"""
target_stats = Counter(y)
# restrict ratio to be a dict or a callable
if isinstance(sampling_strategy, dict) or callable(sampling_strategy):
if isinstance(sampling_strategy, Mapping) or callable(sampling_strategy):
sampling_strategy_ = check_sampling_strategy(
sampling_strategy, y, "under-sampling", **kwargs
)
else:
raise ValueError(
f"'sampling_strategy' has to be a dictionary or a "
f"function returning a dictionary. Got {type(sampling_strategy)} "
f"instead."
)

if verbose:
print(f"The original target distribution in the dataset is: {target_stats}")
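
In practice, the decorated function can be exercised like this (a hedged sketch; the iris data and the requested per-class counts are arbitrary): a dict satisfies the Mapping constraint declared above, so the removed else/raise branch is no longer needed.

import numpy as np
from sklearn.datasets import load_iris

from imblearn.datasets import make_imbalance

X, y = load_iris(return_X_y=True)

# A dict (i.e. a Mapping) satisfies the declared constraint; keep 25, 35 and
# 50 samples for classes 0, 1 and 2 respectively.
X_res, y_res = make_imbalance(
    X, y, sampling_strategy={0: 25, 1: 35, 2: 50}, random_state=0, verbose=True
)
print(np.bincount(y_res))  # expected: [25 35 50]
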
12 changes: 12 additions & 0 deletions imblearn/datasets/_zenodo.py
@@ -54,6 +54,8 @@
from sklearn.datasets import get_data_home
from sklearn.utils import Bunch, check_random_state

from ..utils._param_validation import validate_params

URL = "https://zenodo.org/record/61452/files/benchmark-imbalanced-learn.tar.gz"
PRE_FILENAME = "x"
POST_FILENAME = "data.npz"
@@ -95,6 +97,16 @@
MAP_ID_NAME[v + 1] = k


@validate_params(
{
"data_home": [None, str],
"filter_data": [None, tuple],
"download_if_missing": ["boolean"],
"random_state": ["random_state"],
"shuffle": ["boolean"],
"verbose": ["boolean"],
}
)
def fetch_datasets(
*,
data_home=None,
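
The zenodo fetcher gets the same treatment. A hedged usage sketch (not executed here, since it downloads the benchmark archive from zenodo.org): with the constraints above, filter_data must be a tuple of dataset names or IDs rather than, say, a list or a bare string.

from imblearn.datasets import fetch_datasets

# Fetch (or load from ``data_home``) only the "ecoli" benchmark dataset.
benchmark = fetch_datasets(filter_data=("ecoli",), download_if_missing=True)
print(benchmark["ecoli"].data.shape, benchmark["ecoli"].target.shape)
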
1 change: 0 additions & 1 deletion imblearn/datasets/tests/test_imbalance.py
@@ -22,7 +22,6 @@ def iris():
[
({0: -100, 1: 50, 2: 50}, "in a class cannot be negative"),
({0: 10, 1: 70}, "should be less or equal to the original"),
("random-string", "has to be a dictionary or a function"),
],
)
def test_make_imbalance_error(iris, sampling_strategy, err_msg):
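
The parametrized "random-string" case can be dropped because the decorator now owns that check. A sketch of the behaviour that remains covered (the exact message comes from the shared validator, so only the exception type is asserted; it is assumed to be a ValueError subclass as in scikit-learn's validate_params):

import pytest
from sklearn.datasets import load_iris

from imblearn.datasets import make_imbalance


def test_make_imbalance_rejects_string_strategy():
    X, y = load_iris(return_X_y=True)
    # Raised by the ``@validate_params`` constraints rather than by the
    # removed function-specific ``ValueError`` branch.
    with pytest.raises(ValueError):
        make_imbalance(X, y, sampling_strategy="random-string")
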
135 changes: 115 additions & 20 deletions imblearn/metrics/_classification.py
@@ -15,6 +15,7 @@
# License: MIT

import functools
import numbers
import warnings
from inspect import signature

@@ -26,7 +27,23 @@
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_consistent_length, column_or_1d

from ..utils._param_validation import Interval, StrOptions, validate_params


@validate_params(
{
"y_true": ["array-like"],
"y_pred": ["array-like"],
"labels": ["array-like", None],
"pos_label": [str, numbers.Integral, None],
"average": [
None,
StrOptions({"binary", "micro", "macro", "weighted", "samples"}),
],
"warn_for": ["array-like"],
"sample_weight": ["array-like", None],
}
)
def sensitivity_specificity_support(
y_true,
y_pred,
@@ -57,13 +74,13 @@ def sensitivity_specificity_support(
Parameters
----------
y_true : ndarray of shape (n_samples,)
y_true : array-like of shape (n_samples,)
Ground truth (correct) target values.
y_pred : ndarray of shape (n_samples,)
y_pred : array-like of shape (n_samples,)
Estimated targets as returned by a classifier.
labels : list, default=None
labels : array-like, default=None
The set of labels to include when ``average != 'binary'``, and their
order if ``average is None``. Labels present in the data can be
excluded, for example to calculate a multiclass average ignoring a
@@ -72,8 +89,11 @@
labels are column indices. By default, all labels in ``y_true`` and
``y_pred`` are used in sorted order.
pos_label : str or int, default=1
pos_label : str, int or None, default=1
The class to report if ``average='binary'`` and the data is binary.
If ``pos_label is None`` and in binary classification, this function
returns the average sensitivity and specificity if ``average``
is one of ``'weighted'``.
If the data are multiclass, this will be ignored;
setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
scores for that label only.
@@ -105,7 +125,7 @@ def sensitivity_specificity_support(
This determines which warnings will be made in the case that this
function is being used to return only one of its metrics.
sample_weight : ndarray of shape (n_samples,), default=None
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
Returns
@@ -274,6 +294,19 @@ def sensitivity_specificity_support(
return sensitivity, specificity, true_sum


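A brief sketch of what the StrOptions constraint buys (essentially the same constraint set is reused for sensitivity_score and specificity_score below, and the rejection is assumed to surface as a ValueError subclass, as with scikit-learn's validator): a misspelled average now fails fast, before any metric is computed.

import numpy as np

from imblearn.metrics import sensitivity_specificity_support

y_true = np.array([0, 1, 1, 0, 1, 1])
y_pred = np.array([0, 1, 0, 0, 1, 1])

# "macro" is in the declared StrOptions set, so this call is accepted.
print(sensitivity_specificity_support(y_true, y_pred, average="macro"))

try:
    sensitivity_specificity_support(y_true, y_pred, average="maccro")
except ValueError as exc:
    # Rejected by the StrOptions constraint before any computation happens.
    print(exc)
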
@validate_params(
{
"y_true": ["array-like"],
"y_pred": ["array-like"],
"labels": ["array-like", None],
"pos_label": [str, numbers.Integral, None],
"average": [
None,
StrOptions({"binary", "micro", "macro", "weighted", "samples"}),
],
"sample_weight": ["array-like", None],
}
)
def sensitivity_score(
y_true,
y_pred,
@@ -295,21 +328,23 @@ def sensitivity_score(
Parameters
----------
y_true : ndarray of shape (n_samples,)
y_true : array-like of shape (n_samples,)
Ground truth (correct) target values.
y_pred : ndarray of shape (n_samples,)
y_pred : array-like of shape (n_samples,)
Estimated targets as returned by a classifier.
labels : list, default=None
labels : array-like, default=None
The set of labels to include when ``average != 'binary'``, and their
order if ``average is None``. Labels present in the data can be
excluded, for example to calculate a multiclass average ignoring a
majority negative class, while labels not present in the data will
result in 0 components in a macro average.
pos_label : str or int, default=1
pos_label : str, int or None, default=1
The class to report if ``average='binary'`` and the data is binary.
If ``pos_label is None`` and in binary classification, this function
returns the average sensitivity if ``average`` is one of ``'weighted'``.
If the data are multiclass, this will be ignored;
setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
scores for that label only.
@@ -337,7 +372,7 @@ def sensitivity_score(
meaningful for multilabel classification where this differs from
:func:`accuracy_score`).
sample_weight : ndarray of shape (n_samples,), default=None
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
Returns
@@ -374,6 +409,19 @@ def sensitivity_score(
return s


@validate_params(
{
"y_true": ["array-like"],
"y_pred": ["array-like"],
"labels": ["array-like", None],
"pos_label": [str, numbers.Integral, None],
"average": [
None,
StrOptions({"binary", "micro", "macro", "weighted", "samples"}),
],
"sample_weight": ["array-like", None],
}
)
def specificity_score(
y_true,
y_pred,
@@ -395,21 +443,23 @@ def specificity_score(
Parameters
----------
y_true : ndarray of shape (n_samples,)
y_true : array-like of shape (n_samples,)
Ground truth (correct) target values.
y_pred : ndarray of shape (n_samples,)
y_pred : array-like of shape (n_samples,)
Estimated targets as returned by a classifier.
labels : list, default=None
labels : array-like, default=None
The set of labels to include when ``average != 'binary'``, and their
order if ``average is None``. Labels present in the data can be
excluded, for example to calculate a multiclass average ignoring a
majority negative class, while labels not present in the data will
result in 0 components in a macro average.
pos_label : str or int, default=1
pos_label : str, int or None, default=1
The class to report if ``average='binary'`` and the data is binary.
If ``pos_label is None`` and in binary classification, this function
returns the average specificity if ``average`` is one of ``'weighted'``.
If the data are multiclass, this will be ignored;
setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
scores for that label only.
@@ -437,7 +487,7 @@ def specificity_score(
meaningful for multilabel classification where this differs from
:func:`accuracy_score`).
sample_weight : ndarray of shape (n_samples,), default=None
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
Returns
@@ -474,6 +524,22 @@ def specificity_score(
return s


@validate_params(
{
"y_true": ["array-like"],
"y_pred": ["array-like"],
"labels": ["array-like", None],
"pos_label": [str, numbers.Integral, None],
"average": [
None,
StrOptions(
{"binary", "micro", "macro", "weighted", "samples", "multiclass"}
),
],
"sample_weight": ["array-like", None],
"correction": [Interval(numbers.Real, 0, None, closed="left")],
}
)
def geometric_mean_score(
y_true,
y_pred,
@@ -507,21 +573,24 @@ class is unrecognized by the classifier, G-mean resolves to zero. To
Parameters
----------
y_true : ndarray of shape (n_samples,)
y_true : array-like of shape (n_samples,)
Ground truth (correct) target values.
y_pred : ndarray of shape (n_samples,)
y_pred : array-like of shape (n_samples,)
Estimated targets as returned by a classifier.
labels : list, default=None
labels : array-like, default=None
The set of labels to include when ``average != 'binary'``, and their
order if ``average is None``. Labels present in the data can be
excluded, for example to calculate a multiclass average ignoring a
majority negative class, while labels not present in the data will
result in 0 components in a macro average.
pos_label : str or int, default=1
pos_label : str, int or None, default=1
The class to report if ``average='binary'`` and the data is binary.
If ``pos_label is None`` and in binary classification, this function
returns the average geometric mean if ``average`` is one of
``'weighted'``.
If the data are multiclass, this will be ignored;
setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
scores for that label only.
@@ -539,6 +608,8 @@ class is unrecognized by the classifier, G-mean resolves to zero. To
``'macro'``:
Calculate metrics for each label, and find their unweighted
mean. This does not take label imbalance into account.
``'multiclass'``:
No average is taken.
``'weighted'``:
Calculate metrics for each label, and find their average, weighted
by support (the number of true instances for each label). This
@@ -549,7 +620,7 @@ class is unrecognized by the classifier, G-mean resolves to zero. To
meaningful for multilabel classification where this differs from
:func:`accuracy_score`).
sample_weight : ndarray of shape (n_samples,), default=None
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
correction : float, default=0.0
@@ -658,6 +729,7 @@ class is unrecognized by the classifier, G-mean resolves to zero. To
return gmean


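For geometric_mean_score the constraint set additionally allows average="multiclass" and bounds correction to the non-negative reals via Interval(numbers.Real, 0, None, closed="left"). A usage sketch with arbitrary labels (the rejection is again assumed to be a ValueError subclass):

import numpy as np

from imblearn.metrics import geometric_mean_score

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 1, 1, 1, 2])

# "multiclass" is allowed here in addition to the usual averaging modes, and
# ``correction`` must lie in the left-closed interval [0, inf).
print(geometric_mean_score(y_true, y_pred, average="multiclass", correction=0.001))

try:
    geometric_mean_score(y_true, y_pred, correction=-1.0)
except ValueError as exc:
    # A negative correction falls outside the declared Interval.
    print(exc)
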
@validate_params({"alpha": [numbers.Real], "squared": ["boolean"]})
def make_index_balanced_accuracy(*, alpha=0.1, squared=True):
"""Balance any scoring function using the index balanced accuracy.
@@ -763,6 +835,22 @@ def compute_score(*args, **kwargs):
return decorate


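make_index_balanced_accuracy is a factory that returns a decorator, so only alpha and squared are validated here. A sketch of its usual usage pattern, reusing the toy labels from above:

from imblearn.metrics import geometric_mean_score, make_index_balanced_accuracy

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 1, 1, 2]

# The factory validates ``alpha`` (any real) and ``squared`` (boolean), then
# returns a decorator that wraps an existing scorer.
iba_gmean = make_index_balanced_accuracy(alpha=0.1, squared=True)(geometric_mean_score)
print(iba_gmean(y_true, y_pred))
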
@validate_params(
{
"y_true": ["array-like"],
"y_pred": ["array-like"],
"labels": ["array-like", None],
"target_names": ["array-like", None],
"sample_weight": ["array-like", None],
"digits": [Interval(numbers.Integral, 0, None, closed="left")],
"alpha": [numbers.Real],
"output_dict": ["boolean"],
"zero_division": [
StrOptions({"warn"}),
Interval(numbers.Integral, 0, 1, closed="both"),
],
}
)
def classification_report_imbalanced(
y_true,
y_pred,
@@ -970,6 +1058,13 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.64\
return report


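The report function now also validates digits (a non-negative integer) and zero_division ("warn", 0 or 1). A short usage sketch with the same toy labels:

from imblearn.metrics import classification_report_imbalanced

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 1, 1, 2]

# ``digits`` must be a non-negative integer and ``zero_division`` one of
# "warn", 0 or 1, per the constraints above.
print(classification_report_imbalanced(y_true, y_pred, digits=3, zero_division="warn"))
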
@validate_params(
{
"y_true": ["array-like"],
"y_pred": ["array-like"],
"sample_weight": ["array-like", None],
}
)
def macro_averaged_mean_absolute_error(y_true, y_pred, *, sample_weight=None):
"""Compute Macro-Averaged MAE for imbalanced ordinal classification.
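
macro_averaged_mean_absolute_error only needs array-like checks on its inputs and sample weights. A usage sketch with ordinal labels:

from imblearn.metrics import macro_averaged_mean_absolute_error

# Ordinal targets: the MAE is computed per class and the class-wise values
# are averaged without weighting by support.
y_true = [1, 1, 1, 2, 2, 3]
y_pred = [1, 2, 1, 2, 3, 3]
print(macro_averaged_mean_absolute_error(y_true, y_pred))
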
