Skip to content

Commit

Permalink
Add API to get immediate successors
Browse files Browse the repository at this point in the history
  • Loading branch information
talsperre committed Oct 31, 2024
1 parent 16cb6f8 commit 829f544
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 21 deletions.
82 changes: 66 additions & 16 deletions metaflow/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,21 +1123,21 @@ def _iter_filter(self, x):
# exclude private data artifacts
return x.id[0] != "_"

def immediate_ancestors(self) -> Dict[str, Iterator["Task"]]:
def immediate_ancestors(self) -> Dict[str, List[str]]:
"""
Returns a dictionary with iterators over the immediate ancestors of this task.
Returns a dictionary of immediate ancestors task ids of this task for each
previous step.
Returns
-------
Dict[str, Iterator[Task]]
Dict[str, List[str]]
Dictionary of immediate ancestors of this task. The keys are the
names of the ancestors steps and the values are iterators over the
tasks of the corresponding steps.
names of the ancestors steps and the values are the corresponding
task ids of the ancestors.
"""

def _prev_task(flow_id, run_id, previous_step):
# Find any previous task for current step

step = Step(f"{flow_id}/{run_id}/{previous_step}", _namespace_check=False)
task = next(iter(step.tasks()), None)
if task:
Expand All @@ -1146,10 +1146,6 @@ def _prev_task(flow_id, run_id, previous_step):

flow_id, run_id, step_name, task_id = self.path_components
previous_steps = self.metadata_dict.get("previous_steps", None)
print(
f"flow_id: {flow_id}, run_id: {run_id}, step_name: {step_name}, task_id: {task_id}"
)
print(f"previous_steps: {previous_steps}")

if not previous_steps or len(previous_steps) == 0:
return
Expand All @@ -1165,9 +1161,6 @@ def _prev_task(flow_id, run_id, previous_step):
prev_task.metadata_dict.get("foreach-stack", [])
)

print(
f"prev_foreach_stack_len: {prev_foreach_stack_len}, cur_foreach_stack_len: {cur_foreach_stack_len}"
)
if prev_foreach_stack_len == cur_foreach_stack_len:
field_name = "foreach-indices"
field_value = self.metadata_dict.get(field_name)
Expand All @@ -1181,16 +1174,73 @@ def _prev_task(flow_id, run_id, previous_step):
field_value = self.metadata_dict.get("foreach-indices-truncated")

for prev_step in previous_steps:
# print(f"For task {self.pathspec}, findding parent tasks for step {prev_step} with {field_name} and "
# f"{field_value}")
ancestor_iters[prev_step] = (
self._metaflow.metadata.filter_tasks_by_metadata(
flow_id, run_id, step_name, prev_step, field_name, field_value
)
)

return ancestor_iters

def immediate_successors(self) -> Dict[str, List[str]]:
"""
Returns a dictionary of immediate successors task ids of this task for each
previous step.
Returns
-------
Dict[str, List[str]]
Dictionary of immediate successors of this task. The keys are the
names of the successors steps and the values are the corresponding
task ids of the successors.
"""

def _successor_task(flow_id, run_id, successor_step):
# Find any previous task for current step
step = Step(f"{flow_id}/{run_id}/{successor_step}", _namespace_check=False)
task = next(iter(step.tasks()), None)
if task:
return task
raise MetaflowNotFound(f"No successor task found for step {successor_step}")

flow_id, run_id, step_name, task_id = self.path_components
successor_steps = self.metadata_dict.get("successor_steps", None)

if not successor_steps or len(successor_steps) == 0:
return

cur_foreach_stack_len = len(self.metadata_dict.get("foreach-stack", []))
successor_iters = {}
if len(successor_steps) > 1:
# This is a static split, so there is no change in foreach stack length
successor_foreach_stack_len = cur_foreach_stack_len
else:
successor_task = _successor_task(flow_id, run_id, successor_steps[0])
successor_foreach_stack_len = len(
successor_task.metadata_dict.get("foreach-stack", [])
)

if successor_foreach_stack_len == cur_foreach_stack_len:
field_name = "foreach-indices"
field_value = self.metadata_dict.get(field_name)
elif successor_foreach_stack_len > cur_foreach_stack_len:
# We will compare the foreach-indices value of current task with the
# foreach-indices-truncated value of tasks in successor steps
field_name = "foreach-indices-truncated"
field_value = self.metadata_dict.get("foreach-indices")
else:
# We will compare the foreach-stack-truncated value of current task with the
# foreach-stack value of tasks in successor steps
field_name = "foreach-indices"
field_value = self.metadata_dict.get("foreach-indices-truncated")

for successor_step in successor_steps:
successor_iters[successor_step] = (
self._metaflow.metadata.filter_tasks_by_metadata(
flow_id, run_id, step_name, successor_step, field_name, field_value
)
)
return successor_iters

# def closest_siblings(self) -> Iterator["Task"]:
# """
# Returns an iterator over the closest siblings of this task.
Expand Down
15 changes: 10 additions & 5 deletions metaflow/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,13 @@ def _dynamic_runtime_metadata(foreach_stack):
return foreach_indices, foreach_indices_truncated, foreach_step_names

def _static_runtime_metadata(self, graph_info, step_name):
if step_name == "start":
return []

return [
prev_steps = [
node_name
for node_name, attributes in graph_info["steps"].items()
if step_name in attributes["next"]
]
succesor_steps = graph_info["steps"][step_name]["next"]
return prev_steps, succesor_steps

def __init__(
self,
Expand Down Expand Up @@ -670,7 +669,7 @@ def run_step(
)
}
)
previous_steps = self._static_runtime_metadata(
previous_steps, successor_steps = self._static_runtime_metadata(
self.flow._graph_info, step_name
)
for deco in decorators:
Expand Down Expand Up @@ -791,6 +790,12 @@ def run_step(
type="previous_steps",
tags=metadata_tags,
),
MetaDatum(
field="successor_steps",
value=successor_steps,
type="successor_steps",
tags=metadata_tags,
),
],
)

Expand Down

0 comments on commit 829f544

Please sign in to comment.