From 829f544310a51dcf2d0635a305dc08f4a6974619 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Thu, 31 Oct 2024 02:53:31 -0700 Subject: [PATCH] Add API to get immediate successors --- metaflow/client/core.py | 82 +++++++++++++++++++++++++++++++++-------- metaflow/task.py | 15 +++++--- 2 files changed, 76 insertions(+), 21 deletions(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 0f287ccb177..2f330e99219 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1123,21 +1123,21 @@ def _iter_filter(self, x): # exclude private data artifacts return x.id[0] != "_" - def immediate_ancestors(self) -> Dict[str, Iterator["Task"]]: + def immediate_ancestors(self) -> Dict[str, List[str]]: """ - Returns a dictionary with iterators over the immediate ancestors of this task. + Returns a dictionary of immediate ancestors task ids of this task for each + previous step. Returns ------- - Dict[str, Iterator[Task]] + Dict[str, List[str]] Dictionary of immediate ancestors of this task. The keys are the - names of the ancestors steps and the values are iterators over the - tasks of the corresponding steps. + names of the ancestors steps and the values are the corresponding + task ids of the ancestors. """ def _prev_task(flow_id, run_id, previous_step): # Find any previous task for current step - step = Step(f"{flow_id}/{run_id}/{previous_step}", _namespace_check=False) task = next(iter(step.tasks()), None) if task: @@ -1146,10 +1146,6 @@ def _prev_task(flow_id, run_id, previous_step): flow_id, run_id, step_name, task_id = self.path_components previous_steps = self.metadata_dict.get("previous_steps", None) - print( - f"flow_id: {flow_id}, run_id: {run_id}, step_name: {step_name}, task_id: {task_id}" - ) - print(f"previous_steps: {previous_steps}") if not previous_steps or len(previous_steps) == 0: return @@ -1165,9 +1161,6 @@ def _prev_task(flow_id, run_id, previous_step): prev_task.metadata_dict.get("foreach-stack", []) ) - print( - f"prev_foreach_stack_len: {prev_foreach_stack_len}, cur_foreach_stack_len: {cur_foreach_stack_len}" - ) if prev_foreach_stack_len == cur_foreach_stack_len: field_name = "foreach-indices" field_value = self.metadata_dict.get(field_name) @@ -1181,16 +1174,73 @@ def _prev_task(flow_id, run_id, previous_step): field_value = self.metadata_dict.get("foreach-indices-truncated") for prev_step in previous_steps: - # print(f"For task {self.pathspec}, findding parent tasks for step {prev_step} with {field_name} and " - # f"{field_value}") ancestor_iters[prev_step] = ( self._metaflow.metadata.filter_tasks_by_metadata( flow_id, run_id, step_name, prev_step, field_name, field_value ) ) - return ancestor_iters + def immediate_successors(self) -> Dict[str, List[str]]: + """ + Returns a dictionary of immediate successors task ids of this task for each + previous step. + + Returns + ------- + Dict[str, List[str]] + Dictionary of immediate successors of this task. The keys are the + names of the successors steps and the values are the corresponding + task ids of the successors. + """ + + def _successor_task(flow_id, run_id, successor_step): + # Find any previous task for current step + step = Step(f"{flow_id}/{run_id}/{successor_step}", _namespace_check=False) + task = next(iter(step.tasks()), None) + if task: + return task + raise MetaflowNotFound(f"No successor task found for step {successor_step}") + + flow_id, run_id, step_name, task_id = self.path_components + successor_steps = self.metadata_dict.get("successor_steps", None) + + if not successor_steps or len(successor_steps) == 0: + return + + cur_foreach_stack_len = len(self.metadata_dict.get("foreach-stack", [])) + successor_iters = {} + if len(successor_steps) > 1: + # This is a static split, so there is no change in foreach stack length + successor_foreach_stack_len = cur_foreach_stack_len + else: + successor_task = _successor_task(flow_id, run_id, successor_steps[0]) + successor_foreach_stack_len = len( + successor_task.metadata_dict.get("foreach-stack", []) + ) + + if successor_foreach_stack_len == cur_foreach_stack_len: + field_name = "foreach-indices" + field_value = self.metadata_dict.get(field_name) + elif successor_foreach_stack_len > cur_foreach_stack_len: + # We will compare the foreach-indices value of current task with the + # foreach-indices-truncated value of tasks in successor steps + field_name = "foreach-indices-truncated" + field_value = self.metadata_dict.get("foreach-indices") + else: + # We will compare the foreach-stack-truncated value of current task with the + # foreach-stack value of tasks in successor steps + field_name = "foreach-indices" + field_value = self.metadata_dict.get("foreach-indices-truncated") + + for successor_step in successor_steps: + successor_iters[successor_step] = ( + self._metaflow.metadata.filter_tasks_by_metadata( + flow_id, run_id, step_name, successor_step, field_name, field_value + ) + ) + return successor_iters + # def closest_siblings(self) -> Iterator["Task"]: # """ # Returns an iterator over the closest siblings of this task. diff --git a/metaflow/task.py b/metaflow/task.py index d157cee195a..4ba66847c26 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -47,14 +47,13 @@ def _dynamic_runtime_metadata(foreach_stack): return foreach_indices, foreach_indices_truncated, foreach_step_names def _static_runtime_metadata(self, graph_info, step_name): - if step_name == "start": - return [] - - return [ + prev_steps = [ node_name for node_name, attributes in graph_info["steps"].items() if step_name in attributes["next"] ] + succesor_steps = graph_info["steps"][step_name]["next"] + return prev_steps, succesor_steps def __init__( self, @@ -670,7 +669,7 @@ def run_step( ) } ) - previous_steps = self._static_runtime_metadata( + previous_steps, successor_steps = self._static_runtime_metadata( self.flow._graph_info, step_name ) for deco in decorators: @@ -791,6 +790,12 @@ def run_step( type="previous_steps", tags=metadata_tags, ), + MetaDatum( + field="successor_steps", + value=successor_steps, + type="successor_steps", + tags=metadata_tags, + ), ], )