Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mrchtr committed Mar 19, 2024
1 parent 3b903d8 commit 3ab0b88
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
26 changes: 13 additions & 13 deletions src/fondant/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def compile_local(args):
if args.extra_volumes:
extra_volumes.extend(args.extra_volumes)

pipeline = pipeline_from_string(args.ref)
pipeline = dataset_from_string(args.ref)
compiler = DockerCompiler()
compiler.compile(
pipeline=pipeline,
Expand All @@ -430,23 +430,23 @@ def compile_local(args):
def compile_kfp(args):
from fondant.dataset.compiler import KubeFlowCompiler

pipeline = pipeline_from_string(args.ref)
pipeline = dataset_from_string(args.ref)
compiler = KubeFlowCompiler()
compiler.compile(pipeline=pipeline, output_path=args.output_path)


def compile_vertex(args):
from fondant.dataset.compiler import VertexCompiler

pipeline = pipeline_from_string(args.ref)
pipeline = dataset_from_string(args.ref)
compiler = VertexCompiler()
compiler.compile(pipeline=pipeline, output_path=args.output_path)


def compile_sagemaker(args):
from fondant.dataset.compiler import SagemakerCompiler

pipeline = pipeline_from_string(args.ref)
pipeline = dataset_from_string(args.ref)
compiler = SagemakerCompiler()
compiler.compile(
pipeline=pipeline,
Expand Down Expand Up @@ -612,7 +612,7 @@ def run_local(args):
extra_volumes.extend(args.extra_volumes)

try:
ref = pipeline_from_string(args.ref)
ref = dataset_from_string(args.ref)
except ModuleNotFoundError:
ref = args.ref

Expand All @@ -632,7 +632,7 @@ def run_kfp(args):
msg = "--host argument is required for running on Kubeflow"
raise ValueError(msg)
try:
ref = pipeline_from_string(args.ref)
ref = dataset_from_string(args.ref)
except ModuleNotFoundError:
ref = args.ref

Expand All @@ -644,7 +644,7 @@ def run_vertex(args):
from fondant.dataset.runner import VertexRunner

try:
ref = pipeline_from_string(args.ref)
ref = dataset_from_string(args.ref)
except ModuleNotFoundError:
ref = args.ref

Expand All @@ -661,7 +661,7 @@ def run_sagemaker(args):
from fondant.dataset.runner import SagemakerRunner

try:
ref = pipeline_from_string(args.ref)
ref = dataset_from_string(args.ref)
except ModuleNotFoundError:
ref = args.ref

Expand Down Expand Up @@ -761,8 +761,8 @@ def _called_with_wrong_args(f):
del tb


def pipeline_from_string(string_ref: str) -> Dataset: # noqa: PLR0912
"""Get the pipeline from the provided string reference.
def dataset_from_string(string_ref: str) -> Dataset: # noqa: PLR0912
"""Get the dataset from the provided string reference.
Inspired by Flask:
https://github.com/pallets/flask/blob/d611989/src/flask/cli.py#L112
Expand All @@ -776,7 +776,7 @@ def pipeline_from_string(string_ref: str) -> Dataset: # noqa: PLR0912
The pipeline obtained from the provided string
"""
if ":" not in string_ref:
return pipeline_from_module(string_ref)
return dataset_from_module(string_ref)

module_str, pipeline_str = string_ref.split(":")

Expand Down Expand Up @@ -856,8 +856,8 @@ def pipeline_from_string(string_ref: str) -> Dataset: # noqa: PLR0912
)


def pipeline_from_module(module_str: str) -> Dataset:
"""Try to import a pipeline from a string otherwise raise an ImportFromStringError."""
def dataset_from_module(module_str: str) -> Dataset:
"""Try to import a dataset from a string otherwise raise an ImportFromStringError."""
from fondant.dataset import Dataset

module = get_module(module_str)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
compile_sagemaker,
compile_vertex,
component_from_module,
dataset_from_string,
execute,
get_module,
pipeline_from_string,
run_kfp,
run_local,
run_vertex,
Expand Down Expand Up @@ -133,7 +133,7 @@ def test_component_from_module_error(module_str):
)
def test_pipeline_from_module(module_str):
"""Test that pipeline_from_string works."""
pipeline = pipeline_from_string(module_str)
pipeline = dataset_from_string(module_str)
assert pipeline.name == "test_pipeline"


Expand Down Expand Up @@ -165,13 +165,13 @@ def test_pipeline_from_module(module_str):
def test_pipeline_from_module_error(module_str):
"""Test different error cases for pipeline_from_string."""
with pytest.raises(PipelineImportError):
pipeline_from_string(module_str)
dataset_from_string(module_str)


def test_factory_error_propagated():
"""Test that an error in the factory method is correctly propagated."""
with pytest.raises(NotImplementedError):
pipeline_from_string("examples.example_modules.pipeline:not_implemented")
dataset_from_string("examples.example_modules.pipeline:not_implemented")


def test_execute_logic(monkeypatch):
Expand Down

0 comments on commit 3ab0b88

Please sign in to comment.