-
Notifications
You must be signed in to change notification settings - Fork 0
Private fit function for callbacks #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: base_callbacks_2
Are you sure you want to change the base?
Private fit function for callbacks #18
Conversation
sklearn/callback/tests/_utils.py
Outdated
| def fit(self, X_train=None, y_train=None, X_val=None, y_val=None): | ||
| if isinstance(self, CallbackSupportMixin): | ||
| callback_ctx = self.init_callback_context() | ||
| callback_ctx.eval_on_fit_begin(estimator=self) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
on_fit_begin should be called in __skl_fit__ imo.
sklearn/callback/tests/_utils.py
Outdated
| BaseEstimator class. | ||
| """ | ||
|
|
||
| @_fit_context(prefer_skip_nested_validation=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's test what it looks like to extract the content of the decorator and put it inside fit.
sklearn/callback/tests/_utils.py
Outdated
| ): | ||
| for i in range(self.max_iter): | ||
| subcontext = callback_ctx.subcontext(task_id=i) | ||
| if callback_ctx is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need a new method, like set_task_info, to set max_subtasks, task_name and task_id here, now that init_callback_context has moved.
callback_ctx.set_task_info(...)
sklearn/callback/tests/_utils.py
Outdated
| self, X_train=None, y_train=None, X_val=None, y_val=None, callback_ctx=None | ||
| ): | ||
| for i in range(self.max_iter): | ||
| subcontext = callback_ctx.subcontext(task_id=i) | ||
| if callback_ctx is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
callback context cannot be None since we intialize it in fit and pass it to __skl_fit__.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point
Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
…it-learn into private_fit_for_callbacks
…s merged with its parent
Implement the private
__skl_fit__method for the estimators used for the callbacks.