Skip to content
This repository was archived by the owner on Jun 13, 2023. It is now read-only.

Identify cycles in the pipeline after DAG generation #7

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cf_pipelines/base/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Union

import networkx as nx
from ploomber import DAG
from ploomber.executors import Serial
from ploomber.io import serializer_pickle, unserializer_pickle
Expand All @@ -15,6 +16,7 @@

from cf_pipelines.base.helper_classes import FunctionDetails, ProductLineage
from cf_pipelines.base.utils import get_return_keys_from_function, remove_extension, wrap_preserving_signature
from cf_pipelines.exceptions import CycledPipelineError


class Pipeline:
Expand Down Expand Up @@ -287,6 +289,16 @@ def make_dag(self) -> DAG:
for function_name, dependencies in solved_dependencies.items():
for dependency in dependencies:
callables[dependency] >> callables[function_name]

try:
# TODO: investigate how reliable accessing the private `_G` attribute is
# TODO: ask whether cycles can be found eagerly, without building (executing) the pipeline
cycles = nx.find_cycle(dag._G)
raise CycledPipelineError(cycles)
except nx.NetworkXNoCycle:
# Ironically, this means everything is good! so we just ignore this exception
pass

return dag

def generate_run_id(self) -> str:
Expand Down
10 changes: 10 additions & 0 deletions cf_pipelines/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import List, Tuple


class CycledPipelineError(Exception):
    """Raised when the dependency graph of a pipeline contains a cycle.

    Attributes:
        cycles: The edges that form the cycle, stored exactly as received.
    """

    def __init__(self, cycles: List[Tuple[str, str]]):
        """Build a human-readable message from the edges of the cycle.

        Args:
            cycles: Edges forming the cycle. The first two items of each edge
                are rendered in the message; any additional items (such as an
                edge key or a traversal direction) are ignored.
        """
        self.cycles = cycles

        # NOTE(review): the wording assumes each edge reads "first depends on
        # second" -- confirm against the edge direction of the graph in use.
        # `*_` tolerates edges that carry more than two items.
        message = "\n".join(f'"{f1}" depends on "{f2}".' for f1, f2, *_ in cycles)

        super().__init__("Your pipeline contains cycles: " + message)
42 changes: 42 additions & 0 deletions tests/base/test_pipelines_loops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

from cf_pipelines import Pipeline
from cf_pipelines.exceptions import CycledPipelineError


@pytest.fixture
def simple_looped_pipeline(parse_indented):
    """Provide a pipeline with a single step that is expected to loop on itself.

    The only step takes a keyword argument named ``one`` and returns a
    product keyed ``"one.txt"``.
    """
    pipeline = Pipeline("Simple Pipeline")

    @pipeline.step("step_1")
    def one(*, one):
        return {"one.txt": None}

    return pipeline


@pytest.fixture
def looped_pipeline(parse_indented):
    """Provide a three-step pipeline that is expected to contain a loop.

    ``f_one`` returns ``"two.txt"``; ``f_two`` takes ``two`` and ``four`` and
    returns ``"three.txt"``; ``f_three`` takes ``three`` and returns
    ``"four.txt"``.
    """
    pipeline = Pipeline("Simple Pipeline")

    @pipeline.step("step")
    def f_one():
        return {"two.txt": None}

    @pipeline.step("step")
    def f_two(*, two, four):
        return {"three.txt": None}

    @pipeline.step("step")
    def f_three(*, three):
        return {"four.txt": None}

    return pipeline


@pytest.mark.parametrize("pipeline", ["simple_looped_pipeline", "looped_pipeline"])
def test_make_dag_fail_when_looped(pipeline, request):
    """Check that ``make_dag`` raises ``CycledPipelineError`` for each looped fixture."""
    # The parameter is a fixture name; resolve it before the assertion so that
    # only errors coming from `make_dag` itself are captured.
    looped = request.getfixturevalue(pipeline)

    pytest.raises(CycledPipelineError, looped.make_dag)