Expanding Lifecycle Adapters for Dynamic DAGs / Parallel Execution #1196

@cswartzvi

Is your feature request related to a problem? Please describe.

I hit a bit of a snag while creating some custom multi-level progress bar lifecycle adapters for task-based parallel DAGs (with rich for the curious). Currently, for task-based DAGs, TaskExecutionHook will only fire before and after a task is executed. The hooks have no knowledge of the overall task landscape, including:

  1. Number (and index) of tasks in the current group
  2. Overall groups in the graph
  3. Details about the expander task parameterization
  4. Type of current task (expander, collector, etc.)
  5. Spawning task ID (if available)

Note: Item 1 was originally discussed on Slack: https://hamilton-opensource.slack.com/archives/C03MANME6G5/p1728403433108319

Describe the solution you'd like

After speaking with @elijahbenizzy, an initial implementation for item 1 was suggested: modify the TaskImplementation object to store the current task index and the total number of tasks. This information would then be wired through various methods in the ExecutionState class and eventually be passed to the lifecycle hooks run_before_task_execution and run_after_task_execution on TaskExecutionHook.
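To make the idea concrete, here is a minimal stdlib-only sketch of how an index and a total could flow from a task object into a hook. It does not import Hamilton: TaskInfo, CountingHook, run_group, and the keyword names task_index and task_total are all hypothetical stand-ins, not the names used in the branch.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass(frozen=True)
class TaskInfo:
    """Hypothetical stand-in for a task that knows its place in its group."""

    task_id: str
    index: Optional[int] = None  # 0-based position in the group, if known
    total: Optional[int] = None  # number of tasks in the group, if known


class CountingHook:
    """Mimics the keyword-only style of TaskExecutionHook; not the real class."""

    def __init__(self) -> None:
        self.lines: List[str] = []

    def run_before_task_execution(
        self,
        *,
        task_id: str,
        task_index: Optional[int] = None,  # assumed new keyword
        task_total: Optional[int] = None,  # assumed new keyword
        **kwargs: Any,
    ) -> None:
        if task_index is None or task_total is None:
            # Tasks outside a parallel group carry no position information.
            self.lines.append(f"starting {task_id}")
        else:
            self.lines.append(f"starting {task_id} ({task_index + 1}/{task_total})")


def run_group(tasks: List[TaskInfo], hook: CountingHook) -> None:
    """Plays the role of the executor: forwards the stored fields to the hook."""
    for task in tasks:
        hook.run_before_task_execution(
            task_id=task.task_id, task_index=task.index, task_total=task.total
        )


hook = CountingHook()
run_group([TaskInfo(f"block.{i}", index=i, total=3) for i in range(3)], hook)
run_group([TaskInfo("collect")], hook)
print("\n".join(hook.lines))
```

Making both values optional keeps hooks that ignore them working unchanged, since the extra keywords are simply absorbed by **kwargs.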

While implementing the above in a test branch (https://github.com/cswartzvi/hamilton/tree/update_task_execution_hook), I found that it was still difficult to create a multi-level progress bar without some of the information in items 2-5. To that end I also added:

  • spawning_task_id and purpose parameters on the methods and hooks associated with TaskExecutionHook
  • A new hook, post_task_group, that runs after the tasks are grouped
  • A new hook, post_task_expand, that runs after the expander task is parameterized

With these additional changes (also in the branch above) I was able to create my coveted multi-level progress bar:

from typing import Any, List

import rich.console
import rich.progress
from hamilton.execution.grouping import NodeGroupPurpose, TaskSpec
from hamilton.lifecycle.api import GraphExecutionHook, TaskExecutionHook

# Note: TaskGroupingHook and the purpose keyword exist only on the test branch linked above.

class TaskProgressHook(TaskExecutionHook, TaskGroupingHook, GraphExecutionHook):

    def __init__(self) -> None:
        self._console = rich.console.Console()
        self._progress = rich.progress.Progress(console=self._console)

    def run_before_graph_execution(self, **kwargs: Any):
        pass

    def run_after_graph_execution(self, **kwargs: Any):
        self._progress.stop()  # in case progress thread is lagging

    def run_after_task_grouping(self, *, tasks: List[TaskSpec], **kwargs):
        # Outer bar: one step per task group
        self._progress.add_task("Running Task Groups:", total=len(tasks))
        self._progress.start()

    def run_after_task_expansion(self, *, parameters: dict[str, Any], **kwargs):
        # Inner bar: one step per parameterized block
        self._progress.add_task("Running Parallelizable:", total=len(parameters))

    def run_before_task_execution(self, *, purpose: NodeGroupPurpose, **kwargs):
        if purpose == NodeGroupPurpose.GATHER:
            # All blocks are done: advance the outer bar and stop the inner bar
            self._progress.advance(self._progress.task_ids[0])
            self._progress.stop_task(self._progress.task_ids[-1])

    def run_after_task_execution(self, *, purpose: NodeGroupPurpose, **kwargs):
        if purpose == NodeGroupPurpose.EXECUTE_BLOCK:
            self._progress.advance(self._progress.task_ids[-1])  # inner bar
        else:
            self._progress.advance(self._progress.task_ids[0])  # outer bar

[Image: Multi-Level-Progress]
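For readers who want to follow the counting logic without installing the branch or rich, here is a stdlib-only simulation of the same two-level bookkeeping. Purpose, TwoLevelCounter, and simulate are hypothetical stand-ins, and the order in which simulate calls the hooks is an assumption about how the executor behaves, not something taken from the Hamilton source.

```python
from enum import Enum
from typing import Any, Dict, List


class Purpose(Enum):
    """Stand-in for NodeGroupPurpose; only the members this sketch needs."""

    EXPAND = "expand"
    EXECUTE_BLOCK = "execute_block"
    GATHER = "gather"


class TwoLevelCounter:
    """Tracks an outer (task groups) and an inner (parallel blocks) count."""

    def __init__(self) -> None:
        self.outer_done = self.outer_total = 0
        self.inner_done = self.inner_total = 0

    def run_after_task_grouping(self, *, tasks: List[str], **kwargs: Any) -> None:
        self.outer_total = len(tasks)

    def run_after_task_expansion(
        self, *, parameters: Dict[str, Any], **kwargs: Any
    ) -> None:
        self.inner_done, self.inner_total = 0, len(parameters)

    def run_before_task_execution(self, *, purpose: Purpose, **kwargs: Any) -> None:
        if purpose is Purpose.GATHER:
            # Reaching the gather step means the whole block group has finished.
            self.outer_done += 1

    def run_after_task_execution(self, *, purpose: Purpose, **kwargs: Any) -> None:
        if purpose is Purpose.EXECUTE_BLOCK:
            self.inner_done += 1
        else:
            self.outer_done += 1

    def summary(self) -> str:
        return (
            f"groups {self.outer_done}/{self.outer_total}, "
            f"blocks {self.inner_done}/{self.inner_total}"
        )


def simulate(hook: TwoLevelCounter, parameters: Dict[str, Any]) -> None:
    """Assumed firing order: group, expand, one task per block, then gather."""
    hook.run_after_task_grouping(tasks=["expand", "block", "gather"])
    hook.run_before_task_execution(purpose=Purpose.EXPAND)
    hook.run_after_task_execution(purpose=Purpose.EXPAND)
    hook.run_after_task_expansion(parameters=parameters)
    for _ in parameters:
        hook.run_before_task_execution(purpose=Purpose.EXECUTE_BLOCK)
        hook.run_after_task_execution(purpose=Purpose.EXECUTE_BLOCK)
    hook.run_before_task_execution(purpose=Purpose.GATHER)
    hook.run_after_task_execution(purpose=Purpose.GATHER)


counter = TwoLevelCounter()
simulate(counter, {"0": "a", "1": "b", "2": "c"})
print(counter.summary())
```

The outer count advances once for the expander, once when the gather step begins (closing out the block group), and once when the gather step ends, which is why three groups are reported as complete.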

Maybe I reached a little too far with this for my own selfish goals 😄, either way please let me know if you would be interested in a PR for any, or all, of the changes to the task lifecycle adapters (heck, I would also be willing to add rich plugins if you like that as well). Thanks!

Additional context

Currently, the built-in lifecycle adapter ProgressBar has an indeterminate length for task-based DAGs.
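The difference between an indeterminate and a determinate bar comes down to whether a total is available when rendering. A small stdlib-only illustration follows; render_bar is a hypothetical helper, not part of Hamilton or its ProgressBar adapter.

```python
from typing import Optional


def render_bar(done: int, total: Optional[int], width: int = 10) -> str:
    """Render a text progress bar; fall back to a placeholder without a total."""
    if total is None or total <= 0:
        # Without a total, only the count of finished tasks can be shown.
        return f"[{'?' * width}] {done} done"
    filled = min(width, width * done // total)
    return f"[{'#' * filled}{'.' * (width - filled)}] {done}/{total}"


print(render_bar(3, None))  # what a hook can show today
print(render_bar(3, 6))     # what it could show once the total is passed in
```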
