From 16cb6f86ed26a52c91ef1e014b1e832283ae65d7 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Mon, 21 Oct 2024 15:35:17 -0700 Subject: [PATCH] Add static and runtime dag info, API to fetch ancestor tasks --- metaflow/client/core.py | 80 +++++++++++++++++++++++++++++++++++ metaflow/metadata/metadata.py | 15 +++++++ metaflow/task.py | 65 ++++++++++++++++++++++++++-- 3 files changed, 157 insertions(+), 3 deletions(-) diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 9534ffcca2c..0f287ccb177 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1123,6 +1123,86 @@ def _iter_filter(self, x): # exclude private data artifacts return x.id[0] != "_" + def immediate_ancestors(self) -> Dict[str, Iterator["Task"]]: + """ + Returns a dictionary with iterators over the immediate ancestors of this task. + + Returns + ------- + Dict[str, Iterator[Task]] + Dictionary of immediate ancestors of this task. The keys are the + names of the ancestor steps and the values are iterators over the + tasks of the corresponding steps. 
+ """ + + def _prev_task(flow_id, run_id, previous_step): + # Find any task belonging to the given previous step + + step = Step(f"{flow_id}/{run_id}/{previous_step}", _namespace_check=False) + task = next(iter(step.tasks()), None) + if task: + return task + raise MetaflowNotFound(f"No previous task found for step {previous_step}") + + flow_id, run_id, step_name, task_id = self.path_components + previous_steps = self.metadata_dict.get("previous_steps", None) + print( + f"flow_id: {flow_id}, run_id: {run_id}, step_name: {step_name}, task_id: {task_id}" + ) + print(f"previous_steps: {previous_steps}") + + if not previous_steps or len(previous_steps) == 0: + return + + cur_foreach_stack_len = len(self.metadata_dict.get("foreach-stack", [])) + ancestor_iters = {} + if len(previous_steps) > 1: + # This is a static join, so there is no change in foreach stack length + prev_foreach_stack_len = cur_foreach_stack_len + else: + prev_task = _prev_task(flow_id, run_id, previous_steps[0]) + prev_foreach_stack_len = len( + prev_task.metadata_dict.get("foreach-stack", []) + ) + + print( + f"prev_foreach_stack_len: {prev_foreach_stack_len}, cur_foreach_stack_len: {cur_foreach_stack_len}" + ) + if prev_foreach_stack_len == cur_foreach_stack_len: + field_name = "foreach-indices" + field_value = self.metadata_dict.get(field_name) + elif prev_foreach_stack_len > cur_foreach_stack_len: + field_name = "foreach-indices-truncated" + field_value = self.metadata_dict.get("foreach-indices") + else: + # We will compare the foreach-indices-truncated value of current task with the + # foreach-indices value of tasks in previous steps + field_name = "foreach-indices" + field_value = self.metadata_dict.get("foreach-indices-truncated") + + for prev_step in previous_steps: + # print(f"For task {self.pathspec}, finding parent tasks for step {prev_step} with {field_name} and " + # f"{field_value}") + ancestor_iters[prev_step] = ( + self._metaflow.metadata.filter_tasks_by_metadata( + flow_id, run_id, step_name, 
prev_step, field_name, field_value + ) + ) + + return ancestor_iters + + # def closest_siblings(self) -> Iterator["Task"]: + # """ + # Returns an iterator over the closest siblings of this task. + # + # Returns + # ------- + # Iterator[Task] + # Iterator over the closest siblings of this task + # """ + # flow_id, run_id, step_name, task_id = self.path_components + # print(f"flow_id: {flow_id}, run_id: {run_id}, step_name: {step_name}, task_id: {task_id}") + @property def metadata(self) -> List[Metadata]: """ diff --git a/metaflow/metadata/metadata.py b/metaflow/metadata/metadata.py index 11c3873a85e..a2d607eff7a 100644 --- a/metaflow/metadata/metadata.py +++ b/metaflow/metadata/metadata.py @@ -672,6 +672,21 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt): if metadata: self.register_metadata(run_id, step_name, task_id, metadata) + @classmethod + def _filter_tasks_by_metadata( + cls, flow_id, run_id, step_name, prev_step, field_name, field_value + ): + raise NotImplementedError() + + @classmethod + def filter_tasks_by_metadata( + cls, flow_id, run_id, step_name, prev_step, field_name, field_value + ): + task_ids = cls._filter_tasks_by_metadata( + flow_id, run_id, step_name, prev_step, field_name, field_value + ) + return task_ids + @staticmethod def _apply_filter(elts, filters): if filters is None: diff --git a/metaflow/task.py b/metaflow/task.py index bccaf47c668..d157cee195a 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -4,6 +4,8 @@ import sys import os import time +import json +import hashlib import traceback from types import MethodType, FunctionType @@ -37,6 +39,23 @@ class MetaflowTask(object): MetaflowTask prepares a Flow instance for execution of a single step. 
""" + @staticmethod + def _dynamic_runtime_metadata(foreach_stack): + foreach_indices = [foreach_frame.index for foreach_frame in foreach_stack] + foreach_indices_truncated = foreach_indices[:-1] + foreach_step_names = [foreach_frame.step for foreach_frame in foreach_stack] + return foreach_indices, foreach_indices_truncated, foreach_step_names + + def _static_runtime_metadata(self, graph_info, step_name): + if step_name == "start": + return [] + + return [ + node_name + for node_name, attributes in graph_info["steps"].items() + if step_name in attributes["next"] + ] + def __init__( self, flow, @@ -495,6 +514,33 @@ def run_step( ) ) + # Add runtime dag info + foreach_indices, foreach_indices_truncated, foreach_step_names = ( + self._dynamic_runtime_metadata(foreach_stack) + ) + metadata.extend( + [ + MetaDatum( + field="foreach-indices", + value=foreach_indices, + type="foreach-indices", + tags=metadata_tags, + ), + MetaDatum( + field="foreach-indices-truncated", + value=foreach_indices_truncated, + type="foreach-indices-truncated", + tags=metadata_tags, + ), + MetaDatum( + field="foreach-step-names", + value=foreach_step_names, + type="foreach-step-names", + tags=metadata_tags, + ), + ] + ) + self.metadata.register_metadata( run_id, step_name, @@ -564,12 +610,17 @@ def run_step( self.flow._success = False self.flow._task_ok = None self.flow._exception = None + # Note: All internal flow attributes (ie: non-user artifacts) # should either be set prior to running the user code or listed in # FlowSpec._EPHEMERAL to allow for proper merging/importing of # user artifacts in the user's step code. if join_type: + if join_type == "foreach": + # We only want to persist one of the input paths + self.flow._input_paths = str(input_paths[0]) + # Join step: # Ensure that we have the right number of inputs. 
The @@ -619,7 +670,9 @@ def run_step( ) } ) - + previous_steps = self._static_runtime_metadata( + self.flow._graph_info, step_name + ) for deco in decorators: deco.task_pre_step( step_name, @@ -730,8 +783,14 @@ def run_step( field="attempt_ok", value=attempt_ok, type="internal_attempt_status", - tags=["attempt_id:{0}".format(retry_count)], - ) + tags=metadata_tags, + ), + MetaDatum( + field="previous_steps", + value=previous_steps, + type="previous_steps", + tags=metadata_tags, + ), ], )