Skip to content

Commit

Permalink
[runtime env] Make plugin setup process that has not been refactor ru…
Browse files Browse the repository at this point in the history
…n in threads. (#22588)

I recently realized that during a runtime_env creation process, a plugin/manager that is very slow to setup may block the creation of other runtime_env, so I make plugin/manager setup run in threads.

[The refactor of `PipManager`](#22381) is about to be completed, so I ignore it in this PR.
  • Loading branch information
Catch-Bull authored Feb 28, 2022
1 parent 22bc451 commit aa1885a
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 38 deletions.
30 changes: 19 additions & 11 deletions dashboard/modules/runtime_env/runtime_env_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,25 @@ async def _setup_runtime_env(
for uri in runtime_env.plugin_uris():
self._uris_to_envs[uri].add(serialized_runtime_env)

# Run setup function from all the plugins
for plugin_class_path, config in runtime_env.plugins():
per_job_logger.debug(
f"Setting up runtime env plugin {plugin_class_path}"
)
plugin_class = import_attr(plugin_class_path)
# TODO(simon): implement uri support
plugin_class.create("uri not implemented", json.loads(config), context)
plugin_class.modify_context(
"uri not implemented", json.loads(config), context
)
def setup_plugins():
# Run setup function from all the plugins
for plugin_class_path, config in runtime_env.plugins():
per_job_logger.debug(
f"Setting up runtime env plugin {plugin_class_path}"
)
plugin_class = import_attr(plugin_class_path)
# TODO(simon): implement uri support
plugin_class.create(
"uri not implemented", json.loads(config), context
)
plugin_class.modify_context(
"uri not implemented", json.loads(config), context
)

loop = asyncio.get_event_loop()
# Plugins setup method is sync process, running in other threads
# is to avoid blocks asyncio loop
await loop.run_in_executor(None, setup_plugins)

return context

Expand Down
55 changes: 34 additions & 21 deletions python/ray/_private/runtime_env/conda.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import platform
import runpy
import shutil
import asyncio

from filelock import FileLock
from typing import Optional, List, Dict, Any
Expand Down Expand Up @@ -306,28 +307,40 @@ async def create(
context: RuntimeEnvContext,
logger: Optional[logging.Logger] = default_logger,
) -> int:
logger.debug("Setting up conda for runtime_env: " f"{runtime_env.serialize()}")
protocol, hash = parse_uri(uri)
conda_env_name = self._get_path_from_hash(hash)

conda_dict = _get_conda_dict_with_ray_inserted(runtime_env, logger=logger)
# Currently create method is still a sync process, to avoid blocking
# the loop, need to run this function in another thread.
# TODO(Catch-Bull): Refactor method create into an async process, and
# make this method running in current loop.
def _create():
logger.debug(
"Setting up conda for runtime_env: " f"{runtime_env.serialize()}"
)
protocol, hash = parse_uri(uri)
conda_env_name = self._get_path_from_hash(hash)

logger.info(f"Setting up conda environment with {runtime_env}")
with FileLock(self._installs_and_deletions_file_lock):
try:
conda_yaml_file = os.path.join(self._resources_dir, "environment.yml")
with open(conda_yaml_file, "w") as file:
yaml.dump(conda_dict, file)
create_conda_env_if_needed(
conda_yaml_file, prefix=conda_env_name, logger=logger
)
finally:
os.remove(conda_yaml_file)

if runtime_env.get_extension("_inject_current_ray") == "True":
_inject_ray_to_conda_site(conda_path=conda_env_name, logger=logger)
logger.info(f"Finished creating conda environment at {conda_env_name}")
return get_directory_size_bytes(conda_env_name)
conda_dict = _get_conda_dict_with_ray_inserted(runtime_env, logger=logger)

logger.info(f"Setting up conda environment with {runtime_env}")
with FileLock(self._installs_and_deletions_file_lock):
try:
conda_yaml_file = os.path.join(
self._resources_dir, "environment.yml"
)
with open(conda_yaml_file, "w") as file:
yaml.dump(conda_dict, file)
create_conda_env_if_needed(
conda_yaml_file, prefix=conda_env_name, logger=logger
)
finally:
os.remove(conda_yaml_file)

if runtime_env.get_extension("_inject_current_ray") == "True":
_inject_ray_to_conda_site(conda_path=conda_env_name, logger=logger)
logger.info(f"Finished creating conda environment at {conda_env_name}")
return get_directory_size_bytes(conda_env_name)

loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, _create)

def modify_context(
self,
Expand Down
17 changes: 13 additions & 4 deletions python/ray/_private/runtime_env/py_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from types import ModuleType
from typing import Any, Dict, List, Optional
from pathlib import Path
import asyncio

from ray.experimental.internal_kv import _internal_kv_initialized
from ray._private.runtime_env.context import RuntimeEnvContext
Expand Down Expand Up @@ -129,10 +130,18 @@ async def create(
context: RuntimeEnvContext,
logger: Optional[logging.Logger] = default_logger,
) -> int:
module_dir = download_and_unpack_package(
uri, self._resources_dir, logger=logger
)
return get_directory_size_bytes(module_dir)
# Currently create method is still a sync process, to avoid blocking
# the loop, need to run this function in another thread.
# TODO(Catch-Bull): Refactor method create into an async process, and
# make this method running in current loop.
def _create():
module_dir = download_and_unpack_package(
uri, self._resources_dir, logger=logger
)
return get_directory_size_bytes(module_dir)

loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, _create)

def modify_context(
self,
Expand Down
15 changes: 13 additions & 2 deletions python/ray/_private/runtime_env/working_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from typing import Any, Dict, Optional
from pathlib import Path
import asyncio

from ray.experimental.internal_kv import _internal_kv_initialized
from ray._private.runtime_env.context import RuntimeEnvContext
Expand Down Expand Up @@ -115,8 +116,18 @@ async def create(
context: RuntimeEnvContext,
logger: Optional[logging.Logger] = default_logger,
) -> int:
local_dir = download_and_unpack_package(uri, self._resources_dir, logger=logger)
return get_directory_size_bytes(local_dir)
# Currently create method is still a sync process, to avoid blocking
# the loop, need to run this function in another thread.
# TODO(Catch-Bull): Refactor method create into an async process, and
# make this method running in current loop.
def _create():
local_dir = download_and_unpack_package(
uri, self._resources_dir, logger=logger
)
return get_directory_size_bytes(local_dir)

loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, _create)

def modify_context(
self, uri: Optional[str], runtime_env_dict: Dict, context: RuntimeEnvContext
Expand Down
64 changes: 64 additions & 0 deletions python/ray/tests/test_runtime_env_plugin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import tempfile
from time import sleep

import pytest
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
from ray._private.test_utils import wait_for_condition

import ray

Expand Down Expand Up @@ -72,6 +74,68 @@ def f():
assert output == {"env_value": "42", "tmp_content": "hello", "nice": 19}


MY_PLUGIN_FOR_HANG_CLASS_PATH = "ray.tests.test_runtime_env_plugin.MyPluginForHang"
my_plugin_setup_times = 0


# This plugin will hang when first setup, second setup will ok
class MyPluginForHang(RuntimeEnvPlugin):
env_key = "MY_PLUGIN_FOR_HANG_TEST_ENVIRONMENT_KEY"

@staticmethod
def validate(runtime_env_dict: dict) -> str:
return "True"

@staticmethod
def create(uri: str, runtime_env: dict, ctx: RuntimeEnvContext) -> float:
global my_plugin_setup_times
my_plugin_setup_times += 1

# first setup
if my_plugin_setup_times == 1:
# sleep forever
sleep(3600)

@staticmethod
def modify_context(
uri: str, plugin_config_dict: dict, ctx: RuntimeEnvContext
) -> None:
global my_plugin_setup_times
ctx.env_vars[MyPluginForHang.env_key] = str(my_plugin_setup_times)


def test_plugin_hang(ray_start_regular):
env_key = MyPluginForHang.env_key

@ray.remote(num_cpus=0.1)
def f():
return os.environ[env_key]

refs = [
f.options(
# Avoid hitting the cache of runtime_env
runtime_env={"plugins": {MY_PLUGIN_FOR_HANG_CLASS_PATH: {"name": "f1"}}}
).remote(),
f.options(
runtime_env={"plugins": {MY_PLUGIN_FOR_HANG_CLASS_PATH: {"name": "f2"}}}
).remote(),
]

def condition():
for ref in refs:
try:
res = ray.get(ref, timeout=1)
print("result:", res)
assert int(res) == 2
return True
except Exception as error:
print(f"Got error: {error}")
pass
return False

wait_for_condition(condition, timeout=60)


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit aa1885a

Please sign in to comment.