Enable Ray Fast Register #2606

Merged
merged 7 commits on Jul 29, 2024
Changes from 3 commits
49 changes: 42 additions & 7 deletions plugins/flytekit-ray/flytekitplugins/ray/task.py
@@ -1,16 +1,22 @@
import base64
import json
import os
import typing
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

import yaml
from flytekitplugins.ray.models import HeadGroupSpec, RayCluster, RayJob, WorkerGroupSpec
from flytekitplugins.ray.models import (
HeadGroupSpec,
RayCluster,
RayJob,
WorkerGroupSpec,
)
from google.protobuf.json_format import MessageToDict

from flytekit import lazy_module
from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.extend import TaskPlugins

@@ -40,6 +46,7 @@ class RayJobConfig:
address: typing.Optional[str] = None
shutdown_after_job_finishes: bool = False
ttl_seconds_after_finished: typing.Optional[int] = None
excludes_working_dir: typing.Optional[typing.List[str]] = None


class RayFunctionTask(PythonFunctionTask):
@@ -50,11 +57,31 @@ class RayFunctionTask(PythonFunctionTask):
_RAY_TASK_TYPE = "ray"

def __init__(self, task_config: RayJobConfig, task_function: Callable, **kwargs):
super().__init__(task_config=task_config, task_type=self._RAY_TASK_TYPE, task_function=task_function, **kwargs)
super().__init__(
task_config=task_config,
task_type=self._RAY_TASK_TYPE,
task_function=task_function,
**kwargs,
)
self._task_config = task_config

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
ray.init(address=self._task_config.address)
init_params = {"address": self._task_config.address}

ctx = FlyteContextManager.current_context()
if not ctx.execution_state.is_local_execution():
working_dir = os.getcwd()
init_params["runtime_env"] = {"working_dir": working_dir}

cfg = self._task_config
if cfg.excludes_working_dir:
init_params["runtime_env"]["excludes"] = cfg.excludes_working_dir

# fast register data with timestamp mtime=0 will be zipped and uploaded to ray gcs
# zip does not support timestamps before 1980 -> hacky workaround of touching all the files
Member

Can this comment say explicitly how the data is being moved around? My understanding is:

  1. Flyte's fast register packages the code and sends it to the head node.
  2. Ray then zips the working directory and sends it to the Ray GCS.
  3. A Ray worker pulls the files from Ray's GCS.

Is this correct?

Contributor Author
It's exactly this, with the addition that Ray respects the given excludes variable (list[str]), which functions as an ignore file.
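For context, a minimal sketch of how excludes_working_dir could be used from a task definition; the worker group values and task body are illustrative assumptions, not part of this PR:

```python
from flytekit import task
from flytekitplugins.ray import RayJobConfig, WorkerNodeConfig

@task(
    task_config=RayJobConfig(
        # Illustrative worker group; values are not taken from this PR.
        worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
        # Forwarded to runtime_env["excludes"]; acts like an ignore list when
        # Ray packages the working directory for the cluster.
        excludes_working_dir=[".git", "*.tar.gz"],
    )
)
def train() -> int:
    return 1
```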

Member
@pingsutw Jul 27, 2024

qq: could we exclude the file during the fast registration? (.gitignore)
Do we have a case where we want to upload a file to the head node but don't want it on the worker nodes?

Contributor Author
@fiedlerNr9 Jul 29, 2024

Do we have a case where we want to upload a file to the head node but don't want it on the worker nodes?

No, I don't think so.

What file are you trying to exclude? The fast register .tar.gz? Actually, I am thinking it's only the .tar.gz file coming with mtime=0, right? If we exclude that, we don't need to modify any timestamps on the files?

os.system(f"touch `find {working_dir} -type f`")
Member
@thomasjpfan Jul 26, 2024

This depends on the system having touch + find, which is usually the case.

To be safer, can this use os.utime + os.walk (or Path.rglob) to update the mtimes?
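For illustration, a minimal sketch of what that suggestion could look like, assuming the goal is simply to bump every file's mtime to the current time (this is not the code that was merged):

```python
import os
import time


def refresh_mtimes(working_dir: str) -> None:
    # Bump atime/mtime on every file so that zip's pre-1980 timestamp
    # limitation is not hit when Ray packages the working directory.
    now = time.time()
    for root, _dirs, files in os.walk(working_dir):
        for name in files:
            os.utime(os.path.join(root, name), (now, now))
```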

Contributor Author
@fiedlerNr9 Jul 26, 2024

Good catch! I can try this out and adjust. Also, this should only stay short term. @pingsutw mentioned there is work in progress that allows us to not set mtime=0 for the fast register tar.gz.

Contributor

I do not think we will do this right now @fiedlerNr9, because mtime=0 makes it possible to get consistent hashes for similar tar files; otherwise we get multiple uploads to admin.

I think you should do what @thomasjpfan recommends and use os.walk. The os.system call will break and cause random bugs.

Contributor Author

No need to manually adjust the timestamps for this now. I just catch the fast register data and don't zip it at all.
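As a rough illustration of that final approach, the fast-register archive could be kept out of what Ray packages via the runtime_env excludes; the "*.tar.gz" pattern is an assumption for illustration, not taken from the merged code:

```python
import os

import ray

# Hypothetical sketch: exclude the fast-register tarball (which carries
# mtime=0) from the working directory Ray zips and uploads to its GCS,
# instead of rewriting file timestamps.
ray.init(
    runtime_env={
        "working_dir": os.getcwd(),
        "excludes": ["*.tar.gz"],  # assumed pattern, for illustration only
    }
)
```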


ray.init(**init_params)
return user_params

def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
@@ -67,12 +94,20 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]

ray_job = RayJob(
ray_cluster=RayCluster(
head_group_spec=HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None,
head_group_spec=(
HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None
),
worker_group_spec=[
WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params)
WorkerGroupSpec(
c.group_name,
c.replicas,
c.min_replicas,
c.max_replicas,
c.ray_start_params,
)
for c in cfg.worker_node_config
],
enable_autoscaling=cfg.enable_autoscaling if cfg.enable_autoscaling else False,
enable_autoscaling=(cfg.enable_autoscaling if cfg.enable_autoscaling else False),
),
runtime_env=runtime_env,
runtime_env_yaml=runtime_env_yaml,