13 changes: 13 additions & 0 deletions docs/reference/lifecycle-hooks/RichProgressBar.rst
@@ -0,0 +1,13 @@
==============================
plugins.h_rich.RichProgressBar
==============================

Provides a progress bar for Hamilton execution. You must have `rich` installed to use it:

`pip install sf-hamilton[rich]` (use quotes if using zsh)


.. autoclass:: hamilton.plugins.h_rich.RichProgressBar
:special-members: __init__
:members:
:inherited-members:
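
A minimal usage sketch for context (``my_module`` and the final var ``output_node`` are hypothetical; ``Builder`` and ``with_adapters`` are the standard driver APIs referenced in the index below):

from hamilton import driver
from hamilton.plugins import h_rich

import my_module  # hypothetical module containing your Hamilton functions

# Attach the progress bar like any other lifecycle adapter.
dr = (
    driver.Builder()
    .with_modules(my_module)
    .with_adapters(h_rich.RichProgressBar())
    .build()
)
dr.execute(["output_node"])  # hypothetical final var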
9 changes: 9 additions & 0 deletions docs/reference/lifecycle-hooks/TaskExecutionHook.rst
@@ -0,0 +1,9 @@
===================================
lifecycle.api.TaskExecutionHook
===================================


.. autoclass:: hamilton.lifecycle.api.TaskExecutionHook
:special-members: __init__
:members:
:inherited-members:
9 changes: 9 additions & 0 deletions docs/reference/lifecycle-hooks/TaskGroupingHook.rst
@@ -0,0 +1,9 @@
===================================
lifecycle.api.TaskGroupingHook
===================================


.. autoclass:: hamilton.lifecycle.api.TaskGroupingHook
:special-members: __init__
:members:
:inherited-members:
4 changes: 3 additions & 1 deletion docs/reference/lifecycle-hooks/index.rst
@@ -24,7 +24,8 @@ looking forward.
NodeExecutionMethod
StaticValidator
GraphConstructionHook

TaskExecutionHook
TaskGroupingHook

Available Adapters
-------------------
@@ -51,6 +52,7 @@ Recall to add lifecycle adapters, you just need to call the ``with_adapters`` method
PDBDebugger
PrintLn
ProgressBar
RichProgressBar
DDOGTracer
FunctionInputOutputTypeChecker
SlackNotifierHook
15 changes: 12 additions & 3 deletions hamilton/driver.py
@@ -241,6 +241,15 @@ def execute(
results_cache = state.DictBasedResultCache(prehydrated_results)
# Create tasks from the grouped nodes, filtering/pruning as we go
tasks = grouping.create_task_plan(grouped_nodes, final_vars, overrides, self.adapter)
task_ids = [task.base_id for task in tasks]

if self.adapter.does_hook("post_task_group", is_async=False):
self.adapter.call_all_lifecycle_hooks_sync(
"post_task_group",
run_id=run_id,
task_ids=task_ids,
)

# Create a task graph and execution state
execution_state = state.ExecutionState(
tasks, results_cache, run_id
@@ -2069,7 +2078,7 @@ def with_execution_manager(self, execution_manager: executors.ExecutionManager)
self._require_field_unset("execution_manager", "Cannot set execution manager twice")
self._require_field_unset(
"remote_executor",
"Cannot set execution manager with remote " "executor set -- these are disjoint",
"Cannot set execution manager with remote executor set -- these are disjoint",
)

self.execution_manager = execution_manager
@@ -2086,7 +2095,7 @@ def with_remote_executor(self, remote_executor: executors.TaskExecutor) -> "Builder":
self._require_field_unset("remote_executor", "Cannot set remote executor twice")
self._require_field_unset(
"execution_manager",
"Cannot set remote executor with execution " "manager set -- these are disjoint",
"Cannot set remote executor with execution manager set -- these are disjoint",
)
self.remote_executor = remote_executor
return self
@@ -2102,7 +2111,7 @@ def with_local_executor(self, local_executor: executors.TaskExecutor) -> "Builder":
self._require_field_unset("local_executor", "Cannot set local executor twice")
self._require_field_unset(
"execution_manager",
"Cannot set local executor with execution " "manager set -- these are disjoint",
"Cannot set local executor with execution manager set -- these are disjoint",
)
self.local_executor = local_executor
return self
4 changes: 4 additions & 0 deletions hamilton/execution/executors.py
@@ -108,6 +108,8 @@ def base_execute_task(task: TaskImplementation) -> Dict[str, Any]:
nodes=task.nodes,
inputs=task.dynamic_inputs,
overrides=task.overrides,
spawning_task_id=task.spawning_task_id,
purpose=task.purpose,
)
error = None
success = True
@@ -139,6 +141,8 @@ def base_execute_task(task: TaskImplementation) -> Dict[str, Any]:
results=results,
success=success,
error=error,
spawning_task_id=task.spawning_task_id,
purpose=task.purpose,
)
# This selection is for GC
# We also need to get the override values
8 changes: 8 additions & 0 deletions hamilton/execution/state.py
@@ -384,6 +384,14 @@ def update_task_state(
self.realize_parameterized_group(
completed_task.task_id, parameterization_values, input_to_parameterize
)
if completed_task.adapter.does_hook("post_task_expand", is_async=False):
# TODO -- parameterization_values could be materialized here for generators
completed_task.adapter.call_all_lifecycle_hooks_sync(
"post_task_expand",
run_id=completed_task.run_id,
task_id=completed_task.task_id,
parameters=parameterization_values,
Review thread:

Contributor: So, there's a slight design issue here, curious as to your thoughts. The fact that we create a dict with all the parameterization values is not part of the contract -- the idea is we could go to having a generator where they're not all decided for now. Not married to this (it's a bit baked in that it's a list now), but curious what you're planning on using this value for?

Contributor: Conversely, it could change to storing the whole set of values as a list regardless of whether it's a generator or not, only if this hook exists -- that could be an implementation detail we could handle later (e.g. add something that materializes it) -- this would allow us to release this now and not change the contract.

Contributor (Author): It's funny, while I was writing this I was concerned about exhausting the parameterization values if it was/became a generator. My immediate use case was to determine the number of expanded tasks; however, I thought it might be useful to someone down the road to inspect the expansion results. With that said, I absolutely love the idea of materializing the values only when the hook is defined - I am going to add a comment to the code.

Contributor: Yep, I think a comment is good here -- it's kind of a fun one. You could also have the hook take in an optional value...

)
else:
for candidate_task in self.base_reverse_dependencies[completed_task.base_id]:
# This means it's not spawned by another task, or a node spawning group itself
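
As an aside, the "materialize only when the hook exists" idea from the review thread above could look roughly like the following sketch (hypothetical; not code from this PR):

# Hypothetical helper illustrating the review-thread idea: force-evaluate a
# (possibly lazy) generator of parameterization values only when a hook will consume it.
def maybe_materialize(parameterization_values, adapter):
    if adapter.does_hook("post_task_expand", is_async=False):
        return list(parameterization_values)
    return parameterization_values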
2 changes: 2 additions & 0 deletions hamilton/lifecycle/__init__.py
@@ -9,6 +9,7 @@
ResultBuilder,
StaticValidator,
TaskExecutionHook,
TaskGroupingHook,
)
from .base import LifecycleAdapter # noqa: F401
from .default import ( # noqa: F401
@@ -39,6 +40,7 @@
"NodeExecutionMethod",
"StaticValidator",
"TaskExecutionHook",
"TaskGroupingHook",
"FunctionInputOutputTypeChecker",
"NoEdgeAndInputTypeChecking",
]
57 changes: 57 additions & 0 deletions hamilton/lifecycle/api.py
@@ -15,7 +15,10 @@
# To really fix this we should move everything user-facing out of base, which is a pretty sloppy name for a package anyway
# And put it where it belongs. For now we're OK with the TYPE_CHECKING hack
if TYPE_CHECKING:
from hamilton.execution.grouping import NodeGroupPurpose
from hamilton.graph import FunctionGraph
else:
NodeGroupPurpose = None

from hamilton.graph_types import HamiltonGraph, HamiltonNode
from hamilton.lifecycle.base import (
@@ -27,6 +30,8 @@
BasePostGraphExecute,
BasePostNodeExecute,
BasePostTaskExecute,
BasePostTaskExpand,
BasePostTaskGroup,
BasePreGraphExecute,
BasePreNodeExecute,
BasePreTaskExecute,
@@ -379,13 +384,17 @@ def pre_task_execute(
nodes: List["node.Node"],
inputs: Dict[str, Any],
overrides: Dict[str, Any],
spawning_task_id: Optional[str],
purpose: NodeGroupPurpose,
):
self.run_before_task_execution(
run_id=run_id,
task_id=task_id,
nodes=[HamiltonNode.from_node(n) for n in nodes],
inputs=inputs,
overrides=overrides,
spawning_task_id=spawning_task_id,
purpose=purpose,
)

def post_task_execute(
@@ -397,6 +406,8 @@ def post_task_execute(
results: Optional[Dict[str, Any]],
success: bool,
error: Exception,
spawning_task_id: Optional[str],
purpose: NodeGroupPurpose,
):
self.run_after_task_execution(
run_id=run_id,
Expand All @@ -405,6 +416,8 @@ def post_task_execute(
results=results,
success=success,
error=error,
spawning_task_id=spawning_task_id,
purpose=purpose,
)

@abc.abstractmethod
@@ -416,6 +429,8 @@ def run_before_task_execution(
nodes: List[HamiltonNode],
inputs: Dict[str, Any],
overrides: Dict[str, Any],
spawning_task_id: Optional[str],
purpose: NodeGroupPurpose,
**future_kwargs,
):
"""Implement this to run something after task execution. Tasks are tols used to group nodes.
Expand All @@ -428,6 +443,8 @@ def run_before_task_execution(
:param inputs: Inputs to the task
:param overrides: Overrides passed to the task
:param future_kwargs: Reserved for backwards compatibility.
:param spawning_task_id: ID of the task that spawned this task
:param purpose: Purpose of the current task group
"""
pass

@@ -441,6 +458,8 @@ def run_after_task_execution(
results: Optional[Dict[str, Any]],
success: bool,
error: Exception,
spawning_task_id: Optional[str],
purpose: NodeGroupPurpose,
**future_kwargs,
):
"""Implement this to run something after task execution. See note in run_before_task_execution.
Expand All @@ -452,6 +471,8 @@ def run_after_task_execution(
:param success: Whether the task was successful
:param error: The error the task threw, if any
:param future_kwargs: Reserved for backwards compatibility.
:param spawning_task_id: ID of the task that spawned this task
:param purpose: Purpose of the current task group
"""
pass
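
To make the new parameters concrete, here is a sketch of a hook that uses them (the class name and logging behavior are illustrative only, not part of this PR; the signatures mirror the ones added above):

import logging
from typing import Any, Dict, List, Optional

from hamilton import lifecycle
from hamilton.execution.grouping import NodeGroupPurpose
from hamilton.graph_types import HamiltonNode

logger = logging.getLogger(__name__)


class TaskLoggingHook(lifecycle.TaskExecutionHook):  # hypothetical example
    def run_before_task_execution(
        self,
        *,
        run_id: str,
        task_id: str,
        nodes: List[HamiltonNode],
        inputs: Dict[str, Any],
        overrides: Dict[str, Any],
        spawning_task_id: Optional[str],
        purpose: NodeGroupPurpose,
        **future_kwargs,
    ):
        # spawning_task_id is None unless this task was spawned by a Parallelize block
        logger.info("task %s starting (spawned by %s, purpose=%s)", task_id, spawning_task_id, purpose)

    def run_after_task_execution(
        self,
        *,
        run_id: str,
        task_id: str,
        nodes: List[HamiltonNode],
        results: Optional[Dict[str, Any]],
        success: bool,
        error: Exception,
        spawning_task_id: Optional[str],
        purpose: NodeGroupPurpose,
        **future_kwargs,
    ):
        logger.info("task %s finished (success=%s)", task_id, success)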

@@ -614,6 +635,42 @@ def validate_graph(
return self.run_to_validate_graph(graph=HamiltonGraph.from_graph(graph))


class TaskGroupingHook(BasePostTaskGroup, BasePostTaskExpand):
"""Implement this to run something after task grouping or task expansion. This will allow you to
capture information about the tasks during `Parallelize`/`Collect` blocks in dynamic DAG execution."""

@override
@final
def post_task_group(self, *, run_id: str, task_ids: List[str]):
return self.run_after_task_grouping(run_id=run_id, task_ids=task_ids)

@override
@final
def post_task_expand(self, *, run_id: str, task_id: str, parameters: Dict[str, Any]):
Review thread:

Contributor: Idea -- maybe we allow parameters to be None? I think I'm overthinking this but want to guard against a future case where using this force-evaluates everything... If it's Optional[...] then we can at least guard against it. That said, it could also be an empty dict if we don't want to, and that could be specified by the hook, so let's not worry about it.

return self.run_after_task_expansion(run_id=run_id, task_id=task_id, parameters=parameters)

@abc.abstractmethod
def run_after_task_grouping(self, *, run_id: str, task_ids: List[str], **future_kwargs):
"""Hook that is called after task grouping.
:param run_id: ID of the run, unique in scope of the driver.
:param task_ids: List of tasks that were grouped together.
:param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility.
"""
pass

@abc.abstractmethod
def run_after_task_expansion(
self, *, run_id: str, task_id: str, parameters: Dict[str, Any], **future_kwargs
):
"""Hook that is called after task expansion.
:param run_id: ID of the run, unique in scope of the driver.
:param task_id: ID of the task that was expanded.
:param parameters: Parameters that were passed to the task.
:param future_kwargs: Additional keyword arguments -- this is kept for backwards compatibility.
"""
pass
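
To make the contract concrete, a sketch of a custom implementation (the class name and counting logic are illustrative only, not part of this PR):

from typing import Any, Dict, List

from hamilton import lifecycle


class TaskCounter(lifecycle.TaskGroupingHook):  # hypothetical example
    """Counts grouped tasks and the fan-out of each expansion."""

    def __init__(self):
        self.grouped_task_count = 0
        self.expansions: Dict[str, int] = {}

    def run_after_task_grouping(self, *, run_id: str, task_ids: List[str], **future_kwargs):
        # Called once per run, after the DAG has been grouped into tasks.
        self.grouped_task_count = len(task_ids)

    def run_after_task_expansion(
        self, *, run_id: str, task_id: str, parameters: Dict[str, Any], **future_kwargs
    ):
        # One parameterization value per spawned task (per the review thread above),
        # so the dict's size gives the fan-out of this expansion.
        self.expansions[task_id] = len(parameters)

It would be attached via the driver's ``with_adapters`` method like any other adapter.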


class GraphConstructionHook(BasePostGraphConstruct, abc.ABC):
"""Hook that is run after graph construction. This allows you to register/capture info on the graph.
Note that, in the case of materialization, this may be called multiple times (once when we create the graph,