Skip to content

Commit

Permalink
Add static and runtime dag info, API to fetch ancestor tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
talsperre committed Oct 31, 2024
1 parent a37555b commit 16cb6f8
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 3 deletions.
80 changes: 80 additions & 0 deletions metaflow/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,6 +1123,86 @@ def _iter_filter(self, x):
# exclude private data artifacts
return x.id[0] != "_"

def immediate_ancestors(self) -> Dict[str, Iterator["Task"]]:
"""
Returns a dictionary with iterators over the immediate ancestors of this task.
Returns
-------
Dict[str, Iterator[Task]]
Dictionary of immediate ancestors of this task. The keys are the
names of the ancestors steps and the values are iterators over the
tasks of the corresponding steps.
"""

def _prev_task(flow_id, run_id, previous_step):
# Find any previous task for current step

step = Step(f"{flow_id}/{run_id}/{previous_step}", _namespace_check=False)
task = next(iter(step.tasks()), None)
if task:
return task
raise MetaflowNotFound(f"No previous task found for step {previous_step}")

flow_id, run_id, step_name, task_id = self.path_components
previous_steps = self.metadata_dict.get("previous_steps", None)
print(
f"flow_id: {flow_id}, run_id: {run_id}, step_name: {step_name}, task_id: {task_id}"
)
print(f"previous_steps: {previous_steps}")

if not previous_steps or len(previous_steps) == 0:
return

cur_foreach_stack_len = len(self.metadata_dict.get("foreach-stack", []))
ancestor_iters = {}
if len(previous_steps) > 1:
# This is a static join, so there is no change in foreach stack length
prev_foreach_stack_len = cur_foreach_stack_len
else:
prev_task = _prev_task(flow_id, run_id, previous_steps[0])
prev_foreach_stack_len = len(
prev_task.metadata_dict.get("foreach-stack", [])
)

print(
f"prev_foreach_stack_len: {prev_foreach_stack_len}, cur_foreach_stack_len: {cur_foreach_stack_len}"
)
if prev_foreach_stack_len == cur_foreach_stack_len:
field_name = "foreach-indices"
field_value = self.metadata_dict.get(field_name)
elif prev_foreach_stack_len > cur_foreach_stack_len:
field_name = "foreach-indices-truncated"
field_value = self.metadata_dict.get("foreach-indices")
else:
# We will compare the foreach-stack-truncated value of current task with the
# foreach-stack value of tasks in previous steps
field_name = "foreach-indices"
field_value = self.metadata_dict.get("foreach-indices-truncated")

for prev_step in previous_steps:
# print(f"For task {self.pathspec}, findding parent tasks for step {prev_step} with {field_name} and "
# f"{field_value}")
ancestor_iters[prev_step] = (
self._metaflow.metadata.filter_tasks_by_metadata(
flow_id, run_id, step_name, prev_step, field_name, field_value
)
)

return ancestor_iters

# def closest_siblings(self) -> Iterator["Task"]:
# """
# Returns an iterator over the closest siblings of this task.
#
# Returns
# -------
# Iterator[Task]
# Iterator over the closest siblings of this task
# """
# flow_id, run_id, step_name, task_id = self.path_components
# print(f"flow_id: {flow_id}, run_id: {run_id}, step_name: {step_name}, task_id: {task_id}")

@property
def metadata(self) -> List[Metadata]:
"""
Expand Down
15 changes: 15 additions & 0 deletions metaflow/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,21 @@ def _register_system_metadata(self, run_id, step_name, task_id, attempt):
if metadata:
self.register_metadata(run_id, step_name, task_id, metadata)

@classmethod
def _filter_tasks_by_metadata(
cls, flow_id, run_id, step_name, prev_step, field_name, field_value
):
raise NotImplementedError()

@classmethod
def filter_tasks_by_metadata(
cls, flow_id, run_id, step_name, prev_step, field_name, field_value
):
task_ids = cls._filter_tasks_by_metadata(
flow_id, run_id, step_name, prev_step, field_name, field_value
)
return task_ids

@staticmethod
def _apply_filter(elts, filters):
if filters is None:
Expand Down
65 changes: 62 additions & 3 deletions metaflow/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import sys
import os
import time
import json
import hashlib
import traceback

from types import MethodType, FunctionType
Expand Down Expand Up @@ -37,6 +39,23 @@ class MetaflowTask(object):
MetaflowTask prepares a Flow instance for execution of a single step.
"""

@staticmethod
def _dynamic_runtime_metadata(foreach_stack):
foreach_indices = [foreach_frame.index for foreach_frame in foreach_stack]
foreach_indices_truncated = foreach_indices[:-1]
foreach_step_names = [foreach_frame.step for foreach_frame in foreach_stack]
return foreach_indices, foreach_indices_truncated, foreach_step_names

def _static_runtime_metadata(self, graph_info, step_name):
if step_name == "start":
return []

return [
node_name
for node_name, attributes in graph_info["steps"].items()
if step_name in attributes["next"]
]

def __init__(
self,
flow,
Expand Down Expand Up @@ -495,6 +514,33 @@ def run_step(
)
)

# Add runtime dag info
foreach_indices, foreach_indices_truncated, foreach_step_names = (
self._dynamic_runtime_metadata(foreach_stack)
)
metadata.extend(
[
MetaDatum(
field="foreach-indices",
value=foreach_indices,
type="foreach-indices",
tags=metadata_tags,
),
MetaDatum(
field="foreach-indices-truncated",
value=foreach_indices_truncated,
type="foreach-indices-truncated",
tags=metadata_tags,
),
MetaDatum(
field="foreach-step-names",
value=foreach_step_names,
type="foreach-step-names",
tags=metadata_tags,
),
]
)

self.metadata.register_metadata(
run_id,
step_name,
Expand Down Expand Up @@ -564,12 +610,17 @@ def run_step(
self.flow._success = False
self.flow._task_ok = None
self.flow._exception = None

# Note: All internal flow attributes (ie: non-user artifacts)
# should either be set prior to running the user code or listed in
# FlowSpec._EPHEMERAL to allow for proper merging/importing of
# user artifacts in the user's step code.

if join_type:
if join_type == "foreach":
# We only want to persist one of the input paths
self.flow._input_paths = str(input_paths[0])

# Join step:

# Ensure that we have the right number of inputs. The
Expand Down Expand Up @@ -619,7 +670,9 @@ def run_step(
)
}
)

previous_steps = self._static_runtime_metadata(
self.flow._graph_info, step_name
)
for deco in decorators:
deco.task_pre_step(
step_name,
Expand Down Expand Up @@ -730,8 +783,14 @@ def run_step(
field="attempt_ok",
value=attempt_ok,
type="internal_attempt_status",
tags=["attempt_id:{0}".format(retry_count)],
)
tags=metadata_tags,
),
MetaDatum(
field="previous_steps",
value=previous_steps,
type="previous_steps",
tags=metadata_tags,
),
],
)

Expand Down

0 comments on commit 16cb6f8

Please sign in to comment.