chore: improving sklearn compatibility via parametrize_with_checks (#660)
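The whole PR hangs on scikit-learn's `parametrize_with_checks` helper, which turns the library's estimator API checks into parametrized pytest cases. Below is a minimal sketch of the kind of test module this enables; the import path and constructor arguments are assumptions for illustration, not lines taken from this diff.

```python
# Sketch of a parametrize_with_checks test module. The estimator list and its
# constructor arguments are illustrative placeholders, not the PR's test file.
from sklearn.utils.estimator_checks import parametrize_with_checks

from sklego.linear_model import DemographicParityClassifier  # import path assumed


@parametrize_with_checks([
    DemographicParityClassifier(covariance_threshold=0.9, sensitive_cols=[0]),
])
def test_sklearn_compatible_estimator(estimator, check):
    # Each `check` is one scikit-learn API-contract assertion, run per estimator.
    check(estimator)
```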
```diff
@@ -93,13 +93,9 @@ def fit(self, X, y=None):
         )
         self.pca_.fit(X, y)
         self.offset_ = -self.threshold
-        return self
-
-    def transform(self, X):
-        """Transform the data using the underlying PCA method."""
-        X = check_array(X, estimator=self, dtype=FLOAT_DTYPES)
-        check_is_fitted(self, ["pca_", "offset_"])
-        return self.pca_.transform(X)
+        self.n_features_in_ = X.shape[1]
+        return self
 
     def difference(self, X):
         """Return the calculated difference between original and reconstructed data. Row by row.
```

Review comment on the removed `transform` method: "This method was throwing off the checks as they dynamically look for method names."
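Two things happen in this hunk. First, the common checks decide which groups of tests to run by probing the estimator for methods such as `transform`, so an outlier detector that also exposes `transform` gets pulled into transformer checks it was never meant to satisfy; dropping the method drops those checks. Second, recording `n_features_in_` during `fit` is the attribute the feature-count checks look for. A rough illustration of the method-probing idea (simplified stand-in, not scikit-learn's actual code):

```python
# Rough illustration of how check selection can key off method presence.
# The check names and logic are simplified stand-ins, not sklearn internals.
def select_checks(estimator):
    checks = ["fit_checks", "n_features_in_checks"]
    if hasattr(estimator, "transform"):
        # Presence of `transform` drags in transformer-specific checks.
        checks.append("transformer_checks")
    if hasattr(estimator, "predict_proba"):
        checks.append("probabilistic_classifier_checks")
    return checks
```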
```diff
@@ -124,3 +124,6 @@ def allowed_strategies(self):
             DeprecationWarning,
         )
         return self._ALLOWED_STRATEGIES
+
+    def _more_tags(self):
+        return {"poor_score": True, "non_deterministic": True}
```

Review comment: "I love the …"
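`_more_tags` is the hook estimators use to override the default estimator tags: the returned dict is merged into the defaults, and flags such as `poor_score` and `non_deterministic` tell the common checks to relax the corresponding assertions, which is exactly what an intentionally random baseline estimator needs. A minimal, made-up example of the pattern:

```python
# Minimal illustration of the _more_tags pattern; the estimator below is made
# up for the example and is not part of this PR.
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin


class NoisyBaselineRegressor(RegressorMixin, BaseEstimator):
    """Predicts random noise around the training mean; only here to show _more_tags."""

    def fit(self, X, y):
        self.mean_ = float(np.mean(y))
        self.n_features_in_ = X.shape[1]
        return self

    def predict(self, X):
        rng = np.random.default_rng()
        return self.mean_ + rng.normal(scale=0.1, size=X.shape[0])

    def _more_tags(self):
        # Merged into the default tags: the checks should not expect a good
        # score from this estimator, nor deterministic predictions.
        return {"poor_score": True, "non_deterministic": True}
```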
```diff
@@ -84,6 +84,7 @@ def fit(self, X, y):
             raise ValueError(f"Param `sigma` must be >= 0, got: {self.sigma}")
         self.X_ = X
         self.y_ = y
+        self.n_features_in_ = X.shape[1]
         return self
 
     def _calc_wts(self, x_i):
```
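Same pattern as the PCA hunk above: `n_features_in_` is the attribute the common checks expect `fit` to set, and it is what later lets `predict`/`transform` reject inputs of the wrong width, roughly along these lines (illustrative helper, not code from the diff):

```python
# Illustrative helper showing the convention that n_features_in_ supports.
def _check_n_features(estimator, X):
    if X.shape[1] != estimator.n_features_in_:
        raise ValueError(
            f"X has {X.shape[1]} features, but the estimator was fitted with "
            f"{estimator.n_features_in_} features."
        )
```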
```diff
@@ -491,15 +492,15 @@ def fit(self, X, y):
             raise ValueError(f"penalty should be either 'l1' or 'none', got {self.penalty}")
 
         self.sensitive_col_idx_ = self.sensitive_cols
 
         if isinstance(X, pd.DataFrame):
             self.sensitive_col_idx_ = [i for i, name in enumerate(X.columns) if name in self.sensitive_cols]
         X, y = check_X_y(X, y, accept_large_sparse=False)
 
         sensitive = X[:, self.sensitive_col_idx_]
         if not self.train_sensitive_cols:
             X = np.delete(X, self.sensitive_col_idx_, axis=1)
-            X = self._add_intercept(X)
 
+        X = self._add_intercept(X)
         column_or_1d(y)
         label_encoder = LabelEncoder().fit(y)
         y = label_encoder.transform(y)
```
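The DataFrame branch above resolves sensitive column names into positional indices before `check_X_y` converts the frame into a plain ndarray, so the columns can still be sliced out afterwards. A tiny standalone illustration with made-up data:

```python
# Made-up data to illustrate the name-to-index resolution done before check_X_y.
import pandas as pd

X = pd.DataFrame({"age": [23, 31], "income": [1.2, 3.4], "gender": [0, 1]})
sensitive_cols = ["gender"]

sensitive_col_idx = [i for i, name in enumerate(X.columns) if name in sensitive_cols]
print(sensitive_col_idx)  # [2] -- positions usable on the validated numpy array
```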
```diff
@@ -577,6 +578,9 @@ def _add_intercept(self, X):
         if self.fit_intercept:
             return np.c_[np.ones(len(X)), X]
 
+    def _more_tags(self):
+        return {"poor_score": True}
+
 
 class DemographicParityClassifier(BaseEstimator, LinearClassifierMixin):
     r"""`DemographicParityClassifier` is a logistic regression classifier which can be constrained on demographic
```
```diff
@@ -1017,17 +1021,16 @@ def __init__(
 
     def _get_objective(self, X, y, sample_weight):
         def imbalanced_loss(params):
-            return 0.5 * np.mean(
-                sample_weight
-                * np.where(X @ params > y, self.overestimation_punishment_factor, 1)
-                * np.square(y - X @ params)
+            return 0.5 * np.average(
+                np.where(X @ params > y, self.overestimation_punishment_factor, 1) * np.square(y - X @ params),
+                weights=sample_weight,
             ) + self._regularized_loss(params)
 
         def grad_imbalanced_loss(params):
             return (
                 -(sample_weight * np.where(X @ params > y, self.overestimation_punishment_factor, 1) * (y - X @ params))
                 @ X
-                / X.shape[0]
+                / sample_weight.sum()
             ) + self._regularized_grad_loss(params)
 
         return imbalanced_loss, grad_imbalanced_loss
```
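The change from `np.mean(sample_weight * ...)` to `np.average(..., weights=sample_weight)` swaps the normalizer from the number of samples to `sample_weight.sum()`, and the gradient's `/ X.shape[0]` becomes `/ sample_weight.sum()` for the same reason: the objective then behaves like a proper weighted average, so rescaling all weights by a constant no longer changes it. This is the kind of property the sample-weight checks exercise. A tiny numeric illustration with made-up numbers:

```python
# np.mean(w * loss) vs np.average(loss, weights=w) under a constant rescaling of w.
import numpy as np

per_sample_loss = np.array([1.0, 4.0, 9.0])
w = np.ones(3)

print(np.mean(w * per_sample_loss), np.mean(2 * w * per_sample_loss))
# 4.666...  9.333...   -> the unnormalized form doubles with the weights
print(np.average(per_sample_loss, weights=w), np.average(per_sample_loss, weights=2 * w))
# 4.666...  4.666...   -> the weighted average is invariant to the rescaling
```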
```diff
@@ -1128,15 +1131,16 @@ def __init__(
 
     def _get_objective(self, X, y, sample_weight):
         def quantile_loss(params):
-            return np.mean(
-                sample_weight * np.where(X @ params < y, self.quantile, 1 - self.quantile) * np.abs(y - X @ params)
+            return np.average(
+                np.where(X @ params < y, self.quantile, 1 - self.quantile) * np.abs(y - X @ params),
+                weights=sample_weight,
             ) + self._regularized_loss(params)
 
         def grad_quantile_loss(params):
             return (
                 -(sample_weight * np.where(X @ params < y, self.quantile, 1 - self.quantile) * np.sign(y - X @ params))
                 @ X
-                / X.shape[0]
+                / sample_weight.sum()
             ) + self._regularized_grad_loss(params)
 
         return quantile_loss, grad_quantile_loss
```

Comment on lines +1148 to +1157:

"Hey @Garve, sorry to drag you into this mess. No need for you to read all the changes. TL;DR is that in this PR I added tests using scikit-learn parametrize_with_checks to check for (better) compatibility. As some of the tests were failing, I took a closer look and adjusted the formulas to take into account …"

"If it's really just one solver, I would just not sweat it. We should document it for sure, but we can always patch that later."

"Apparently, for Ubuntu and Windows 'TNC' is failing as well, but only for …"
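The quantile objective gets the same treatment, and because the gradient is now divided by `sample_weight.sum()` it stays consistent with the reweighted loss. A self-contained sketch of the invariance this buys; the function bodies mirror the new code in this hunk, but the data and the `quantile` value are made up:

```python
# Standalone check that both the loss and its gradient are unchanged when all
# sample weights are scaled by a constant. Data and quantile are made up.
import numpy as np

rng = np.random.default_rng(42)
X = rng.normal(size=(50, 3))
y = rng.normal(size=50)
params = rng.normal(size=3)
quantile = 0.3


def quantile_loss(params, sample_weight):
    return np.average(
        np.where(X @ params < y, quantile, 1 - quantile) * np.abs(y - X @ params),
        weights=sample_weight,
    )


def grad_quantile_loss(params, sample_weight):
    return (
        -(sample_weight * np.where(X @ params < y, quantile, 1 - quantile) * np.sign(y - X @ params))
        @ X
        / sample_weight.sum()
    )


w = np.ones(50)
assert np.isclose(quantile_loss(params, w), quantile_loss(params, 3 * w))
assert np.allclose(grad_quantile_loss(params, w), grad_quantile_loss(params, 3 * w))
```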
Review comment: "That's the good news! They are both working with numpy 2.0rc."