Skip to content

Commit

Permalink
GH-5768: Better pyflyte boolean parsing (#2764)
Browse files Browse the repository at this point in the history
* Add --no_{input_name}

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>

* Write tests

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>

* Autoformat

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>

* Rename test tasks and fix test_get_entities_in_file

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>

* Support - and _

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>

* Fix lint warning

Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>

---------

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>
Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>
Co-authored-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>
  • Loading branch information
2 people authored and kumare3 committed Nov 8, 2024
1 parent 6dbb113 commit cbdc012
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 2 deletions.
11 changes: 9 additions & 2 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,10 +450,17 @@ def to_click_option(
# If a query has been specified, the input is never strictly required at this layer
required = False if default_val and isinstance(default_val, ArtifactQuery) else required

if literal_converter.is_bool():
click_cli_parameter_names = [
f"--{input_name}/--no_{input_name}",
f"--{input_name}/--no-{input_name.replace('_', '-')}",
]
else:
click_cli_parameter_names = [f"--{input_name}"]

return click.Option(
param_decls=[f"--{input_name}"],
param_decls=click_cli_parameter_names,
type=literal_converter.click_type,
is_flag=literal_converter.is_bool(),
default=default_val,
show_default=True,
required=required,
Expand Down
40 changes: 40 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,40 @@ def test_union_type1(input):
assert result.exit_code == 0


@pytest.mark.parametrize(
"extra_cli_args, task_name, expected_output",
[
(("--a_b",), "test_boolean", True),
(("--no_a_b",), "test_boolean", False),
(("--no-a-b",), "test_boolean", False),
(tuple(), "test_boolean_default_true", True),
(("--a_b",), "test_boolean_default_true", True),
(("--no_a_b",), "test_boolean_default_true", False),
(("--no-a-b",), "test_boolean_default_true", False),
(tuple(), "test_boolean_default_false", False),
(("--a_b",), "test_boolean_default_false", True),
(("--no_a_b",), "test_boolean_default_false", False),
(("--no-a-b",), "test_boolean_default_false", False),
],
)
def test_boolean_type(extra_cli_args, task_name, expected_output):
runner = CliRunner()
result = runner.invoke(
pyflyte.main,
[
"run",
os.path.join(DIR_NAME, "workflow.py"),
task_name,
*extra_cli_args,
],
catch_exceptions=False,
)
assert result.exit_code == 0
assert str(expected_output) in result.stdout


def test_all_types_with_json_input():
runner = CliRunner()
result = runner.invoke(
Expand Down Expand Up @@ -391,6 +425,9 @@ def test_get_entities_in_file(workflow_file):
"task_with_env_vars",
"task_with_list",
"task_with_optional",
"test_boolean",
"test_boolean_default_false",
"test_boolean_default_true",
"test_union1",
"test_union2",
]
Expand All @@ -405,6 +442,9 @@ def test_get_entities_in_file(workflow_file):
"task_with_env_vars",
"task_with_list",
"task_with_optional",
"test_boolean",
"test_boolean_default_false",
"test_boolean_default_true",
"test_union1",
"test_union2",
]
Expand Down
12 changes: 12 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,18 @@ def test_union1(a: typing.Union[int, FlyteFile, typing.Dict[str, float], datetim
def test_union2(a: typing.Union[float, typing.List[int], MyDataclass]):
print(a)

@task
def test_boolean(a_b: bool):
print(a_b)

@task
def test_boolean_default_true(a_b: bool = True):
print(a_b)

@task
def test_boolean_default_false(a_b: bool = False):
print(a_b)


@workflow
def my_wf(
Expand Down

0 comments on commit cbdc012

Please sign in to comment.