Description
Is your feature request related to a problem? Please describe.
I hit a bit of a snag while creating some custom multi-level progress bar lifecycle adapters for task-based parallel DAGs (with rich for the curious). Currently, for task-based DAGs, TaskExecutionHook will only fire before and after a task is executed. The hooks have no knowledge of the overall task landscape, including:
1. Number (and index) of tasks in the current group
2. Overall groups in the graph
3. Details about the expander task parameterization
4. Type of current task (expander, collector, etc.)
5. Spawning task ID (if available)
Note: Item 1 was originally discussed on Slack: https://hamilton-opensource.slack.com/archives/C03MANME6G5/p1728403433108319
Describe the solution you'd like
After speaking with @elijahbenizzy, an initial implementation for item 1 was suggested that modifies the TaskImplementation object to store the current task index and the total number of tasks. This information would then be wired through various methods in the ExecutionState class and be eventually passed to the lifecycle hooks run_after_task_execution and run_before_task_execution on TaskExecutionHook.
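As a rough illustration of that wiring (not Hamilton's actual code), the sketch below shows how a task's index and the group size could reach the before/after hooks as keyword arguments. Only the two hook method names come from the proposal; `TaskInfo`, `RecordingHook`, `run_group`, and the keyword names `task_index` and `total_tasks` are hypothetical and chosen for illustration.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple


@dataclass(frozen=True)
class TaskInfo:
    """Hypothetical stand-in for the extra fields stored on a task object."""

    task_id: str
    task_index: int  # position of this task within its group
    total_tasks: int  # number of tasks in the group


class RecordingHook:
    """Records the values each hook call receives (assumed keyword names)."""

    def __init__(self) -> None:
        self.events: List[Tuple[str, str, int, int]] = []

    def run_before_task_execution(
        self, *, task_id: str, task_index: int, total_tasks: int, **kwargs: Any
    ) -> None:
        self.events.append(("before", task_id, task_index, total_tasks))

    def run_after_task_execution(
        self, *, task_id: str, task_index: int, total_tasks: int, **kwargs: Any
    ) -> None:
        self.events.append(("after", task_id, task_index, total_tasks))


def run_group(
    task_ids: List[str], hook: RecordingHook, execute: Callable[[str], Any]
) -> List[Any]:
    """Run each task in order, passing its index and the group size to the hooks."""
    total = len(task_ids)
    results = []
    for index, task_id in enumerate(task_ids):
        info = TaskInfo(task_id=task_id, task_index=index, total_tasks=total)
        hook.run_before_task_execution(
            task_id=info.task_id,
            task_index=info.task_index,
            total_tasks=info.total_tasks,
        )
        results.append(execute(task_id))
        hook.run_after_task_execution(
            task_id=info.task_id,
            task_index=info.task_index,
            total_tasks=info.total_tasks,
        )
    return results


hook = RecordingHook()
results = run_group(["a", "b", "c"], hook, execute=str.upper)
```

With the index and total available on every call, a progress bar can show a determinate length instead of guessing how many tasks remain.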
While implementing the above in a test branch (https://github.com/cswartzvi/hamilton/tree/update_task_execution_hook), I found that it was still difficult to create a multi-level progress bar without some of the information in items 2-5. To that end I also added:
- spawning_task_id and purpose to the methods and hooks associated with TaskExecutionHook
- Created a new hook post_task_group that runs after the tasks are grouped
- Created a new hook post_task_expand that runs after the expander task is parameterized
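To make the role of these additions concrete, here is a minimal, dependency-free sketch of how the grouping hook, the expansion hook, and the purpose argument could drive two counters (an outer one for task groups, an inner one for parallel blocks). The `Purpose` enum, the `TwoLevelCounter` class, and the simulated call order are assumptions for illustration only and do not reproduce Hamilton's executor.

```python
from enum import Enum
from typing import Any, Dict, List, Optional


class Purpose(Enum):
    """Hypothetical local stand-in for the task purpose passed to the hooks."""

    EXPAND = "expand"
    EXECUTE_BLOCK = "execute_block"
    GATHER = "gather"
    EXECUTE_SINGLE = "execute_single"


class TwoLevelCounter:
    """Tracks outer (task group) and inner (parallel block) progress."""

    def __init__(self) -> None:
        self.outer_total = 0
        self.outer_done = 0
        self.inner_total = 0
        self.inner_done = 0

    def run_after_task_grouping(self, *, tasks: List[str], **kwargs: Any) -> None:
        # The number of groups fixes the length of the outer bar.
        self.outer_total = len(tasks)

    def run_after_task_expansion(self, *, parameters: Dict[str, Any], **kwargs: Any) -> None:
        # The number of expanded parameters fixes the length of the inner bar.
        self.inner_total = len(parameters)
        self.inner_done = 0

    def run_before_task_execution(self, *, purpose: Purpose, **kwargs: Any) -> None:
        # Reaching the gather step means the whole parallel block group is done.
        if purpose is Purpose.GATHER:
            self.outer_done += 1

    def run_after_task_execution(
        self, *, purpose: Purpose, spawning_task_id: Optional[str] = None, **kwargs: Any
    ) -> None:
        if purpose is Purpose.EXECUTE_BLOCK:
            self.inner_done += 1
        else:
            self.outer_done += 1


# Simulated run: one single task, an expander, three parallel blocks, a gather.
counter = TwoLevelCounter()
counter.run_after_task_grouping(tasks=["load", "expand", "block", "gather"])
counter.run_before_task_execution(purpose=Purpose.EXECUTE_SINGLE)
counter.run_after_task_execution(purpose=Purpose.EXECUTE_SINGLE)
counter.run_before_task_execution(purpose=Purpose.EXPAND)
counter.run_after_task_execution(purpose=Purpose.EXPAND)
counter.run_after_task_expansion(parameters={"0": "x", "1": "y", "2": "z"})
for _ in range(3):
    counter.run_before_task_execution(purpose=Purpose.EXECUTE_BLOCK, spawning_task_id="expand")
    counter.run_after_task_execution(purpose=Purpose.EXECUTE_BLOCK, spawning_task_id="expand")
counter.run_before_task_execution(purpose=Purpose.GATHER)
counter.run_after_task_execution(purpose=Purpose.GATHER)
```

Both counters end at their totals, which is the property a determinate multi-level progress bar needs.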
With these additional changes (also in the branch above) I was able to create my coveted multi-level progress bar:
    from typing import Any, Dict, List

    import rich.console
    import rich.progress

    # Import paths are assumed; TaskGroupingHook and the purpose argument
    # exist only on the test branch linked above.
    from hamilton.execution.grouping import NodeGroupPurpose, TaskSpec
    from hamilton.lifecycle.api import GraphExecutionHook, TaskExecutionHook, TaskGroupingHook


    class TaskProgressHook(TaskExecutionHook, TaskGroupingHook, GraphExecutionHook):
        def __init__(self) -> None:
            self._console = rich.console.Console()
            self._progress = rich.progress.Progress(console=self._console)

        def run_before_graph_execution(self, **kwargs: Any):
            pass

        def run_after_graph_execution(self, **kwargs: Any):
            self._progress.stop()  # in case progress thread is lagging

        def run_after_task_grouping(self, *, tasks: List[TaskSpec], **kwargs):
            # Outer bar: one step per task group
            self._progress.add_task("Running Task Groups:", total=len(tasks))
            self._progress.start()

        def run_after_task_expansion(self, *, parameters: Dict[str, Any], **kwargs):
            # Inner bar: one step per expanded parameter
            self._progress.add_task("Running Parallelizable:", total=len(parameters))

        def run_before_task_execution(self, *, purpose: NodeGroupPurpose, **kwargs):
            if purpose == NodeGroupPurpose.GATHER:
                self._progress.advance(self._progress.task_ids[0])
                self._progress.stop_task(self._progress.task_ids[-1])

        def run_after_task_execution(self, *, purpose: NodeGroupPurpose, **kwargs):
            if purpose == NodeGroupPurpose.EXECUTE_BLOCK:
                self._progress.advance(self._progress.task_ids[-1])
            else:
                self._progress.advance(self._progress.task_ids[0])

Maybe I reached a little too far with this for my own selfish goals 😄. Either way, please let me know if you would be interested in a PR for any, or all, of the changes to the task lifecycle adapters (heck, I would also be willing to add rich plugins if you'd like that as well). Thanks!
Additional context
Currently, the built-in lifecycle adapter ProgressBar has an indeterminate length for task-based DAGs.
