@@ -76,17 +76,18 @@ class MultiLabelEstimator(sklearn.base.BaseEstimator):
76
76
scoring_metric (str, optional): The scoring metric. Defaults to 'P@1'.
77
77
"""
78
78
79
- def __init__ (self , options : str = "" , linear_technique : str = "1vsrest" , scoring_metric : str = "P@1" ):
79
+ def __init__ (self , options : str = "" , linear_technique : str = "1vsrest" , scoring_metric : str = "P@1" , multiclass : bool = False ):
80
80
super ().__init__ ()
81
81
self .options = options
82
82
self .linear_technique = linear_technique
83
83
self .scoring_metric = scoring_metric
84
84
self ._is_fitted = False
85
+ self .multiclass = multiclass
85
86
86
87
def fit (self , X : sparse .csr_matrix , y : sparse .csr_matrix ):
87
88
X , y = sklearn .utils .validation .check_X_y (X , y , accept_sparse = True , multi_output = True )
88
89
self ._is_fitted = True
89
- self .model = LINEAR_TECHNIQUES [self .linear_technique ](y , X , self .options )
90
+ self .model = LINEAR_TECHNIQUES [self .linear_technique ](y , X , options = self .options )
90
91
return self
91
92
92
93
def predict (self , X : sparse .csr_matrix ) -> np .ndarray :
@@ -96,8 +97,9 @@ def predict(self, X: sparse.csr_matrix) -> np.ndarray:
96
97
97
98
def score (self , X : sparse .csr_matrix , y : sparse .csr_matrix ) -> float :
98
99
metrics = linear .get_metrics (
99
- [self .scoring_metric ],
100
- y .shape [1 ],
100
+ monitor_metrics = [self .scoring_metric ],
101
+ num_classes = y .shape [1 ],
102
+ multiclass = self .multiclass
101
103
)
102
104
preds = self .predict (X )
103
105
metrics .update (preds , y .toarray ())
0 commit comments