Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions sklearn/callback/_callback_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import uuid

from sklearn.callback import AutoPropagatedCallback

# TODO(callbacks): move these explanations into a dedicated user guide.
Expand Down Expand Up @@ -120,6 +122,10 @@ class CallbackContext:

- parent : CallbackContext or None
The parent context of this context. None if this context is the root.

- uuid : UUID
The UUID relative to the task tree, meaning the same UUID is shared by a context
and all its children.
"""

@classmethod
Expand Down Expand Up @@ -154,6 +160,7 @@ def _from_estimator(cls, estimator, *, task_name, task_id, max_subtasks=None):
new_ctx.max_subtasks = max_subtasks
new_ctx.prev_estimator_name = None
new_ctx.prev_task_name = None
new_ctx.uuid = uuid.uuid4()

if hasattr(estimator, "_parent_callback_ctx"):
# This context's task is the root task of the estimator which itself
Expand Down Expand Up @@ -201,6 +208,7 @@ def _from_parent(cls, parent_context, *, task_name, task_id, max_subtasks=None):
new_ctx.max_subtasks = max_subtasks
new_ctx.prev_estimator_name = None
new_ctx.prev_task_name = None
new_ctx.uuid = parent_context.uuid

# This task is a subtask of another task of a same estimator
parent_context._add_child(new_ctx)
Expand Down Expand Up @@ -238,6 +246,7 @@ def _merge_with(self, other_context):
# meta-estimator's leaf context
self.parent = other_context.parent
self.task_id = other_context.task_id
self.uuid = other_context.uuid
other_context.parent._children_map[self.task_id] = self

# Keep information about the context it was merged with
Expand Down
27 changes: 21 additions & 6 deletions sklearn/callback/tests/test_callback_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,21 +70,19 @@ def _make_task_tree(n_children, n_grandchildren):
)

for i in range(n_children):
child = CallbackContext._from_estimator(
estimator,
child = CallbackContext._from_parent(
root,
task_name="child task",
task_id=i,
max_subtasks=n_grandchildren,
)
root._add_child(child)

for j in range(n_grandchildren):
grandchild = CallbackContext._from_estimator(
estimator,
grandchild = CallbackContext._from_parent(
child,
task_name="grandchild task",
task_id=j,
)
child._add_child(grandchild)

return root

Expand All @@ -102,12 +100,14 @@ def test_task_tree():
assert len(get_context_path(child)) == 2
assert len(child._children_map) == 5
assert root.max_subtasks == 3
assert child.uuid == root.uuid

for grandchild in child._children_map.values():
assert grandchild.parent is child
assert len(get_context_path(grandchild)) == 3
assert len(grandchild._children_map) == 0
assert child.max_subtasks == 5
assert grandchild.uuid == root.uuid

# 1 root + 1 * 3 children + 1 * 3 * 5 grandchildren
expected_n_nodes = np.sum(np.cumprod([1, 3, 5]))
Expand Down Expand Up @@ -177,9 +177,24 @@ def test_merge_with():

assert inner_root.parent is outer_root
assert inner_root.task_id == outer_child.task_id
assert inner_root.uuid == outer_child.uuid
assert outer_child not in outer_root._children_map.values()
assert inner_root in outer_root._children_map.values()

# The name and estimator name of the tasks it was merged with are stored
assert inner_root.prev_task_name == outer_child.task_name
assert inner_root.prev_estimator_name == outer_child.estimator_name


def test_subcontext():
"""Sanity check for the `subcontext` method."""
estimator = Estimator()
estimator.set_callbacks(TestingCallback())
context = CallbackContext._from_estimator(estimator, task_name="task", task_id=0)
subcontext = context.subcontext(task_name="subtask", task_id=1)

assert context.uuid == subcontext.uuid
assert context._callbacks == subcontext._callbacks
assert context.estimator_name == subcontext.estimator_name
assert context._estimator_depth == subcontext._estimator_depth
assert subcontext.task_id in context._children_map
Loading