Dynamically wrap fit in the CallbackSupportMixin to initialize the callback context #20
base: base_callbacks_2
Conversation
jeremiedbb
left a comment
Let's try another alternative in a new PR, that is an extension of this one, where fit accepts a callback_context arg. We assume that fit has the arg and the decorator initializes the context and passes it to fit.
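A minimal sketch of the alternative described above, where the decorator builds the context and passes it to fit explicitly. The `CallbackContext` class here is a hypothetical stand-in for the real one; the only assumption carried over from the discussion is that every wrapped fit accepts a `callback_context` argument.

```python
import functools


class CallbackContext:
    """Hypothetical stand-in for the real callback context object."""

    @classmethod
    def _from_estimator(cls, estimator):
        return cls()


def _fit_callback(fit_method):
    # Initialize the context and hand it to fit explicitly, assuming
    # fit has a callback_context parameter.
    @functools.wraps(fit_method)
    def wrapper(estimator, *args, **kwargs):
        ctx = CallbackContext._from_estimator(estimator)
        return fit_method(estimator, *args, callback_context=ctx, **kwargs)

    return wrapper


class Est:
    @_fit_callback
    def fit(self, X=None, y=None, callback_context=None):
        self.ctx_ = callback_context  # fit receives the context directly
        return self


est = Est().fit()
assert isinstance(est.ctx_, CallbackContext)
```

Making the context an explicit argument avoids stashing mutable state on the estimator, at the cost of requiring every fit signature to declare it.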
sklearn/callback/_mixin.py
Outdated
```python
def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    if hasattr(cls, "fit"):
```
Not sure this defensive condition is useful. If someone adds the mixin to a non-estimator class, it's their problem :)
Yes, I was anticipating cases where we might have callbacks on transform: we would just add another if statement with a _transform_callback decorator, to handle classes that have transform but no fit, or vice versa.
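The idea above could look like the sketch below, with one if-branch per decorated method. The `_make_callback_decorator` factory and the `_callback_transform_ctx` attribute name are hypothetical; only `_fit_callback`, `_transform_callback`, and the `__init_subclass__` hook come from the discussion.

```python
import functools


def _make_callback_decorator(ctx_attr):
    """Hypothetical factory producing _fit_callback / _transform_callback."""

    def decorator(method):
        @functools.wraps(method)
        def wrapper(estimator, *args, **kwargs):
            # object() is a stand-in for a real CallbackContext
            setattr(estimator, ctx_attr, object())
            return method(estimator, *args, **kwargs)

        return wrapper

    return decorator


_fit_callback = _make_callback_decorator("_callback_fit_ctx")
_transform_callback = _make_callback_decorator("_callback_transform_ctx")


class CallbackSupportMixin:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # One branch per method handles classes with transform but no fit,
        # or vice versa.
        if "fit" in vars(cls):
            cls.fit = _fit_callback(cls.fit)
        if "transform" in vars(cls):
            cls.transform = _transform_callback(cls.transform)


class TransformerOnly(CallbackSupportMixin):
    def transform(self, X=None):
        return X


t = TransformerOnly()
t.transform()
assert hasattr(t, "_callback_transform_ctx")
assert not hasattr(t, "_callback_fit_ctx")
```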
sklearn/callback/_mixin.py
Outdated
```python
def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    if hasattr(cls, "fit"):
        cls.fit = _fit_callback(cls.fit)
```
We'll probably need to wrap all the methods that fit, i.e. fit, fit_transform, fit_predict and partial_fit.
Although I'm not sure what the behavior will be then in a case like fit calling fit_transform internally.
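Wrapping all fitting methods could be sketched as below, by looping over the candidate names in `__init_subclass__` and wrapping only those the subclass itself defines. This is a sketch, not the PR's implementation; the fixed tuple of method names is an assumption.

```python
import functools


class CallbackContext:
    """Hypothetical stand-in for the real context object."""

    @classmethod
    def _from_estimator(cls, estimator):
        return cls()


def _fit_callback(fit_method):
    @functools.wraps(fit_method)
    def wrapper(estimator, *args, **kwargs):
        estimator._callback_fit_ctx = CallbackContext._from_estimator(estimator)
        return fit_method(estimator, *args, **kwargs)

    return wrapper


class CallbackSupportMixin:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Wrap every fitting method the subclass itself defines.
        for name in ("fit", "fit_transform", "fit_predict", "partial_fit"):
            if name in vars(cls):
                setattr(cls, name, _fit_callback(vars(cls)[name]))


class Est(CallbackSupportMixin):
    def fit(self, X=None):
        return self

    def partial_fit(self, X=None):
        return self


est = Est().fit()
assert hasattr(est, "_callback_fit_ctx")
```

Note that with this naive version, fit calling fit_transform internally would re-initialize the context mid-fit, which is exactly the open question raised above.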
sklearn/callback/_mixin.py
Outdated
```python
        means the maximum number of subtasks is not known in advance.

    @functools.wraps(fit_method)
    def wrapper(estimator, *args, **kwargs):
        estimator._callback_fit_ctx = CallbackContext._from_estimator(estimator)
```
To be safe w.r.t. my previous comment, we can first check that the attribute doesn't already exist.
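The guard suggested above could look like this sketch: only the outermost fitting call creates the context, and a try/finally removes it when that call returns, so nested calls such as fit invoking fit_transform reuse the existing context. The cleanup-in-finally detail is an assumption inspired by the commit messages below, not a confirmed part of the PR.

```python
import functools


def _fit_callback(fit_method):
    @functools.wraps(fit_method)
    def wrapper(estimator, *args, **kwargs):
        # Only the outermost fitting call creates (and later removes) the
        # context; nested fitting calls see the attribute already set.
        outermost = not hasattr(estimator, "_callback_fit_ctx")
        if outermost:
            estimator._callback_fit_ctx = object()  # stand-in for CallbackContext
        try:
            return fit_method(estimator, *args, **kwargs)
        finally:
            if outermost:
                del estimator._callback_fit_ctx

    return wrapper


class Est:
    @_fit_callback
    def fit(self, X=None):
        self.had_ctx_in_fit_ = hasattr(self, "_callback_fit_ctx")
        return self.fit_transform(X)  # nested call reuses the context

    @_fit_callback
    def fit_transform(self, X=None):
        return self


est = Est().fit()
assert est.had_ctx_in_fit_
assert not hasattr(est, "_callback_fit_ctx")  # cleaned up after fit returns
```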
Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
The callback context relative to fit is initialized in a decorator around fit, and the fit method is dynamically decorated through __init_subclass__ in the CallbackSupportMixin class.