Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,11 +1335,7 @@ def wrapper(estimator, *args, **kwargs):
prefer_skip_nested_validation or global_skip_validation
)
):
try:
return fit_method(estimator, *args, **kwargs)
finally:
if hasattr(estimator, "_callback_fit_ctx"):
estimator._callback_fit_ctx.eval_on_fit_end(estimator=estimator)
return fit_method(estimator, *args, **kwargs)

return wrapper

Expand Down
13 changes: 3 additions & 10 deletions sklearn/callback/_callback_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class CallbackContext:
"""

@classmethod
def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):
def _from_estimator(cls, estimator, task_name):
"""Private constructor to create a root context.

Parameters
Expand All @@ -133,13 +133,6 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):

task_name : str
The name of the task this context is responsible for.

task_id : int
The id of the task this context is responsible for.

max_subtasks : int or None, default=None
The maximum number of subtasks of this task. 0 means it's a leaf.
None means the maximum number of subtasks is not known in advance.
"""
new_ctx = cls.__new__(cls)

Expand All @@ -148,10 +141,10 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):
new_ctx._callbacks = getattr(estimator, "_skl_callbacks", [])
new_ctx.estimator_name = estimator.__class__.__name__
new_ctx.task_name = task_name
new_ctx.task_id = task_id
new_ctx.task_id = 0
new_ctx.parent = None
new_ctx._children_map = {}
new_ctx.max_subtasks = max_subtasks
new_ctx.max_subtasks = None
new_ctx.prev_estimator_name = None
new_ctx.prev_task_name = None

Expand Down
44 changes: 26 additions & 18 deletions sklearn/callback/_mixin.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import functools

from sklearn.callback._base import Callback
from sklearn.callback._callback_context import CallbackContext


class CallbackSupportMixin:
"""Mixin class to add callback support to an estimator."""

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
for fit_method in ("fit", "fit_transform", "fit_predict", "partial_fit"):
if hasattr(cls, fit_method):
setattr(cls, fit_method, _fit_callback(getattr(cls, fit_method)))

def set_callbacks(self, callbacks):
"""Set callbacks for the estimator.

Expand All @@ -31,25 +39,25 @@ def set_callbacks(self, callbacks):

return self

def __skl_init_callback_context__(self, task_name="fit", max_subtasks=None):
"""Initialize the callback context for the estimator.

Parameters
----------
task_name : str, default='fit'
The name of the root task.
def _fit_callback(fit_method):
"""Decorator to initialize the callback context for the fit methods."""

max_subtasks : int or None, default=None
The maximum number of subtasks that can be children of the root task. None
means the maximum number of subtasks is not known in advance.
@functools.wraps(fit_method)
def callback_wrapper(estimator, *args, **kwargs):
ctx_already_existing = hasattr(estimator, "__sklearn_callback_fit_ctx__")
if not ctx_already_existing:
estimator.__sklearn_callback_fit_ctx__ = CallbackContext._from_estimator(
estimator, task_name=fit_method.__name__
)

Returns
-------
callback_fit_ctx : CallbackContext
The callback context for the estimator.
"""
self._callback_fit_ctx = CallbackContext._from_estimator(
estimator=self, task_name=task_name, task_id=0, max_subtasks=max_subtasks
)
try:
return fit_method(estimator, *args, **kwargs)
finally:
if not ctx_already_existing:
estimator.__sklearn_callback_fit_ctx__.eval_on_fit_end(estimator)
del estimator.__sklearn_callback_fit_ctx__
if hasattr(estimator, "_parent_callback_ctx"):
del estimator._parent_callback_ctx

return self._callback_fit_ctx
return callback_wrapper
20 changes: 9 additions & 11 deletions sklearn/callback/tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,10 @@ def __init__(self, max_iter=20, computation_intensity=0.001):
self.computation_intensity = computation_intensity

@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
callback_ctx = self.__skl_init_callback_context__(
max_subtasks=self.max_iter
).eval_on_fit_begin(estimator=self)

def fit(self, X=None, y=None, X_val=None, y_val=None):
callback_ctx = self.__sklearn_callback_fit_ctx__
callback_ctx.max_subtasks = self.max_iter
callback_ctx.eval_on_fit_begin(estimator=self)
for i in range(self.max_iter):
subcontext = callback_ctx.subcontext(task_id=i)

Expand Down Expand Up @@ -83,11 +82,10 @@ def __init__(self, computation_intensity=0.001):
self.computation_intensity = computation_intensity

@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
callback_ctx = self.__skl_init_callback_context__().eval_on_fit_begin(
def fit(self, X=None, y=None, X_val=None, y_val=None):
callback_ctx = self.__sklearn_callback_fit_ctx__.eval_on_fit_begin(
estimator=self
)

i = 0
while True:
subcontext = callback_ctx.subcontext(task_id=i)
Expand Down Expand Up @@ -128,9 +126,9 @@ def __init__(

@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
callback_ctx = self.__skl_init_callback_context__(
max_subtasks=self.n_outer
).eval_on_fit_begin(estimator=self)
callback_ctx = self.__sklearn_callback_fit_ctx__
callback_ctx.max_subtasks = self.n_outer
callback_ctx.eval_on_fit_begin(estimator=self)

Parallel(n_jobs=self.n_jobs, prefer=self.prefer)(
delayed(_func)(
Expand Down
69 changes: 26 additions & 43 deletions sklearn/callback/tests/test_callback_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_propagate_callbacks():
metaestimator = MetaEstimator(estimator)
metaestimator.set_callbacks([not_propagated_callback, propagated_callback])

callback_ctx = metaestimator.__skl_init_callback_context__()
callback_ctx = CallbackContext._from_estimator(metaestimator, task_name="fit")
callback_ctx.propagate_callbacks(estimator)

assert hasattr(estimator, "_parent_callback_ctx")
Expand All @@ -35,7 +35,7 @@ def test_propagate_callback_no_callback():
estimator = Estimator()
metaestimator = MetaEstimator(estimator)

callback_ctx = metaestimator.__skl_init_callback_context__()
callback_ctx = CallbackContext._from_estimator(metaestimator, task_name="fit")
assert len(callback_ctx._callbacks) == 0

callback_ctx.propagate_callbacks(estimator)
Expand All @@ -62,28 +62,20 @@ def test_auto_propagated_callbacks():
def _make_task_tree(n_children, n_grandchildren):
"""Helper function to create a tree of tasks with a context for each task."""
estimator = Estimator()
root = CallbackContext._from_estimator(
estimator,
task_name="root task",
task_id=0,
max_subtasks=n_children,
)
root = CallbackContext._from_estimator(estimator, task_name="root task")
root.max_subtasks = n_children

for i in range(n_children):
child = CallbackContext._from_estimator(
estimator,
task_name="child task",
task_id=i,
max_subtasks=n_grandchildren,
)
child = CallbackContext._from_estimator(estimator, task_name="child task")
child.task_id = (i,)
child.max_subtasks = n_grandchildren
root._add_child(child)

for j in range(n_grandchildren):
grandchild = CallbackContext._from_estimator(
estimator,
task_name="grandchild task",
task_id=j,
estimator, task_name="grandchild task"
)
grandchild.task_id = j
child._add_child(grandchild)

return root
Expand Down Expand Up @@ -122,57 +114,48 @@ def test_task_tree():
def test_add_child():
"""Sanity check for the `_add_child` method."""
estimator = Estimator()
root = CallbackContext._from_estimator(
estimator, task_name="root task", task_id=0, max_subtasks=2
)
root = CallbackContext._from_estimator(estimator, task_name="root task")
root.max_subtasks = 2

root._add_child(
CallbackContext._from_estimator(estimator, task_name="child task", task_id=0)
)
first_child = CallbackContext._from_estimator(estimator, task_name="child task")

root._add_child(first_child)
assert root.max_subtasks == 2
assert len(root._children_map) == 1

second_child = CallbackContext._from_estimator(estimator, task_name="child task")
# root already has a child with id 0
with pytest.raises(
ValueError, match=r"Callback context .* already has a child with task_id=0"
):
root._add_child(
CallbackContext._from_estimator(
estimator, task_name="child task", task_id=0
)
)
root._add_child(second_child)

root._add_child(
CallbackContext._from_estimator(estimator, task_name="child task", task_id=1)
)
second_child.task_id = 1
root._add_child(second_child)
assert len(root._children_map) == 2

third_child = CallbackContext._from_estimator(estimator, task_name="child task")
third_child.task_id = 2
# root can have at most 2 children
with pytest.raises(ValueError, match=r"Cannot add child to callback context"):
root._add_child(
CallbackContext._from_estimator(
estimator, task_name="child task", task_id=2
)
)
root._add_child(third_child)


def test_merge_with():
"""Sanity check for the `_merge_with` method."""
estimator = Estimator()
meta_estimator = MetaEstimator(estimator)
outer_root = CallbackContext._from_estimator(
meta_estimator, task_name="root", task_id=0, max_subtasks=2
)
outer_root = CallbackContext._from_estimator(meta_estimator, task_name="root")
outer_root.max_subtasks = 2

# Add a child task within the same estimator
outer_child = CallbackContext._from_estimator(
meta_estimator, task_name="child", task_id="id", max_subtasks=1
)
outer_child = CallbackContext._from_estimator(meta_estimator, task_name="child")
outer_child.max_subtasks = 1
outer_root._add_child(outer_child)

# The root task of the inner estimator is merged with (and effectively replaces)
# a leaf of the outer estimator because they correspond to the same formal task.
inner_root = CallbackContext._from_estimator(estimator, task_name="root", task_id=0)
inner_root = CallbackContext._from_estimator(estimator, task_name="root")
inner_root._merge_with(outer_child)

assert inner_root.parent is outer_root
Expand Down
10 changes: 4 additions & 6 deletions sklearn/callback/tests/test_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,8 @@ def test_set_callbacks_error(callbacks):
estimator.set_callbacks(callbacks)


def test_init_callback_context():
"""Sanity check for the `__skl_init_callback_context__` method."""
def test_callback_removed_after_fit():
"""Test that the __sklearn_callback_fit_ctx__ attribute gets removed after fit."""
estimator = Estimator()
callback_ctx = estimator.__skl_init_callback_context__()

assert hasattr(estimator, "_callback_fit_ctx")
assert hasattr(callback_ctx, "_callbacks")
estimator.fit()
assert not hasattr(estimator, "__sklearn_callback_fit_ctx__")
Loading