diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 664b576c80..e375475a10 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -34,12 +34,6 @@ ImageConfig, SerializationSettings, ) -from flytekit.configuration import ( - DefaultImages, - FastSerializationSettings, - ImageConfig, - SerializationSettings, -) from flytekit.configuration.plugin import get_plugin from flytekit.core import context_manager from flytekit.core.artifact import ArtifactQuery @@ -53,11 +47,6 @@ key_value_callback, labels_callback, ) -from flytekit.interaction.click_types import ( - FlyteLiteralConverter, - key_value_callback, - labels_callback, -) from flytekit.interaction.string_literals import literal_string_repr from flytekit.loggers import logger from flytekit.models import security @@ -71,13 +60,6 @@ FlyteWorkflow, remote_fs, ) -from flytekit.remote import ( - FlyteLaunchPlan, - FlyteRemote, - FlyteTask, - FlyteWorkflow, - remote_fs, -) from flytekit.remote.executions import FlyteWorkflowExecution from flytekit.tools import module_loader from flytekit.tools.script_mode import _find_project_root, compress_scripts @@ -93,9 +75,7 @@ class RunLevelComputedParams: project_root: typing.Optional[str] = None module: typing.Optional[str] = None - temp_file_name: typing.Optional[str] = ( - None # Used to store the temporary location of the file downloaded - ) + temp_file_name: typing.Optional[str] = None # Used to store the temporary location of the file downloaded @dataclass @@ -306,9 +286,7 @@ class RunLevelParams(PyFlyteParams): ) ) - computed_params: RunLevelComputedParams = field( - default_factory=RunLevelComputedParams - ) + computed_params: RunLevelComputedParams = field(default_factory=RunLevelComputedParams) _remote: typing.Optional[FlyteRemote] = None def remote_instance(self) -> FlyteRemote: @@ -316,9 +294,7 @@ def remote_instance(self) -> FlyteRemote: data_upload_location = None if self.is_remote: data_upload_location = remote_fs.REMOTE_PLACEHOLDER - self._remote = get_plugin().get_remote( - self.config_file, self.project, self.domain, data_upload_location - ) + self._remote = get_plugin().get_remote(self.config_file, self.project, self.domain, data_upload_location) return self._remote @property @@ -337,25 +313,19 @@ def options(cls) -> typing.List[click.Option]: return [get_option_from_metadata(f.metadata) for f in fields(cls) if f.metadata] -def load_naive_entity( - module_name: str, entity_name: str, project_root: str -) -> typing.Union[WorkflowBase, PythonTask]: +def load_naive_entity(module_name: str, entity_name: str, project_root: str) -> typing.Union[WorkflowBase, PythonTask]: """ Load the workflow of a script file. N.B.: it assumes that the file is self-contained, in other words, there are no relative imports. """ - flyte_ctx_builder = ( - context_manager.FlyteContextManager.current_context().new_builder() - ) + flyte_ctx_builder = context_manager.FlyteContextManager.current_context().new_builder() with context_manager.FlyteContextManager.with_context(flyte_ctx_builder): with module_loader.add_sys_path(project_root): importlib.import_module(module_name) return module_loader.load_object_from_module(f"{module_name}.{entity_name}") -def dump_flyte_remote_snippet( - execution: FlyteWorkflowExecution, project: str, domain: str -): +def dump_flyte_remote_snippet(execution: FlyteWorkflowExecution, project: str, domain: str): click.secho( f""" In order to have programmatic access to the execution, use the following snippet: @@ -394,9 +364,7 @@ def get_entities_in_file(filename: pathlib.Path, should_delete: bool) -> Entitie additional_path = str(pathlib.Path.cwd()) else: additional_path = _find_project_root(filename) - module_name = str(filename.relative_to(additional_path).with_suffix("")).replace( - os.path.sep, "." - ) + module_name = str(filename.relative_to(additional_path).with_suffix("")).replace(os.path.sep, ".") with context_manager.FlyteContextManager.with_context(flyte_ctx): with module_loader.add_sys_path(additional_path): importlib.import_module(module_name) @@ -453,9 +421,7 @@ def to_click_option( description_extra = f": {json.dumps(literal_var.type.metadata)}" # If a query has been specified, the input is never strictly required at this layer - required = ( - False if default_val and isinstance(default_val, ArtifactQuery) else required - ) + required = False if default_val and isinstance(default_val, ArtifactQuery) else required return click.Option( param_decls=[f"--{input_name}"], @@ -541,8 +507,6 @@ def _update_flyte_context(params: RunLevelParams) -> FlyteContext.Builder: file_access = FileAccessProvider( local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"), raw_output_prefix=output_prefix, - local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"), - raw_output_prefix=output_prefix, ) # The task might run on a remote machine if raw_output_prefix is a remote path, @@ -555,11 +519,6 @@ def _update_flyte_context(params: RunLevelParams) -> FlyteContext.Builder: str(archive_fname), params.computed_params.module, ) - compress_scripts( - params.computed_params.project_root, - str(archive_fname), - params.computed_params.module, - ) remote_dir = file_access.get_random_remote_directory() remote_archive_fname = f"{remote_dir}/script_mode.tar.gz" file_access.put_data(str(archive_fname), remote_archive_fname) @@ -582,14 +541,10 @@ def is_optional(_type): """ Checks if the given type is Optional Type """ - return typing.get_origin(_type) is typing.Union and type(None) in typing.get_args( - _type - ) + return typing.get_origin(_type) is typing.Union and type(None) in typing.get_args(_type) -def run_command( - ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask] -): +def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]): """ Returns a function that is used to implement WorkflowCommand and execute a flyte workflow. """ @@ -645,9 +600,7 @@ def _run(*args, **kwargs): inputs[input_name] = False if not run_level_params.is_remote: - with FlyteContextManager.with_context( - _update_flyte_context(run_level_params) - ): + with FlyteContextManager.with_context(_update_flyte_context(run_level_params)): if run_level_params.envvars: for env_var, value in run_level_params.envvars.items(): os.environ[env_var] = value @@ -666,9 +619,7 @@ def _run(*args, **kwargs): image_config = run_level_params.image_config image_config = patch_image_config(config_file, image_config) - with context_manager.FlyteContextManager.with_context( - remote.context.new_builder() - ): + with context_manager.FlyteContextManager.with_context(remote.context.new_builder()): remote_entity = remote.register_script( entity, project=run_level_params.project, @@ -711,21 +662,15 @@ def __init__(self, name: str, h: str, entity_name: str, launcher: str, **kwargs) self._launcher = launcher self._entity = None - def _fetch_entity( - self, ctx: click.Context - ) -> typing.Union[FlyteLaunchPlan, FlyteTask]: + def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, FlyteTask]: if self._entity: return self._entity run_level_params: RunLevelParams = ctx.obj r = run_level_params.remote_instance() if self._launcher == self.LP_LAUNCHER: - entity = r.fetch_launch_plan( - run_level_params.project, run_level_params.domain, self._entity_name - ) + entity = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, self._entity_name) else: - entity = r.fetch_task( - run_level_params.project, run_level_params.domain, self._entity_name - ) + entity = r.fetch_task(run_level_params.project, run_level_params.domain, self._entity_name) self._entity = entity return entity @@ -778,9 +723,7 @@ def get_params(self, ctx: click.Context) -> typing.List["click.Parameter"]: entity.default_inputs.parameters, ) else: - self.params = self._get_params( - ctx, entity.interface.inputs, types - ) + self.params = self._get_params(ctx, entity.interface.inputs, types) return super().get_params(ctx) @@ -820,26 +763,18 @@ def __init__(self, command_name: str): self._command_name = command_name self._entities = [] - def _get_entities( - self, r: FlyteRemote, project: str, domain: str, limit: int - ) -> typing.List[str]: + def _get_entities(self, r: FlyteRemote, project: str, domain: str, limit: int) -> typing.List[str]: """ Retreieves the right entities from the remote flyte instance. """ if self._command_name == self.LAUNCHPLAN_COMMAND: - lps = r.client.list_launch_plan_ids_paginated( - project=project, domain=domain, limit=limit - ) + lps = r.client.list_launch_plan_ids_paginated(project=project, domain=domain, limit=limit) return [l.name for l in lps[0]] elif self._command_name == self.WORKFLOW_COMMAND: - wfs = r.client.list_workflow_ids_paginated( - project=project, domain=domain, limit=limit - ) + wfs = r.client.list_workflow_ids_paginated(project=project, domain=domain, limit=limit) return [w.name for w in wfs[0]] elif self._command_name == self.TASK_COMMAND: - tasks = r.client.list_task_ids_paginated( - project=project, domain=domain, limit=limit - ) + tasks = r.client.list_task_ids_paginated(project=project, domain=domain, limit=limit) return [t.name for t in tasks[0]] return [] @@ -854,10 +789,6 @@ def list_commands(self, ctx): f"[cyan]Gathering [{run_level_params.limit}] remote LaunchPlans...", total=None, ) - task = progress.add_task( - f"[cyan]Gathering [{run_level_params.limit}] remote LaunchPlans...", - total=None, - ) with progress: progress.start_task(task) try: @@ -866,10 +797,6 @@ def list_commands(self, ctx): run_level_params.project, run_level_params.domain, run_level_params.limit, - r, - run_level_params.project, - run_level_params.domain, - run_level_params.limit, ) return self._entities except FlyteSystemException as e: @@ -1006,10 +933,6 @@ def _create_command( # Add options for each of the workflow inputs params = [] - for ( - input_name, - input_type_val, - ) in loaded_entity.python_interface.inputs_with_defaults.items(): for ( input_name, input_type_val, @@ -1077,9 +1000,7 @@ def get_command(self, ctx, exe_entity): entity = load_naive_entity(module, exe_entity, project_root) - return self._create_command( - ctx, exe_entity, run_level_params, entity, is_workflow - ) + return self._create_command(ctx, exe_entity, run_level_params, entity, is_workflow) class RunCommand(click.RichGroup): @@ -1099,9 +1020,7 @@ def __init__(self, *args, **kwargs): def list_commands(self, ctx, add_remote: bool = True): if self._files: return self._files - self._files = [ - str(p) for p in pathlib.Path(".").glob("*.py") if str(p) != "__init__.py" - ] + self._files = [str(p) for p in pathlib.Path(".").glob("*.py") if str(p) != "__init__.py"] self._files = sorted(self._files) if add_remote: self._files = self._files + [ @@ -1125,9 +1044,7 @@ def get_command(self, ctx, filename): return RemoteEntityGroup(RemoteEntityGroup.WORKFLOW_COMMAND) elif filename == RemoteEntityGroup.TASK_COMMAND: return RemoteEntityGroup(RemoteEntityGroup.TASK_COMMAND) - return WorkflowCommand( - filename, name=filename, help=f"Run a [workflow|task] from {filename}" - ) + return WorkflowCommand(filename, name=filename, help=f"Run a [workflow|task] from {filename}") _run_help = """