Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[runtime env] Async pip runtime env #22381

Merged
merged 18 commits into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 37 additions & 38 deletions python/ray/_private/runtime_env/pip.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import contextlib
import asyncio
import os
import sys
import json
Expand All @@ -9,11 +9,14 @@
from filelock import FileLock
from typing import Optional, List, Dict, Tuple

from ray._private.runtime_env.conda_utils import exec_cmd_stream_to_logger
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.packaging import Protocol, parse_uri
from ray._private.runtime_env.utils import RuntimeEnv
from ray._private.utils import get_directory_size_bytes, try_to_create_directory
from ray._private.runtime_env.utils import RuntimeEnv, check_output_cmd
from ray._private.utils import (
get_directory_size_bytes,
try_to_create_directory,
asynccontextmanager,
)

default_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,8 +89,8 @@ def _is_in_virtualenv() -> bool:
)

@staticmethod
@contextlib.contextmanager
def _check_ray(python: str, cwd: str, logger: logging.Logger):
@asynccontextmanager
async def _check_ray(python: str, cwd: str, logger: logging.Logger):
"""A context manager to check ray is not overwritten.

Currently, we only check ray version and path. It works for virtualenv,
Expand All @@ -96,24 +99,22 @@ def _check_ray(python: str, cwd: str, logger: logging.Logger):
- ray is in virtualenv's site-packages.
"""

def _get_ray_version_and_path() -> Tuple[str, str]:
async def _get_ray_version_and_path() -> Tuple[str, str]:
check_ray_cmd = [
python,
"-c",
"import ray; print(ray.__version__, ray.__path__[0])",
]
exit_code, output = exec_cmd_stream_to_logger(
check_ray_cmd, logger, cwd=cwd, env={}
output = await check_output_cmd(
check_ray_cmd, logger=logger, cwd=cwd, env={}
)
if exit_code != 0:
raise RuntimeError("Get ray version and path failed.")
# The print after `import ray` may be followed by extra trailing tokens, so we drop them with *_
ray_version, ray_path, *_ = [s.strip() for s in output.split()]
return ray_version, ray_path

version, path = _get_ray_version_and_path()
version, path = await _get_ray_version_and_path()
yield
actual_version, actual_path = _get_ray_version_and_path()
actual_version, actual_path = await _get_ray_version_and_path()
if actual_version != version or actual_path != path:
raise RuntimeError(
"Changing the ray version is not allowed: \n"
Expand All @@ -126,7 +127,9 @@ def _get_ray_version_and_path() -> Tuple[str, str]:
)

@classmethod
def _create_or_get_virtualenv(cls, path: str, cwd: str, logger: logging.Logger):
async def _create_or_get_virtualenv(
cls, path: str, cwd: str, logger: logging.Logger
):
"""Create or get a virtualenv from path."""

python = sys.executable
Expand Down Expand Up @@ -189,16 +192,10 @@ def _create_or_get_virtualenv(cls, path: str, cwd: str, logger: logging.Logger):
virtualenv_path,
current_python_dir,
)
exit_code, output = exec_cmd_stream_to_logger(
create_venv_cmd, logger, cwd=cwd, env={}
)
if exit_code != 0:
raise RuntimeError(
f"Failed to create virtualenv {virtualenv_path}:\n{output}"
)
await check_output_cmd(create_venv_cmd, logger=logger, cwd=cwd, env={})

@classmethod
def _install_pip_packages(
async def _install_pip_packages(
cls,
path: str,
pip_packages: List[str],
Expand All @@ -209,9 +206,16 @@ def _install_pip_packages(
python = _PathHelper.get_virtualenv_python(path)
# TODO(fyrestone): Support -i, --no-deps, --no-cache-dir, ...
pip_requirements_file = _PathHelper.get_requirements_file(path)
with open(pip_requirements_file, "w") as file:
for line in pip_packages:
file.write(line + "\n")

def _gen_requirements_txt():
with open(pip_requirements_file, "w") as file:
for line in pip_packages:
file.write(line + "\n")

# Avoid blocking the event loop.
loop = asyncio.get_running_loop()
fyrestone marked this conversation as resolved.
Show resolved Hide resolved
await loop.run_in_executor(None, _gen_requirements_txt)

pip_install_cmd = [
python,
"-m",
Expand All @@ -222,15 +226,9 @@ def _install_pip_packages(
pip_requirements_file,
]
logger.info("Installing python requirements to %s", virtualenv_path)
exit_code, output = exec_cmd_stream_to_logger(
pip_install_cmd, logger, cwd=cwd, env={}
)
if exit_code != 0:
raise RuntimeError(
f"Failed to install python requirements to {virtualenv_path}:\n{output}"
)
await check_output_cmd(pip_install_cmd, logger=logger, cwd=cwd, env={})

def run(self):
async def run(self):
path = self._target_dir
logger = self._logger
pip_packages = self._runtime_env.pip_packages()
Expand All @@ -240,10 +238,10 @@ def run(self):
exec_cwd = os.path.join(path, "exec_cwd")
os.makedirs(exec_cwd, exist_ok=True)
try:
self._create_or_get_virtualenv(path, exec_cwd, logger)
await self._create_or_get_virtualenv(path, exec_cwd, logger)
python = _PathHelper.get_virtualenv_python(path)
with self._check_ray(python, exec_cwd, logger):
self._install_pip_packages(path, pip_packages, exec_cwd, logger)
async with self._check_ray(python, exec_cwd, logger):
await self._install_pip_packages(path, pip_packages, exec_cwd, logger)
# TODO(fyrestone): pip check.
except Exception:
logger.info("Delete incomplete virtualenv: %s", path)
Expand Down Expand Up @@ -316,9 +314,10 @@ async def create(

with FileLock(self._installs_and_deletions_file_lock):
pip_processor = PipProcessor(target_dir, runtime_env, logger)
pip_processor.run()
await pip_processor.run()

return get_directory_size_bytes(target_dir)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, get_directory_size_bytes, target_dir)

def modify_context(
self,
Expand Down
117 changes: 115 additions & 2 deletions python/ray/_private/runtime_env/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from typing import Dict, List, Tuple, Any
import asyncio
import itertools
import json
from ray.core.generated.runtime_env_common_pb2 import RuntimeEnv as ProtoRuntimeEnv
import logging
import subprocess
import textwrap
import types
from typing import Dict, List, Tuple, Any

from google.protobuf import json_format
from ray.core.generated.runtime_env_common_pb2 import RuntimeEnv as ProtoRuntimeEnv


def _build_proto_pip_runtime_env(runtime_env_dict: dict, runtime_env: ProtoRuntimeEnv):
Expand Down Expand Up @@ -280,3 +287,109 @@ def from_dict(
_build_proto_container_runtime_env(runtime_env_dict, proto_runtime_env)
_build_proto_plugin_runtime_env(runtime_env_dict, proto_runtime_env)
return cls(proto_runtime_env=proto_runtime_env)


class SubprocessCalledProcessError(subprocess.CalledProcessError):
"""The subprocess.CalledProcessError with stripped stdout."""

LAST_N_LINES = 10
architkulkarni marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, *args, cmd_index=None, **kwargs):
    """Build the error and remember which numbered command produced it.

    Args:
        *args: Positional arguments forwarded unchanged to
            subprocess.CalledProcessError.
        cmd_index: Optional index of the failed command; when it is not
            None, __str__ prefixes the message with it.
        **kwargs: Keyword arguments forwarded unchanged to
            subprocess.CalledProcessError.
    """
    # The parent initializer never reads cmd_index, so the order of these
    # two statements does not matter.
    super().__init__(*args, **kwargs)
    self.cmd_index = cmd_index

@staticmethod
def _get_last_n_line(str_data: str, last_n_lines: int) -> str:
    """Return the trailing ``last_n_lines`` lines of ``str_data``.

    Args:
        str_data: The text to truncate.
        last_n_lines: How many trailing lines to keep. A negative value
            disables truncation and returns ``str_data`` unchanged.

    Returns:
        The last ``last_n_lines`` lines of the stripped text joined by
        newlines, or an empty string when ``last_n_lines`` is 0.
    """
    if last_n_lines < 0:
        return str_data
    if last_n_lines == 0:
        # `lines[-0:]` is `lines[0:]`, i.e. the whole list, so zero needs
        # its own branch to mean "keep nothing".
        return ""
    lines = str_data.strip().split("\n")
    return "\n".join(lines[-last_n_lines:])

def __str__(self):
str_list = (
[]
if self.cmd_index is None
else [f"Run cmd[{self.cmd_index}] failed with the following details."]
)
str_list.append(super().__str__())
out = {
"stdout": self.stdout,
"stderr": self.stderr,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks very similar to this part, maybe you can modify the implementation in conda_utils.py to avoid splitting the implementation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this part

Yes. But we can't assume that the stderr is always redirected to the stdout. The implementation of SubprocessCalledProcessError should cover the interface of the parent subprocess.CalledProcessError.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks very similar to this part, maybe you can modify the implementation in conda_utils.py to avoid splitting the implementation

The exec_cmd and exec_cmd_stream_to_logger in the conda_utils.py are used by many runtime envs, not only by conda. I think it's better to put the exec utils into the utils.py instead of conda_utils.py.

But, we are going to make all the runtime env async, so exec_cmd and exec_cmd_stream_to_logger will be removed after they are not referenced.

}
for name, s in out.items():
if s:
subtitle = f"Last {self.LAST_N_LINES} lines of {name}:"
last_n_line_str = self._get_last_n_line(s, self.LAST_N_LINES).strip()
str_list.append(
f"{subtitle}\n{textwrap.indent(last_n_line_str, ' ' * 4)}"
)
return "\n".join(str_list)


async def check_output_cmd(
    cmd: List[str],
    *,
    logger: logging.Logger,
    cmd_index_gen: types.GeneratorType = itertools.count(1),
    **kwargs,
) -> str:
    """Run command with arguments and return its output.

    The command's stderr is redirected into stdout, so the returned string
    (and the ``output`` attribute of a raised error) contains both streams.

    Args:
        cmd: The cmdline should be a sequence of program arguments. The
            program to execute is the first item in cmd.
        logger: The logger instance.
        cmd_index_gen: The cmd index generator, default is itertools.count(1).
            The default is created once when the function is defined, so all
            calls that rely on it share one running index.
        kwargs: All arguments are passed to the create_subprocess_exec.

    Returns:
        The stdout of cmd, decoded as UTF-8.

    Raises:
        CalledProcessError: If the return code of cmd is not 0. The error
            carries the return code in ``returncode`` and the captured
            output in ``output``.
        RuntimeError: If starting the process or communicating with it
            raises any exception other than a cancellation.
        asyncio.CancelledError: If the surrounding task is cancelled while
            the command is running; the child process is killed first.
    """

    cmd_index = next(cmd_index_gen)
    logger.info("Run cmd[%s] %s", cmd_index, repr(cmd))

    proc = None
    try:
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
            **kwargs,
        )
        # Use communicate instead of polling stdout:
        #   * Avoid deadlocks due to streams pausing reading or writing and blocking the
        #     child process. Please refer to:
        #     https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.stderr
        #   * Avoid mixing multiple outputs of concurrent cmds.
        stdout, _ = await proc.communicate()
    except asyncio.CancelledError:
        # Cancellation must propagate unchanged: wrapping it in RuntimeError
        # would stop task.cancel() / asyncio.wait_for() from working on the
        # caller's side. The finally clause still kills the child process.
        raise
    except BaseException as e:
        raise RuntimeError(f"Run cmd[{cmd_index}] got exception.") from e
    else:
        stdout = stdout.decode("utf-8")
        if stdout:
            logger.info("Output of cmd[%s]: %s", cmd_index, stdout)
        else:
            logger.info("No output for cmd[%s]", cmd_index)
        if proc.returncode != 0:
            raise SubprocessCalledProcessError(
                proc.returncode, cmd, output=stdout, cmd_index=cmd_index
            )
        return stdout
    finally:
        if proc is not None:
            # Kill process. It may already have exited, in which case there
            # is nothing left to kill.
            try:
                proc.kill()
            except ProcessLookupError:
                pass
            # Wait process exit.
            await proc.wait()
Loading