Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,11 +1335,7 @@ def wrapper(estimator, *args, **kwargs):
prefer_skip_nested_validation or global_skip_validation
)
):
try:
return fit_method(estimator, *args, **kwargs)
finally:
if hasattr(estimator, "_callback_fit_ctx"):
estimator._callback_fit_ctx.eval_on_fit_end(estimator=estimator)
return fit_method(estimator, *args, **kwargs)

return wrapper

Expand Down
13 changes: 3 additions & 10 deletions sklearn/callback/_callback_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class CallbackContext:
"""

@classmethod
def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):
def _from_estimator(cls, estimator, task_name):
"""Private constructor to create a root context.

Parameters
Expand All @@ -133,13 +133,6 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):

task_name : str
The name of the task this context is responsible for.

task_id : int
The id of the task this context is responsible for.

max_subtasks : int or None, default=None
The maximum number of subtasks of this task. 0 means it's a leaf.
None means the maximum number of subtasks is not known in advance.
"""
new_ctx = cls.__new__(cls)

Expand All @@ -148,10 +141,10 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):
new_ctx._callbacks = getattr(estimator, "_skl_callbacks", [])
new_ctx.estimator_name = estimator.__class__.__name__
new_ctx.task_name = task_name
new_ctx.task_id = task_id
new_ctx.task_id = 0
new_ctx.parent = None
new_ctx._children_map = {}
new_ctx.max_subtasks = max_subtasks
new_ctx.max_subtasks = None
new_ctx.prev_estimator_name = None
new_ctx.prev_task_name = None

Expand Down
39 changes: 22 additions & 17 deletions sklearn/callback/_mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import functools

from sklearn.callback._base import Callback
from sklearn.callback._callback_context import CallbackContext

Expand Down Expand Up @@ -31,25 +33,28 @@ def set_callbacks(self, callbacks):

return self

def __skl_init_callback_context__(self, task_name="fit", max_subtasks=None):
"""Initialize the callback context for the estimator.

Parameters
----------
task_name : str, default='fit'
The name of the root task.
def fit_callback(fit_method):
"""Decorator to initialize the callback context for the fit methods."""

max_subtasks : int or None, default=None
The maximum number of subtasks that can be children of the root task. None
means the maximum number of subtasks is not known in advance.
@functools.wraps(fit_method)
def callback_wrapper(estimator, *args, **kwargs):
if not isinstance(estimator, CallbackSupportMixin):
raise ValueError(
f"Estimator {estimator.__class__.__name__} does not support callbacks,"
" as it does not inherit from CallbackSupportMixin."
)

Returns
-------
callback_fit_ctx : CallbackContext
The callback context for the estimator.
"""
self._callback_fit_ctx = CallbackContext._from_estimator(
estimator=self, task_name=task_name, task_id=0, max_subtasks=max_subtasks
estimator.__sklearn_callback_fit_ctx__ = CallbackContext._from_estimator(
estimator, task_name=fit_method.__name__
)

return self._callback_fit_ctx
try:
return fit_method(estimator, *args, **kwargs)
finally:
estimator.__sklearn_callback_fit_ctx__.eval_on_fit_end(estimator)
del estimator.__sklearn_callback_fit_ctx__
if hasattr(estimator, "_parent_callback_ctx"):
del estimator._parent_callback_ctx

return callback_wrapper
75 changes: 61 additions & 14 deletions sklearn/callback/tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from sklearn.base import BaseEstimator, _fit_context, clone
from sklearn.callback import CallbackSupportMixin
from sklearn.callback._mixin import fit_callback
from sklearn.utils.parallel import Parallel, delayed


Expand All @@ -14,7 +15,7 @@ class TestingCallback:
def on_fit_begin(self, estimator):
pass

def on_fit_end(self):
def on_fit_end(self, estimator, context):
pass

def on_fit_task_end(self, estimator, context, **kwargs):
Expand Down Expand Up @@ -49,12 +50,12 @@ def __init__(self, max_iter=20, computation_intensity=0.001):
self.max_iter = max_iter
self.computation_intensity = computation_intensity

@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
callback_ctx = self.__skl_init_callback_context__(
max_subtasks=self.max_iter
).eval_on_fit_begin(estimator=self)

@fit_callback
@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X=None, y=None, X_val=None, y_val=None):
callback_ctx = self.__sklearn_callback_fit_ctx__
callback_ctx.max_subtasks = self.max_iter
callback_ctx.eval_on_fit_begin(estimator=self)
for i in range(self.max_iter):
subcontext = callback_ctx.subcontext(task_id=i)

Expand Down Expand Up @@ -82,12 +83,12 @@ class WhileEstimator(CallbackSupportMixin, BaseEstimator):
def __init__(self, computation_intensity=0.001):
self.computation_intensity = computation_intensity

@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
callback_ctx = self.__skl_init_callback_context__().eval_on_fit_begin(
@fit_callback
@_fit_context(prefer_skip_nested_validation=True)
def fit(self, X=None, y=None, X_val=None, y_val=None):
callback_ctx = self.__sklearn_callback_fit_ctx__.eval_on_fit_begin(
estimator=self
)

i = 0
while True:
subcontext = callback_ctx.subcontext(task_id=i)
Expand Down Expand Up @@ -126,11 +127,12 @@ def __init__(
self.n_jobs = n_jobs
self.prefer = prefer

@fit_callback
@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
callback_ctx = self.__skl_init_callback_context__(
max_subtasks=self.n_outer
).eval_on_fit_begin(estimator=self)
callback_ctx = self.__sklearn_callback_fit_ctx__
callback_ctx.max_subtasks = self.n_outer
callback_ctx.eval_on_fit_begin(estimator=self)

Parallel(n_jobs=self.n_jobs, prefer=self.prefer)(
delayed(_func)(
Expand Down Expand Up @@ -167,3 +169,48 @@ def _func(meta_estimator, inner_estimator, X, y, *, callback_ctx):
estimator=meta_estimator,
data={"X_train": X, "y_train": y},
)


class EstimatorWithoutCallbackMixin(BaseEstimator):
@fit_callback
def fit(self, X=None, y=None):
pass


class SimpleMetaEstimator(CallbackSupportMixin, BaseEstimator):
"""A class that mimics the behavior of a meta-estimator that does not clone the
estimator and does not parallelize.

There is no iteration, the meta estimator simply calls the fit of the estimator once
in a subcontext.
"""

_parameter_constraints: dict = {}

def __init__(self, estimator):
self.estimator = estimator

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags._prefer_skip_nested_validation = False
return tags

@fit_callback
@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
callback_ctx = self.__sklearn_callback_fit_ctx__
callback_ctx.max_subtasks = 1
callback_ctx.eval_on_fit_begin(estimator=self)
subcontext = callback_ctx.subcontext(task_name="subtask").propagate_callbacks(
sub_estimator=self.estimator
)
self.estimator.fit(X, y)
callback_ctx.eval_on_fit_task_end(
estimator=self,
data={
"X_train": X,
"y_train": y,
},
)

return self
79 changes: 36 additions & 43 deletions sklearn/callback/tests/test_callback_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sklearn.callback.tests._utils import (
Estimator,
MetaEstimator,
SimpleMetaEstimator,
TestingAutoPropagatedCallback,
TestingCallback,
)
Expand All @@ -22,7 +23,7 @@ def test_propagate_callbacks():
metaestimator = MetaEstimator(estimator)
metaestimator.set_callbacks([not_propagated_callback, propagated_callback])

callback_ctx = metaestimator.__skl_init_callback_context__()
callback_ctx = CallbackContext._from_estimator(metaestimator, task_name="fit")
callback_ctx.propagate_callbacks(estimator)

assert hasattr(estimator, "_parent_callback_ctx")
Expand All @@ -35,7 +36,7 @@ def test_propagate_callback_no_callback():
estimator = Estimator()
metaestimator = MetaEstimator(estimator)

callback_ctx = metaestimator.__skl_init_callback_context__()
callback_ctx = CallbackContext._from_estimator(metaestimator, task_name="fit")
assert len(callback_ctx._callbacks) == 0

callback_ctx.propagate_callbacks(estimator)
Expand All @@ -62,28 +63,20 @@ def test_auto_propagated_callbacks():
def _make_task_tree(n_children, n_grandchildren):
"""Helper function to create a tree of tasks with a context for each task."""
estimator = Estimator()
root = CallbackContext._from_estimator(
estimator,
task_name="root task",
task_id=0,
max_subtasks=n_children,
)
root = CallbackContext._from_estimator(estimator, task_name="root task")
root.max_subtasks = n_children

for i in range(n_children):
child = CallbackContext._from_estimator(
estimator,
task_name="child task",
task_id=i,
max_subtasks=n_grandchildren,
)
child = CallbackContext._from_estimator(estimator, task_name="child task")
child.task_id = (i,)
child.max_subtasks = n_grandchildren
root._add_child(child)

for j in range(n_grandchildren):
grandchild = CallbackContext._from_estimator(
estimator,
task_name="grandchild task",
task_id=j,
estimator, task_name="grandchild task"
)
grandchild.task_id = j
child._add_child(grandchild)

return root
Expand Down Expand Up @@ -122,57 +115,48 @@ def test_task_tree():
def test_add_child():
"""Sanity check for the `_add_child` method."""
estimator = Estimator()
root = CallbackContext._from_estimator(
estimator, task_name="root task", task_id=0, max_subtasks=2
)
root = CallbackContext._from_estimator(estimator, task_name="root task")
root.max_subtasks = 2

root._add_child(
CallbackContext._from_estimator(estimator, task_name="child task", task_id=0)
)
first_child = CallbackContext._from_estimator(estimator, task_name="child task")

root._add_child(first_child)
assert root.max_subtasks == 2
assert len(root._children_map) == 1

second_child = CallbackContext._from_estimator(estimator, task_name="child task")
# root already has a child with id 0
with pytest.raises(
ValueError, match=r"Callback context .* already has a child with task_id=0"
):
root._add_child(
CallbackContext._from_estimator(
estimator, task_name="child task", task_id=0
)
)
root._add_child(second_child)

root._add_child(
CallbackContext._from_estimator(estimator, task_name="child task", task_id=1)
)
second_child.task_id = 1
root._add_child(second_child)
assert len(root._children_map) == 2

third_child = CallbackContext._from_estimator(estimator, task_name="child task")
third_child.task_id = 2
# root can have at most 2 children
with pytest.raises(ValueError, match=r"Cannot add child to callback context"):
root._add_child(
CallbackContext._from_estimator(
estimator, task_name="child task", task_id=2
)
)
root._add_child(third_child)


def test_merge_with():
"""Sanity check for the `_merge_with` method."""
estimator = Estimator()
meta_estimator = MetaEstimator(estimator)
outer_root = CallbackContext._from_estimator(
meta_estimator, task_name="root", task_id=0, max_subtasks=2
)
outer_root = CallbackContext._from_estimator(meta_estimator, task_name="root")
outer_root.max_subtasks = 2

# Add a child task within the same estimator
outer_child = CallbackContext._from_estimator(
meta_estimator, task_name="child", task_id="id", max_subtasks=1
)
outer_child = CallbackContext._from_estimator(meta_estimator, task_name="child")
outer_child.max_subtasks = 1
outer_root._add_child(outer_child)

# The root task of the inner estimator is merged with (and effectively replaces)
# a leaf of the outer estimator because they correspond to the same formal task.
inner_root = CallbackContext._from_estimator(estimator, task_name="root", task_id=0)
inner_root = CallbackContext._from_estimator(estimator, task_name="root")
inner_root._merge_with(outer_child)

assert inner_root.parent is outer_root
Expand All @@ -183,3 +167,12 @@ def test_merge_with():
# The name and estimator name of the tasks it was merged with are stored
assert inner_root.prev_task_name == outer_child.task_name
assert inner_root.prev_estimator_name == outer_child.estimator_name


def test_no_parent_callback_after_fit():
"""Check that the `_parent_callback_ctx` attribute does not survive after fit."""
estimator = Estimator()
meta_estimator = SimpleMetaEstimator(estimator)
meta_estimator.set_callbacks(TestingAutoPropagatedCallback())
meta_estimator.fit()
assert not hasattr(estimator, "_parent_callback_ctx")
Loading
Loading