Skip to content

Conversation

@FrancoisPgm
Copy link

The callback context relative to the fit is intialized in a decorator around fit, and the fit function is dynamically decorated through __init_subclass__ in the CallbackSupportMixin class.

@github-actions
Copy link

github-actions bot commented Oct 7, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 36a5aaa. Link to the linter CI: here

Copy link
Owner

@jeremiedbb jeremiedbb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's try another alternative in a new PR, that is an extension of this one, where fit accepts a callback_context arg. We assume that fit has the arg and the decorator initializes the context and passes it to fit.


def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if hasattr(cls, "fit"):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure this defensive condition is useful. If one adds the mixin to a non-estimator class, it's his problem :)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was anticipating cases where we might have callbacks on transform, so we would just add another if statement with a _transform_callback decorator, to tackle classes with transform but no fit or vice versa.

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if hasattr(cls, "fit"):
cls.fit = _fit_callback(cls.fit)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll probably need to wrap all methods that fit, i.e. fit, fit_transform, fit_predict, partial_fit.

Although I'm not sure what will be the behavior then in a case like fit calling fit_transform internally.

means the maximum number of subtasks is not known in advance.
@functools.wraps(fit_method)
def wrapper(estimator, *args, **kwargs):
estimator._callback_fit_ctx = CallbackContext._from_estimator(estimator)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be safe w.r.t. my previous comment we can check first if the attribute doesn't already exist

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants