Skip to content

Commit 6e8fa71

Browse files
authored
Merge pull request #389 from poppopting/ktdev
explicitly assign arguments to avoid incorrect argument assignments
2 parents d2f315d + ac6d1d1 commit 6e8fa71

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

libmultilabel/linear/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,18 @@ class MultiLabelEstimator(sklearn.base.BaseEstimator):
7676
scoring_metric (str, optional): The scoring metric. Defaults to 'P@1'.
7777
"""
7878

79-
def __init__(self, options: str = "", linear_technique: str = "1vsrest", scoring_metric: str = "P@1"):
79+
def __init__(self, options: str = "", linear_technique: str = "1vsrest", scoring_metric: str = "P@1", multiclass: bool = False):
8080
super().__init__()
8181
self.options = options
8282
self.linear_technique = linear_technique
8383
self.scoring_metric = scoring_metric
8484
self._is_fitted = False
85+
self.multiclass = multiclass
8586

8687
def fit(self, X: sparse.csr_matrix, y: sparse.csr_matrix):
8788
X, y = sklearn.utils.validation.check_X_y(X, y, accept_sparse=True, multi_output=True)
8889
self._is_fitted = True
89-
self.model = LINEAR_TECHNIQUES[self.linear_technique](y, X, self.options)
90+
self.model = LINEAR_TECHNIQUES[self.linear_technique](y, X, options=self.options)
9091
return self
9192

9293
def predict(self, X: sparse.csr_matrix) -> np.ndarray:
@@ -96,8 +97,9 @@ def predict(self, X: sparse.csr_matrix) -> np.ndarray:
9697

9798
def score(self, X: sparse.csr_matrix, y: sparse.csr_matrix) -> float:
9899
metrics = linear.get_metrics(
99-
[self.scoring_metric],
100-
y.shape[1],
100+
monitor_metrics=[self.scoring_metric],
101+
num_classes=y.shape[1],
102+
multiclass=self.multiclass
101103
)
102104
preds = self.predict(X)
103105
metrics.update(preds, y.toarray())

linear_trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ def linear_train(datasets, config):
5151
model = LINEAR_TECHNIQUES[config.linear_technique](
5252
datasets["train"]["y"],
5353
datasets["train"]["x"],
54-
config.liblinear_options,
55-
config.tree_degree,
56-
config.tree_max_depth,
54+
options=config.liblinear_options,
55+
K=config.tree_degree,
56+
dmax=config.tree_max_depth,
5757
)
5858
else:
5959
model = LINEAR_TECHNIQUES[config.linear_technique](

0 commit comments

Comments
 (0)