diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 2d19cb4dbed..8abc75eba19 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -1,35 +1,29 @@ import enum +import io import json import os import pathlib import shutil import sys -import io +from typing import Iterator, List import mock import pytest import yaml from click.testing import CliRunner -from flytekit.loggers import logging, logger +from flytekit import workflow from flytekit.clis.sdk_in_container import pyflyte -from flytekit.clis.sdk_in_container.run import ( - RunLevelParams, - get_entities_in_file, - run_command, -) +from flytekit.clis.sdk_in_container.run import (RunLevelParams, + get_entities_in_file, + run_command) from flytekit.configuration import Config, Image, ImageConfig from flytekit.core.task import task -from flytekit.image_spec.image_spec import ( - ImageBuildEngine, - ImageSpec, -) +from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec from flytekit.interaction.click_types import DirParamType, FileParamType +from flytekit.loggers import logger, logging from flytekit.remote import FlyteRemote -from typing import Iterator, List from flytekit.types.iterator import JSON -from flytekit import workflow - pytest.importorskip("pandas") @@ -205,7 +199,7 @@ def test_pyflyte_run_cli(workflow_file): "--s", json.dumps({"x": {"i": 1, "a": ["h", "e"]}}), "--t", - json.dumps({"i": [{"i":1,"a":["h","e"]}]}), + json.dumps({"i": [{"i": 1, "a": ["h", "e"]}]}), ], catch_exceptions=False, ) @@ -238,6 +232,37 @@ def test_union_type1(input): assert result.exit_code == 0 +@pytest.mark.parametrize( + "extra_cli_args, task_name, expected_output", + [ + (("--a",), "test_task_boolean", True), + (("--no_a",), "test_task_boolean", False), + + (tuple(), "test_task_boolean_default_true", True), + (("--a",), "test_task_boolean_default_true", True), + (("--no_a",), "test_task_boolean_default_true", False), + + (tuple(), "test_task_boolean_default_false", False), + (("--a",), "test_task_boolean_default_false", True), + (("--no_a",), "test_task_boolean_default_false", False), + ], +) +def test_boolean_type(extra_cli_args, task_name, expected_output): + runner = CliRunner() + result = runner.invoke( + pyflyte.main, + [ + "run", + os.path.join(DIR_NAME, "workflow.py"), + task_name, + *extra_cli_args, + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert str(expected_output) in result.stdout + + def test_all_types_with_json_input(): runner = CliRunner() result = runner.invoke( @@ -247,7 +272,9 @@ def test_all_types_with_json_input(): os.path.join(DIR_NAME, "workflow.py"), "my_wf", "--inputs-file", - os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json"), + os.path.join( + os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json" + ), ], catch_exceptions=False, ) @@ -259,7 +286,15 @@ def test_all_types_with_yaml_input(): result = runner.invoke( pyflyte.main, - ["run", os.path.join(DIR_NAME, "workflow.py"), "my_wf", "--inputs-file", os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.yaml")], + [ + "run", + os.path.join(DIR_NAME, "workflow.py"), + "my_wf", + "--inputs-file", + os.path.join( + os.path.dirname(os.path.realpath(__file__)), "my_wf_input.yaml" + ), + ], catch_exceptions=False, ) assert result.exit_code == 0, result.stdout @@ -267,7 +302,16 @@ def test_all_types_with_yaml_input(): def test_all_types_with_pipe_input(monkeypatch): runner = CliRunner() - input= str(json.load(open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json"),"r"))) + input = str( + json.load( + open( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), "my_wf_input.json" + ), + "r", + ) + ) + ) monkeypatch.setattr("sys.stdin", io.StringIO(input)) result = runner.invoke( pyflyte.main, @@ -827,4 +871,7 @@ def test_entity_non_found_in_file(): catch_exceptions=False, ) assert result.exit_code == 1 - assert "FlyteEntityNotFoundException: Task/Workflow \'my_wffffff\' not found in module \n\'pyflyte.workflow\'" in result.stdout + assert ( + "FlyteEntityNotFoundException: Task/Workflow 'my_wffffff' not found in module \n'pyflyte.workflow'" + in result.stdout + ) diff --git a/tests/flytekit/unit/cli/pyflyte/workflow.py b/tests/flytekit/unit/cli/pyflyte/workflow.py index 104538c3382..496f5c28b9f 100644 --- a/tests/flytekit/unit/cli/pyflyte/workflow.py +++ b/tests/flytekit/unit/cli/pyflyte/workflow.py @@ -80,6 +80,18 @@ def test_union1(a: typing.Union[int, FlyteFile, typing.Dict[str, float], datetim def test_union2(a: typing.Union[float, typing.List[int], MyDataclass]): print(a) +@task +def test_task_boolean(a: bool): + print(a) + +@task +def test_task_boolean_default_true(a: bool = True): + print(a) + +@task +def test_task_boolean_default_false(a: bool = False): + print(a) + @workflow def my_wf(