Skip to content

Commit

Permalink
fix: duplicate code while merging
Browse files Browse the repository at this point in the history
  • Loading branch information
mao3267 committed Aug 1, 2024
1 parent dc741d4 commit ec52ab8
Showing 1 changed file with 23 additions and 106 deletions.
129 changes: 23 additions & 106 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,6 @@
ImageConfig,
SerializationSettings,
)
from flytekit.configuration import (
DefaultImages,
FastSerializationSettings,
ImageConfig,
SerializationSettings,
)
from flytekit.configuration.plugin import get_plugin
from flytekit.core import context_manager
from flytekit.core.artifact import ArtifactQuery
Expand All @@ -53,11 +47,6 @@
key_value_callback,
labels_callback,
)
from flytekit.interaction.click_types import (
FlyteLiteralConverter,
key_value_callback,
labels_callback,
)
from flytekit.interaction.string_literals import literal_string_repr
from flytekit.loggers import logger
from flytekit.models import security
Expand All @@ -71,13 +60,6 @@
FlyteWorkflow,
remote_fs,
)
from flytekit.remote import (
FlyteLaunchPlan,
FlyteRemote,
FlyteTask,
FlyteWorkflow,
remote_fs,
)
from flytekit.remote.executions import FlyteWorkflowExecution
from flytekit.tools import module_loader
from flytekit.tools.script_mode import _find_project_root, compress_scripts
Expand All @@ -93,9 +75,7 @@ class RunLevelComputedParams:

project_root: typing.Optional[str] = None
module: typing.Optional[str] = None
temp_file_name: typing.Optional[str] = (
None # Used to store the temporary location of the file downloaded
)
temp_file_name: typing.Optional[str] = None # Used to store the temporary location of the file downloaded


@dataclass
Expand Down Expand Up @@ -306,19 +286,15 @@ class RunLevelParams(PyFlyteParams):
)
)

computed_params: RunLevelComputedParams = field(
default_factory=RunLevelComputedParams
)
computed_params: RunLevelComputedParams = field(default_factory=RunLevelComputedParams)
_remote: typing.Optional[FlyteRemote] = None

def remote_instance(self) -> FlyteRemote:
if self._remote is None:
data_upload_location = None
if self.is_remote:
data_upload_location = remote_fs.REMOTE_PLACEHOLDER
self._remote = get_plugin().get_remote(
self.config_file, self.project, self.domain, data_upload_location
)
self._remote = get_plugin().get_remote(self.config_file, self.project, self.domain, data_upload_location)
return self._remote

@property
Expand All @@ -337,25 +313,19 @@ def options(cls) -> typing.List[click.Option]:
return [get_option_from_metadata(f.metadata) for f in fields(cls) if f.metadata]


def load_naive_entity(
module_name: str, entity_name: str, project_root: str
) -> typing.Union[WorkflowBase, PythonTask]:
def load_naive_entity(module_name: str, entity_name: str, project_root: str) -> typing.Union[WorkflowBase, PythonTask]:
"""
Load the workflow of a script file.
N.B.: it assumes that the file is self-contained, in other words, there are no relative imports.
"""
flyte_ctx_builder = (
context_manager.FlyteContextManager.current_context().new_builder()
)
flyte_ctx_builder = context_manager.FlyteContextManager.current_context().new_builder()
with context_manager.FlyteContextManager.with_context(flyte_ctx_builder):
with module_loader.add_sys_path(project_root):
importlib.import_module(module_name)
return module_loader.load_object_from_module(f"{module_name}.{entity_name}")


def dump_flyte_remote_snippet(
execution: FlyteWorkflowExecution, project: str, domain: str
):
def dump_flyte_remote_snippet(execution: FlyteWorkflowExecution, project: str, domain: str):
click.secho(
f"""
In order to have programmatic access to the execution, use the following snippet:
Expand Down Expand Up @@ -394,9 +364,7 @@ def get_entities_in_file(filename: pathlib.Path, should_delete: bool) -> Entitie
additional_path = str(pathlib.Path.cwd())
else:
additional_path = _find_project_root(filename)
module_name = str(filename.relative_to(additional_path).with_suffix("")).replace(
os.path.sep, "."
)
module_name = str(filename.relative_to(additional_path).with_suffix("")).replace(os.path.sep, ".")
with context_manager.FlyteContextManager.with_context(flyte_ctx):
with module_loader.add_sys_path(additional_path):
importlib.import_module(module_name)
Expand Down Expand Up @@ -453,9 +421,7 @@ def to_click_option(
description_extra = f": {json.dumps(literal_var.type.metadata)}"

# If a query has been specified, the input is never strictly required at this layer
required = (
False if default_val and isinstance(default_val, ArtifactQuery) else required
)
required = False if default_val and isinstance(default_val, ArtifactQuery) else required

return click.Option(
param_decls=[f"--{input_name}"],
Expand Down Expand Up @@ -541,8 +507,6 @@ def _update_flyte_context(params: RunLevelParams) -> FlyteContext.Builder:
file_access = FileAccessProvider(
local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"),
raw_output_prefix=output_prefix,
local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"),
raw_output_prefix=output_prefix,
)

# The task might run on a remote machine if raw_output_prefix is a remote path,
Expand All @@ -555,11 +519,6 @@ def _update_flyte_context(params: RunLevelParams) -> FlyteContext.Builder:
str(archive_fname),
params.computed_params.module,
)
compress_scripts(
params.computed_params.project_root,
str(archive_fname),
params.computed_params.module,
)
remote_dir = file_access.get_random_remote_directory()
remote_archive_fname = f"{remote_dir}/script_mode.tar.gz"
file_access.put_data(str(archive_fname), remote_archive_fname)
Expand All @@ -582,14 +541,10 @@ def is_optional(_type):
"""
Checks if the given type is Optional Type
"""
return typing.get_origin(_type) is typing.Union and type(None) in typing.get_args(
_type
)
return typing.get_origin(_type) is typing.Union and type(None) in typing.get_args(_type)


def run_command(
ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]
):
def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]):
"""
Returns a function that is used to implement WorkflowCommand and execute a flyte workflow.
"""
Expand Down Expand Up @@ -645,9 +600,7 @@ def _run(*args, **kwargs):
inputs[input_name] = False

if not run_level_params.is_remote:
with FlyteContextManager.with_context(
_update_flyte_context(run_level_params)
):
with FlyteContextManager.with_context(_update_flyte_context(run_level_params)):
if run_level_params.envvars:
for env_var, value in run_level_params.envvars.items():
os.environ[env_var] = value
Expand All @@ -666,9 +619,7 @@ def _run(*args, **kwargs):
image_config = run_level_params.image_config
image_config = patch_image_config(config_file, image_config)

with context_manager.FlyteContextManager.with_context(
remote.context.new_builder()
):
with context_manager.FlyteContextManager.with_context(remote.context.new_builder()):
remote_entity = remote.register_script(
entity,
project=run_level_params.project,
Expand Down Expand Up @@ -711,21 +662,15 @@ def __init__(self, name: str, h: str, entity_name: str, launcher: str, **kwargs)
self._launcher = launcher
self._entity = None

def _fetch_entity(
self, ctx: click.Context
) -> typing.Union[FlyteLaunchPlan, FlyteTask]:
def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, FlyteTask]:
if self._entity:
return self._entity
run_level_params: RunLevelParams = ctx.obj
r = run_level_params.remote_instance()
if self._launcher == self.LP_LAUNCHER:
entity = r.fetch_launch_plan(
run_level_params.project, run_level_params.domain, self._entity_name
)
entity = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, self._entity_name)
else:
entity = r.fetch_task(
run_level_params.project, run_level_params.domain, self._entity_name
)
entity = r.fetch_task(run_level_params.project, run_level_params.domain, self._entity_name)
self._entity = entity
return entity

Expand Down Expand Up @@ -778,9 +723,7 @@ def get_params(self, ctx: click.Context) -> typing.List["click.Parameter"]:
entity.default_inputs.parameters,
)
else:
self.params = self._get_params(
ctx, entity.interface.inputs, types
)
self.params = self._get_params(ctx, entity.interface.inputs, types)

return super().get_params(ctx)

Expand Down Expand Up @@ -820,26 +763,18 @@ def __init__(self, command_name: str):
self._command_name = command_name
self._entities = []

def _get_entities(
self, r: FlyteRemote, project: str, domain: str, limit: int
) -> typing.List[str]:
def _get_entities(self, r: FlyteRemote, project: str, domain: str, limit: int) -> typing.List[str]:
"""
Retreieves the right entities from the remote flyte instance.
"""
if self._command_name == self.LAUNCHPLAN_COMMAND:
lps = r.client.list_launch_plan_ids_paginated(
project=project, domain=domain, limit=limit
)
lps = r.client.list_launch_plan_ids_paginated(project=project, domain=domain, limit=limit)
return [l.name for l in lps[0]]
elif self._command_name == self.WORKFLOW_COMMAND:
wfs = r.client.list_workflow_ids_paginated(
project=project, domain=domain, limit=limit
)
wfs = r.client.list_workflow_ids_paginated(project=project, domain=domain, limit=limit)
return [w.name for w in wfs[0]]
elif self._command_name == self.TASK_COMMAND:
tasks = r.client.list_task_ids_paginated(
project=project, domain=domain, limit=limit
)
tasks = r.client.list_task_ids_paginated(project=project, domain=domain, limit=limit)
return [t.name for t in tasks[0]]
return []

Expand All @@ -854,10 +789,6 @@ def list_commands(self, ctx):
f"[cyan]Gathering [{run_level_params.limit}] remote LaunchPlans...",
total=None,
)
task = progress.add_task(
f"[cyan]Gathering [{run_level_params.limit}] remote LaunchPlans...",
total=None,
)
with progress:
progress.start_task(task)
try:
Expand All @@ -866,10 +797,6 @@ def list_commands(self, ctx):
run_level_params.project,
run_level_params.domain,
run_level_params.limit,
r,
run_level_params.project,
run_level_params.domain,
run_level_params.limit,
)
return self._entities
except FlyteSystemException as e:
Expand Down Expand Up @@ -1006,10 +933,6 @@ def _create_command(

# Add options for each of the workflow inputs
params = []
for (
input_name,
input_type_val,
) in loaded_entity.python_interface.inputs_with_defaults.items():
for (
input_name,
input_type_val,
Expand Down Expand Up @@ -1077,9 +1000,7 @@ def get_command(self, ctx, exe_entity):

entity = load_naive_entity(module, exe_entity, project_root)

return self._create_command(
ctx, exe_entity, run_level_params, entity, is_workflow
)
return self._create_command(ctx, exe_entity, run_level_params, entity, is_workflow)


class RunCommand(click.RichGroup):
Expand All @@ -1099,9 +1020,7 @@ def __init__(self, *args, **kwargs):
def list_commands(self, ctx, add_remote: bool = True):
if self._files:
return self._files
self._files = [
str(p) for p in pathlib.Path(".").glob("*.py") if str(p) != "__init__.py"
]
self._files = [str(p) for p in pathlib.Path(".").glob("*.py") if str(p) != "__init__.py"]
self._files = sorted(self._files)
if add_remote:
self._files = self._files + [
Expand All @@ -1125,9 +1044,7 @@ def get_command(self, ctx, filename):
return RemoteEntityGroup(RemoteEntityGroup.WORKFLOW_COMMAND)
elif filename == RemoteEntityGroup.TASK_COMMAND:
return RemoteEntityGroup(RemoteEntityGroup.TASK_COMMAND)
return WorkflowCommand(
filename, name=filename, help=f"Run a [workflow|task] from {filename}"
)
return WorkflowCommand(filename, name=filename, help=f"Run a [workflow|task] from {filename}")


_run_help = """
Expand Down

0 comments on commit ec52ab8

Please sign in to comment.