Skip to content

Commit d9307a0

Browse files
authored
Merge pull request #362 from thomas0125/fix_gridsearch
Fix GridSearchCV to support sklearn>=1.3.0
2 parents 7106704 + 5b04bc5 commit d9307a0

File tree

3 files changed: +10 additions, -11 deletions

libmultilabel/linear/utils.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,32 +109,31 @@ class GridSearchCV(sklearn.model_selection.GridSearchCV):
109109
The usage is similar to sklearn's, except that the parameter ``scoring`` is unavailable. Instead, specify ``scoring_metric`` in ``MultiLabelEstimator`` in the Pipeline.
110110
111111
Args:
112-
pipeline (sklearn.pipeline.Pipeline): A sklearn Pipeline for grid search.
112+
estimator (estimator object): An estimator for grid search.
113113
param_grid (dict): Search space for a grid search containing a dictionary of
114114
parameters and their corresponding list of candidate values.
115115
n_jobs (int, optional): Number of CPU cores run in parallel. Defaults to None.
116116
"""
117117

118-
_required_parameters = ["pipeline", "param_grid"]
118+
_required_parameters = ["estimator", "param_grid"]
119119

120-
def __init__(self, pipeline: sklearn.pipeline.Pipeline, param_grid: dict, n_jobs=None, **kwargs):
121-
assert isinstance(pipeline, sklearn.pipeline.Pipeline)
120+
def __init__(self, estimator, param_grid: dict, n_jobs=None, **kwargs):
122121
if n_jobs is not None and n_jobs > 1:
123-
param_grid = self._set_singlecore_options(pipeline, param_grid)
122+
param_grid = self._set_singlecore_options(estimator, param_grid)
124123
if "scoring" in kwargs.keys():
125124
raise ValueError(
126125
"Please specify the validation metric with `MultiLabelEstimator.scoring_metric` in the Pipeline instead of using the parameter `scoring`."
127126
)
128127

129-
super().__init__(estimator=pipeline, n_jobs=n_jobs, param_grid=param_grid, **kwargs)
128+
super().__init__(estimator=estimator, n_jobs=n_jobs, param_grid=param_grid, **kwargs)
130129

131-
def _set_singlecore_options(self, pipeline: sklearn.pipeline.Pipeline, param_grid: dict):
130+
def _set_singlecore_options(self, estimator, param_grid: dict):
132131
"""Set liblinear options to `-m 1`. The grid search option `n_jobs`
133132
runs multiple processes in parallel. Using multithreaded liblinear
134133
in conjunction with grid search oversubscribes the CPU and deteriorates
135134
the performance significantly.
136135
"""
137-
params = pipeline.get_params()
136+
params = estimator.get_params()
138137
for name, transform in params.items():
139138
if isinstance(transform, MultiLabelEstimator):
140139
regex = r"-m \d+"

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ liblinear-multicore
22
numba
33
pandas>1.3.0
44
PyYAML
5-
scikit-learn==1.2.2
5+
scikit-learn
66
scipy
77
tqdm

setup.cfg

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = libmultilabel
3-
version = 0.6.1
3+
version = 0.6.2
44
author = LibMultiLabel Team
55
license = MIT License
66
license_file = LICENSE
@@ -29,7 +29,7 @@ install_requires =
2929
numba
3030
pandas>1.3.0
3131
PyYAML
32-
scikit-learn==1.2.2
32+
scikit-learn
3333
scipy
3434
tqdm
3535

0 commit comments

Comments
 (0)