Skip to content

Commit

Permalink
Write tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom-Newton committed Sep 23, 2024
1 parent fb95942 commit 9658cbe
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 19 deletions.
85 changes: 66 additions & 19 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,29 @@
import enum
import io
import json
import os
import pathlib
import shutil
import sys
import io
from typing import Iterator, List

import mock
import pytest
import yaml
from click.testing import CliRunner
from flytekit.loggers import logging, logger

from flytekit import workflow
from flytekit.clis.sdk_in_container import pyflyte
from flytekit.clis.sdk_in_container.run import (
RunLevelParams,
get_entities_in_file,
run_command,
)
from flytekit.clis.sdk_in_container.run import (RunLevelParams,
get_entities_in_file,
run_command)
from flytekit.configuration import Config, Image, ImageConfig
from flytekit.core.task import task
from flytekit.image_spec.image_spec import (
ImageBuildEngine,
ImageSpec,
)
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec
from flytekit.interaction.click_types import DirParamType, FileParamType
from flytekit.loggers import logger, logging
from flytekit.remote import FlyteRemote
from typing import Iterator, List
from flytekit.types.iterator import JSON
from flytekit import workflow


pytest.importorskip("pandas")

Expand Down Expand Up @@ -205,7 +199,7 @@ def test_pyflyte_run_cli(workflow_file):
"--s",
json.dumps({"x": {"i": 1, "a": ["h", "e"]}}),
"--t",
json.dumps({"i": [{"i":1,"a":["h","e"]}]}),
json.dumps({"i": [{"i": 1, "a": ["h", "e"]}]}),
],
catch_exceptions=False,
)
Expand Down Expand Up @@ -238,6 +232,37 @@ def test_union_type1(input):
assert result.exit_code == 0


@pytest.mark.parametrize(
"extra_cli_args, task_name, expected_output",
[
(("--a",), "test_task_boolean", True),
(("--no_a",), "test_task_boolean", False),
(tuple(), "test_task_boolean_default_true", True),
(("--a",), "test_task_boolean_default_true", True),
(("--no_a",), "test_task_boolean_default_true", False),
(tuple(), "test_task_boolean_default_false", False),
(("--a",), "test_task_boolean_default_false", True),
(("--no_a",), "test_task_boolean_default_false", False),
],
)
def test_boolean_type(extra_cli_args, task_name, expected_output):
runner = CliRunner()
result = runner.invoke(
pyflyte.main,
[
"run",
os.path.join(DIR_NAME, "workflow.py"),
task_name,
*extra_cli_args,
],
catch_exceptions=False,
)
assert result.exit_code == 0
assert str(expected_output) in result.stdout


def test_all_types_with_json_input():
runner = CliRunner()
result = runner.invoke(
Expand All @@ -247,7 +272,9 @@ def test_all_types_with_json_input():
os.path.join(DIR_NAME, "workflow.py"),
"my_wf",
"--inputs-file",
os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json"),
os.path.join(
os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json"
),
],
catch_exceptions=False,
)
Expand All @@ -259,15 +286,32 @@ def test_all_types_with_yaml_input():

result = runner.invoke(
pyflyte.main,
["run", os.path.join(DIR_NAME, "workflow.py"), "my_wf", "--inputs-file", os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.yaml")],
[
"run",
os.path.join(DIR_NAME, "workflow.py"),
"my_wf",
"--inputs-file",
os.path.join(
os.path.dirname(os.path.realpath(__file__)), "my_wf_input.yaml"
),
],
catch_exceptions=False,
)
assert result.exit_code == 0, result.stdout


def test_all_types_with_pipe_input(monkeypatch):
runner = CliRunner()
input= str(json.load(open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json"),"r")))
input = str(
json.load(
open(
os.path.join(
os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json"
),
"r",
)
)
)
monkeypatch.setattr("sys.stdin", io.StringIO(input))
result = runner.invoke(
pyflyte.main,
Expand Down Expand Up @@ -827,4 +871,7 @@ def test_entity_non_found_in_file():
catch_exceptions=False,
)
assert result.exit_code == 1
assert "FlyteEntityNotFoundException: Task/Workflow \'my_wffffff\' not found in module \n\'pyflyte.workflow\'" in result.stdout
assert (
"FlyteEntityNotFoundException: Task/Workflow 'my_wffffff' not found in module \n'pyflyte.workflow'"
in result.stdout
)
12 changes: 12 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,18 @@ def test_union1(a: typing.Union[int, FlyteFile, typing.Dict[str, float], datetim
def test_union2(a: typing.Union[float, typing.List[int], MyDataclass]):
print(a)

@task
def test_task_boolean(a: bool):
print(a)

@task
def test_task_boolean_default_true(a: bool = True):
print(a)

@task
def test_task_boolean_default_false(a: bool = False):
print(a)


@workflow
def my_wf(
Expand Down

0 comments on commit 9658cbe

Please sign in to comment.