Skip to content

Commit

Permalink
MAINT compatibility sklearn 1.5 (#1074)
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed May 28, 2024
1 parent 78d94a5 commit 7c26d56
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 18 deletions.
2 changes: 1 addition & 1 deletion imblearn/ensemble/_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
self.sampler_ = clone(self.sampler)
# RandomUnderSampler is not supporting sample_weight. We need to pass
# None.
return super()._fit(X, y, self.max_samples, sample_weight=None)
return super()._fit(X, y, self.max_samples)

# TODO: remove when minimum supported version of scikit-learn is 1.1
@available_if(_estimator_has("decision_function"))
Expand Down
14 changes: 8 additions & 6 deletions imblearn/ensemble/_easy_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
check_target_type(y)
# RandomUnderSampler is not supporting sample_weight. We need to pass
# None.
return super()._fit(X, y, self.max_samples, sample_weight=None)
return super()._fit(X, y, self.max_samples)

# TODO: remove when minimum supported version of scikit-learn is 1.1
@available_if(_estimator_has("decision_function"))
Expand Down Expand Up @@ -365,9 +365,11 @@ def base_estimator_(self):
raise error
raise error

def _more_tags(self):
def _get_estimator(self):
if self.estimator is None:
estimator = AdaBoostClassifier(algorithm="SAMME")
else:
estimator = self.estimator
return {"allow_nan": _safe_tags(estimator, "allow_nan")}
return AdaBoostClassifier(algorithm="SAMME")
return self.estimator

# TODO: remove when minimum supported version of scikit-learn is 1.5
def _more_tags(self):
return {"allow_nan": _safe_tags(self._get_estimator(), "allow_nan")}
9 changes: 8 additions & 1 deletion imblearn/over_sampling/_smote/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@
import warnings

import numpy as np
import sklearn
from scipy import sparse
from sklearn.base import clone
from sklearn.exceptions import DataConversionWarning
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
from sklearn.utils import (
_get_column_indices,
_safe_indexing,
check_array,
check_random_state,
)
from sklearn.utils.fixes import parse_version
from sklearn.utils.sparsefuncs_fast import (
csr_mean_variance_axis0,
)
Expand All @@ -34,6 +35,12 @@
from ...utils.fixes import _is_pandas_df, _mode
from ..base import BaseOverSampler

sklearn_version = parse_version(sklearn.__version__).base_version
if parse_version(sklearn_version) < parse_version("1.5"):
from sklearn.utils import _get_column_indices
else:
from sklearn.utils._indexing import _get_column_indices


class BaseSMOTE(BaseOverSampler):
"""Base class for the different SMOTE algorithms."""
Expand Down
10 changes: 9 additions & 1 deletion imblearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# Christos Aridas
# Guillaume Lemaitre <g.lemaitre58@gmail.com>
# License: BSD
import sklearn
from sklearn import pipeline
from sklearn.base import clone
from sklearn.utils import Bunch, _print_elapsed_time
from sklearn.utils import Bunch
from sklearn.utils.fixes import parse_version
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import check_memory

Expand All @@ -34,6 +36,12 @@

__all__ = ["Pipeline", "make_pipeline"]

sklearn_version = parse_version(sklearn.__version__).base_version
if parse_version(sklearn_version) < parse_version("1.5"):
from sklearn.utils import _print_elapsed_time
else:
from sklearn.utils._user_interface import _print_elapsed_time


class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
"""Pipeline of transforms and resamples with a final estimator.
Expand Down
38 changes: 29 additions & 9 deletions imblearn/utils/_metadata_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,9 +1086,12 @@ def _serialize(self):

def __iter__(self):
if self._self_request:
yield "$self_request", RouterMappingPair(
mapping=MethodMapping.from_str("one-to-one"),
router=self._self_request,
yield (
"$self_request",
RouterMappingPair(
mapping=MethodMapping.from_str("one-to-one"),
router=self._self_request,
),
)
for name, route_mapping in self._route_mappings.items():
yield (name, route_mapping)
Expand Down Expand Up @@ -1234,7 +1237,7 @@ def __init__(self, name, keys, validate_keys=True):

def __get__(self, instance, owner):
# we would want to have a method which accepts only the expected args
def func(**kw):
def func(*args, **kw):
"""Updates the request for provided parameters
This docstring is overwritten below.
Expand All @@ -1253,15 +1256,32 @@ def func(**kw):
f"arguments are: {set(self.keys)}"
)

requests = instance._get_metadata_request()
# This makes it possible to use the decorated method as an unbound
# method, for instance when monkeypatching.
# https://github.com/scikit-learn/scikit-learn/issues/28632
if instance is None:
_instance = args[0]
args = args[1:]
else:
_instance = instance

# Replicating python's behavior when positional args are given other
# than `self`, and `self` is only allowed if this method is unbound.
if args:
raise TypeError(
f"set_{self.name}_request() takes 0 positional argument but"
f" {len(args)} were given"
)

requests = _instance._get_metadata_request()
method_metadata_request = getattr(requests, self.name)

for prop, alias in kw.items():
if alias is not UNCHANGED:
method_metadata_request.add_request(param=prop, alias=alias)
instance._metadata_request = requests
_instance._metadata_request = requests

return instance
return _instance

# Now we set the relevant attributes of the function so that it seems
# like a normal method to the end user, with known expected arguments.
Expand Down Expand Up @@ -1525,13 +1545,13 @@ def process_routing(_obj, _method, /, **kwargs):
metadata to corresponding methods or corresponding child objects. The object
names are those defined in `obj.get_metadata_routing()`.
"""
if not _routing_enabled() and not kwargs:
if not kwargs:
# If routing is not enabled and kwargs are empty, then we don't have to
# try doing any routing, we can simply return a structure which returns
# an empty dict on routed_params.ANYTHING.ANY_METHOD.
class EmptyRequest:
def get(self, name, default=None):
return default if default else {}
return Bunch(**{method: dict() for method in METHODS})

def __getitem__(self, name):
return Bunch(**{method: dict() for method in METHODS})
Expand Down

0 comments on commit 7c26d56

Please sign in to comment.