Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[runtime env] Async pip runtime env #22381

Merged
merged 18 commits into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 37 additions & 38 deletions python/ray/_private/runtime_env/pip.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import contextlib
import asyncio
import os
import sys
import json
Expand All @@ -9,11 +9,14 @@
from filelock import FileLock
from typing import Optional, List, Dict, Tuple

from ray._private.runtime_env.conda_utils import exec_cmd_stream_to_logger
from ray._private.runtime_env.context import RuntimeEnvContext
from ray._private.runtime_env.packaging import Protocol, parse_uri
from ray._private.runtime_env.utils import RuntimeEnv
from ray._private.utils import get_directory_size_bytes, try_to_create_directory
from ray._private.runtime_env.utils import RuntimeEnv, check_output_cmd
from ray._private.utils import (
get_directory_size_bytes,
try_to_create_directory,
asynccontextmanager,
)

default_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,8 +89,8 @@ def _is_in_virtualenv() -> bool:
)

@staticmethod
@contextlib.contextmanager
def _check_ray(python: str, cwd: str, logger: logging.Logger):
@asynccontextmanager
async def _check_ray(python: str, cwd: str, logger: logging.Logger):
"""A context manager to check ray is not overwritten.

Currently, we only check ray version and path. It works for virtualenv,
Expand All @@ -96,24 +99,22 @@ def _check_ray(python: str, cwd: str, logger: logging.Logger):
- ray is in virtualenv's site-packages.
"""

def _get_ray_version_and_path() -> Tuple[str, str]:
async def _get_ray_version_and_path() -> Tuple[str, str]:
check_ray_cmd = [
python,
"-c",
"import ray; print(ray.__version__, ray.__path__[0])",
]
exit_code, output = exec_cmd_stream_to_logger(
check_ray_cmd, logger, cwd=cwd, env={}
output = await check_output_cmd(
check_ray_cmd, logger=logger, cwd=cwd, env={}
)
if exit_code != 0:
raise RuntimeError("Get ray version and path failed.")
# print after import ray may have  endings, so we strip them by *_
ray_version, ray_path, *_ = [s.strip() for s in output.split()]
return ray_version, ray_path

version, path = _get_ray_version_and_path()
version, path = await _get_ray_version_and_path()
yield
actual_version, actual_path = _get_ray_version_and_path()
actual_version, actual_path = await _get_ray_version_and_path()
if actual_version != version or actual_path != path:
raise RuntimeError(
"Changing the ray version is not allowed: \n"
Expand All @@ -126,7 +127,9 @@ def _get_ray_version_and_path() -> Tuple[str, str]:
)

@classmethod
def _create_or_get_virtualenv(cls, path: str, cwd: str, logger: logging.Logger):
async def _create_or_get_virtualenv(
cls, path: str, cwd: str, logger: logging.Logger
):
"""Create or get a virtualenv from path."""
if cls._is_in_virtualenv():
# TODO(fyrestone): Handle create virtualenv from virtualenv.
Expand Down Expand Up @@ -174,16 +177,10 @@ def _create_or_get_virtualenv(cls, path: str, cwd: str, logger: logging.Logger):
virtualenv_path,
]
logger.info("Creating virtualenv at %s", virtualenv_path)
exit_code, output = exec_cmd_stream_to_logger(
create_venv_cmd, logger, cwd=cwd, env={}
)
if exit_code != 0:
raise RuntimeError(
f"Failed to create virtualenv {virtualenv_path}:\n{output}"
)
await check_output_cmd(create_venv_cmd, logger=logger, cwd=cwd, env={})

@classmethod
def _install_pip_packages(
async def _install_pip_packages(
cls,
path: str,
pip_packages: List[str],
Expand All @@ -194,9 +191,16 @@ def _install_pip_packages(
python = _PathHelper.get_virtualenv_python(path)
# TODO(fyrestone): Support -i, --no-deps, --no-cache-dir, ...
pip_requirements_file = _PathHelper.get_requirements_file(path)
with open(pip_requirements_file, "w") as file:
for line in pip_packages:
file.write(line + "\n")

def _gen_requirements_txt():
with open(pip_requirements_file, "w") as file:
for line in pip_packages:
file.write(line + "\n")

# Avoid blocking the event loop.
loop = asyncio.get_running_loop()
fyrestone marked this conversation as resolved.
Show resolved Hide resolved
await loop.run_in_executor(None, _gen_requirements_txt)

pip_install_cmd = [
python,
"-m",
Expand All @@ -207,15 +211,9 @@ def _install_pip_packages(
pip_requirements_file,
]
logger.info("Installing python requirements to %s", virtualenv_path)
exit_code, output = exec_cmd_stream_to_logger(
pip_install_cmd, logger, cwd=cwd, env={}
)
if exit_code != 0:
raise RuntimeError(
f"Failed to install python requirements to {virtualenv_path}:\n{output}"
)
await check_output_cmd(pip_install_cmd, logger=logger, cwd=cwd, env={})

def run(self):
async def run(self):
path = self._target_dir
logger = self._logger
pip_packages = self._runtime_env.pip_packages()
Expand All @@ -225,10 +223,10 @@ def run(self):
exec_cwd = os.path.join(path, "exec_cwd")
os.makedirs(exec_cwd, exist_ok=True)
try:
self._create_or_get_virtualenv(path, exec_cwd, logger)
await self._create_or_get_virtualenv(path, exec_cwd, logger)
python = _PathHelper.get_virtualenv_python(path)
with self._check_ray(python, exec_cwd, logger):
self._install_pip_packages(path, pip_packages, exec_cwd, logger)
async with self._check_ray(python, exec_cwd, logger):
await self._install_pip_packages(path, pip_packages, exec_cwd, logger)
# TODO(fyrestone): pip check.
except Exception:
logger.info("Delete incomplete virtualenv: %s", path)
Expand Down Expand Up @@ -301,9 +299,10 @@ async def create(

with FileLock(self._installs_and_deletions_file_lock):
pip_processor = PipProcessor(target_dir, runtime_env, logger)
pip_processor.run()
await pip_processor.run()

return get_directory_size_bytes(target_dir)
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, get_directory_size_bytes, target_dir)

def modify_context(
self,
Expand Down
96 changes: 94 additions & 2 deletions python/ray/_private/runtime_env/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from typing import Dict, List, Tuple, Any
import asyncio
import itertools
import json
from ray.core.generated.runtime_env_common_pb2 import RuntimeEnv as ProtoRuntimeEnv
import logging
import subprocess
import types
from typing import Dict, List, Tuple, Any

from google.protobuf import json_format
from ray.core.generated.runtime_env_common_pb2 import RuntimeEnv as ProtoRuntimeEnv


def _build_proto_pip_runtime_env(runtime_env_dict: dict, runtime_env: ProtoRuntimeEnv):
Expand Down Expand Up @@ -280,3 +286,89 @@ def from_dict(
_build_proto_container_runtime_env(runtime_env_dict, proto_runtime_env)
_build_proto_plugin_runtime_env(runtime_env_dict, proto_runtime_env)
return cls(proto_runtime_env=proto_runtime_env)


class SubprocessCalledProcessError(subprocess.CalledProcessError):
"""The subprocess.CalledProcessError with stripped stdout."""

LAST_N_LINES = 10
architkulkarni marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def _get_last_n_line(str_data: str, last_n_lines: int) -> str:
if last_n_lines < 0:
return str_data
lines = str_data.split("\n")
return "\n".join(lines[-last_n_lines:])

def __str__(self):
str_list = [super().__str__().strip(",.")]
out = {
"stdout": self.stdout,
"stderr": self.stderr,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks very similar to this part, maybe you can modify the implementation in conda_utils.py to avoid splitting the implementation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this part

Yes. But we can't assume that the stderr is always redirected to the stdout. The implementation of SubprocessCalledProcessError should cover the interface of the parent subprocess.CalledProcessError.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks very similar to this part, maybe you can modify the implementation in conda_utils.py to avoid splitting the implementation

The exec_cmd and exec_cmd_stream_to_logger in the conda_utils.py are used by many runtime envs, not only by the conda. I thinks it's better to put the exec utils into the utils.py instead of conda_utils.py.

But, we are going to make all the runtime env async, so exec_cmd and exec_cmd_stream_to_logger will be removed after they are not referenced.

}
for name, s in out.items():
if s:
str_list.append(f"{name}={self._get_last_n_line(s, self.LAST_N_LINES)}")
return ", ".join(str_list)
fyrestone marked this conversation as resolved.
Show resolved Hide resolved


async def check_output_cmd(
cmd: List[str],
*,
logger: logging.Logger,
cmd_index_gen: types.GeneratorType = itertools.count(1),
**kwargs,
) -> str:
"""Run command with arguments and return its output.

If the return code was non-zero it raises a CalledProcessError. The
CalledProcessError object will have the return code in the returncode
attribute and any output in the output attribute.

Args:
cmd: The cmdline should be a sequence of program arguments or else
a single string or path-like object. The program to execute is
the first item in cmd.
logger: The logger instance.
cmd_index_gen: The cmd index generator, default is itertools.count(1).
kwargs: All arguments are passed to the create_subprocess_exec.

Returns:
The stdout of cmd.

Raises:
CalledProcessError: If the return code of cmd is not 0.
fyrestone marked this conversation as resolved.
Show resolved Hide resolved
"""

cmd_index = next(cmd_index_gen)
logger.info("Run cmd[%s] %s", cmd_index, repr(cmd))
proc = await asyncio.create_subprocess_exec(
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT, **kwargs
)

try:
# Use communicate instead of polling stdout:
# * Avoid deadlocks due to streams pausing reading or writing and blocking the
# child process. Please refer to:
# https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.stderr
# * Avoid mixing multiple outputs of concurrent cmds.
stdout, _ = await proc.communicate()
except BaseException as e:
raise RuntimeError(f"Run cmd[{cmd_index}] got exception.") from e
else:
stdout = stdout.decode("utf-8")
if stdout:
logger.info("Output of cmd[%s]: %s", cmd_index, stdout)
else:
logger.info("No output for cmd[%s]", cmd_index)
if proc.returncode != 0:
raise SubprocessCalledProcessError(proc.returncode, cmd, output=stdout)
return stdout
finally:
# Kill process.
try:
proc.kill()
except ProcessLookupError:
pass
# Wait process exit.
await proc.wait()
115 changes: 115 additions & 0 deletions python/ray/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,3 +1260,118 @@ def get_directory_size_bytes(path: Union[str, Path] = ".") -> int:
total_size_bytes += os.path.getsize(fp)

return total_size_bytes


try:
from contextlib import asynccontextmanager
except ImportError:
# Copy from https://github.com/python-trio/async_generator
# for compatible with Python 3.6
import sys
from functools import wraps
from inspect import isasyncgenfunction

class _aclosing:
def __init__(self, aiter):
self._aiter = aiter

async def __aenter__(self):
return self._aiter

async def __aexit__(self, *args):
await self._aiter.aclose()

# Very much derived from the one in contextlib, by copy/pasting and then
# asyncifying everything. (Also I dropped the obscure support for using
# context managers as function decorators. It could be re-added; I just
# couldn't be bothered.)
# So this is a derivative work licensed under the PSF License, which requires
# the following notice:
#
# Copyright © 2001-2017 Python Software Foundation; All Rights Reserved
class _AsyncGeneratorContextManager:
def __init__(self, func, args, kwds):
self._func_name = func.__name__
self._agen = func(*args, **kwds).__aiter__()

async def __aenter__(self):
if sys.version_info < (3, 5, 2):
self._agen = await self._agen
try:
return await self._agen.asend(None)
except StopAsyncIteration:
raise RuntimeError("async generator didn't yield") from None

async def __aexit__(self, type, value, traceback):
async with _aclosing(self._agen):
if type is None:
try:
await self._agen.asend(None)
except StopAsyncIteration:
return False
else:
raise RuntimeError("async generator didn't stop")
else:
# It used to be possible to have type != None, value == None:
# https://bugs.python.org/issue1705170
# but AFAICT this can't happen anymore.
assert value is not None
try:
await self._agen.athrow(type, value, traceback)
raise RuntimeError("async generator didn't stop after athrow()")
except StopAsyncIteration as exc:
# Suppress StopIteration *unless* it's the same exception
# that was passed to throw(). This prevents a
# StopIteration raised inside the "with" statement from
# being suppressed.
return exc is not value
except RuntimeError as exc:
# Don't re-raise the passed in exception. (issue27112)
if exc is value:
return False
# Likewise, avoid suppressing if a StopIteration exception
# was passed to throw() and later wrapped into a
# RuntimeError (see PEP 479).
if (
isinstance(value, (StopIteration, StopAsyncIteration))
and exc.__cause__ is value
):
return False
raise
except: # noqa: E722
# only re-raise if it's *not* the exception that was
# passed to throw(), because __exit__() must not raise an
# exception unless __exit__() itself failed. But throw()
# has to raise the exception to signal propagation, so
# this fixes the impedance mismatch between the throw()
# protocol and the __exit__() protocol.
#
if sys.exc_info()[1] is value:
return False
raise

def __enter__(self):
raise RuntimeError(
"use 'async with {func_name}(...)', not 'with {func_name}(...)'".format(
func_name=self._func_name
)
)

def __exit__(self): # pragma: no cover
assert False, """Never called, but should be defined"""

def asynccontextmanager(func):
"""Like @contextmanager, but async."""
if not isasyncgenfunction(func):
raise TypeError(
"must be an async generator (native or from async_generator; "
"if using @async_generator then @acontextmanager must be on top."
)

@wraps(func)
def helper(*args, **kwds):
return _AsyncGeneratorContextManager(func, args, kwds)

# A hint for sphinxcontrib-trio:
helper.__returns_acontextmanager__ = True
return helper
Loading