diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 183f5d8c5f..bcd0599df0 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -450,12 +450,12 @@ def to_click_option( required = False if default_val and isinstance(default_val, ArtifactQuery) else required if literal_converter.is_bool(): - click_cli_paramater_name = f"--{input_name}/--no_{input_name}" + click_cli_parameter_names = [f"--{input_name}/--no_{input_name}", f"--{input_name}/--no-{input_name.replace('_', '-')}"] else: - click_cli_paramater_name = f"--{input_name}" + click_cli_parameter_names = [f"--{input_name}"] return click.Option( - param_decls=[click_cli_paramater_name], + param_decls=click_cli_parameter_names, type=literal_converter.click_type, default=default_val, show_default=True, diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index ca5d20f48f..17d18023ba 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -241,16 +241,19 @@ def test_union_type1(input): @pytest.mark.parametrize( "extra_cli_args, task_name, expected_output", [ - (("--a",), "test_boolean", True), - (("--no_a",), "test_boolean", False), + (("--a_b",), "test_boolean", True), + (("--no_a_b",), "test_boolean", False), + (("--no-a-b",), "test_boolean", False), (tuple(), "test_boolean_default_true", True), - (("--a",), "test_boolean_default_true", True), - (("--no_a",), "test_boolean_default_true", False), + (("--a_b",), "test_boolean_default_true", True), + (("--no_a_b",), "test_boolean_default_true", False), + (("--no-a-b",), "test_boolean_default_true", False), (tuple(), "test_boolean_default_false", False), - (("--a",), "test_boolean_default_false", True), - (("--no_a",), "test_boolean_default_false", False), + (("--a_b",), "test_boolean_default_false", True), + (("--no_a_b",), "test_boolean_default_false", False), + (("--no-a-b",), "test_boolean_default_false", False), ], ) def test_boolean_type(extra_cli_args, task_name, expected_output): diff --git a/tests/flytekit/unit/cli/pyflyte/workflow.py b/tests/flytekit/unit/cli/pyflyte/workflow.py index 7aca5d3bf4..2d65041439 100644 --- a/tests/flytekit/unit/cli/pyflyte/workflow.py +++ b/tests/flytekit/unit/cli/pyflyte/workflow.py @@ -81,16 +81,16 @@ def test_union2(a: typing.Union[float, typing.List[int], MyDataclass]): print(a) @task -def test_boolean(a: bool): - print(a) +def test_boolean(a_b: bool): + print(a_b) @task -def test_boolean_default_true(a: bool = True): - print(a) +def test_boolean_default_true(a_b: bool = True): + print(a_b) @task -def test_boolean_default_false(a: bool = False): - print(a) +def test_boolean_default_false(a_b: bool = False): + print(a_b) @workflow