Skip to content

Commit

Permalink
Remove grpcio from minimal dependency (ray-project#38243)
Browse files Browse the repository at this point in the history
Removes grpcio from ray dependency while adds it to ray[default] dependency.

Removed use of grpc client: gcs_node_info_stub.DrainNode in autoscaler. Moved it into raylet.pyx.
Removed use of grpc client: gcs_node_resources_stub.GetAllResourceUsage in autoscaler. Moved it into raylet.pyx.
Removed use of grpc status code. Moved it into raylet.pyx.
Removed Ray Client tests in test_runtime_env_ray_minimal.py

---------

Signed-off-by: Ruiyang Wang <rywang014@gmail.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
  • Loading branch information
rynewang authored and arvind-chandra committed Aug 31, 2023
1 parent 14f136c commit c0d98b1
Show file tree
Hide file tree
Showing 24 changed files with 239 additions and 217 deletions.
88 changes: 47 additions & 41 deletions dashboard/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
from ray.dashboard.consts import _PARENT_DEATH_THREASHOLD
from ray._private.gcs_pubsub import GcsAioPublisher
from ray._raylet import GcsClient
from ray._private.gcs_utils import GcsAioClient
from ray._private.ray_logging import (
Expand All @@ -31,12 +30,6 @@
# Import psutil after ray so the packaged version is used.
import psutil

try:
from grpc import aio as aiogrpc
except ImportError:
from grpc.experimental import aio as aiogrpc


# Publishes at most this number of lines of Raylet logs, when the Raylet dies
# unexpectedly.
_RAYLET_LOG_MAX_PUBLISH_LINES = 20
Expand All @@ -52,19 +45,6 @@

logger = logging.getLogger(__name__)

# We would want to suppress deprecating warnings from aiogrpc library
# with the usage of asyncio.get_event_loop() in python version >=3.10
# This could be removed once https://github.com/grpc/grpc/issues/32526
# is released, and we used higher versions of grpcio that that.
if sys.version_info.major >= 3 and sys.version_info.minor >= 10:
import warnings

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
aiogrpc.init_grpc_aio()
else:
aiogrpc.init_grpc_aio()


class DashboardAgent:
def __init__(
Expand Down Expand Up @@ -117,13 +97,44 @@ def __init__(
assert self.ppid > 0
logger.info("Parent pid is %s", self.ppid)

# Setup raylet channel
options = ray_constants.GLOBAL_GRPC_OPTIONS
self.aiogrpc_raylet_channel = ray._private.utils.init_grpc_channel(
f"{self.ip}:{self.node_manager_port}", options, asynchronous=True
)
# grpc server is None in mininal.
self.server = None
# http_server is None in minimal.
self.http_server = None

# Used by the agent and sub-modules.
# TODO(architkulkarni): Remove gcs_client once the agent exclusively uses
# gcs_aio_client and not gcs_client.
self.gcs_client = GcsClient(address=self.gcs_address)
_initialize_internal_kv(self.gcs_client)
assert _internal_kv_initialized()
self.gcs_aio_client = GcsAioClient(address=self.gcs_address)

if not self.minimal:
self._init_non_minimal()

def _init_non_minimal(self):
from ray._private.gcs_pubsub import GcsAioPublisher
self.aio_publisher = GcsAioPublisher(address=self.gcs_address)

try:
from grpc import aio as aiogrpc
except ImportError:
from grpc.experimental import aio as aiogrpc

# We would want to suppress deprecating warnings from aiogrpc library
# with the usage of asyncio.get_event_loop() in python version >=3.10
# This could be removed once https://github.com/grpc/grpc/issues/32526
# is released, and we used higher versions of grpcio that that.
if sys.version_info.major >= 3 and sys.version_info.minor >= 10:
import warnings

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
aiogrpc.init_grpc_aio()
else:
aiogrpc.init_grpc_aio()

# Setup grpc server
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0),))
grpc_ip = "127.0.0.1" if self.ip == "127.0.0.1" else "0.0.0.0"
try:
Expand All @@ -143,19 +154,6 @@ def __init__(
else:
logger.info("Dashboard agent grpc address: %s:%s", grpc_ip, self.grpc_port)

# If the agent is started as non-minimal version, http server should
# be configured to communicate with the dashboard in a head node.
self.http_server = None

# Used by the agent and sub-modules.
# TODO(architkulkarni): Remove gcs_client once the agent exclusively uses
# gcs_aio_client and not gcs_client.
self.gcs_client = GcsClient(address=self.gcs_address)
_initialize_internal_kv(self.gcs_client)
assert _internal_kv_initialized()
self.gcs_aio_client = GcsAioClient(address=self.gcs_address)
self.publisher = GcsAioPublisher(address=self.gcs_address)

async def _configure_http_server(self, modules):
from ray.dashboard.http_server_agent import HttpServerAgent

Expand All @@ -180,9 +178,16 @@ def _load_modules(self):

@property
def http_session(self):
assert self.http_server, "Accessing unsupported API in a minimal ray."
assert self.http_server, \
"Accessing unsupported API (HttpServerAgent) in a minimal ray."
return self.http_server.http_session

@property
def publisher(self):
assert self.aio_publisher, \
"Accessing unsupported API (GcsAioPublisher) in a minimal ray."
return self.aio_publisher

async def run(self):
async def _check_parent():
"""Check if raylet is dead and fate-share if it is."""
Expand Down Expand Up @@ -311,9 +316,10 @@ async def _check_parent():
# TODO: Use async version if performance is an issue
# -1 should indicate that http server is not started.
http_port = -1 if not self.http_server else self.http_server.http_port
grpc_port = -1 if not self.server else self.grpc_port
await self.gcs_aio_client.internal_kv_put(
f"{dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX}{self.node_id}".encode(),
json.dumps([http_port, self.grpc_port]).encode(),
json.dumps([http_port, grpc_port]).encode(),
True,
namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
)
Expand Down
1 change: 1 addition & 0 deletions dashboard/optional_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
from aiohttp.typedefs import PathLike # noqa: F401
from aiohttp.web import RouteDef # noqa: F401
import pydantic # noqa: F401
import grpc # noqa: F401
4 changes: 2 additions & 2 deletions dashboard/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def run(self, server):
"""
Run the module in an asyncio loop. An agent module can provide
servicers to the server.
:param server: Asyncio GRPC server.
:param server: Asyncio GRPC server, or None if ray is minimal.
"""

@staticmethod
Expand Down Expand Up @@ -79,7 +79,7 @@ async def run(self, server):
"""
Run the module in an asyncio loop. A head module can provide
servicers to the server.
:param server: Asyncio GRPC server.
:param server: Asyncio GRPC server, or None if ray is minimal.
"""

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
Ray Client
==========

.. warning::
Ray Client requires pip package `ray[client]`. If you installed the minimal Ray (e.g. `pip install ray`), please reinstall by executing `pip install ray[client]`.

**What is the Ray Client?**

The Ray Client is an API that connects a Python script to a **remote** Ray cluster. Effectively, it allows you to leverage a remote Ray cluster just like you would with Ray running on your local machine.
Expand Down
68 changes: 0 additions & 68 deletions python/ray/_private/gcs_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
import enum
import logging
import inspect
import os
import asyncio
from functools import wraps
from typing import Optional

import grpc

from ray._private import ray_constants

import ray._private.gcs_aio_client
Expand Down Expand Up @@ -98,60 +91,6 @@ def create_gcs_channel(address: str, aio=False):
return init_grpc_channel(address, options=_GRPC_OPTIONS, asynchronous=aio)


# This global variable is used for testing only
_called_freq = {}


def _auto_reconnect(f):
# This is for testing to count the frequence
# of gcs call
if inspect.iscoroutinefunction(f):

@wraps(f)
async def wrapper(self, *args, **kwargs):
if "TEST_RAY_COLLECT_KV_FREQUENCY" in os.environ:
global _called_freq
name = f.__name__
if name not in _called_freq:
_called_freq[name] = 0
_called_freq[name] += 1

remaining_retry = self._nums_reconnect_retry
while True:
try:
return await f(self, *args, **kwargs)
except grpc.RpcError as e:
if e.code() in (
grpc.StatusCode.UNAVAILABLE,
grpc.StatusCode.UNKNOWN,
):
if remaining_retry <= 0:
logger.error(
"Failed to connect to GCS. Please check"
" `gcs_server.out` for more details."
)
raise
logger.debug(
"Failed to send request to gcs, reconnecting. " f"Error {e}"
)
try:
self._connect()
except Exception:
logger.error(f"Connecting to gcs failed. Error {e}")
await asyncio.sleep(1)
remaining_retry -= 1
continue
raise

return wrapper
else:

raise NotImplementedError(
"This code moved to Cython, see "
"https://github.com/ray-project/ray/pull/33769"
)


class GcsChannel:
def __init__(self, gcs_address: Optional[str] = None, aio: bool = False):
self._gcs_address = gcs_address
Expand All @@ -171,13 +110,6 @@ def channel(self):
return self._channel


class GcsCode(enum.IntEnum):
# corresponding to ray/src/ray/common/status.h
OK = 0
NotFound = 17
GrpcUnavailable = 26


# re-export
GcsAioClient = ray._private.gcs_aio_client.GcsAioClient

Expand Down
8 changes: 7 additions & 1 deletion python/ray/_private/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import ray._private.ray_constants as ray_constants
from ray._private.utils import (
validate_node_labels,
check_ray_client_dependencies_installed,
)


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -397,8 +399,12 @@ def _check_usage(self):
raise ValueError(
"max_worker_port must be higher than min_worker_port."
)

if self.ray_client_server_port is not None:
if not check_ray_client_dependencies_installed():
raise ValueError(
"Ray Client requires pip package `ray[client]`. "
"If you installed the minimal Ray (e.g. `pip install ray`), "
"please reinstall by executing `pip install ray[client]`.")
if (
self.ray_client_server_port < 1024
or self.ray_client_server_port > 65535
Expand Down
13 changes: 9 additions & 4 deletions python/ray/_private/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,13 @@
import requests
from ray._raylet import Config

import grpc
import numpy as np
import psutil # We must import psutil after ray because we bundle it with ray.
from ray._private import (
ray_constants,
)
from ray._private.worker import RayContext
import yaml
from grpc._channel import _InactiveRpcError

import ray
import ray._private.gcs_utils as gcs_utils
Expand All @@ -45,9 +43,7 @@
from ray.core.generated import (
gcs_pb2,
node_manager_pb2,
node_manager_pb2_grpc,
gcs_service_pb2,
gcs_service_pb2_grpc,
)
from ray.util.queue import Empty, Queue, _QueueActor
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
Expand Down Expand Up @@ -1474,6 +1470,9 @@ async def get_total_killed_nodes(self):
return self.killed_nodes

def _kill_raylet(self, ip, port, graceful=False):
import grpc
from grpc._channel import _InactiveRpcError
from ray.core.generated import node_manager_pb2_grpc
raylet_address = f"{ip}:{port}"
channel = grpc.insecure_channel(raylet_address)
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
Expand Down Expand Up @@ -1694,6 +1693,8 @@ def wandb_setup_api_key_hook():

# Get node stats from node manager.
def get_node_stats(raylet, num_retry=5, timeout=2):
import grpc
from ray.core.generated import node_manager_pb2_grpc
raylet_address = f'{raylet["NodeManagerAddress"]}:{raylet["NodeManagerPort"]}'
channel = ray._private.utils.init_grpc_channel(raylet_address)
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
Expand All @@ -1711,6 +1712,7 @@ def get_node_stats(raylet, num_retry=5, timeout=2):

# Gets resource usage assuming gcs is local.
def get_resource_usage(gcs_address, timeout=10):
from ray.core.generated import gcs_service_pb2_grpc
if not gcs_address:
gcs_address = ray.worker._global_node.gcs_address

Expand Down Expand Up @@ -1739,6 +1741,9 @@ def get_load_metrics_report(webui_url):

# Send a RPC to the raylet to have it self-destruct its process.
def kill_raylet(raylet, graceful=False):
import grpc
from grpc._channel import _InactiveRpcError
from ray.core.generated import node_manager_pb2_grpc
raylet_address = f'{raylet["NodeManagerAddress"]}:{raylet["NodeManagerPort"]}'
channel = grpc.insecure_channel(raylet_address)
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
Expand Down
3 changes: 1 addition & 2 deletions python/ray/_private/tls_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import os
import socket

import grpc


def generate_self_signed_tls_certs():
"""Create self-signed key/cert pair for testing.
Expand Down Expand Up @@ -68,6 +66,7 @@ def generate_self_signed_tls_certs():


def add_port_to_grpc_server(server, address):
import grpc
if os.environ.get("RAY_USE_TLS", "0").lower() in ("1", "true"):
server_cert_chain, private_key, ca_cert = load_certs_from_env()
credentials = grpc.ssl_server_credentials(
Expand Down
Loading

0 comments on commit c0d98b1

Please sign in to comment.