Skip to content
6 changes: 1 addition & 5 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,11 +1335,7 @@ def wrapper(estimator, *args, **kwargs):
prefer_skip_nested_validation or global_skip_validation
)
):
try:
return fit_method(estimator, *args, **kwargs)
finally:
if hasattr(estimator, "_callback_fit_ctx"):
estimator._callback_fit_ctx.eval_on_fit_end(estimator=estimator)
return fit_method(estimator, *args, **kwargs)

return wrapper

Expand Down
20 changes: 7 additions & 13 deletions sklearn/callback/_callback_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@
#
# @_fit_context()
# def fit(self, X, y):
# callback_ctx = self.__skl__init_callback_context__(max_subtasks=self.max_iter)
# callback_ctx = CallbackContext._from_estimator(self)
# callback_ctx.set_task_info(max_subtasks=self.max_iter)
# callback_ctx.eval_on_fit_begin(estimator=self)
#
# for i in range(self.max_iter):
Expand All @@ -89,8 +90,8 @@ class CallbackContext:
This class is responsible for managing the callbacks and holding the tree structure
of an estimator's tasks. Each instance corresponds to a task of the estimator.

Instances of this class should be created using the `__skl_init_callback_context__`
method of its estimator or the `subcontext` method of this class.
Instances of this class should be created either with the `_from_estimator`
method, passing the owning estimator, or with the `subcontext` method of this class.

These contexts are passed to the callback hooks to be able to keep track of the
position of a task in the task tree within the callbacks.
Expand Down Expand Up @@ -123,7 +124,7 @@ class CallbackContext:
"""

@classmethod
def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):
def _from_estimator(cls, estimator, task_name):
"""Private constructor to create a root context.

Parameters
Expand All @@ -133,13 +134,6 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):

task_name : str
The name of the task this context is responsible for.

task_id : int
The id of the task this context is responsible for.

max_subtasks : int or None, default=None
The maximum number of subtasks of this task. 0 means it's a leaf.
None means the maximum number of subtasks is not known in advance.
"""
new_ctx = cls.__new__(cls)

Expand All @@ -148,10 +142,10 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):
new_ctx._callbacks = getattr(estimator, "_skl_callbacks", [])
new_ctx.estimator_name = estimator.__class__.__name__
new_ctx.task_name = task_name
new_ctx.task_id = task_id
new_ctx.task_id = 0
new_ctx.parent = None
new_ctx._children_map = {}
new_ctx.max_subtasks = max_subtasks
new_ctx.max_subtasks = None
new_ctx.prev_estimator_name = None
new_ctx.prev_task_name = None

Expand Down
24 changes: 0 additions & 24 deletions sklearn/callback/_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: BSD-3-Clause

from sklearn.callback._base import Callback
from sklearn.callback._callback_context import CallbackContext


class CallbackSupportMixin:
Expand Down Expand Up @@ -30,26 +29,3 @@ def set_callbacks(self, callbacks):
self._skl_callbacks = callbacks

return self

def __skl_init_callback_context__(self, task_name="fit", max_subtasks=None):
"""Initialize the callback context for the estimator.

Parameters
----------
task_name : str, default='fit'
The name of the root task.

max_subtasks : int or None, default=None
The maximum number of subtasks that can be children of the root task. None
means the maximum number of subtasks is not known in advance.

Returns
-------
callback_fit_ctx : CallbackContext
The callback context for the estimator.
"""
self._callback_fit_ctx = CallbackContext._from_estimator(
estimator=self, task_name=task_name, task_id=0, max_subtasks=max_subtasks
)

return self._callback_fit_ctx
187 changes: 152 additions & 35 deletions sklearn/callback/tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@

import time

from sklearn.base import BaseEstimator, _fit_context, clone
from sklearn._config import config_context, get_config
from sklearn.base import BaseEstimator, clone
from sklearn.callback import CallbackSupportMixin
from sklearn.callback._callback_context import CallbackContext
from sklearn.utils._tags import get_tags
from sklearn.utils.parallel import Parallel, delayed


Expand All @@ -14,7 +17,7 @@ class TestingCallback:
def on_fit_begin(self, estimator):
pass

def on_fit_end(self):
def on_fit_end(self, estimator, context):
pass

def on_fit_task_end(self, estimator, context, **kwargs):
Expand All @@ -37,7 +40,45 @@ def on_fit_task_end(self, estimator, context, **kwargs):
pass # pragma: no cover


class Estimator(CallbackSupportMixin, BaseEstimator):
class BaseEstimatorPrivateFit(BaseEstimator):
    """Add a public ``fit`` driving a private ``__sklearn_fit__`` to ``BaseEstimator``.

    ``fit`` performs parameter validation and, when the estimator supports
    callbacks, creates a root callback context, forwards it to
    ``__sklearn_fit__`` and guarantees the fit-end callback hook runs.
    """

    def fit(self, X=None, y=None, X_val=None, y_val=None):
        # Validate hyperparameters unless validation is globally disabled.
        skip_validation = get_config()["skip_parameter_validation"]
        if not skip_validation:
            self._validate_params()

        # Nested estimators skip re-validation when either the global flag or
        # this estimator's tag asks for it.
        nested_skip = (
            skip_validation or get_tags(self)._prefer_skip_nested_validation
        )
        with config_context(skip_parameter_validation=nested_skip):
            if not isinstance(self, CallbackSupportMixin):
                # No callback machinery: delegate straight to the private fit.
                return self.__sklearn_fit__(X=X, y=y, X_val=X_val, y_val=y_val)

            ctx = CallbackContext._from_estimator(estimator=self, task_name="fit")
            try:
                return self.__sklearn_fit__(
                    X=X, y=y, X_val=X_val, y_val=y_val, callback_ctx=ctx
                )
            finally:
                # Always run the fit-end hook, then drop any context a
                # meta-estimator may have attached to this estimator.
                ctx.eval_on_fit_end(estimator=self)
                if hasattr(self, "_parent_callback_ctx"):
                    del self._parent_callback_ctx


class Estimator(CallbackSupportMixin, BaseEstimatorPrivateFit):
"""A class that mimics the behavior of an estimator.

The iterative part uses a loop with a max number of iterations known in advance.
Expand All @@ -49,20 +90,29 @@ def __init__(self, max_iter=20, computation_intensity=0.001):
self.max_iter = max_iter
self.computation_intensity = computation_intensity

@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
callback_ctx = self.__skl_init_callback_context__(
max_subtasks=self.max_iter
).eval_on_fit_begin(estimator=self)
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags._prefer_skip_nested_validation = False
return tags

def __sklearn_fit__(
self, X=None, y=None, X_val=None, y_val=None, callback_ctx=None
):
callback_ctx.max_subtasks = self.max_iter
callback_ctx.eval_on_fit_begin(estimator=self)
for i in range(self.max_iter):
subcontext = callback_ctx.subcontext(task_id=i)

time.sleep(self.computation_intensity) # Computation intensive task

if subcontext.eval_on_fit_task_end(
estimator=self,
data={"X_train": X, "y_train": y},
data={
"X_train": X,
"y_train": y,
"X_val": X_val,
"y_val": y_val,
},
):
break

Expand All @@ -71,7 +121,7 @@ def fit(self, X=None, y=None):
return self


class WhileEstimator(CallbackSupportMixin, BaseEstimator):
class WhileEstimator(CallbackSupportMixin, BaseEstimatorPrivateFit):
"""A class that mimics the behavior of an estimator.

The iterative part uses a loop with a max number of iterations known in advance.
Expand All @@ -82,12 +132,15 @@ class WhileEstimator(CallbackSupportMixin, BaseEstimator):
def __init__(self, computation_intensity=0.001):
self.computation_intensity = computation_intensity

@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
callback_ctx = self.__skl_init_callback_context__().eval_on_fit_begin(
estimator=self
)
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags._prefer_skip_nested_validation = False
return tags

def __sklearn_fit__(
self, X=None, y=None, X_val=None, y_val=None, callback_ctx=None
):
callback_ctx.eval_on_fit_begin(estimator=self)
i = 0
while True:
subcontext = callback_ctx.subcontext(task_id=i)
Expand All @@ -96,7 +149,12 @@ def fit(self, X=None, y=None):

if subcontext.eval_on_fit_task_end(
estimator=self,
data={"X_train": X, "y_train": y},
data={
"X_train": X,
"y_train": y,
"X_val": X_val,
"y_val": y_val,
},
):
break

Expand All @@ -108,7 +166,7 @@ def fit(self, X=None, y=None):
return self


class MetaEstimator(CallbackSupportMixin, BaseEstimator):
class MetaEstimator(CallbackSupportMixin, BaseEstimatorPrivateFit):
"""A class that mimics the behavior of a meta-estimator.

It has two levels of iterations. The outer level uses parallelism and the inner
Expand All @@ -126,18 +184,26 @@ def __init__(
self.n_jobs = n_jobs
self.prefer = prefer

@_fit_context(prefer_skip_nested_validation=False)
def fit(self, X=None, y=None):
callback_ctx = self.__skl_init_callback_context__(
max_subtasks=self.n_outer
).eval_on_fit_begin(estimator=self)
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags._prefer_skip_nested_validation = False
return tags

def __sklearn_fit__(
self, X=None, y=None, X_val=None, y_val=None, callback_ctx=None
):
callback_ctx.max_subtasks = self.n_outer
callback_ctx.eval_on_fit_begin(estimator=self)
Parallel(n_jobs=self.n_jobs, prefer=self.prefer)(
delayed(_func)(
self,
self.estimator,
X,
y,
data={
"X_train": X,
"y_train": y,
"X_val": X_val,
"y_val": y_val,
},
callback_ctx=callback_ctx.subcontext(
task_name="outer", task_id=i, max_subtasks=self.n_inner
),
Expand All @@ -148,22 +214,73 @@ def fit(self, X=None, y=None):
return self


def _func(meta_estimator, inner_estimator, data, *, callback_ctx):
    """Fit clones of `inner_estimator` for the inner loop of `MetaEstimator`.

    Parameters
    ----------
    meta_estimator : estimator instance
        The meta-estimator owning the inner loop; reported to the callbacks.

    inner_estimator : estimator instance
        The estimator cloned and fitted at each inner iteration.

    data : dict
        Mapping with keys "X_train", "y_train", "X_val" and "y_val" forwarded
        to the inner estimator's fit and to the callback hooks.

    callback_ctx : CallbackContext or None
        The context of the outer task, or None when callbacks are disabled.
    """
    for i in range(meta_estimator.n_inner):
        est = clone(inner_estimator)

        # One subcontext per inner task; propagate the meta-estimator's
        # callbacks to the freshly cloned sub-estimator.
        inner_ctx = (
            callback_ctx.subcontext(task_name="inner", task_id=i).propagate_callbacks(
                sub_estimator=est
            )
            if callback_ctx is not None
            else None
        )

        est.fit(
            X=data["X_train"],
            y=data["y_train"],
            X_val=data["X_val"],
            y_val=data["y_val"],
        )

        if callback_ctx is not None:
            inner_ctx.eval_on_fit_task_end(
                estimator=meta_estimator,
                data=data,
            )

    # The outer task itself is done once all inner tasks finished.
    if callback_ctx is not None:
        callback_ctx.eval_on_fit_task_end(
            estimator=meta_estimator,
            data=data,
        )
)

class SimpleMetaEstimator(CallbackSupportMixin, BaseEstimatorPrivateFit):
    """A class that mimics the behavior of a meta-estimator that does not clone the
    estimator and does not parallelize.

    There is no iteration, the meta estimator simply calls the fit of the estimator once
    in a subcontext.
    """

    _parameter_constraints: dict = {}

    def __init__(self, estimator):
        self.estimator = estimator

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags._prefer_skip_nested_validation = False
        return tags

    def __sklearn_fit__(
        self, X=None, y=None, X_val=None, y_val=None, callback_ctx=None
    ):
        # The root "fit" task has exactly one subtask: fitting the wrapped
        # estimator.
        callback_ctx.max_subtasks = 1
        callback_ctx.eval_on_fit_begin(estimator=self)

        # Fit the wrapped estimator inside a dedicated subcontext so this
        # meta-estimator's callbacks are propagated to it.
        subcontext = callback_ctx.subcontext(task_name="subtask").propagate_callbacks(
            sub_estimator=self.estimator
        )
        self.estimator.fit(X=X, y=y, X_val=X_val, y_val=y_val)

        # BUGFIX: end the task on the subtask's context. It was previously
        # called on the root context, leaving `subcontext` unused and
        # inconsistent with `Estimator.__sklearn_fit__` and `_func`, which
        # both end each finished task on that task's own context.
        subcontext.eval_on_fit_task_end(
            estimator=self,
            data={
                "X_train": X,
                "y_train": y,
                "X_val": X_val,
                "y_val": y_val,
            },
        )

        return self
Loading
Loading