Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/reference/lifecycle-hooks/TaskReturnHook.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
================================
lifecycle.api.TaskReturnHook
================================


.. autoclass:: hamilton.lifecycle.api.TaskReturnHook
   :special-members: __init__
   :members:
   :inherited-members:
9 changes: 9 additions & 0 deletions docs/reference/lifecycle-hooks/TaskSubmissionHook.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
================================
lifecycle.api.TaskSubmissionHook
================================


.. autoclass:: hamilton.lifecycle.api.TaskSubmissionHook
   :special-members: __init__
   :members:
   :inherited-members:
2 changes: 2 additions & 0 deletions docs/reference/lifecycle-hooks/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ looking forward.
NodeExecutionMethod
StaticValidator
GraphConstructionHook
TaskSubmissionHook
TaskReturnHook
TaskExecutionHook
TaskGroupingHook

Expand Down
109 changes: 89 additions & 20 deletions hamilton/execution/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from concurrent.futures.process import ProcessPoolExecutor

from concurrent.futures import Executor, Future, ThreadPoolExecutor
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Protocol

from hamilton import node
from hamilton.execution.graph_functions import execute_subdag
Expand All @@ -20,13 +20,17 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class TaskFuture:
class TaskFuture(Protocol):
"""Simple representation of a future. TODO -- add cancel().
This is a clean wrapper over a python future, and we may end up just using that at some point."""

get_state: Callable[[], TaskState]
get_result: Callable[[], Any]
def get_state(self) -> TaskState:
"""Returns the state of the task."""
...

def get_result(self) -> Any:
"""Returns the result of the task."""
...


class TaskExecutor(abc.ABC):
Expand Down Expand Up @@ -162,6 +166,39 @@ def base_execute_task(task: TaskImplementation) -> Dict[str, Any]:
return final_retval


# NOTE(review): this class declares no dataclass fields and defines its own
# __init__, so the decorator only contributes a generated __repr__/__eq__ (and,
# per the dataclasses docs, a field-less __eq__ with __hash__ set to None).
# Kept as-is to leave comparison semantics untouched -- confirm it is intended.
@dataclasses.dataclass
class TaskFutureWrappingFunction(TaskFuture):
    """Wraps a python function call in a TaskFuture.

    The wrapped function is evaluated lazily, and at most once, the first time
    either ``get_state`` or ``get_result`` is called. Its return value, or the
    exception it raised, is cached so later calls never re-run the function.
    """

    def __init__(self, function: Callable[[], Any]):
        """
        :param function: Zero-argument callable that computes the task's result.
        """
        self.function = function
        self._results = None
        self._done = False
        self._exception = None

    def _run_once(self) -> None:
        """Runs the wrapped function if it has not run yet and caches the outcome."""
        if self._done:
            return
        try:
            self._results = self.function()
        except Exception as e:
            logger.exception("Task failed")
            self._exception = e
        finally:
            # Mark as done on every path so the function is never invoked twice,
            # regardless of which accessor triggered the evaluation.
            self._done = True

    def get_state(self) -> TaskState:
        """Runs the function if needed and reports how it ended.

        :return: ``TaskState.FAILED`` if the function raised, otherwise
            ``TaskState.SUCCESSFUL``.
        """
        self._run_once()
        if self._exception is not None:
            return TaskState.FAILED
        return TaskState.SUCCESSFUL

    def get_result(self) -> Any:
        """Runs the function if needed and returns its cached result.

        :return: Whatever the wrapped function returned.
        :raises Exception: Re-raises the exception the wrapped function raised.
        """
        self._run_once()
        if self._exception is not None:
            raise self._exception
        return self._results


class SynchronousLocalTaskExecutor(TaskExecutor):
"""Basic synchronous/local task executor that runs tasks
in the same process, at submit time."""
Expand All @@ -172,9 +209,7 @@ def submit_task(self, task: TaskImplementation) -> TaskFuture:
:param task: Task to submit
:return: Future associated with this task
"""
# No error management for now
result = base_execute_task(task)
return TaskFuture(get_state=lambda: TaskState.SUCCESSFUL, get_result=lambda: result)
return TaskFutureWrappingFunction(functools.partial(base_execute_task, task))

def can_submit_task(self) -> bool:
"""We can always submit a task as the task submission is blocking!
Expand All @@ -190,6 +225,7 @@ def finalize(self):
pass


@dataclasses.dataclass
class TaskFutureWrappingPythonFuture(TaskFuture):
"""Wraps a python future in a TaskFuture"""

Expand Down Expand Up @@ -383,11 +419,22 @@ def run_graph_to_completion(
execution_manager.init()
try:
while not GraphState.is_terminal(execution_state.get_graph_state()):
# get the next task from the queue
# Get the next task from the queue
next_task = execution_state.release_next_task()
if next_task is not None:
task_executor = execution_manager.get_executor_for_task(next_task)
if task_executor.can_submit_task():
if next_task.adapter.does_hook("pre_task_submission", is_async=False):
next_task.adapter.call_all_lifecycle_hooks_sync(
"pre_task_submission",
run_id=next_task.run_id,
task_id=next_task.task_id,
nodes=next_task.nodes,
inputs=next_task.dynamic_inputs,
overrides=next_task.overrides,
spawning_task_id=next_task.spawning_task_id,
purpose=next_task.purpose,
)
try:
submitted = task_executor.submit_task(next_task)
except Exception as e:
Expand All @@ -396,20 +443,42 @@ def run_graph_to_completion(
f"{[item.name for item in next_task.nodes]}"
)
raise e
task_futures[next_task.task_id] = submitted
task_futures[next_task] = submitted
else:
# Whoops, back on the queue
# We should probably wait a bit here, but for now we're going to keep
# burning through
# Whoops, back on the queue. We should probably wait a bit here, but for
# now we're going to keep burning through
execution_state.reject_task(task_to_reject=next_task)
# update all the tasks in flight
# copy so we can modify
for task_name, task_future in task_futures.copy().items():

# Update all the tasks in flight (copy so we can modify)
for task, task_future in task_futures.copy().items():
result, error = None, None
state = task_future.get_state()
result = task_future.get_result()
execution_state.update_task_state(task_name, state, result)
if TaskState.is_terminal(state):
del task_futures[task_name]
try:
result = task_future.get_result()
except Exception as e:
logger.exception(
f"Exception resolving task {task.task_id}, with nodes: "
f"{[item.name for item in task.nodes]}"
)
error = e
finally:
execution_state.update_task_state(task.task_id, state, result)
if TaskState.is_terminal(state):
if task.adapter.does_hook("post_task_return", is_async=False):
task.adapter.call_all_lifecycle_hooks_sync(
"post_task_return",
run_id=task.run_id,
task_id=task.task_id,
nodes=task.nodes,
success=state == TaskState.SUCCESSFUL,
error=error,
result=result,
spawning_task_id=task.spawning_task_id,
purpose=task.purpose,
)
del task_futures[task]
if error:
raise error
logger.info(f"Graph is done, graph state is {execution_state.get_graph_state()}")
finally:
execution_manager.finalize()
8 changes: 8 additions & 0 deletions hamilton/execution/grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,14 @@ def __post_init__(self):
super(TaskImplementation, self).__post_init__()
self.task_id = self.determine_task_id(self.base_id, self.spawning_task_id, self.group_id)

def __hash__(self) -> int:
    """Hashes the task by its ``task_id`` alone, matching ``__eq__``."""
    identifier = self.task_id
    return hash(identifier)

def __eq__(self, other: object) -> bool:
    """Compares two tasks by ``task_id``.

    :param other: Object to compare against.
    :return: True when ``other`` is a ``TaskImplementation`` with the same
        ``task_id``; ``NotImplemented`` for any other type, so Python can try
        the reflected comparison before settling on inequality.
    """
    if not isinstance(other, TaskImplementation):
        # Returning NotImplemented (rather than False) still makes
        # ``task == unrelated_object`` evaluate to False, but gives the other
        # operand's __eq__ a chance to answer first.
        return NotImplemented
    return self.task_id == other.task_id


class GroupingStrategy(abc.ABC):
"""Base class for grouping nodes"""
Expand Down
4 changes: 4 additions & 0 deletions hamilton/lifecycle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
StaticValidator,
TaskExecutionHook,
TaskGroupingHook,
TaskReturnHook,
TaskSubmissionHook,
)
from .base import LifecycleAdapter # noqa: F401
from .default import ( # noqa: F401
Expand Down Expand Up @@ -43,4 +45,6 @@
"TaskGroupingHook",
"FunctionInputOutputTypeChecker",
"NoEdgeAndInputTypeChecking",
"TaskReturnHook",
"TaskSubmissionHook",
]
Loading