Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

eliminate redundant literal conversion for Iterator[JSON] type #2602

Merged
merged 4 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import tempfile
import typing
from dataclasses import dataclass, field, fields
from typing import get_args
from typing import Iterator, get_args

import rich_click as click
from mashumaro.codecs.json import JSONEncoder
from rich.progress import Progress
from typing_extensions import get_origin

from flytekit import Annotations, FlyteContext, FlyteContextManager, Labels, Literal
from flytekit.clis.sdk_in_container.helpers import patch_image_config
Expand Down Expand Up @@ -538,10 +539,21 @@ def _run(*args, **kwargs):
for input_name, v in entity.python_interface.inputs_with_defaults.items():
processed_click_value = kwargs.get(input_name)
optional_v = False

skip_default_value_selection = False
if processed_click_value is None and isinstance(v, typing.Tuple):
optional_v = is_optional(v[0])
if len(v) == 2:
processed_click_value = v[1]
if entity_type == "workflow" and hasattr(v[0], "__args__"):
origin_base_type = get_origin(v[0])
if inspect.isclass(origin_base_type) and issubclass(origin_base_type, Iterator): # Iterator
args = getattr(v[0], "__args__")
if isinstance(args, tuple) and get_origin(args[0]) is typing.Union: # Iterator[JSON]
logger.debug(f"Detected Iterator[JSON] in {entity.name} input annotations...")
skip_default_value_selection = True

if not skip_default_value_selection:
optional_v = is_optional(v[0])
if len(v) == 2:
processed_click_value = v[1]
if isinstance(processed_click_value, ArtifactQuery):
if run_level_params.is_remote:
click.secho(
Expand Down
214 changes: 196 additions & 18 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,34 @@
import pytest
import yaml
from click.testing import CliRunner
from flytekit.loggers import logging, logger

from flytekit.clis.sdk_in_container import pyflyte
from flytekit.clis.sdk_in_container.run import RunLevelParams, get_entities_in_file, run_command
from flytekit.clis.sdk_in_container.run import (
RunLevelParams,
get_entities_in_file,
run_command,
)
from flytekit.configuration import Config, Image, ImageConfig
from flytekit.core.task import task
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec, calculate_hash_from_image_spec
from flytekit.image_spec.image_spec import (
ImageBuildEngine,
ImageSpec,
calculate_hash_from_image_spec,
)
from flytekit.interaction.click_types import DirParamType, FileParamType
from flytekit.remote import FlyteRemote
from typing import Iterator
from flytekit.types.iterator import JSON
from flytekit import workflow


pytest.importorskip("pandas")

REMOTE_WORKFLOW_FILE = "https://raw.githubusercontent.com/flyteorg/flytesnacks/8337b64b33df046b2f6e4cba03c74b7bdc0c4fb1/cookbook/core/flyte_basics/basic_workflow.py"
IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py")
IMPERATIVE_WORKFLOW_FILE = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py"
)
DIR_NAME = os.path.dirname(os.path.realpath(__file__))


Expand All @@ -46,7 +61,9 @@ def workflow_file(request, tmp_path_factory):
@pytest.fixture
def remote():
with mock.patch("flytekit.clients.friendly.SynchronousFlyteClient") as mock_client:
flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
flyte_remote = FlyteRemote(
config=Config.auto(), default_project="p1", default_domain="d1"
)
flyte_remote._client = mock_client
return flyte_remote

Expand All @@ -70,7 +87,9 @@ def test_pyflyte_run_wf(remote, remote_flag, workflow_file):
with mock.patch("flytekit.configuration.plugin.FlyteRemote"):
runner = CliRunner()
result = runner.invoke(
pyflyte.main, ["run", remote_flag, workflow_file, "my_wf", "--help"], catch_exceptions=False
pyflyte.main,
["run", remote_flag, workflow_file, "my_wf", "--help"],
catch_exceptions=False,
)

assert result.exit_code == 0
Expand All @@ -81,7 +100,9 @@ def test_pyflyte_run_with_labels():
with mock.patch("flytekit.configuration.plugin.FlyteRemote"):
runner = CliRunner()
result = runner.invoke(
pyflyte.main, ["run", "--remote", str(workflow_file), "my_wf", "--help"], catch_exceptions=False
pyflyte.main,
["run", "--remote", str(workflow_file), "my_wf", "--help"],
catch_exceptions=False,
)
assert result.exit_code == 0

Expand All @@ -100,7 +121,16 @@ def test_copy_all_files():
runner = CliRunner()
result = runner.invoke(
pyflyte.main,
["run", "--copy-all", IMPERATIVE_WORKFLOW_FILE, "wf", "--in1", "hello", "--in2", "world"],
[
"run",
"--copy-all",
IMPERATIVE_WORKFLOW_FILE,
"wf",
"--in1",
"hello",
"--in2",
"world",
],
catch_exceptions=False,
)
assert result.exit_code == 0
Expand Down Expand Up @@ -176,7 +206,13 @@ def test_pyflyte_run_cli(workflow_file):

@pytest.mark.parametrize(
"input",
["1", os.path.join(DIR_NAME, "testdata/df.parquet"), '{"x":1.0, "y":2.0}', "2020-05-01", "RED"],
[
"1",
os.path.join(DIR_NAME, "testdata/df.parquet"),
'{"x":1.0, "y":2.0}',
"2020-05-01",
"RED",
],
)
def test_union_type1(input):
runner = CliRunner()
Expand Down Expand Up @@ -300,7 +336,10 @@ def test_nested_workflow(working_dir, wf_path, monkeypatch: pytest.MonkeyPatch):
],
catch_exceptions=False,
)
assert result.stdout.strip() == "Running Execution on local.\nRunning Execution on local."
assert (
result.stdout.strip()
== "Running Execution on local.\nRunning Execution on local."
)
assert result.exit_code == 0


Expand All @@ -325,12 +364,18 @@ def test_list_default_arguments(wf_path):

# default case, what comes from click if no image is specified, the click param is configured to use the default.
ic_result_1 = ImageConfig(
default_image=Image(name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest"),
images=[Image(name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest")],
default_image=Image(
name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest"
),
images=[
Image(name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest")
],
)
# test that command line args are merged with the file
ic_result_2 = ImageConfig(
default_image=Image(name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"),
default_image=Image(
name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"
),
images=[
Image(name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"),
Image(name="asdf", fqn="ghcr.io/asdf/asdf", tag="latest"),
Expand All @@ -345,7 +390,9 @@ def test_list_default_arguments(wf_path):
)
# test that command line args override the file
ic_result_3 = ImageConfig(
default_image=Image(name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"),
default_image=Image(
name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"
),
images=[
Image(name="default", fqn="cr.flyte.org/flyteorg/flytekit", tag="py3.9-latest"),
Image(name="xyz", fqn="ghcr.io/asdf/asdf", tag="latest"),
Expand Down Expand Up @@ -395,21 +442,29 @@ def test_list_default_arguments(wf_path):
reason="Github macos-latest image does not have docker installed as per https://github.com/orgs/community/discussions/25777",
)
def test_pyflyte_run_run(
mock_image, image_string, leaf_configuration_file_name, final_image_config, mock_image_spec_builder
mock_image,
image_string,
leaf_configuration_file_name,
final_image_config,
mock_image_spec_builder,
):
mock_image.return_value = "cr.flyte.org/flyteorg/flytekit:py3.9-latest"
ImageBuildEngine.register("test", mock_image_spec_builder)

@task
def tk():
...
def tk(): ...

mock_click_ctx = mock.MagicMock()
mock_remote = mock.MagicMock()
image_tuple = (image_string,)
image_config = ImageConfig.validate_image(None, "", image_tuple)

pp = pathlib.Path(__file__).parent.parent.parent / "configuration" / "configs" / leaf_configuration_file_name
pp = (
pathlib.Path(__file__).parent.parent.parent
/ "configuration"
/ "configs"
/ leaf_configuration_file_name
)

obj = RunLevelParams(
project="p",
Expand All @@ -429,6 +484,125 @@ def check_image(*args, **kwargs):
run_command(mock_click_ctx, tk)()


def jsons() -> Iterator[dict]:
    """Yield the sample JSON records used as an ``Iterator[JSON]`` input.

    Each call returns a fresh generator. It currently produces a single
    request record addressed to ``/v1/chat/completions``.
    """
    # `yield from` replaces the manual `for ...: yield x` loop; the function
    # stays a generator so callers still receive a one-shot iterator.
    yield from [
        {
            "custom_id": "request-1",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "What is 2+2?"},
                ],
            },
        },
    ]


@mock.patch("flytekit.configuration.default_images.DefaultImages.default_image")
def test_pyflyte_run_with_iterator_json_type(
    mock_image, mock_image_spec_builder, caplog
):
    """Check that ``run_command`` reports Iterator[JSON] inputs only for workflows.

    The "Detected Iterator[JSON] ..." debug message must be logged when the
    entity is a workflow whose input is annotated ``Iterator[JSON]``, and must
    not be logged for a ``list[int]`` workflow, a plain task, or a workflow
    taking ``Iterator[int]``.
    """
    # NOTE(review): from the PR discussion -- a lot of the tests in this suite
    # should use caplog as well; that can be tackled in a separate PR.
    mock_image.return_value = "cr.flyte.org/flyteorg/flytekit:py3.9-latest"
    ImageBuildEngine.register(
        "test",
        mock_image_spec_builder,
    )

    @task
    def t1(x: Iterator[JSON]) -> Iterator[JSON]:
        return x

    @workflow
    def tk(x: Iterator[JSON] = jsons()) -> Iterator[JSON]:
        return t1(x=x)

    @task
    def t2(x: list[int]) -> list[int]:
        return x

    @workflow
    def tk_list(x: list[int] = [1, 2, 3]) -> list[int]:
        return t2(x=x)

    @task
    def t3(x: Iterator[int]) -> Iterator[int]:
        return x

    @workflow
    def tk_simple_iterator(x: Iterator[int] = iter([1, 2, 3])) -> Iterator[int]:
        return t3(x=x)

    mock_click_ctx = mock.MagicMock()
    mock_remote = mock.MagicMock()
    image_tuple = ("ghcr.io/flyteorg/mydefault:py3.9-latest",)
    image_config = ImageConfig.validate_image(None, "", image_tuple)

    pp = (
        pathlib.Path(__file__).parent.parent.parent
        / "configuration"
        / "configs"
        / "no_images.yaml"
    )

    obj = RunLevelParams(
        project="p",
        domain="d",
        image_config=image_config,
        remote=True,
        config_file=str(pp),
    )
    obj._remote = mock_remote
    mock_click_ctx.obj = obj

    def check_image(*args, **kwargs):
        assert kwargs["image_config"] == ic_result_1

    mock_remote.register_script.side_effect = check_image

    # (entity, exact debug message, whether the message is expected in the log)
    cases = (
        (
            tk,
            "Detected Iterator[JSON] in pyflyte.test_run.tk input annotations...",
            True,
        ),
        (
            tk_list,
            "Detected Iterator[JSON] in pyflyte.test_run.tk_list input annotations...",
            False,
        ),
        (
            t1,
            "Detected Iterator[JSON] in pyflyte.test_run.t1 input annotations...",
            False,
        ),
        (
            tk_simple_iterator,
            "Detected Iterator[JSON] in pyflyte.test_run.tk_simple_iterator input annotations...",
            False,
        ),
    )

    # caplog only sees records that propagate to the root logger. Remember the
    # previous setting and restore it afterwards so this test does not leak
    # logger state into the rest of the suite.
    previous_propagate = logger.propagate
    logger.propagate = True
    try:
        for entity, expected_message, should_log in cases:
            with caplog.at_level(logging.DEBUG, logger="flytekit"):
                run_command(mock_click_ctx, entity)()
            was_logged = any(
                expected_message in record[2] for record in caplog.record_tuples
            )
            assert was_logged is should_log, expected_message
            # Start the next case with an empty capture buffer.
            caplog.clear()
    finally:
        logger.propagate = previous_propagate


def test_file_param():
m = mock.MagicMock()
flyte_file = FileParamType().convert(__file__, m, m)
Expand Down Expand Up @@ -484,7 +658,11 @@ def test_pyflyte_run_with_none(a_val, workflow_file):
"envs, envs_argument, expected_output",
[
(["--env", "MY_ENV_VAR=hello"], '["MY_ENV_VAR"]', "hello"),
(["--env", "MY_ENV_VAR=hello", "--env", "ABC=42"], '["MY_ENV_VAR","ABC"]', "hello,42"),
(
["--env", "MY_ENV_VAR=hello", "--env", "ABC=42"],
'["MY_ENV_VAR","ABC"]',
"hello,42",
),
],
)
@pytest.mark.parametrize(
Expand Down
Loading