Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions sklearn/callback/_callback_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import inspect
import warnings

from sklearn.base import BaseEstimator
from sklearn.callback import AutoPropagatedCallback

# TODO(callbacks): move these explanations into a dedicated user guide.
Expand Down Expand Up @@ -141,6 +144,18 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):
The maximum number of subtasks of this task. 0 means it's a leaf.
None means the maximum number of subtasks is not known in advance.
"""
if hasattr(estimator, "_skl_callbacks") and estimator._skl_callbacks:
meta_est_no_callback = called_from_no_callback_meta_estimator()
if meta_est_no_callback is not None:
warnings.warn(
f"The estimator {estimator.__class__.__name__} which supports"
f" callbacks is used within the fitting of a {meta_est_no_callback}"
" meta-estimator which does not support callbacks. The behaviour of"
f" callbacks that are attached to {estimator.__class__.__name__}"
" will be undefined.",
UserWarning,
)

new_ctx = cls.__new__(cls)

# We don't store the estimator in the context to avoid circular references
Expand Down Expand Up @@ -358,6 +373,8 @@ def propagate_callbacks(self, sub_estimator):
sub_estimator : estimator instance
The estimator to which the callbacks should be propagated.
"""
from sklearn.callback._mixin import CallbackSupportMixin

bad_callbacks = [
callback.__class__.__name__
for callback in getattr(sub_estimator, "_skl_callbacks", [])
Expand Down Expand Up @@ -385,6 +402,16 @@ def propagate_callbacks(self, sub_estimator):
if not callbacks_to_propagate:
return self

if not isinstance(sub_estimator, CallbackSupportMixin):
warnings.warn(
f"The estimator {sub_estimator.__class__.__name__} which does not"
" supports callbacks is being used in a meta-estimator which supports"
" callbacks. The callbacks will not be propagated through this"
" estimator.",
UserWarning,
)
return self

# We store the parent context in the sub-estimator to be able to merge the
# task trees of the sub-estimator and the meta-estimator.
sub_estimator._parent_callback_ctx = self
Expand Down Expand Up @@ -415,3 +442,30 @@ def get_context_path(context):
if context.parent is None
else get_context_path(context.parent) + [context]
)


def called_from_no_callback_meta_estimator():
    """Check if the current call happens within the fit of a meta-estimator that
    does not support callbacks.

    The call stack is walked outward, starting from the caller of this function,
    looking for a frame that executes a function named ``fit``, ``fit_transform``
    or ``partial_fit`` and whose local ``self`` is a `BaseEstimator` which is not
    a `CallbackSupportMixin`.

    Only the frames of the current thread are inspected: a frame chain never
    crosses a thread or process boundary, so a fit running in a worker does not
    see the frames of the code that dispatched it.

    Returns
    -------
    str or None
        The name of the class of the innermost meta-estimator found in the call
        stack which does not support callbacks, None otherwise.
    """
    # Function-scope import, as done elsewhere in this module (presumably to
    # avoid an import cycle with the mixin module).
    from sklearn.callback._mixin import CallbackSupportMixin

    fit_method_names = ("fit", "fit_transform", "partial_fit")

    # Walk the frames through `f_back` rather than calling `inspect.stack()`:
    # the latter eagerly builds a FrameInfo for every frame and reads the source
    # files to fill in the code context, none of which is needed here.
    current = inspect.currentframe()
    # `currentframe` returns None on implementations without frame support.
    frame = current.f_back if current is not None else None
    try:
        while frame is not None:
            if frame.f_code.co_name in fit_method_names:
                # `f_locals` is read once; a missing `self` yields None, which
                # fails the isinstance check just like an unrelated object.
                caller_self = frame.f_locals.get("self")
                if isinstance(caller_self, BaseEstimator) and not isinstance(
                    caller_self, CallbackSupportMixin
                ):
                    return caller_self.__class__.__name__
            frame = frame.f_back

        return None
    finally:
        # Drop the frame references explicitly, as recommended by the `inspect`
        # documentation, so that no reference cycle outlives this call.
        del current, frame
56 changes: 55 additions & 1 deletion sklearn/callback/tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class TestingCallback:
def on_fit_begin(self, estimator):
pass

def on_fit_end(self):
def on_fit_end(self, estimator, data):
pass

def on_fit_task_end(self, estimator, context, **kwargs):
Expand Down Expand Up @@ -167,3 +167,57 @@ def _func(meta_estimator, inner_estimator, X, y, *, callback_ctx):
estimator=meta_estimator,
data={"X_train": X, "y_train": y},
)


class MetaEstimatorNoCallback(BaseEstimator):
    """Test double for a meta-estimator that has no callback support.

    It derives from `BaseEstimator` only. Its ``fit`` dispatches ``n_outer``
    calls to `_func_no_callback` through `Parallel`, each of them receiving the
    wrapped estimator.
    """

    _parameter_constraints: dict = {}

    def __init__(
        self, estimator, n_outer=4, n_inner=3, n_jobs=None, prefer="processes"
    ):
        # Parameters are stored untouched, one attribute per constructor argument.
        self.estimator = estimator
        self.n_jobs = n_jobs
        self.prefer = prefer
        self.n_outer = n_outer
        self.n_inner = n_inner

    @_fit_context(prefer_skip_nested_validation=False)
    def fit(self, X=None, y=None):
        """Run the ``n_outer`` outer tasks and return ``self``."""
        outer_task = delayed(_func_no_callback)
        pending_tasks = (
            outer_task(self, self.estimator, X, y) for _ in range(self.n_outer)
        )
        Parallel(n_jobs=self.n_jobs, prefer=self.prefer)(pending_tasks)

        return self


def _func_no_callback(meta_estimator, inner_estimator, X, y):
    """Fit ``meta_estimator.n_inner`` successive clones of `inner_estimator`.

    Each iteration clones `inner_estimator` and fits the clone on ``X`` and
    ``y``; the fitted clones are discarded and nothing is returned.
    """
    for _ in range(meta_estimator.n_inner):
        clone(inner_estimator).fit(X, y)


class EstimatorNoCallback(BaseEstimator):
    """A class that mimics the behavior of an estimator which does not support
    callbacks.

    Parameters
    ----------
    max_iter : int, default=20
        Number of iterations run by ``fit``.

    computation_intensity : float, default=0.001
        Time in seconds slept at each iteration to simulate computation.

    Attributes
    ----------
    n_iter_ : int
        Number of iterations run by the last call to ``fit``.
    """

    _parameter_constraints: dict = {}

    def __init__(self, max_iter=20, computation_intensity=0.001):
        self.max_iter = max_iter
        self.computation_intensity = computation_intensity

    @_fit_context(prefer_skip_nested_validation=False)
    def fit(self, X=None, y=None):
        """Run ``max_iter`` sleeping iterations, set ``n_iter_`` and return self."""
        # Initialised before the loop so that `n_iter_` is well defined (0) when
        # `max_iter` is 0, instead of reading an unbound loop variable.
        n_iter = 0
        for n_iter in range(1, self.max_iter + 1):
            time.sleep(self.computation_intensity)  # Computation intensive task

        self.n_iter_ = n_iter

        return self
26 changes: 26 additions & 0 deletions sklearn/callback/tests/test_callback_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from sklearn.callback._callback_context import CallbackContext, get_context_path
from sklearn.callback.tests._utils import (
Estimator,
EstimatorNoCallback,
MetaEstimator,
MetaEstimatorNoCallback,
TestingAutoPropagatedCallback,
TestingCallback,
)
Expand Down Expand Up @@ -183,3 +185,27 @@ def test_merge_with():
# The name and estimator name of the tasks it was merged with are stored
assert inner_root.prev_task_name == outer_child.task_name
assert inner_root.prev_estimator_name == outer_child.estimator_name


@pytest.mark.parametrize("n_jobs", [1, 2])
@pytest.mark.parametrize("prefer", ["threads", "processes"])
def test_no_callback_meta_est_warning(n_jobs, prefer):
    """Fitting an estimator holding callbacks inside a meta-estimator without
    callback support emits a UserWarning."""
    inner = Estimator()
    inner.set_callbacks(TestingCallback())
    wrapper = MetaEstimatorNoCallback(inner, n_jobs=n_jobs, prefer=prefer)

    expected_msg = "meta-estimator which does not support callbacks."
    with pytest.warns(UserWarning, match=expected_msg):
        wrapper.fit()


def test_no_callback_est_in_meta_est():
    """Propagating callbacks to a sub-estimator without callback support emits
    a UserWarning."""
    wrapper = MetaEstimator(EstimatorNoCallback())
    wrapper.set_callbacks(TestingAutoPropagatedCallback())

    # Must match the emitted warning text verbatim.
    expected_msg = (
        "which does not supports callbacks is being used in a meta-estimator"
    )
    with pytest.warns(UserWarning, match=expected_msg):
        wrapper.fit()
Loading