Skip to content

Commit

Permalink
Enable Ray Fast Register (#2606)
Browse files Browse the repository at this point in the history
Signed-off-by: Jan Fiedler <jan@union.ai>
  • Loading branch information
fiedlerNr9 authored Jul 29, 2024
1 parent df3ab4c commit 5bc5d5c
Showing 1 changed file with 41 additions and 7 deletions.
48 changes: 41 additions & 7 deletions plugins/flytekit-ray/flytekitplugins/ray/task.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import base64
import json
import os
import typing
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

import yaml
from flytekitplugins.ray.models import HeadGroupSpec, RayCluster, RayJob, WorkerGroupSpec
from flytekitplugins.ray.models import (
HeadGroupSpec,
RayCluster,
RayJob,
WorkerGroupSpec,
)
from google.protobuf.json_format import MessageToDict

from flytekit import lazy_module
from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.extend import TaskPlugins

Expand Down Expand Up @@ -40,6 +46,7 @@ class RayJobConfig:
address: typing.Optional[str] = None
shutdown_after_job_finishes: bool = False
ttl_seconds_after_finished: typing.Optional[int] = None
excludes_working_dir: typing.Optional[typing.List[str]] = None


class RayFunctionTask(PythonFunctionTask):
Expand All @@ -50,11 +57,30 @@ class RayFunctionTask(PythonFunctionTask):
_RAY_TASK_TYPE = "ray"

def __init__(self, task_config: RayJobConfig, task_function: Callable, **kwargs):
super().__init__(task_config=task_config, task_type=self._RAY_TASK_TYPE, task_function=task_function, **kwargs)
super().__init__(
task_config=task_config,
task_type=self._RAY_TASK_TYPE,
task_function=task_function,
**kwargs,
)
self._task_config = task_config

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
ray.init(address=self._task_config.address)
init_params = {"address": self._task_config.address}

ctx = FlyteContextManager.current_context()
if not ctx.execution_state.is_local_execution():
working_dir = os.getcwd()
init_params["runtime_env"] = {
"working_dir": working_dir,
"excludes": ["script_mode.tar.gz", "fast*.tar.gz"],
}

cfg = self._task_config
if cfg.excludes_working_dir:
init_params["runtime_env"]["excludes"].extend(cfg.excludes_working_dir)

ray.init(**init_params)
return user_params

def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
Expand All @@ -67,12 +93,20 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]

ray_job = RayJob(
ray_cluster=RayCluster(
head_group_spec=HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None,
head_group_spec=(
HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None
),
worker_group_spec=[
WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params)
WorkerGroupSpec(
c.group_name,
c.replicas,
c.min_replicas,
c.max_replicas,
c.ray_start_params,
)
for c in cfg.worker_node_config
],
enable_autoscaling=cfg.enable_autoscaling if cfg.enable_autoscaling else False,
enable_autoscaling=(cfg.enable_autoscaling if cfg.enable_autoscaling else False),
),
runtime_env=runtime_env,
runtime_env_yaml=runtime_env_yaml,
Expand Down

0 comments on commit 5bc5d5c

Please sign in to comment.