Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Ray Fast Register #2606

Merged
merged 7 commits into from
Jul 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions plugins/flytekit-ray/flytekitplugins/ray/task.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import base64
import json
import os
import typing
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

import yaml
from flytekitplugins.ray.models import HeadGroupSpec, RayCluster, RayJob, WorkerGroupSpec
from flytekitplugins.ray.models import (
HeadGroupSpec,
RayCluster,
RayJob,
WorkerGroupSpec,
)
from google.protobuf.json_format import MessageToDict

from flytekit import lazy_module
from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.extend import TaskPlugins

Expand Down Expand Up @@ -40,6 +46,7 @@ class RayJobConfig:
address: typing.Optional[str] = None
shutdown_after_job_finishes: bool = False
ttl_seconds_after_finished: typing.Optional[int] = None
excludes_working_dir: typing.Optional[typing.List[str]] = None


class RayFunctionTask(PythonFunctionTask):
Expand All @@ -50,11 +57,30 @@ class RayFunctionTask(PythonFunctionTask):
_RAY_TASK_TYPE = "ray"

def __init__(self, task_config: RayJobConfig, task_function: Callable, **kwargs):
super().__init__(task_config=task_config, task_type=self._RAY_TASK_TYPE, task_function=task_function, **kwargs)
super().__init__(
task_config=task_config,
task_type=self._RAY_TASK_TYPE,
task_function=task_function,
**kwargs,
)
self._task_config = task_config

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
ray.init(address=self._task_config.address)
init_params = {"address": self._task_config.address}

ctx = FlyteContextManager.current_context()
if not ctx.execution_state.is_local_execution():
working_dir = os.getcwd()
init_params["runtime_env"] = {
"working_dir": working_dir,
"excludes": ["script_mode.tar.gz", "fast*.tar.gz"],
}

cfg = self._task_config
if cfg.excludes_working_dir:
init_params["runtime_env"]["excludes"].extend(cfg.excludes_working_dir)

ray.init(**init_params)
return user_params

def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
Expand All @@ -67,12 +93,20 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]

ray_job = RayJob(
ray_cluster=RayCluster(
head_group_spec=HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None,
head_group_spec=(
HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None
),
worker_group_spec=[
WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params)
WorkerGroupSpec(
c.group_name,
c.replicas,
c.min_replicas,
c.max_replicas,
c.ray_start_params,
)
for c in cfg.worker_node_config
],
enable_autoscaling=cfg.enable_autoscaling if cfg.enable_autoscaling else False,
enable_autoscaling=(cfg.enable_autoscaling if cfg.enable_autoscaling else False),
),
runtime_env=runtime_env,
runtime_env_yaml=runtime_env_yaml,
Expand Down
Loading